sagemaker-core 1.0.62__py3-none-any.whl → 2.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/__init__.py +2 -0
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2399 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +247 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1220 -0
- sagemaker/core/git_utils.py +415 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2977 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +509 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +770 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1267 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +686 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
- sagemaker/core/image_uri_config/huggingface.json +2287 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +252 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +494 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +972 -0
- sagemaker/core/image_uris.py +816 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +197 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +501 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +171 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +423 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +246 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1497 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +239 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +205 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1599 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1310 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +347 -0
- sagemaker/core/remote_function/core/serialization.py +410 -0
- sagemaker/core/remote_function/core/stored_function.py +223 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +102 -0
- sagemaker/core/remote_function/invoke_function.py +167 -0
- sagemaker/core/remote_function/job.py +2121 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +82 -0
- sagemaker/core/telemetry/telemetry_logging.py +285 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +345 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +514 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.3.1.dist-info/RECORD +351 -0
- sagemaker_core-2.3.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/modules/local_core}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module is used to define the environment variables for the training job container."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
from typing import Dict, Any
|
|
17
|
+
import multiprocessing
|
|
18
|
+
import subprocess
|
|
19
|
+
import json
|
|
20
|
+
import os
|
|
21
|
+
import sys
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
import logging
|
|
24
|
+
|
|
25
|
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
26
|
+
|
|
27
|
+
from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611
|
|
28
|
+
safe_serialize,
|
|
29
|
+
safe_deserialize,
|
|
30
|
+
read_distributed_json,
|
|
31
|
+
read_source_code_json,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
# Initialize logger
# SM_LOG_LEVEL comes from the environment and falls back to 20 (the numeric
# value of logging.INFO). It is passed through int() below, so a value taken
# from the environment has to be a numeric string.
SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20)
logger = logging.getLogger(__name__)
# Log records from this module are written to stdout.
console_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(console_handler)
logger.setLevel(int(SM_LOG_LEVEL))

# Directory exported as SM_MODEL_DIR.
SM_MODEL_DIR = "/opt/ml/model"

# Input directories exported as SM_INPUT_DIR, SM_INPUT_DATA_DIR and
# SM_INPUT_CONFIG_DIR; each data channel is mapped to a sub-directory of
# SM_INPUT_DATA_DIR by set_env().
SM_INPUT_DIR = "/opt/ml/input"
SM_INPUT_DATA_DIR = "/opt/ml/input/data"
SM_INPUT_CONFIG_DIR = "/opt/ml/input/config"

# Output directories exported as SM_OUTPUT_DIR, SM_OUTPUT_FAILURE and
# SM_OUTPUT_DATA_DIR.
SM_OUTPUT_DIR = "/opt/ml/output"
SM_OUTPUT_FAILURE = "/opt/ml/output/failure"
SM_OUTPUT_DATA_DIR = "/opt/ml/output/data"
# Exported as SM_SOURCE_DIR when a source-code config is present.
SM_SOURCE_DIR_PATH = "/opt/ml/input/data/code"
# Exported as SM_DISTRIBUTED_DRIVER_DIR when a distributed config is present.
SM_DISTRIBUTED_DRIVER_DIR_PATH = "/opt/ml/input/data/sm_drivers/distributed_drivers"

# Fixed master address and port exported as SM_MASTER_ADDR / SM_MASTER_PORT.
SM_MASTER_ADDR = "algo-1"
SM_MASTER_PORT = 7777

# JSON configuration files that main() loads before calling set_env().
RESOURCE_CONFIG = f"{SM_INPUT_CONFIG_DIR}/resourceconfig.json"
INPUT_DATA_CONFIG = f"{SM_INPUT_CONFIG_DIR}/inputdataconfig.json"
HYPERPARAMETERS_CONFIG = f"{SM_INPUT_CONFIG_DIR}/hyperparameters.json"

# Default file that set_env() writes the "export NAME='value'" lines to.
ENV_OUTPUT_FILE = "/opt/ml/input/sm_training.env"

