sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""HuggingFace framework support for SageMaker."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
from sagemaker.core.huggingface.estimator import HuggingFace # noqa: F401
|
|
17
|
+
from sagemaker.core.huggingface.llm_utils import get_huggingface_llm_image_uri # noqa: F401
|
|
18
|
+
from sagemaker.core.huggingface.model import HuggingFaceModel, HuggingFacePredictor # noqa: F401
|
|
19
|
+
from sagemaker.core.huggingface.processing import HuggingFaceProcessor # noqa: F401
|
|
20
|
+
from sagemaker.core.huggingface.training_compiler.config import TrainingCompilerConfig # noqa: F401
|
|
21
|
+
|
|
22
|
+
# Public names re-exported by ``sagemaker.core.huggingface``.
# NOTE(review): ``HuggingFace``, ``HuggingFaceModel`` and ``HuggingFacePredictor`` are imported
# from ``sagemaker.core.huggingface.estimator`` / ``sagemaker.core.huggingface.model``; neither
# module appears in this release's file listing, so importing this package may fail with
# ModuleNotFoundError — confirm those modules actually ship.
__all__ = [
    "HuggingFace", "HuggingFaceModel", "HuggingFacePredictor",  # estimator / model / predictor
    "HuggingFaceProcessor", "TrainingCompilerConfig",  # processing and training compiler
    "get_huggingface_llm_image_uri",  # LLM inference image URI lookup
]
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Functions for generating ECR image URIs for pre-built SageMaker Docker images."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
from typing import Optional
|
|
18
|
+
import importlib.util
|
|
19
|
+
|
|
20
|
+
import urllib.request
|
|
21
|
+
from urllib.error import HTTPError, URLError
|
|
22
|
+
import json
|
|
23
|
+
from json import JSONDecodeError
|
|
24
|
+
import logging
|
|
25
|
+
from sagemaker.core import image_uris
|
|
26
|
+
from sagemaker.core.helper.session_helper import Session
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def get_huggingface_llm_image_uri(
    backend: str,
    session: Optional[Session] = None,
    region: Optional[str] = None,
    version: Optional[str] = None,
) -> str:
    """Retrieves the image URI for inference.

    Args:
        backend (str): The backend to use. Valid values are "huggingface",
            "huggingface-neuronx", "huggingface-tei", "huggingface-tei-cpu" and "lmi".
        session (Session): The SageMaker Session whose boto session region is used when
            ``region`` is not given. If both are None, a new ``Session()`` is created to
            resolve the region. (Default: None).
        region (str): The AWS region to use for image URI. (default: None).
        version (str): The framework version for which to retrieve an
            image URI. If no version is set, defaults to latest version; for the "lmi"
            backend the default is "0.24.0". (default: None).

    Returns:
        str: The image URI string.

    Raises:
        ValueError: If ``backend`` is not one of the supported values.
    """
    # Backends resolved through ``image_uris.retrieve`` with an inference image scope:
    # backend name -> (image_uris framework name, extra keyword arguments).
    inference_backends = {
        "huggingface": ("huggingface-llm", {}),
        "huggingface-neuronx": ("huggingface-llm-neuronx", {"inference_tool": "neuronx"}),
        "huggingface-tei": ("huggingface-tei", {}),
        "huggingface-tei-cpu": ("huggingface-tei-cpu", {}),
    }

    # Validate before resolving the region so an unsupported backend fails with the
    # intended ValueError instead of whatever creating a Session() might raise.
    if backend != "lmi" and backend not in inference_backends:
        raise ValueError("Unsupported backend: %s" % backend)

    if region is None:
        active_session = session if session is not None else Session()
        region = active_session.boto_session.region_name

    if backend == "lmi":
        # The LMI backend maps to the DJL DeepSpeed image with its own pinned default version.
        return image_uris.retrieve(
            framework="djl-deepspeed", region=region, version=version or "0.24.0"
        )

    framework, extra_kwargs = inference_backends[backend]
    return image_uris.retrieve(
        framework,
        region=region,
        version=version,
        image_scope="inference",
        **extra_kwargs,
    )
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] = None) -> dict:
    """Retrieves the json metadata of the HuggingFace Model via HuggingFace API.

    Args:
        model_id (str): The HuggingFace Model ID
        hf_hub_token (str): The HuggingFace Hub Token needed for Private/Gated HuggingFace Models

    Returns:
        dict: The model metadata retrieved with the HuggingFace API

    Raises:
        ValueError: If ``model_id`` is empty, if the Hub answers with HTTP 401 (gated/private
            model without valid credentials), or if no (or empty) metadata was retrieved.
    """
    if not model_id:
        raise ValueError("Model ID is empty. Please provide a valid Model ID.")
    hf_model_metadata_url = f"https://huggingface.co/api/models/{model_id}"

    # Keep the plain URL separate from the Request object: the URL is what the warning
    # below should log, and the Request carries the Authorization header.
    request = hf_model_metadata_url
    if hf_hub_token:
        request = urllib.request.Request(
            hf_model_metadata_url, None, {"Authorization": "Bearer " + hf_hub_token}
        )

    hf_model_metadata_json = None
    try:
        with urllib.request.urlopen(request) as response:
            hf_model_metadata_json = json.load(response)
    except (HTTPError, URLError, TimeoutError, JSONDecodeError) as e:
        # Check the status code rather than str(e): the reason phrase embedded in the
        # message comes from the server and is not guaranteed to be "Unauthorized".
        if isinstance(e, HTTPError) and e.code == 401:
            raise ValueError(
                "Trying to access a gated/private HuggingFace model without valid credentials. "
                "Please provide a HUGGING_FACE_HUB_TOKEN in env_vars"
            ) from e
        logger.warning(
            "Exception encountered while trying to retrieve HuggingFace model metadata %s. "
            "Details: %s",
            hf_model_metadata_url,
            e,
        )
    if not hf_model_metadata_json:
        raise ValueError(
            "Did not find model metadata for the following HuggingFace Model ID %s" % model_id
        )
    return hf_model_metadata_json
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def download_huggingface_model_metadata(
    model_id: str, model_local_path: str, hf_hub_token: Optional[str] = None
) -> None:
    """Download a HuggingFace Hub repository snapshot into a local directory.

    Despite the name, this fetches the whole model snapshot through
    ``huggingface_hub.snapshot_download``, not only its metadata.

    Args:
        model_id (str): The HuggingFace Model ID (Hub repository id).
        model_local_path (str): Directory the snapshot is written to; it is created
            (with any missing parents) before the download starts.
        hf_hub_token (str): Optional HuggingFace Hub token forwarded to the download.

    Raises:
        ImportError: If huggingface_hub is not installed.
    """
    # huggingface_hub is optional, so probe for it before importing.
    hub_spec = importlib.util.find_spec("huggingface_hub")
    if hub_spec is None:
        raise ImportError("Unable to import huggingface_hub, check if huggingface_hub is installed")

    from huggingface_hub import snapshot_download

    os.makedirs(model_local_path, exist_ok=True)
    logger.info("Downloading model %s from Hugging Face Hub to %s", model_id, model_local_path)
    snapshot_download(repo_id=model_id, local_dir=model_local_path, token=hf_hub_token)
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module contains code related to HuggingFace Processors which are used for Processing jobs.
|
|
14
|
+
|
|
15
|
+
These jobs let customers perform data pre-processing, post-processing, feature engineering,
|
|
16
|
+
data validation, and model evaluation and interpretation on SageMaker.
|
|
17
|
+
"""
|
|
18
|
+
from __future__ import absolute_import
|
|
19
|
+
|
|
20
|
+
from typing import Union, Optional, List, Dict
|
|
21
|
+
|
|
22
|
+
from sagemaker.core.helper.session_helper import Session
|
|
23
|
+
from sagemaker.core.network import NetworkConfig
|
|
24
|
+
from sagemaker.core.processing import FrameworkProcessor
|
|
25
|
+
from sagemaker.core.huggingface.estimator import HuggingFace
|
|
26
|
+
|
|
27
|
+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
|
|
28
|
+
from sagemaker.core.common_utils import format_tags, Tags
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class HuggingFaceProcessor(FrameworkProcessor):
|
|
32
|
+
"""Handles Amazon SageMaker processing tasks for jobs using HuggingFace containers."""
|
|
33
|
+
|
|
34
|
+
estimator_cls = HuggingFace
|
|
35
|
+
|
|
36
|
+
def __init__(
    self,
    role: Optional[Union[str, PipelineVariable]] = None,
    instance_count: Optional[Union[int, PipelineVariable]] = None,
    instance_type: Optional[Union[str, PipelineVariable]] = None,
    transformers_version: Optional[str] = None,
    tensorflow_version: Optional[str] = None,
    pytorch_version: Optional[str] = None,
    py_version: str = "py36",
    image_uri: Optional[Union[str, PipelineVariable]] = None,
    command: Optional[List[str]] = None,
    volume_size_in_gb: Union[int, PipelineVariable] = 30,
    volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
    output_kms_key: Optional[Union[str, PipelineVariable]] = None,
    code_location: Optional[str] = None,
    max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
    base_job_name: Optional[str] = None,
    sagemaker_session: Optional[Session] = None,
    env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
    tags: Optional[Tags] = None,
    network_config: Optional[NetworkConfig] = None,
):
    """This processor executes a Python script in a HuggingFace execution environment.

    Unless ``image_uri`` is specified, the environment is an Amazon-built Docker container
    that executes functions defined in the supplied ``code`` Python script.

    The arguments have the same meaning as in ``FrameworkProcessor``, with the following
    exceptions.

    Args:
        transformers_version (str): Transformers version you want to use for
            executing your model training code. Defaults to ``None``. Required unless
            ``image_uri`` is provided. Supported versions are those listed in the
            HuggingFace image URI configuration (``image_uri_config/huggingface.json``).
        tensorflow_version (str): TensorFlow version you want to use for
            executing your model training code. Defaults to ``None``. Required unless
            ``pytorch_version`` is provided.
        pytorch_version (str): PyTorch version you want to use for
            executing your model training code. Defaults to ``None``. Required unless
            ``tensorflow_version`` is provided.
        py_version (str): Python version you want to use for executing your model training
            code. Defaults to ``"py36"``. Must be a Python version available for the
            selected transformers / framework version combination.

    .. tip::

        You can find additional parameters for initializing this class at
        :class:`~sagemaker.core.processing.FrameworkProcessor`.
    """
    # Kept on the instance because ``_create_estimator`` forwards both of them to the
    # HuggingFace estimator together with the transformers version.
    self.pytorch_version = pytorch_version
    self.tensorflow_version = tensorflow_version
    # NOTE(review): arguments are passed positionally, so this order must match
    # FrameworkProcessor.__init__ (defined in sagemaker.core.processing, not visible here).
    # ``transformers_version`` appears to occupy the base class's framework-version slot,
    # which ``_create_estimator`` reads back as ``self.framework_version`` — confirm.
    super().__init__(
        self.estimator_cls,
        transformers_version,
        role,
        instance_count,
        instance_type,
        py_version,
        image_uri,
        command,
        volume_size_in_gb,
        volume_kms_key,
        output_kms_key,
        code_location,
        max_runtime_in_seconds,
        base_job_name,
        sagemaker_session,
        env,
        format_tags(tags),
        network_config,
    )
|
|
108
|
+
|
|
109
|
+
def _create_estimator(
    self,
    entry_point="",
    source_dir=None,
    dependencies=None,
    git_config=None,
):
    """Build the HuggingFace estimator that backs this processor.

    The generic processor factory forwards a single framework version. A HuggingFace
    estimator instead takes three version arguments (Transformers, PyTorch and
    TensorFlow), so this override wires them up explicitly.

    Args:
        entry_point (str): Script the estimator runs. Defaults to ``""``.
        source_dir (str): Directory holding the entry point and its helpers.
        dependencies (list): Extra paths shipped alongside ``source_dir``.
        git_config (dict): Optional Git repository settings for the code.

    Returns:
        Whatever ``self.estimator_cls`` returns when called with these settings.
    """
    # ``framework_version`` is what __init__ forwarded from ``transformers_version``.
    version_settings = dict(
        transformers_version=self.framework_version,
        tensorflow_version=self.tensorflow_version,
        pytorch_version=self.pytorch_version,
        py_version=self.py_version,
    )
    code_settings = dict(
        entry_point=entry_point,
        source_dir=source_dir,
        dependencies=dependencies,
        git_config=git_config,
        code_location=self.code_location,
    )
    # Network isolation, the debugger hook and the profiler are always switched off
    # for the estimator built here.
    infra_settings = dict(
        enable_network_isolation=False,
        image_uri=self.image_uri,
        role=self.role,
        instance_count=self.instance_count,
        instance_type=self.instance_type,
        sagemaker_session=self.sagemaker_session,
        debugger_hook_config=False,
        disable_profiler=True,
    )
    return self.estimator_cls(**version_settings, **code_settings, **infra_settings)
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Configuration for the SageMaker Training Compiler."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
import logging
|
|
16
|
+
from typing import Union
|
|
17
|
+
from packaging.specifiers import SpecifierSet
|
|
18
|
+
from packaging.version import Version
|
|
19
|
+
|
|
20
|
+
from sagemaker.core.training_compiler.config import TrainingCompilerConfig as BaseConfig
|
|
21
|
+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
|
|
22
|
+
|
|
23
|
+
# Module-level logger; TrainingCompilerConfig.validate emits its advisory warnings through it.
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class TrainingCompilerConfig(BaseConfig):
    """The SageMaker Training Compiler configuration class."""

    SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "p3dn", "g4dn", "p4d", "g5"]
    # Instance types treated as EFA-capable. ``validate`` logs a warning when
    # ``pytorchxla`` distribution is enabled on an instance type not listed here.
    SUPPORTED_INSTANCE_TYPES_WITH_EFA = [
        "ml.g4dn.8xlarge",
        "ml.g4dn.12xlarge",
        "ml.g5.48xlarge",
        "ml.p3dn.24xlarge",
        "ml.p4d.24xlarge",
    ]

    def __init__(
        self,
        enabled: Union[bool, PipelineVariable] = True,
        debug: Union[bool, PipelineVariable] = False,
    ):
        """This class initializes a ``TrainingCompilerConfig`` instance.

        `Amazon SageMaker Training Compiler
        <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
        is a feature of SageMaker Training
        and speeds up training jobs by optimizing model execution graphs.

        You can compile Hugging Face models
        by passing the object of this configuration class to the ``compiler_config``
        parameter of the :class:`~sagemaker.huggingface.HuggingFace`
        estimator.

        Args:
            enabled (bool or PipelineVariable): Optional. Switch to enable SageMaker
                Training Compiler. The default is ``True``.
            debug (bool or PipelineVariable): Optional. Whether to dump detailed logs
                for debugging. This comes with a potential performance slowdown.
                The default is ``False``.

        **Example**: The following code shows the basic usage of the
        :class:`sagemaker.huggingface.TrainingCompilerConfig()` class
        to run a HuggingFace training job with the compiler.

        .. code-block:: python

            from sagemaker.core.huggingface import HuggingFace, TrainingCompilerConfig

            huggingface_estimator=HuggingFace(
                ...
                compiler_config=TrainingCompilerConfig()
            )

        .. seealso::

            For more information about how to enable SageMaker Training Compiler
            for various training settings such as using TensorFlow-based models,
            PyTorch-based models, and distributed training,
            see `Enable SageMaker Training Compiler
            <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html>`_
            in the `Amazon SageMaker Training Compiler developer guide
            <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_.

        """

        super(TrainingCompilerConfig, self).__init__(enabled=enabled, debug=debug)

    @classmethod
    def validate(cls, estimator):
        """Checks if SageMaker Training Compiler is configured correctly.

        Checks, after the base-class validation:

        * ``pytorch_version`` (when set) must be within 1.9-1.11.
        * ``image_uri`` must not be overridden.
        * With ``pytorchxla`` distribution enabled, ``tensorflow_version`` must be unset and
          ``pytorch_version`` (when set) must be >= 1.11. A warning is logged when the
          instance type is not in ``SUPPORTED_INSTANCE_TYPES_WITH_EFA``.
        * With a ``distribution`` that has no ``pytorchxla`` key, PyTorch >= 1.11 is rejected.
        * Without any ``distribution``, a multi-instance PyTorch >= 1.11 job only logs a
          warning suggesting ``pytorchxla``.

        Args:
            estimator (:class:`sagemaker.huggingface.HuggingFace`): An estimator object.
                If SageMaker Training Compiler is enabled, it will validate whether
                the estimator is configured to be compatible with Training Compiler.

        Raises:
            ValueError: Raised if the requested configuration is not compatible
                with SageMaker Training Compiler.
        """

        super(TrainingCompilerConfig, cls).validate(estimator)

        # Parse the PyTorch version once; every check below compares against it.
        # ``None`` when the estimator carries no PyTorch version.
        pt_version = Version(estimator.pytorch_version) if estimator.pytorch_version else None

        if pt_version is not None and (
            pt_version in SpecifierSet("< 1.9") or pt_version in SpecifierSet("> 1.11")
        ):
            error_helper_string = (
                "SageMaker Training Compiler is only supported "
                "with HuggingFace PyTorch 1.9-1.11. "
                "Received pytorch_version={} which is unsupported."
            )
            raise ValueError(error_helper_string.format(estimator.pytorch_version))

        if estimator.image_uri:
            # Fixed: the adjacent literals previously ran together ("Compiler.Specify"),
            # and the parameter is ``transformers_version``, not ``transformer_version``.
            error_helper_string = (
                "Overriding the image URI is currently not supported "
                "for SageMaker Training Compiler. "
                "Specify the following parameters to run the Hugging Face training job "
                "with SageMaker Training Compiler enabled: "
                "transformers_version, tensorflow_version or pytorch_version, and compiler_config."
            )
            raise ValueError(error_helper_string)

        if estimator.distribution:
            pt_xla_present = "pytorchxla" in estimator.distribution
            pt_xla_enabled = estimator.distribution.get("pytorchxla", {}).get("enabled", False)
            if pt_xla_enabled:
                if estimator.tensorflow_version:
                    error_helper_string = (
                        "Distribution mechanism 'pytorchxla' is currently only supported for "
                        "PyTorch >= 1.11 when SageMaker Training Compiler is enabled. Received "
                        "tensorflow_version={} which is unsupported."
                    )
                    raise ValueError(error_helper_string.format(estimator.tensorflow_version))
                if pt_version is not None and pt_version in SpecifierSet("< 1.11"):
                    error_helper_string = (
                        "Distribution mechanism 'pytorchxla' is currently only supported for "
                        "PyTorch >= 1.11 when SageMaker Training Compiler is enabled."
                        " Received pytorch_version={} which is unsupported."
                    )
                    raise ValueError(error_helper_string.format(estimator.pytorch_version))
                if estimator.instance_type not in cls.SUPPORTED_INSTANCE_TYPES_WITH_EFA:
                    logger.warning(
                        "Consider using instances with EFA support when "
                        "training with PyTorch >= 1.11 and SageMaker Training Compiler "
                        "enabled. SageMaker Training Compiler leverages EFA to provide better "
                        "performance for distributed training."
                    )
            # A distribution is set but it is not pytorchxla: PyTorch >= 1.11 only
            # supports pytorchxla with the compiler.
            if not pt_xla_present and pt_version is not None and pt_version in SpecifierSet(
                ">= 1.11"
            ):
                error_helper_string = (
                    "'pytorchxla' is the only distribution mechanism currently supported "
                    "for PyTorch >= 1.11 when SageMaker Training Compiler is enabled."
                    " Received distribution={} which is unsupported."
                )
                raise ValueError(error_helper_string.format(estimator.distribution))
        elif estimator.instance_count and estimator.instance_count > 1:
            # Multi-instance job without any distribution: advisory only.
            if pt_version is not None and pt_version in SpecifierSet(">= 1.11"):
                logger.warning(
                    "Consider setting 'distribution' to 'pytorchxla' for distributed "
                    "training with PyTorch >= 1.11 and SageMaker Training Compiler enabled."
                )
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Accessors to retrieve hyperparameters for training jobs."""
|
|
14
|
+
|
|
15
|
+
from __future__ import absolute_import
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from typing import Dict, Optional
|
|
19
|
+
|
|
20
|
+
from sagemaker.core.jumpstart import utils as jumpstart_utils
|
|
21
|
+
from sagemaker.core.jumpstart import artifacts
|
|
22
|
+
from sagemaker.core.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
|
|
23
|
+
from sagemaker.core.jumpstart.enums import HyperparameterValidationMode, JumpStartModelType
|
|
24
|
+
from sagemaker.core.jumpstart.validators import validate_hyperparameters
|
|
25
|
+
from sagemaker.core.helper.session_helper import Session
|
|
26
|
+
|
|
27
|
+
# Module-level logger for this hyperparameter-accessor module.
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def retrieve_default(
    region: Optional[str] = None,
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    instance_type: Optional[str] = None,
    include_container_hyperparameters: bool = False,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Dict[str, str]:
    """Look up the default training hyperparameters of a JumpStart model.

    Args:
        region (str): AWS Region whose model metadata is consulted. (Default: None).
        model_id (str): JumpStart model ID to look up. (Default: None).
        model_version (str): Version of the JumpStart model. (Default: None).
        hub_arn (str): ARN of the SageMaker Hub holding the model details. (Default: None).
        instance_type (str): Optional instance type, for instance-specific
            hyperparameters. (Default: None).
        include_container_hyperparameters (bool): Whether to also return the container
            hyperparameters. These do not tune the algorithm; SageMaker Training uses them
            to set up the training container, e.g. to name the entrypoint script. boto3
            callers may need them, while the ``Estimator`` classes add them automatically.
            (Default: False).
        tolerate_vulnerable_model (bool): When True, model versions whose scripts have
            dependencies with known security vulnerabilities are accepted instead of
            raising. (Default: False).
        tolerate_deprecated_model (bool): When True, deprecated models are accepted
            instead of raising. (Default: False).
        sagemaker_session (sagemaker.session.Session): Session used for SageMaker
            interactions. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart model config to apply. (Default: None).
        model_type (JumpStartModelType): Open-weights or proprietary model.
            (Default: JumpStartModelType.OPEN_WEIGHTS).

    Returns:
        dict: The hyperparameters to use for the model.

    Raises:
        ValueError: If ``model_id`` and ``model_version`` are not accepted as JumpStart
            model input.
    """
    has_model_input = jumpstart_utils.is_jumpstart_model_input(model_id, model_version)
    if has_model_input:
        # All lookup work happens in the artifacts helper; this is a thin façade.
        return artifacts._retrieve_default_hyperparameters(
            model_id=model_id,
            model_version=model_version,
            hub_arn=hub_arn,
            config_name=config_name,
            model_type=model_type,
            region=region,
            instance_type=instance_type,
            include_container_hyperparameters=include_container_hyperparameters,
            tolerate_vulnerable_model=tolerate_vulnerable_model,
            tolerate_deprecated_model=tolerate_deprecated_model,
            sagemaker_session=sagemaker_session,
        )

    raise ValueError(
        "Must specify JumpStart `model_id` and `model_version` when retrieving hyperparameters."
    )
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def validate(
    region: Optional[str] = None,
    model_id: Optional[str] = None,
    hub_arn: Optional[str] = None,
    model_version: Optional[str] = None,
    hyperparameters: Optional[dict] = None,
    validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> None:
    """Validates hyperparameters for models.

    Args:
        region (str): The AWS Region for which to validate hyperparameters. (Default: None).
        model_id (str): The model ID of the model for which to validate hyperparameters.
            (Default: None).
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        model_version (str): The version of the model for which to validate hyperparameters.
            (Default: None).
        hyperparameters (dict): Hyperparameters to validate. Although the signature
            defaults to None, a value is required; None raises ``ValueError``.
        validation_mode (HyperparameterValidationMode): Method of validation to use with
            hyperparameters. If set to ``VALIDATE_PROVIDED``, only hyperparameters provided
            to this function will be validated, the missing hyperparameters will be ignored.
            If set to ``VALIDATE_ALGORITHM``, all algorithm hyperparameters will be validated.
            If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated.
            (Default: HyperparameterValidationMode.VALIDATE_PROVIDED).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated models should be tolerated
            (exception not raised). False if these models should raise an exception.
            (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        None: Whatever ``validate_hyperparameters`` returns is passed through; the
        annotated return type is ``None``.

    Raises:
        JumpStartHyperparametersError: If the hyperparameter is not formatted correctly,
            according to its specs in the model metadata.
        ValueError: If the combination of arguments specified is not supported, or if
            ``hyperparameters`` is None.
        RuntimeError: If ``model_id`` or ``model_version`` is None after passing the
            JumpStart input check.
    """

    if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
        raise ValueError(
            "Must specify JumpStart `model_id` and `model_version` when validating hyperparameters."
        )

    # Defensive: guarantees both identifiers are concrete before they are forwarded.
    if model_id is None or model_version is None:
        raise RuntimeError("Model ID and version must both be non-None")

    if hyperparameters is None:
        raise ValueError("Must specify hyperparameters.")

    return validate_hyperparameters(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        hyperparameters=hyperparameters,
        validation_mode=validation_mode,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
    )
|