sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/__init__.py +0 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/__init__.py +0 -4
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.47.dist-info/RECORD +0 -35
- sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""An utils function for runtime environment. This must be kept independent of SageMaker PySDK"""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import argparse
|
|
17
|
+
import json
|
|
18
|
+
import os
|
|
19
|
+
import subprocess
|
|
20
|
+
import sys
|
|
21
|
+
import time
|
|
22
|
+
from typing import List
|
|
23
|
+
|
|
24
|
+
import paramiko
|
|
25
|
+
|
|
26
|
+
if __package__ is None or __package__ == "":
|
|
27
|
+
from runtime_environment_manager import (
|
|
28
|
+
get_logger,
|
|
29
|
+
)
|
|
30
|
+
else:
|
|
31
|
+
from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import (
|
|
32
|
+
get_logger,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
# Process exit codes used by this bootstrap script.
SUCCESS_EXIT_CODE = 0
DEFAULT_FAILURE_CODE = 1

# Default status file that the master node touches on each worker once the job
# has ended; workers poll for it before they stop waiting.
FINISHED_STATUS_FILE = "/tmp/done.algo-1"
# Template (formatted with a worker host name) of the file a worker touches on
# the master node to signal that it is ready.
READY_FILE = "/tmp/ready.%s"
DEFAULT_SSH_PORT = 22

# File that receives the failure reason when bootstrapping fails.
FAILURE_REASON_PATH = "/opt/ml/output/failure"
|
|
44
|
+
|
|
45
|
+
# Module-level logger shared by every function in this script; it is created by
# the get_logger helper imported from runtime_environment_manager.
logger = get_logger()
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy):
    """Missing-host-key policy that only trusts hosts named ``algo-*``.

    Used for the SSH connections between SageMaker distributed training nodes.

    Example:
        >>> client = paramiko.SSHClient()
        >>> client.set_missing_host_key_policy(CustomHostKeyPolicy())
        >>> # Will succeed for SageMaker algorithm containers
        >>> client.connect('algo-1234.internal')
        >>> # Will raise SSHException for other unknown hosts
        >>> client.connect('unknown-host')  # raises SSHException
    """

    def missing_host_key(self, client, hostname, key):
        """Store the key of an ``algo-*`` host; refuse every other host.

        Args:
            client: The SSHClient instance whose host keys are updated.
            hostname: The hostname being connected to.
            key: The host key presented by that host.
        Raises:
            paramiko.SSHException: If ``hostname`` does not start with ``algo-``.
        """
        # Reject first so the accepted path reads straight through.
        if not hostname.startswith("algo-"):
            raise paramiko.SSHException(f"Unknown host key for {hostname}")

        known_keys = client.get_host_keys()
        known_keys.add(hostname, key.get_name(), key)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _parse_args(sys_args):
|
|
77
|
+
"""Parses CLI arguments."""
|
|
78
|
+
parser = argparse.ArgumentParser()
|
|
79
|
+
parser.add_argument("--job_ended", type=str, default="0")
|
|
80
|
+
args, _ = parser.parse_known_args(sys_args)
|
|
81
|
+
return args
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool:
    """Report whether an SSH connection to ``host`` on ``port`` can be opened.

    Args:
        host: Host name to connect to.
        port: Port to connect to.
    Returns:
        True when the connection attempt completes without raising, False when
        any exception occurs (the exception is logged at debug level).
    """
    try:
        with paramiko.SSHClient() as ssh_client:
            ssh_client.load_system_host_keys()
            ssh_client.set_missing_host_key_policy(CustomHostKeyPolicy())
            ssh_client.connect(host, port=port)
            logger.info("Can connect to host %s", host)
        # Still inside the try so an error while closing the client is caught too.
        return True
    except Exception as err:  # pylint: disable=W0703
        logger.info("Cannot connect to host %s", host)
        logger.debug("Connection failed with exception: %s", err)
        return False
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _write_file_to_host(host: str, status_file: str) -> bool:
    """Create a file on the given host by running ``touch`` over ``ssh``.

    Args:
        host: Host on which the file is created.
        status_file: Path of the file to touch on that host.
    Returns:
        True if the ssh command exits successfully, False if it exits with a
        non-zero status.
    """
    logger.info("Writing %s to %s", status_file, host)
    touch_command = ["ssh", host, "touch", f"{status_file}"]
    try:
        subprocess.run(touch_command, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError:
        logger.info("Cannot connect to %s", host)
        return False

    logger.info("Finished writing status file")
    return True
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _write_failure_reason_file(failure_msg):
    """Write the failure reason to the 'failure' file if bootstrapping failed.

    An already existing failure file is left untouched.

    See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html
    Args:
        failure_msg: The content of file to be written.
    """
    if os.path.exists(FAILURE_REASON_PATH):
        return

    with open(FAILURE_REASON_PATH, "w") as failure_file:
        failure_file.write("RuntimeEnvironmentError: " + failure_msg)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _wait_for_master(master_host: str, port: int = DEFAULT_SSH_PORT, timeout: int = 300):
    """Block on a worker node until the master node accepts SSH connections.

    Args:
        master_host: Host name of the master node.
        port: SSH port to probe.
        timeout: Seconds after which a failed probe raises instead of retrying.
    Raises:
        TimeoutError: If the master is still unreachable once ``timeout`` has elapsed.
    """
    began_at = time.time()
    connected = False
    while not connected:
        logger.info("Worker is attempting to connect to the master node %s...", master_host)
        connected = _can_connect(master_host, port)
        if connected:
            logger.info("Worker can connect to master node %s.", master_host)
        elif time.time() - began_at > timeout:
            raise TimeoutError("Timed out waiting for master %s to be reachable." % master_host)
        else:
            time.sleep(5)  # Pause for 5 seconds between probes
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _wait_for_status_file(status_file: str):
    """Block until ``status_file`` exists, polling every 30 seconds.

    Args:
        status_file: Path of the file to wait for. There is no timeout.
    """
    logger.info("Waiting for status file %s", status_file)
    while True:
        if os.path.exists(status_file):
            break
        time.sleep(30)
    logger.info("Found status file %s", status_file)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _wait_for_workers(worker_hosts: List[str], port: int = DEFAULT_SSH_PORT, timeout: int = 300):
    """Block on the master node until every worker is reachable and ready.

    A worker counts as ready when an SSH connection to it succeeds and its
    ready file (``READY_FILE`` formatted with the worker name) exists locally.

    Args:
        worker_hosts: Host names of the worker nodes; returns at once if empty.
        port: SSH port to probe on each worker.
        timeout: Seconds after which a failed round raises instead of retrying.
    Raises:
        TimeoutError: If some worker is still not ready once ``timeout`` has elapsed.
    """
    began_at = time.time()
    if not worker_hosts:
        logger.info("No worker nodes to connect to.")
        return

    while True:
        logger.info("Master is attempting to connect to all workers...")
        # Stop at the first worker that is not ready, checking connectivity
        # before the ready file for each worker.
        for worker in worker_hosts:
            if not _can_connect(worker, port):
                break
            if not os.path.exists(READY_FILE % worker):
                break
        else:
            logger.info("Master can connect to all worker nodes.")
            return

        if time.time() - began_at > timeout:
            raise TimeoutError("Timed out waiting for workers to be reachable.")

        time.sleep(5)  # Pause for 5 seconds between rounds
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def bootstrap_master_node(worker_hosts: List[str]):
    """Bootstrap the master node by waiting for all worker nodes.

    Args:
        worker_hosts: Host names of the worker nodes to wait for.
    """
    logger.info("Bootstrapping master node...")
    _wait_for_workers(worker_hosts=worker_hosts)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def bootstrap_worker_node(
    master_host: str, current_host: str, status_file: str = FINISHED_STATUS_FILE
):
    """Bootstrap a worker node and block until the status file appears.

    The worker waits for the master to be reachable, touches its ready file on
    the master, then waits for ``status_file`` to exist locally.

    Args:
        master_host: Host name of the master node.
        current_host: Host name of this worker; used to name its ready file.
        status_file: Local path whose existence ends the wait.
    """
    logger.info("Bootstrapping worker node...")
    _wait_for_master(master_host)

    ready_file = READY_FILE % current_host
    if not _write_file_to_host(master_host, ready_file):
        # The result used to be discarded, so a failed write went unreported.
        # Control flow is unchanged; the failure is only made visible.
        logger.warning(
            "Failed to write ready file %s to master node %s", ready_file, master_host
        )

    _wait_for_status_file(status_file)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def start_sshd_daemon():
    """Launch the SSH server on the current node.

    Raises:
        RuntimeError: If ``/usr/sbin/sshd`` does not exist.
    """
    sshd_path = "/usr/sbin/sshd"

    if os.path.exists(sshd_path):
        # -D tells sshd not to detach, so it keeps running as the spawned child
        # process; Popen returns immediately without waiting for it.
        subprocess.Popen([sshd_path, "-D"])
        logger.info("Started SSH daemon.")
        return

    raise RuntimeError("SSH daemon not found.")
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def write_status_file_to_workers(worker_hosts: List[str], status_file: str = FINISHED_STATUS_FILE):
    """Touch the status file on every worker node, retrying on failure.

    Args:
        worker_hosts: Host names of the worker nodes.
        status_file: Path of the file to create on each worker.
    Raises:
        TimeoutError: If writing to a worker has failed more than five times.
    """
    for worker in worker_hosts:
        failures = 0
        while True:
            if _write_file_to_host(worker, status_file):
                break
            # Back off before counting the failure, as every failed attempt does.
            time.sleep(5)
            failures += 1
            if failures > 5:
                raise TimeoutError("Timed out waiting for %s to be reachable." % worker)
            logger.info("Retrying to write status file to %s", worker)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def main(sys_args=None):
    """Entry point for the bootstrap script.

    While the job is running (``--job_ended`` is ``"0"``) this starts sshd and
    bootstraps the node as master or worker. Otherwise the master writes the
    status file to every worker. Any exception is logged, recorded in the
    failure reason file, and turned into exit code ``DEFAULT_FAILURE_CODE``.

    Args:
        sys_args: Command-line arguments to parse.
    """
    try:
        job_ended = _parse_args(sys_args).job_ended

        main_host = os.environ["SM_MASTER_ADDR"]
        current_host = os.environ["SM_CURRENT_HOST"]
        is_master = current_host == main_host

        if job_ended != "0":
            logger.info("Job ended, writing status file to workers")

            if is_master:
                all_hosts = json.loads(os.environ["SM_HOSTS"])
                write_status_file_to_workers(
                    [host for host in all_hosts if host != main_host]
                )
        else:
            logger.info("Job is running, bootstrapping nodes")

            start_sshd_daemon()

            if is_master:
                # SM_HOSTS is only read on the master node.
                all_hosts = json.loads(os.environ["SM_HOSTS"])
                bootstrap_master_node([host for host in all_hosts if host != main_host])
            else:
                bootstrap_worker_node(main_host, current_host)
    except Exception as err:  # pylint: disable=broad-except
        logger.exception("Error encountered while bootstrapping runtime environment: %s", err)

        _write_failure_reason_file(str(err))

        sys.exit(DEFAULT_FAILURE_CODE)
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
# Run the bootstrap only when executed as a script, forwarding the CLI
# arguments without the program name.
if __name__ == "__main__":
    main(sys_args=sys.argv[1:])
|