# Key names containing any of these words (compared case-insensitively, as
# substrings) have their values replaced by HIDDEN_VALUE when logged.
SENSITIVE_KEYWORDS = ["SECRET", "PASSWORD", "KEY", "TOKEN", "PRIVATE", "CREDS", "CREDENTIALS"]
HIDDEN_VALUE = "******"
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def num_cpus() -> int:
    """Report the CPU count seen by ``multiprocessing``.

    Returns:
        int: The value of ``multiprocessing.cpu_count()`` for this container.
    """
    cpu_total = multiprocessing.cpu_count()
    return cpu_total
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def num_gpus() -> int:
    """Count the GPUs listed by ``nvidia-smi --list-gpus``.

    Returns:
        int: Number of output lines starting with ``"GPU "``, or 0 when the
            command cannot be started or exits with a non-zero status.
    """
    listing_command = ["nvidia-smi", "--list-gpus"]
    try:
        raw_listing = subprocess.check_output(listing_command)
    except (OSError, subprocess.CalledProcessError):
        # A missing or failing nvidia-smi is treated as "no GPUs present".
        logger.info("No GPUs detected (normal if no gpus installed)")
        return 0

    gpu_rows = [
        row for row in raw_listing.decode("utf-8").splitlines() if row.startswith("GPU ")
    ]
    return len(gpu_rows)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def num_neurons() -> int:
    """Return the number of neuron cores available in the current container.

    Runs ``neuron-ls -j`` and sums the ``nc_count`` field of every entry in
    the JSON list it prints. Detection is best-effort: every failure mode is
    logged and reported as zero cores rather than raised.

    Returns:
        int: Number of Neuron Cores available in the current container, or 0
            when ``neuron-ls`` is missing, fails, or prints output that is not
            valid JSON.
    """
    try:
        cmd = ["neuron-ls", "-j"]
        output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8")
        devices = json.loads(output)
        neuron_cores = sum(item.get("nc_count", 0) for item in devices)
        logger.info("Found %s neurons on this instance", neuron_cores)
        return neuron_cores
    except OSError:
        # The neuron-ls executable could not be started.
        logger.info("No Neurons detected (normal if no neurons installed)")
        return 0
    except subprocess.CalledProcessError as e:
        if e.output is not None:
            try:
                msg = e.output.decode("utf-8").partition("error=")[2]
                # Adjacent literals are used instead of a backslash
                # continuation inside the string, which would embed the
                # continuation line's indentation in the logged message.
                logger.info(
                    "No Neurons detected (normal if no neurons installed). "
                    "If neuron installed then %s",
                    msg,
                )
            except AttributeError:
                logger.info("No Neurons detected (normal if no neurons installed)")
        else:
            logger.info("No Neurons detected (normal if no neurons installed)")
    except ValueError:
        # stderr is merged into stdout above, so any diagnostic text printed
        # by neuron-ls makes the payload unparsable (json.JSONDecodeError and
        # UnicodeDecodeError are both ValueError subclasses).
        logger.warning("Could not parse neuron-ls output; assuming no Neurons are available")

    return 0
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def deserialize_hyperparameters(hyperparameters: Dict[str, str]) -> Dict[str, Any]:
    """Convert string-encoded hyperparameters back to their original types.

    Args:
        hyperparameters (Dict[str, str]): Hyperparameters as strings.

    Returns:
        Dict[str, Any]: A new dictionary with the same keys, where every value
            has been passed through ``safe_deserialize``.
    """
    return {name: safe_deserialize(raw) for name, raw in hyperparameters.items()}
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _escape_single_quotes(text: str) -> str:
    """Escape ``'`` so ``text`` can be placed inside a single-quoted shell word."""
    # Nothing can be escaped inside '...', so each quote is emitted by closing
    # the word, adding a backslash-escaped quote, and reopening the word.
    return text.replace("'", "'\\''")


