sagemaker-core 1.0.62__py3-none-any.whl → 2.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/__init__.py +2 -0
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2399 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +247 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1220 -0
- sagemaker/core/git_utils.py +415 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2977 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +509 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +770 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1267 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +686 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
- sagemaker/core/image_uri_config/huggingface.json +2287 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +252 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +494 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +972 -0
- sagemaker/core/image_uris.py +816 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +197 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +501 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +171 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +423 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +246 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1497 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +239 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +205 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1599 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1310 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +347 -0
- sagemaker/core/remote_function/core/serialization.py +410 -0
- sagemaker/core/remote_function/core/stored_function.py +223 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +102 -0
- sagemaker/core/remote_function/invoke_function.py +167 -0
- sagemaker/core/remote_function/job.py +2121 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +82 -0
- sagemaker/core/telemetry/telemetry_logging.py +285 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +345 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +514 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.3.1.dist-info/RECORD +351 -0
- sagemaker_core-2.3.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/modules/local_core}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
sagemaker/core/modules/local_core/local_container.py (new file)
@@ -0,0 +1,605 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""LocalContainer class module."""
from __future__ import absolute_import

import base64
import os
import re
import shutil
import subprocess
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, ConfigDict

from sagemaker.core.local.image import (
    _Volume,
    _aws_credentials,
    _check_output,
    _pull_image,
    _stream_output,
    _write_json_file,
)
from sagemaker.core.local.utils import check_for_studio, recursive_copy
from sagemaker.core.constants import DIR_PARAM_NAME
from sagemaker.core.modules import logger, Session
from sagemaker.core.modules.configs import Channel
from sagemaker.core.common_utils import (
    ECR_URI_PATTERN,
    create_tar_file,
    _module_import_error,
    download_folder,
)
from sagemaker.core.utils.utils import Unassigned
from sagemaker.core.shapes import DataSource

from six.moves.urllib.parse import urlparse

STUDIO_HOST_NAME = "sagemaker-local"
DOCKER_COMPOSE_FILENAME = "docker-compose.yaml"
DOCKER_COMPOSE_HTTP_TIMEOUT_ENV = "COMPOSE_HTTP_TIMEOUT"
DOCKER_COMPOSE_HTTP_TIMEOUT = "120"

REGION_ENV_NAME = "AWS_REGION"
TRAINING_JOB_NAME_ENV_NAME = "TRAINING_JOB_NAME"
S3_ENDPOINT_URL_ENV_NAME = "S3_ENDPOINT_URL"
S3_ENDPOINT_URL_ENV_NAME = "S3_ENDPOINT_URL"
SM_STUDIO_LOCAL_MODE = "SM_STUDIO_LOCAL_MODE"


class _LocalContainer(BaseModel):
    """A local training job class for local mode model trainer.

    Attributes:
        training_job_name (str):
            The name of the training job.
        instance_type (str):
            The instance type.
        instance_count (int):
            The number of instances.
        image (str):
            The image name for training.
        container_root (str):
            The directory path for the local container root.
        input_from_s3 (bool):
            If the input is from s3.
        is_studio (bool):
            If the container is running on SageMaker studio instance.
        hosts (Optional[List[str]]):
            The list of host names.
        input_data_config: Optional[List[Channel]]
            The input data channels for the training job.
            Takes a list of Channel objects or a dictionary of channel names to DataSourceType.
            DataSourceType can be an S3 URI string, local file path string,
            S3DataSource object, or FileSystemDataSource object.
        environment (Optional[Dict[str, str]]):
            The environment variables for the training job.
        hyper_parameters (Optional[Dict[str, Any]]):
            The hyperparameters for the training job.
        sagemaker_session (Optional[Session]):
            The SageMaker session.
            For local mode training, SageMaker session will only be used when input is from S3 or
            image needs to be pulled from ECR.
        container_entrypoint (Optional[List[str]]):
            The command to be executed in the container.
        container_arguments (Optional[List[str]]):
            The arguments of the container commands.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

    training_job_name: str
    instance_type: str
    instance_count: int
    image: str
    container_root: str
    input_from_s3: Optional[bool] = False
    is_studio: Optional[bool] = False
    hosts: Optional[List[str]] = []
    input_data_config: Optional[List[Channel]]
    environment: Optional[Dict[str, str]]
    hyper_parameters: Optional[Dict[str, str]]
    sagemaker_session: Optional[Session] = None
    container_entrypoint: Optional[List[str]]
    container_arguments: Optional[List[str]]

    _temporary_folders: List[str] = []

    def model_post_init(self, __context: Any):
        """Post init method to perform custom validation and set default values."""
        self.hosts = [f"algo-{i}" for i in range(1, self.instance_count + 1)]
        if self.environment is None:
            self.environment = {}
        if self.hyper_parameters is None:
            self.hyper_parameters = {}

        for channel in self.input_data_config:
            if channel.data_source and channel.data_source.s3_data_source != Unassigned():
                self.input_from_s3 = True
                data_distribution = channel.data_source.s3_data_source.s3_data_distribution_type
                if self.sagemaker_session is None:
                    # In local mode only initiate session when neccessary
                    self.sagemaker_session = Session()
            elif (
                channel.data_source and channel.data_source.file_system_data_source != Unassigned()
            ):
                self.input_from_s3 = False
                data_distribution = channel.data_source.file_system_data_source.file_system_type
            else:
                raise ValueError(
                    "Need channel.data_source to have s3_data_source or file_system_data_source"
                )

            supported_distributions = ["FullyReplicated", "EFS"]
            if data_distribution and data_distribution not in supported_distributions:
                raise RuntimeError(
                    "Invalid Data Distribution: '{}'. Local mode currently supports FullyReplicated "
                    "Distribution for S3 data source and EFS Distribution for local data source.".format(
                        data_distribution,
                    )
                )
        self.is_studio = check_for_studio()

    def train(
        self,
        wait: bool,
    ) -> str:
        """Run a training job locally using docker-compose.

        Args:
            wait (bool):
                Whether to wait the training output before exiting.
        """
        # create output/data folder since sagemaker-containers 2.0 expects it
        os.makedirs(os.path.join(self.container_root, "output", "data"), exist_ok=True)
        # A shared directory for all the containers. It is only mounted if the training script is
        # Local.
        os.makedirs(os.path.join(self.container_root, "shared"), exist_ok=True)

        data_dir = os.path.join(self.container_root, "input", "data")
        os.makedirs(data_dir, exist_ok=True)
        volumes = self._prepare_training_volumes(
            data_dir, self.input_data_config, self.hyper_parameters
        )
        # If local, source directory needs to be updated to mounted /opt/ml/code path
        if DIR_PARAM_NAME in self.hyper_parameters:
            src_dir = self.hyper_parameters[DIR_PARAM_NAME]
            parsed_uri = urlparse(src_dir)
            if parsed_uri.scheme == "file":
                self.hyper_parameters[DIR_PARAM_NAME] = "/opt/ml/code"

        for host in self.hosts:
            # Create the configuration files
            self._create_config_file_directories(host)
            self._write_config_files(host, self.input_data_config, self.hyper_parameters)

        self.environment[TRAINING_JOB_NAME_ENV_NAME] = self.training_job_name
        if self.input_from_s3:
            self.environment[S3_ENDPOINT_URL_ENV_NAME] = (
                self.sagemaker_session.s3_resource.meta.client._endpoint.host
            )

        if self._ecr_login_if_needed():
            _pull_image(self.image)

        if self.sagemaker_session:
            self.environment[REGION_ENV_NAME] = self.sagemaker_session.boto_region_name

        compose_data = self._generate_compose_file(self.environment, volumes)
        compose_command = self._generate_compose_command(wait)
        process = subprocess.Popen(
            compose_command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
        )

        try:
            _stream_output(process)
        finally:
            artifacts = self.retrieve_artifacts(compose_data)

            # Print our Job Complete line
            logger.info("Local training job completed, output artifacts saved to %s", artifacts)

            shutil.rmtree(os.path.join(self.container_root, "input"))
            shutil.rmtree(os.path.join(self.container_root, "shared"))
            for host in self.hosts:
                shutil.rmtree(os.path.join(self.container_root, host))
            for folder in self._temporary_folders:
                shutil.rmtree(os.path.join(self.container_root, folder))
        return artifacts

    def retrieve_artifacts(
        self,
        compose_data: dict,
    ):
        """Get the model artifacts from all the container nodes.

        Used after training completes to gather the data from all the
        individual containers. As the official SageMaker Training Service, it
        will override duplicate files if multiple containers have the same file
        names.

        Args:
            compose_data (dict): Docker-Compose configuration in dictionary
                format.

        Returns: Local path to the collected model artifacts.
        """
        # We need a directory to store the artfiacts from all the nodes
        # and another one to contained the compressed final artifacts
        artifacts = os.path.join(self.container_root, "artifacts")
        compressed_artifacts = os.path.join(self.container_root, "compressed_artifacts")
        os.makedirs(artifacts, exist_ok=True)

        model_artifacts = os.path.join(artifacts, "model")
        output_artifacts = os.path.join(artifacts, "output")

        artifact_dirs = [model_artifacts, output_artifacts, compressed_artifacts]
        for d in artifact_dirs:
            os.makedirs(d, exist_ok=True)

        # Gather the artifacts from all nodes into artifacts/model and artifacts/output
        for host in self.hosts:
            volumes = compose_data["services"][str(host)]["volumes"]
            volumes = [v[:-2] if v.endswith(":z") else v for v in volumes]
            for volume in volumes:
                if re.search(r"^[A-Za-z]:", volume):
                    unit, host_dir, container_dir = volume.split(":")
                    host_dir = unit + ":" + host_dir
                else:
                    host_dir, container_dir = volume.split(":")
                if container_dir == "/opt/ml/model":
                    recursive_copy(host_dir, model_artifacts)
                elif container_dir == "/opt/ml/output":
                    recursive_copy(host_dir, output_artifacts)

        # Tar Artifacts -> model.tar.gz and output.tar.gz
        model_files = [os.path.join(model_artifacts, name) for name in os.listdir(model_artifacts)]
        output_files = [
            os.path.join(output_artifacts, name) for name in os.listdir(output_artifacts)
        ]
        create_tar_file(model_files, os.path.join(compressed_artifacts, "model.tar.gz"))
        create_tar_file(output_files, os.path.join(compressed_artifacts, "output.tar.gz"))

        output_data = "file://%s" % compressed_artifacts

        return os.path.join(output_data, "model.tar.gz")

    def _create_config_file_directories(self, host: str):
        """Creates the directories for the config files.

        Args:
            host (str): The name of the current host.
        """
        for d in ["input", "input/config", "output", "model"]:
            os.makedirs(os.path.join(self.container_root, host, d), exist_ok=True)

    def _write_config_files(
        self,
        host: str,
        input_data_config: Optional[List[Channel]],
        hyper_parameters: Optional[Dict[str, str]],
    ):
        """Write the config files for the training containers.

        This method writes the hyper_parameters, resources and input data
        configuration files.

        Returns: None

        Args:
            host (str): The name of the current host.
            input_data_config (List[Channel]): Training input channels to be used for
                training.
            hyper_parameters (Dict[str, str]): Hyperparameters for training.
        """
        config_path = os.path.join(self.container_root, host, "input", "config")
        # Only support single container now
        resource_config = {
            "current_host": host,
            "hosts": self.hosts,
            "network_interface_name": "ethwe",
            "current_instance_type": self.instance_type,
        }

        json_input_data_config = {}
        for channel in input_data_config:
            channel_name = channel.channel_name
            json_input_data_config[channel_name] = {"TrainingInputMode": "File"}
            if channel.content_type != Unassigned():
                json_input_data_config[channel_name]["ContentType"] = channel.content_type

        _write_json_file(os.path.join(config_path, "hyperparameters.json"), hyper_parameters)
        _write_json_file(os.path.join(config_path, "resourceconfig.json"), resource_config)
        _write_json_file(os.path.join(config_path, "inputdataconfig.json"), json_input_data_config)

    def _generate_compose_file(self, environment: Dict[str, str], volumes: List[str]) -> dict:
        """Writes a config file describing a training/hosting environment.

        This method generates a docker compose configuration file, it has an
        entry for each container that will be created (based on self.hosts). it
        calls
        :meth:~sagemaker.local_session.SageMakerContainer._create_docker_host to
        generate the config for each individual container.

        Args:
            environment (Dict[str, str]): a dictionary with environment variables to be
                passed on to the containers.
            volumes (List[str]): a list of volumes that will be mapped to
                the containers

        Returns: (dict) A dictionary representation of the configuration that was written.
        """

        if os.environ.get(DOCKER_COMPOSE_HTTP_TIMEOUT_ENV) is None:
            os.environ[DOCKER_COMPOSE_HTTP_TIMEOUT_ENV] = DOCKER_COMPOSE_HTTP_TIMEOUT

        services = {
            host: self._create_docker_host(host, environment, volumes) for host in self.hosts
        }

        if self.is_studio:
            content = {
                "services": services,
            }
        else:
            content = {
                "services": services,
                "networks": {"sagemaker-local": {"name": "sagemaker-local"}},
            }

        docker_compose_path = os.path.join(self.container_root, DOCKER_COMPOSE_FILENAME)

        try:
            import yaml
        except ImportError as e:
            logger.error(_module_import_error("yaml", "Local mode", "local"))
            raise e

        yaml_content = yaml.dump(content, default_flow_style=False)
        with open(docker_compose_path, "w") as f:
            f.write(yaml_content)

        return content

    def _create_docker_host(
        self,
        host: str,
        environment: Dict[str, str],
        volumes: List[str],
    ) -> Dict:
        """Creates the docker host configuration.

        Args:
            host (str): The host address
            environment (Dict[str, str]): a dictionary with environment variables to be
                passed on to the containers.
            volumes (List[str]): List of volumes that will be mapped to the containers
        """
        environment = ["{}={}".format(k, v) for k, v in environment.items()]
        aws_creds = None
        if self.sagemaker_session:
            # In local mode only get aws credentials when neccessary
            aws_creds = _aws_credentials(self.sagemaker_session.boto_session)
        if aws_creds is not None:
            environment.extend(aws_creds)

        if self.is_studio:
            environment.extend([f"{SM_STUDIO_LOCAL_MODE}=True"])

        # Add volumes for the input and output of each host
        host_volumes = volumes.copy()
        subdirs = ["output", "output/data", "input"]
        for subdir in subdirs:
            host_dir = os.path.join(self.container_root, host, subdir)
            container_dir = "/opt/ml/{}".format(subdir)
            volume = _Volume(host_dir, container_dir)
            host_volumes.append(volume.map)

        host_config = {
            "image": self.image,
            "volumes": host_volumes,
            "environment": environment,
        }

        if self.container_entrypoint:
            host_config["entrypoint"] = self.container_entrypoint
        if self.container_arguments:
            host_config["entrypoint"] = host_config["entrypoint"] + self.container_arguments

        if self.is_studio:
            host_config["network_mode"] = "sagemaker"
        else:
            host_config["networks"] = {"sagemaker-local": {"aliases": [host]}}

        # for GPU support pass in nvidia as the runtime, this is equivalent
        # to setting --runtime=nvidia in the docker commandline.
        if self.instance_type == "local_gpu":
            host_config["deploy"] = {
                "resources": {
                    "reservations": {"devices": [{"count": "all", "capabilities": ["gpu"]}]}
                }
            }

        return host_config

    def _generate_compose_command(self, wait: bool):
        """Invokes the docker compose command.

        Args:
            wait (bool): Whether to wait for the docker command result.
        """
        _compose_cmd_prefix = self._get_compose_cmd_prefix()

        command = _compose_cmd_prefix + [
            "-f",
            os.path.join(self.container_root, DOCKER_COMPOSE_FILENAME),
            "up",
            "--build",
            "--abort-on-container-exit" if wait else "--detach",
        ]

        logger.info("docker command: %s", " ".join(command))
        return command

    def _ecr_login_if_needed(self):
        """Log into ECR, if needed.

        Only ECR images that not have been pulled locally need login.
        """
        sagemaker_pattern = re.compile(ECR_URI_PATTERN)
        sagemaker_match = sagemaker_pattern.match(self.image)
        if not sagemaker_match:
            return False

        # Do we already have the image locally?
        if _check_output("docker images -q %s" % self.image).strip():
            return False

        if not self.sagemaker_session:
            # In local mode only initiate session when neccessary
            self.sagemaker_session = Session()

        ecr = self.sagemaker_session.boto_session.client("ecr")
        auth = ecr.get_authorization_token(registryIds=[self.image.split(".")[0]])
        authorization_data = auth["authorizationData"][0]

        raw_token = base64.b64decode(authorization_data["authorizationToken"])
        token = raw_token.decode("utf-8").strip("AWS:")
        ecr_url = auth["authorizationData"][0]["proxyEndpoint"]

        # Log in to ecr, but use communicate to not print creds to the console
        cmd = f"docker login {ecr_url} -u AWS --password-stdin".split()
        proc = subprocess.Popen(
            cmd,
            stdin=subprocess.PIPE,
        )

        proc.communicate(input=token.encode())

        return True

    def _prepare_training_volumes(
        self,
        data_dir: str,
        input_data_config: Optional[List[Channel]],
        hyper_parameters: Optional[Dict[str, str]],
    ) -> List[str]:
        """Prepares the training volumes based on input and output data configs.

        Args:
            data_dir (str): The directory of input data.
            input_data_config (Optional[List[Channel]]): Training input channels to be used for
                training.
            hyper_parameters (Optional[Dict[str, str]]): Hyperparameters for training.
        """
        volumes = []
        model_dir = os.path.join(self.container_root, "model")
        volumes.append(_Volume(model_dir, "/opt/ml/model").map)

        # Mount the metadata directory if present.
        # Only expected to be present on SM notebook instances.
        # This is used by some DeepEngine libraries
        metadata_dir = "/opt/ml/metadata"
        if os.path.isdir(metadata_dir):
            volumes.append(_Volume(metadata_dir, metadata_dir).map)

        # Set up the channels for the containers. For local data we will
        # mount the local directory to the container. For S3 Data we will download the S3 data
        # first.
        for channel in input_data_config:
            channel_name = channel.channel_name
            channel_dir = os.path.join(data_dir, channel_name)
            os.makedirs(channel_dir, exist_ok=True)

            data_source_local_path = self._get_data_source_local_path(channel.data_source)
            volumes.append(_Volume(data_source_local_path, channel=channel_name).map)

        # If there is a training script directory and it is a local directory,
        # mount it to the container.
        if DIR_PARAM_NAME in hyper_parameters:
            training_dir = hyper_parameters[DIR_PARAM_NAME]
            parsed_uri = urlparse(training_dir)
            if parsed_uri.scheme == "file":
                host_dir = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
                volumes.append(_Volume(host_dir, "/opt/ml/code").map)
                shared_dir = os.path.join(self.container_root, "shared")
                volumes.append(_Volume(shared_dir, "/opt/ml/shared").map)

        return volumes

    def _get_data_source_local_path(self, data_source: DataSource):
        """Return a local data path of :class:`sagemaker.local.data.DataSource`.

        If the data source is from S3, the data will be downloaded to a temporary
        local path.
        If the data source is local file, the absolute path will be returned.

        Args:
            data_source (DataSource): a data source of local file or s3

        Returns:
            str: The local path of the data.
        """
        if data_source.s3_data_source != Unassigned():
            uri = data_source.s3_data_source.s3_uri
            parsed_uri = urlparse(uri)
            local_dir = TemporaryDirectory(prefix=os.path.join(self.container_root + "/")).name
            self._temporary_folders.append(local_dir)
            download_folder(parsed_uri.netloc, parsed_uri.path, local_dir, self.sagemaker_session)
            return local_dir
        else:
            return os.path.abspath(data_source.file_system_data_source.directory_path)

    def _get_compose_cmd_prefix(self) -> List[str]:
        """Gets the Docker Compose command.

        The method initially looks for 'docker compose' v2
        executable, if not found looks for 'docker-compose' executable.

        Returns:
            List[str]: Docker Compose executable split into list.

        Raises:
            ImportError: If Docker Compose executable was not found.
        """
        compose_cmd_prefix = []

        output = None
        try:
            output = subprocess.check_output(
                ["docker", "compose", "version"],
                stderr=subprocess.DEVNULL,
                encoding="UTF-8",
            )
        except subprocess.CalledProcessError:
            logger.info(
                "'Docker Compose' is not installed. "
                "Proceeding to check for 'docker-compose' CLI."
            )

        if output and "v2" in output.strip():
            logger.info("'Docker Compose' found using Docker CLI.")
            compose_cmd_prefix.extend(["docker", "compose"])
            return compose_cmd_prefix

        if shutil.which("docker-compose") is not None:
            logger.info("'Docker Compose' found using Docker Compose CLI.")
            compose_cmd_prefix.extend(["docker-compose"])
            return compose_cmd_prefix

        raise ImportError(
            "Docker Compose is not installed. "
            "Local Mode features will not work without docker compose. "
            "For more information on how to install 'docker compose', please, see "
            "https://docs.docker.com/compose/install/"
        )
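The module above is the docker-compose backend for local mode training: train() materializes the /opt/ml directory layout under container_root, writes docker-compose.yaml, brings up one algo-N service per instance, and returns a file:// path to the compressed model.tar.gz gathered back from the containers. As orientation only, a minimal sketch of driving the class directly; the image tag, paths, and argument values below are illustrative assumptions, and real callers pass Channel objects in input_data_config (their constructor is not shown in this hunk, so the list is left empty here).

# Hypothetical sketch, not an excerpt from the package. Field names follow the
# _LocalContainer model above; the image tag and local paths are made up.
from sagemaker.core.modules.local_core.local_container import _LocalContainer

container = _LocalContainer(
    training_job_name="local-train-example",
    instance_type="local",             # "local_gpu" would add the GPU device reservation
    instance_count=2,                  # model_post_init derives hosts ["algo-1", "algo-2"]
    image="my-training-image:latest",  # non-ECR image, so no ECR login or pull is attempted
    container_root="/tmp/local-train",
    input_data_config=[],              # real runs pass Channel objects (see class docstring)
    environment={},
    hyper_parameters={},
    container_entrypoint=None,         # fall back to the image's default entrypoint
    container_arguments=None,
)
artifacts = container.train(wait=True)  # blocks on `docker compose up --abort-on-container-exit`

With wait=True the call streams the compose output, then collects /opt/ml/model and /opt/ml/output from every host, tars them, and cleans up the per-host input directories and any temporary S3 download folders.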
sagemaker/core/modules/templates.py (new file)
@@ -0,0 +1,83 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Templates module."""
from __future__ import absolute_import

EXECUTE_BASE_COMMANDS = """
CMD="{base_command}"
echo "Executing command: $CMD"
eval $CMD
"""

EXECUTE_BASIC_SCRIPT_DRIVER = """
echo "Running Basic Script driver"
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/basic_script_driver.py
"""

EXEUCTE_DISTRIBUTED_DRIVER = """
echo "Running {driver_name} Driver"
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/{driver_script}
"""

TRAIN_SCRIPT_TEMPLATE = """
#!/bin/bash
set -e
echo "Starting training script"

handle_error() {{
    EXIT_STATUS=$?
    echo "An error occurred with exit code $EXIT_STATUS"
    if [ ! -s /opt/ml/output/failure ]; then
        echo "Training Execution failed. For more details, see CloudWatch logs at 'aws/sagemaker/TrainingJobs'.
TrainingJob - $TRAINING_JOB_NAME" >> /opt/ml/output/failure
    fi
    exit $EXIT_STATUS
}}

check_python() {{
    SM_PYTHON_CMD=$(command -v python3 || command -v python)
    SM_PIP_CMD=$(command -v pip3 || command -v pip)

    # Check if Python is found
    if [[ -z "$SM_PYTHON_CMD" || -z "$SM_PIP_CMD" ]]; then
        echo "Error: The Python executable was not found in the system path."
        return 1
    fi

    return 0
}}

trap 'handle_error' ERR

check_python

$SM_PYTHON_CMD --version

echo "/opt/ml/input/config/resourceconfig.json:"
cat /opt/ml/input/config/resourceconfig.json
echo

echo "/opt/ml/input/config/inputdataconfig.json:"
cat /opt/ml/input/config/inputdataconfig.json
echo

echo "Setting up environment variables"
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/scripts/environment.py
source /opt/ml/input/sm_training.env

{working_dir}
{install_requirements}
{execute_driver}

echo "Training Container Execution Completed"
"""
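These are plain str.format templates: the doubled braces keep the shell function bodies literal, while {base_command}, {driver_name}, {driver_script}, {working_dir}, {install_requirements}, and {execute_driver} are the substitution points. A minimal rendering sketch with made-up placeholder values follows; how the package itself assembles these strings is not part of this hunk, so the output filename and values are assumptions.

# Illustrative rendering only; the placeholder values are assumptions.
from sagemaker.core.modules.templates import (
    EXEUCTE_DISTRIBUTED_DRIVER,  # constant name exactly as spelled in the module above
    TRAIN_SCRIPT_TEMPLATE,
)

execute_driver = EXEUCTE_DISTRIBUTED_DRIVER.format(
    driver_name="Torchrun",
    driver_script="torchrun_driver.py",  # one of the drivers under distributed_drivers/
)

train_script = TRAIN_SCRIPT_TEMPLATE.format(
    working_dir="cd /opt/ml/input/data/code",
    install_requirements="$SM_PIP_CMD install -r requirements.txt",
    execute_driver=execute_driver,
)

with open("sm_train.sh", "w") as f:  # hypothetical output filename
    f.write(train_script)

The rendered script traps errors into /opt/ml/output/failure, locates the container's Python and pip, dumps the resource and input-data configs written by _LocalContainer._write_config_files, sources the environment produced by scripts/environment.py, and then runs whatever driver was substituted in.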