def set_env(
    resource_config: Dict[str, Any],
    input_data_config: Dict[str, Any],
    hyperparameters_config: Dict[str, Any],
    output_file: str = ENV_OUTPUT_FILE,
):
    """Set environment variables for the training job container.

    Builds the ``SM_*`` variables from the fixed paths, the optional
    source-code and distributed configuration, the data channels, the
    hyperparameters and the host/resource information, writes them to
    ``output_file`` as ``export NAME='value'`` lines and then logs them.

    Args:
        resource_config (Dict[str, Any]): Resource configuration for the training job.
            Must contain ``current_host``, ``current_instance_type``, ``hosts`` and
            ``network_interface_name``.
        input_data_config (Dict[str, Any]): Input data configuration for the training job.
            Its keys are used as the channel names.
        hyperparameters_config (Dict[str, Any]): Hyperparameters configuration for the training job.
        output_file (str): Output file to write the environment variables.

    Raises:
        KeyError: If a required ``resource_config`` key or the
            ``TRAINING_JOB_NAME`` environment variable is missing.
        ValueError: If ``current_host`` is not one of ``hosts``.
    """
    # Constants
    env_vars = {
        "SM_MODEL_DIR": SM_MODEL_DIR,
        "SM_INPUT_DIR": SM_INPUT_DIR,
        "SM_INPUT_DATA_DIR": SM_INPUT_DATA_DIR,
        "SM_INPUT_CONFIG_DIR": SM_INPUT_CONFIG_DIR,
        "SM_OUTPUT_DIR": SM_OUTPUT_DIR,
        "SM_OUTPUT_FAILURE": SM_OUTPUT_FAILURE,
        "SM_OUTPUT_DATA_DIR": SM_OUTPUT_DATA_DIR,
        "SM_LOG_LEVEL": SM_LOG_LEVEL,
        "SM_MASTER_ADDR": SM_MASTER_ADDR,
        "SM_MASTER_PORT": SM_MASTER_PORT,
    }

    # SourceCode and DistributedConfig Environment Variables
    source_code = read_source_code_json()
    if source_code:
        env_vars["SM_SOURCE_DIR"] = SM_SOURCE_DIR_PATH
        env_vars["SM_ENTRY_SCRIPT"] = source_code.get("entry_script", "")

    distributed = read_distributed_json()
    if distributed:
        env_vars["SM_DISTRIBUTED_DRIVER_DIR"] = SM_DISTRIBUTED_DRIVER_DIR_PATH
        env_vars["SM_DISTRIBUTED_CONFIG"] = distributed

    # Data Channels
    channels = list(input_data_config.keys())
    for channel in channels:
        env_vars[f"SM_CHANNEL_{channel.upper()}"] = f"{SM_INPUT_DATA_DIR}/{channel}"
    env_vars["SM_CHANNELS"] = channels

    # Hyperparameters
    hps = deserialize_hyperparameters(hyperparameters_config)
    for key, value in hps.items():
        key_upper = key.replace("-", "_").upper()
        env_vars[f"SM_HP_{key_upper}"] = value
    env_vars["SM_HPS"] = hps

    # Host Variables
    current_host = resource_config["current_host"]
    current_instance_type = resource_config["current_instance_type"]
    hosts = resource_config["hosts"]
    sorted_hosts = sorted(hosts)

    env_vars["SM_CURRENT_HOST"] = current_host
    env_vars["SM_CURRENT_INSTANCE_TYPE"] = current_instance_type
    env_vars["SM_HOSTS"] = sorted_hosts
    env_vars["SM_NETWORK_INTERFACE_NAME"] = resource_config["network_interface_name"]
    env_vars["SM_HOST_COUNT"] = len(sorted_hosts)
    # The rank is the position of this host in the sorted host list.
    env_vars["SM_CURRENT_HOST_RANK"] = sorted_hosts.index(current_host)

    env_vars["SM_NUM_CPUS"] = num_cpus()
    env_vars["SM_NUM_GPUS"] = num_gpus()
    env_vars["SM_NUM_NEURONS"] = num_neurons()

    # Misc.
    env_vars["SM_RESOURCE_CONFIG"] = resource_config
    env_vars["SM_INPUT_DATA_CONFIG"] = input_data_config

    # All Training Environment Variables
    env_vars["SM_TRAINING_ENV"] = {
        "channel_input_dirs": {
            channel: env_vars[f"SM_CHANNEL_{channel.upper()}"] for channel in channels
        },
        "current_host": env_vars["SM_CURRENT_HOST"],
        "current_instance_type": env_vars["SM_CURRENT_INSTANCE_TYPE"],
        "hosts": env_vars["SM_HOSTS"],
        "master_addr": env_vars["SM_MASTER_ADDR"],
        "master_port": env_vars["SM_MASTER_PORT"],
        "hyperparameters": env_vars["SM_HPS"],
        "input_data_config": input_data_config,
        "input_config_dir": env_vars["SM_INPUT_CONFIG_DIR"],
        "input_data_dir": env_vars["SM_INPUT_DATA_DIR"],
        "input_dir": env_vars["SM_INPUT_DIR"],
        "job_name": os.environ["TRAINING_JOB_NAME"],
        "log_level": env_vars["SM_LOG_LEVEL"],
        "model_dir": env_vars["SM_MODEL_DIR"],
        "network_interface_name": env_vars["SM_NETWORK_INTERFACE_NAME"],
        "num_cpus": env_vars["SM_NUM_CPUS"],
        "num_gpus": env_vars["SM_NUM_GPUS"],
        "num_neurons": env_vars["SM_NUM_NEURONS"],
        "output_data_dir": env_vars["SM_OUTPUT_DATA_DIR"],
        "resource_config": env_vars["SM_RESOURCE_CONFIG"],
    }
    with open(output_file, "w") as f:
        for key, value in env_vars.items():
            # Values are wrapped in single quotes, so an embedded quote has to
            # be escaped or it would end the shell word early and corrupt
            # every following line of the file.
            serialized = _escape_single_quotes(str(safe_serialize(value)))
            f.write(f"export {key}='{serialized}'\n")

    logger.info("Environment Variables:")
    log_env_variables(env_vars_dict=env_vars)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def mask_sensitive_info(data):
    """Return a copy of ``data`` with sensitive string values masked.

    A string value is replaced by ``HIDDEN_VALUE`` when its key contains any
    of ``SENSITIVE_KEYWORDS`` (case-insensitive substring match). Nested
    dictionaries are processed recursively; all other values are carried over
    unchanged.

    The input is never modified, so dictionaries that are only being logged
    keep their real values.

    Args:
        data: The object to mask. Anything that is not a dictionary is
            returned as-is.

    Returns:
        A new dictionary with masked values, or ``data`` itself when it is not
        a dictionary.
    """
    if not isinstance(data, dict):
        return data

    masked = {}
    for k, v in data.items():
        if isinstance(v, dict):
            masked[k] = mask_sensitive_info(v)
        elif isinstance(v, str) and any(
            keyword.lower() in k.lower() for keyword in SENSITIVE_KEYWORDS
        ):
            masked[k] = HIDDEN_VALUE
        else:
            masked[k] = v
    return masked
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def log_key_value(key: str, value: str):
    """Write ``key=value`` to the log, hiding anything that looks sensitive.

    The whole value is hidden when the key contains a sensitive keyword.
    Otherwise a dictionary (given directly or encoded as a JSON string) is
    logged as JSON after masking; any other JSON value is logged in decoded
    form and non-JSON values are logged unchanged.
    """
    key_is_sensitive = any(word.lower() in key.lower() for word in SENSITIVE_KEYWORDS)
    if key_is_sensitive:
        logger.info("%s=%s", key, HIDDEN_VALUE)
        return

    if isinstance(value, dict):
        logger.info("%s=%s", key, json.dumps(mask_sensitive_info(value)))
        return

    try:
        parsed = json.loads(value)
    except (json.JSONDecodeError, TypeError):
        # Not JSON text (or not a string at all): log the raw value.
        logger.info("%s=%s", key, value)
        return

    if isinstance(parsed, dict):
        logger.info("%s=%s", key, json.dumps(mask_sensitive_info(parsed)))
    else:
        logger.info("%s=%s", key, parsed)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def log_env_variables(env_vars_dict: Dict[str, Any]):
    """Log every process environment variable, then every entry of ``env_vars_dict``.

    Each pair is emitted through ``log_key_value``.
    """
    for mapping in (os.environ, env_vars_dict):
        for name, setting in mapping.items():
            log_key_value(name, setting)
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def main():
    """Load the job's JSON configuration files and export the training environment.

    Reads the resource, input-data and hyperparameter configuration files
    (in that order) and hands their parsed contents to ``set_env``, which is
    given ``ENV_OUTPUT_FILE`` as its output file.
    """

    def _read_json(path):
        # Parse a single JSON configuration file from disk.
        with open(path, "r") as config_file:
            return json.load(config_file)

    set_env(
        resource_config=_read_json(RESOURCE_CONFIG),
        input_data_config=_read_json(INPUT_DATA_CONFIG),
        hyperparameters_config=_read_json(HYPERPARAMETERS_CONFIG),
        output_file=ENV_OUTPUT_FILE,
    )


if __name__ == "__main__":
    main()
|
|
File without changes
|
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Utility functions for SageMaker training recipes."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import math
|
|
17
|
+
import os
|
|
18
|
+
import json
|
|
19
|
+
import shutil
|
|
20
|
+
import tempfile
|
|
21
|
+
from urllib.request import urlretrieve
|
|
22
|
+
from typing import Dict, Any, Optional, Tuple
|
|
23
|
+
|
|
24
|
+
import omegaconf
|
|
25
|
+
from omegaconf import OmegaConf, dictconfig
|
|
26
|
+
|
|
27
|
+
from sagemaker.core.image_uris import retrieve
|
|
28
|
+
|
|
29
|
+
from sagemaker.core.modules import logger
|
|
30
|
+
from sagemaker.core.modules.utils import _run_clone_command_silent
|
|
31
|
+
from sagemaker.core.modules.configs import Compute, SourceCode
|
|
32
|
+
from sagemaker.core.modules.distributed import Torchrun, SMP
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _try_resolve_recipe(recipe, key=None):
    """Resolve interpolations in ``recipe`` and return it, or ``None`` on failure.

    When ``key`` is given, the recipe is first nested under that key in a new
    ``DictConfig``, resolved there, and the entry stored under ``key`` is
    returned.
    """
    nested = key is not None
    candidate = dictconfig.DictConfig({key: recipe}) if nested else recipe
    try:
        OmegaConf.resolve(candidate)
    except omegaconf.errors.OmegaConfBaseException:
        return None
    return candidate[key] if nested else candidate
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _determine_device_type(instance_type: str) -> str:
    """Map an instance type to its device type: ``gpu``, ``trainium`` or ``cpu``.

    The second dot-separated component of ``instance_type`` is treated as the
    instance family. Families starting with ``p`` or ``g`` are GPU, families
    starting with ``trn`` are Trainium, and everything else is CPU.
    """
    family = instance_type.split(".")[1]
    prefix_table = (
        (("p", "g"), "gpu"),
        (("trn",), "trainium"),
    )
    for prefixes, device in prefix_table:
        if family.startswith(prefixes):
            return device
    return "cpu"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _load_recipes_cfg() -> Dict[str, Any]:
    """Load training recipes configuration json.

    The configuration is read from ``training_recipes.json`` located in the
    same directory as this module.

    Returns:
        The parsed JSON content of the configuration file.
    """
    training_recipes_cfg_filename = os.path.join(os.path.dirname(__file__), "training_recipes.json")
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(training_recipes_cfg_filename, encoding="utf-8") as training_recipes_cfg_file:
        return json.load(training_recipes_cfg_file)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _load_base_recipe(
    training_recipe: str,
    recipe_overrides: Optional[Dict[str, Any]] = None,
    training_recipes_cfg: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Load recipe and apply overrides.

    ``training_recipe`` is interpreted as follows:

    * ends with ``.yaml`` and is an existing local file: the file is used;
    * ends with ``.yaml`` otherwise: it is fetched with ``urlretrieve``;
    * anything else: it is treated as a recipe name and looked up under
      ``recipes_collection/recipes/<name>.yaml`` in a fresh clone of the
      launcher repository (``TRAINING_LAUNCHER_GIT`` environment variable,
      falling back to ``training_recipes_cfg["launcher_repo"]``).

    Args:
        training_recipe: Recipe name, local path, or URL.
        recipe_overrides: Values merged on top of the loaded recipe.
        training_recipes_cfg: Training recipes configuration; only consulted
            for ``launcher_repo`` when a recipe name is given.

    Returns:
        The loaded recipe merged with ``recipe_overrides``.

    Raises:
        ValueError: If the recipe cannot be fetched, no launcher repository
            is configured, or the named recipe does not exist.
    """
    if recipe_overrides is None:
        recipe_overrides = dict()

    # mkstemp creates the file atomically and keeps it until we remove it, so
    # the path cannot be claimed by another process in the meantime.
    fd, temp_local_recipe = tempfile.mkstemp(prefix="recipe_original", suffix=".yaml")
    os.close(fd)

    try:
        if training_recipe.endswith(".yaml"):
            if os.path.isfile(training_recipe):
                shutil.copy(training_recipe, temp_local_recipe)
            else:
                try:
                    # NOTE(review): urlretrieve accepts any URL scheme it
                    # supports; callers should only pass trusted locations.
                    urlretrieve(training_recipe, temp_local_recipe)
                except Exception as e:
                    raise ValueError(
                        f"Could not fetch the provided recipe {training_recipe}: exception {str(e)}"
                    ) from e
        else:
            launcher_repo = os.environ.get("TRAINING_LAUNCHER_GIT", None) or (
                training_recipes_cfg or {}
            ).get("launcher_repo")
            if not launcher_repo:
                raise ValueError(
                    "No launcher repository configured: set TRAINING_LAUNCHER_GIT or "
                    "provide launcher_repo in the training recipes configuration."
                )

            # The clone is only needed long enough to copy the recipe out.
            with tempfile.TemporaryDirectory(prefix="launcher_") as recipe_launcher_dir:
                _run_clone_command_silent(launcher_repo, recipe_launcher_dir)

                recipe_path = os.path.join(
                    recipe_launcher_dir,
                    "recipes_collection",
                    "recipes",
                    training_recipe + ".yaml",
                )
                if os.path.isfile(recipe_path):
                    shutil.copy(recipe_path, temp_local_recipe)
                else:
                    raise ValueError(f"Recipe {training_recipe} not found.")

        recipe = OmegaConf.load(temp_local_recipe)
    finally:
        # Remove the scratch copy on success and on every failure path.
        if os.path.exists(temp_local_recipe):
            os.unlink(temp_local_recipe)

    recipe = OmegaConf.merge(recipe, recipe_overrides)
    return recipe
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _register_custom_resolvers():
    """Register the arithmetic OmegaConf resolvers used by training recipes.

    Registers ``multiply``, ``divide_ceil``, ``divide_floor`` and ``add``;
    a resolver that is already registered is left untouched.
    """
    resolver_table = (
        ("multiply", lambda x, y: x * y, {"replace": True}),
        ("divide_ceil", lambda x, y: int(math.ceil(x / y)), {"replace": True}),
        ("divide_floor", lambda x, y: int(math.floor(x / y)), {"replace": True}),
        ("add", lambda *numbers: sum(numbers), {}),
    )
    for resolver_name, resolver_fn, options in resolver_table:
        if OmegaConf.has_resolver(resolver_name):
            continue
        OmegaConf.register_new_resolver(resolver_name, resolver_fn, **options)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _get_trainining_recipe_gpu_model_name_and_script(model_type: str):
    """Return ``(model base name, entry script)`` for a recipe's model type.

    The model type is matched by prefix against the supported model
    families; a ``ValueError`` is raised when no family matches.
    """
    script_by_family = {
        "llama": "llama_pretrain.py",
        "mistral": "mistral_pretrain.py",
        "mixtral": "mixtral_pretrain.py",
        "deepseek": "deepseek_pretrain.py",
    }

    for family, script in script_by_family.items():
        if model_type.startswith(family):
            return family, script

    raise ValueError(f"Model type {model_type} not supported")
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _configure_gpu_args(
    training_recipes_cfg: Dict[str, Any],
    region_name: str,
    recipe: OmegaConf,
    recipe_train_dir: tempfile.TemporaryDirectory,
) -> Dict[str, Any]:
    """Configure arguments specific to GPU.

    Clones the adapter repository (``TRAINING_ADAPTER_GIT`` environment
    variable, falling back to ``training_recipes_cfg["adapter_repo"]``) into
    ``recipe_train_dir`` and points the source code at the example directory
    and entry script that match the recipe's ``model.model_type``.

    Args:
        training_recipes_cfg: Training recipes configuration.
        region_name: AWS region used when the training image is retrieved.
        recipe: The loaded recipe; must contain ``model.model_type``.
        recipe_train_dir: Temporary directory that receives the clone.

    Returns:
        A dictionary with the keys ``source_code``, ``training_image`` and
        ``distributed``.

    Raises:
        ValueError: If the recipe lacks ``model`` or ``model.model_type``, or
            the model type is not supported.
    """
    source_code = SourceCode()
    args = dict()

    adapter_repo = os.environ.get("TRAINING_ADAPTER_GIT", None) or training_recipes_cfg.get(
        "adapter_repo"
    )
    _run_clone_command_silent(adapter_repo, recipe_train_dir.name)

    if "model" not in recipe:
        raise ValueError("Supplied recipe does not contain required field model.")
    if "model_type" not in recipe["model"]:
        raise ValueError("Supplied recipe does not contain required field model_type.")
    model_type = recipe["model"]["model_type"]

    model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(model_type)

    source_code.source_dir = os.path.join(recipe_train_dir.name, "examples", model_base_name)
    source_code.entry_script = script

    gpu_image_cfg = training_recipes_cfg.get("gpu_image")
    if isinstance(gpu_image_cfg, str):
        # A plain string is used as the image URI directly.
        training_image = gpu_image_cfg
    else:
        training_image = retrieve(
            gpu_image_cfg.get("framework"),
            region=region_name,
            version=gpu_image_cfg.get("version"),
            image_scope="training",
            # "additional_args" is optional; unpacking None would raise TypeError.
            **(gpu_image_cfg.get("additional_args") or {}),
        )

    # Setting dummy parameters for now
    torch_distributed = Torchrun(smp=SMP(random_seed="123456"))
    args.update(
        {
            "source_code": source_code,
            "training_image": training_image,
            "distributed": torch_distributed,
        }
    )
    return args
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def _configure_trainium_args(
    training_recipes_cfg: Dict[str, Any],
    region_name: str,
    recipe_train_dir: tempfile.TemporaryDirectory,
) -> Dict[str, Any]:
    """Configure arguments specific to Trainium.

    Clones ``training_recipes_cfg["neuron_dist_repo"]`` into
    ``recipe_train_dir`` and uses its ``examples`` directory, with
    ``training_orchestrator.py`` as the entry script.

    Args:
        training_recipes_cfg: Training recipes configuration.
        region_name: AWS region used when the training image is retrieved.
        recipe_train_dir: Temporary directory that receives the clone.

    Returns:
        A dictionary with the keys ``source_code``, ``training_image`` and
        ``distributed``.
    """
    source_code = SourceCode()
    args = dict()

    _run_clone_command_silent(training_recipes_cfg.get("neuron_dist_repo"), recipe_train_dir.name)

    source_code.source_dir = os.path.join(recipe_train_dir.name, "examples")
    source_code.entry_script = "training_orchestrator.py"
    neuron_image_cfg = training_recipes_cfg.get("neuron_image")
    if isinstance(neuron_image_cfg, str):
        # A plain string is used as the image URI directly.
        training_image = neuron_image_cfg
    else:
        training_image = retrieve(
            neuron_image_cfg.get("framework"),
            region=region_name,
            version=neuron_image_cfg.get("version"),
            image_scope="training",
            # "additional_args" is optional; unpacking None would raise TypeError.
            **(neuron_image_cfg.get("additional_args") or {}),
        )

    args.update(
        {
            "source_code": source_code,
            "training_image": training_image,
            "distributed": Torchrun(),
        }
    )
    return args
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def _get_args_from_recipe(
    training_recipe: str,
    compute: Compute,
    region_name: str,
    recipe_overrides: Optional[Dict[str, Any]],
    requirements: Optional[str],
) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]:
    """Get arguments for ModelTrainer from a training recipe.

    Returns a dictionary of arguments to be used with ModelTrainer like:
    ```python
    {
        "source_code": SourceCode,
        "training_image": str,
        "distributed": DistributedConfig,
        "compute": Compute,
        "hyperparameters": Dict[str, Any],
    }
    ```

    Args:
        training_recipe (str):
            Name of the training recipe or path to the recipe file.
        compute (Compute):
            Compute configuration for training. When ``instance_count`` is
            not set, it is filled in from the recipe's ``trainer.num_nodes``.
        region_name (str):
            Name of the AWS region.
        recipe_overrides (Optional[Dict[str, Any]]):
            Overrides for the training recipe.
        requirements (Optional[str]):
            Path to the requirements file.

    Returns:
        A tuple of the argument dictionary and the temporary directory that
        holds the training source code. The directory is returned so the
        caller controls its lifetime.

    Raises:
        ValueError: If required settings are missing, the requirements file
            does not exist, or the device type is not supported.
        RuntimeError: If the recipe cannot be resolved.
    """
    if compute.instance_type is None:
        raise ValueError("Must set `instance_type` in compute when using training recipes.")

    training_recipes_cfg = _load_recipes_cfg()
    recipe = _load_base_recipe(training_recipe, recipe_overrides, training_recipes_cfg)

    if "trainer" not in recipe:
        raise ValueError("Supplied recipe does not contain required field trainer.")

    # Set instance_count
    if compute.instance_count and "num_nodes" in recipe["trainer"]:
        logger.warning(
            f"Using Compute to set instance_count:\n{compute}."
            "\nIgnoring trainer -> num_nodes in recipe."
        )
    if compute.instance_count is None:
        if "num_nodes" not in recipe["trainer"]:
            raise ValueError(
                "Must provide Compute with instance_count or" " set trainer -> num_nodes in recipe."
            )
        compute.instance_count = recipe["trainer"]["num_nodes"]

    if requirements and not os.path.isfile(requirements):
        raise ValueError(f"Recipe requirements file {requirements} not found.")

    # Validate the device type before any temporary directory is created.
    device_type = _determine_device_type(compute.instance_type)
    if device_type not in ("gpu", "trainium"):
        raise ValueError(f"Devices of type {device_type} are not supported with training recipes.")

    # Get Training Image, SourceCode, and distributed args
    recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_")
    try:
        if device_type == "gpu":
            args = _configure_gpu_args(training_recipes_cfg, region_name, recipe, recipe_train_dir)
        else:
            args = _configure_trainium_args(training_recipes_cfg, region_name, recipe_train_dir)

        _register_custom_resolvers()

        # Resolve Final Recipe: first as-is, then nested under the keys its
        # interpolations may expect.
        final_recipe = None
        for wrapper_key in (None, "recipes", "training"):
            final_recipe = _try_resolve_recipe(recipe, wrapper_key)
            if final_recipe is not None:
                break
        if final_recipe is None:
            raise RuntimeError("Could not resolve provided recipe.")

        # Save Final Recipe to source_dir
        OmegaConf.save(
            config=final_recipe, f=os.path.join(args["source_code"].source_dir, "recipe.yaml")
        )

        # If recipe_requirements is provided, copy it to source_dir
        if requirements:
            shutil.copy(requirements, args["source_code"].source_dir)
            args["source_code"].requirements = os.path.basename(requirements)
    except BaseException:
        # The caller never receives the directory on failure, so remove it here.
        recipe_train_dir.cleanup()
        raise

    # Update args with compute and hyperparameters
    args.update(
        {
            "compute": compute,
            "hyperparameters": {"config-path": ".", "config-name": "recipe.yaml"},
        }
    )

    return args, recipe_train_dir
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Types module."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
from typing import Union
|
|
17
|
+
from sagemaker.core.modules.configs import S3DataSource, FileSystemDataSource
|
|
18
|
+
|
|
19
|
+
# Type alias for values accepted as a data source: a str, an S3DataSource,
# or a FileSystemDataSource.
DataSourceType = Union[str, S3DataSource, FileSystemDataSource]
|