sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/__init__.py +0 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/__init__.py +0 -4
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.47.dist-info/RECORD +0 -35
- sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,511 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Functions for generating ECR image URIs for pre-built SageMaker Docker images."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
from typing import Optional
|
|
20
|
+
from packaging.version import Version
|
|
21
|
+
import requests
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
from sagemaker.core.serverless_inference_config import ServerlessInferenceConfig
|
|
25
|
+
from sagemaker.core.training_compiler_config import TrainingCompilerConfig
|
|
26
|
+
from sagemaker.core.fw_utils import (
|
|
27
|
+
GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY,
|
|
28
|
+
GRAVITON_ALLOWED_FRAMEWORKS,
|
|
29
|
+
)
|
|
30
|
+
from sagemaker.core.common_utils import _botocore_resolver, get_instance_type_family
|
|
31
|
+
|
|
32
|
+
logger = logging.getLogger(__name__)
|
|
33
|
+
|
|
34
|
+
# Format template for an ECR image URI; filled with ``registry``, ``hostname``
# and ``repository`` keyword fields.
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"

# Framework name constants.
HUGGING_FACE_FRAMEWORK = "huggingface"
HUGGING_FACE_LLM_FRAMEWORK = "huggingface-llm"
HUGGING_FACE_TEI_GPU_FRAMEWORK = "huggingface-tei"
HUGGING_FACE_TEI_CPU_FRAMEWORK = "huggingface-tei-cpu"
HUGGING_FACE_LLM_NEURONX_FRAMEWORK = "huggingface-llm-neuronx"
XGBOOST_FRAMEWORK = "xgboost"
SKLEARN_FRAMEWORK = "sklearn"
# NOTE(review): this is a plain string, not a collection, so ``x in
# TRAINIUM_ALLOWED_FRAMEWORKS`` is a substring test (e.g. "torch" matches).
# Confirm whether a tuple such as ("pytorch",) was intended.
TRAINIUM_ALLOWED_FRAMEWORKS = "pytorch"
# Image scope value used for Graviton inference images.
INFERENCE_GRAVITON = "inference_graviton"
DATA_WRANGLER_FRAMEWORK = "data-wrangler"
STABILITYAI_FRAMEWORK = "stabilityai"
SAGEMAKER_TRITONSERVER_FRAMEWORK = "sagemaker-tritonserver"
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _get_image_tag(
    container_version,
    distributed,
    final_image_scope,
    framework,
    inference_tool,
    instance_type,
    processor,
    py_version,
    tag_prefix,
    version,
):
    """Return image tag based on framework, container, and compute configuration(s).

    Args:
        container_version: Container version component passed to ``_format_tag``.
        distributed: Forwarded to ``_should_auto_select_container_version``.
        final_image_scope: Resolved image scope; compared against
            ``INFERENCE_GRAVITON`` for XGBoost/SKLearn.
        framework: Framework name (for example ``"xgboost"``).
        inference_tool: Inference tool component passed to ``_format_tag``.
        instance_type: Instance type, or ``None``.
        processor: Processor component passed to ``_format_tag``; ``"gpu"``
            triggers the Triton suffix handling below.
        py_version: Python version component passed to ``_format_tag``.
        tag_prefix: Tag prefix component passed to ``_format_tag``.
        version: Framework version, used to look up the arm64 tag for
            XGBoost/SKLearn Graviton images.

    Returns:
        The image tag string.

    Raises:
        KeyError: If an arm64 tag is requested for an XGBoost/SKLearn version
            that is not present in the arm64 mapping.
    """
    instance_type_family = get_instance_type_family(instance_type)
    if framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
        if instance_type_family and final_image_scope == INFERENCE_GRAVITON:
            _validate_arg(
                instance_type_family,
                GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY,
                "instance type",
            )
        if (
            instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
            or final_image_scope == INFERENCE_GRAVITON
        ):
            # Graviton images use fixed arm64 tags rather than formatted ones.
            version_to_arm64_tag_mapping = {
                "xgboost": {
                    "1.5-1": "1.5-1-arm64",
                    "1.3-1": "1.3-1-arm64",
                },
                "sklearn": {
                    "1.0-1": "1.0-1-arm64-cpu-py3",
                },
            }
            tag = version_to_arm64_tag_mapping[framework][version]
        else:
            tag = _format_tag(tag_prefix, processor, py_version, container_version, inference_tool)
    else:
        tag = _format_tag(tag_prefix, processor, py_version, container_version, inference_tool)

    if instance_type is not None and _should_auto_select_container_version(
        instance_type, distributed
    ):
        # Known framework/tag combinations that need a container-version suffix.
        container_versions = {
            "tensorflow-2.3-gpu-py37": "cu110-ubuntu18.04-v3",
            "tensorflow-2.3.1-gpu-py37": "cu110-ubuntu18.04",
            "tensorflow-2.3.2-gpu-py37": "cu110-ubuntu18.04",
            "tensorflow-1.15-gpu-py37": "cu110-ubuntu18.04-v8",
            "tensorflow-1.15.4-gpu-py37": "cu110-ubuntu18.04",
            "tensorflow-1.15.5-gpu-py37": "cu110-ubuntu18.04",
            "mxnet-1.8-gpu-py37": "cu110-ubuntu16.04-v1",
            "mxnet-1.8.0-gpu-py37": "cu110-ubuntu16.04",
            "pytorch-1.6-gpu-py36": "cu110-ubuntu18.04-v3",
            "pytorch-1.6.0-gpu-py36": "cu110-ubuntu18.04",
            "pytorch-1.6-gpu-py3": "cu110-ubuntu18.04-v3",
            "pytorch-1.6.0-gpu-py3": "cu110-ubuntu18.04",
        }
        key = "-".join([framework, tag])
        if key in container_versions:
            tag = "-".join([tag, container_versions[key]])

    # Triton images don't have a trailing -gpu tag. Only -cpu images do.
    if framework == SAGEMAKER_TRITONSERVER_FRAMEWORK and processor == "gpu":
        # str.rstrip("-gpu") would strip any trailing run of the characters
        # '-', 'g', 'p', 'u' (not the literal suffix), so remove the exact
        # suffix instead.
        gpu_suffix = "-gpu"
        if tag.endswith(gpu_suffix):
            tag = tag[: -len(gpu_suffix)]

    return tag
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _config_for_framework_and_scope(framework, image_scope, accelerator_type=None):
    """Load a framework's JSON config and narrow it to the requested image scope.

    Args:
        framework: Framework name whose config is loaded.
        image_scope: Requested image scope; may be falsy.
        accelerator_type: Optional accelerator type. When given it is validated
            and the scope is forced to ``"eia"``.

    Returns:
        The whole config when it has a top-level ``"scope"`` key, otherwise the
        sub-config stored under the resolved image scope.
    """
    framework_config = config_for_framework(framework)
    has_top_level_scope = "scope" in framework_config

    if accelerator_type:
        _validate_accelerator_type(accelerator_type)

        if image_scope not in ("eia", "inference"):
            logger.warning(
                "Elastic inference is for inference only. Ignoring image scope: %s.",
                image_scope,
            )
        image_scope = "eia"

    # Without a top-level "scope" list, the config's own keys are the scopes.
    if has_top_level_scope:
        scopes = framework_config["scope"]
    else:
        scopes = list(framework_config.keys())

    if len(scopes) == 1:
        only_scope = scopes[0]
        if image_scope and image_scope != only_scope:
            logger.warning(
                "Defaulting to only supported image scope: %s. Ignoring image scope: %s.",
                only_scope,
                image_scope,
            )
        image_scope = only_scope

    if not image_scope and has_top_level_scope and set(scopes) == {"training", "inference"}:
        logger.info(
            "Same images used for training and inference. Defaulting to image scope: %s.",
            scopes[0],
        )
        image_scope = scopes[0]

    _validate_arg(image_scope, scopes, "image scope")
    if has_top_level_scope:
        return framework_config
    return framework_config[image_scope]
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _validate_instance_deprecation(framework, instance_type, version):
    """Check if instance type is deprecated for a certain framework with a certain version.

    Args:
        framework: Framework name; only ``"pytorch"`` and ``"tensorflow"`` are checked.
        instance_type: Instance type whose family is inspected.
        version: Framework version string, parsed with ``packaging.version.Version``.

    Raises:
        ValueError: If the instance family is ``p2`` and the framework version is
            PyTorch >= 1.13 or TensorFlow >= 2.12.
    """
    if get_instance_type_family(instance_type) == "p2":
        if (framework == "pytorch" and Version(version) >= Version("1.13")) or (
            framework == "tensorflow" and Version(version) >= Version("2.12")
        ):
            # The first fragment ends with ". " so the concatenated sentences
            # do not run together ("...2.12For information...").
            raise ValueError(
                "P2 instances have been deprecated for sagemaker jobs starting PyTorch 1.13 and TensorFlow 2.12. "
                "For information about supported instance types please refer to "
                "https://aws.amazon.com/sagemaker/pricing/"
            )
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def _validate_for_suppported_frameworks_and_instance_type(framework, instance_type):
    """Check that ``framework`` is allowed on the hardware implied by ``instance_type``.

    Nothing is checked when ``instance_type`` is ``None``. Otherwise the
    framework is passed to ``_validate_framework`` for Trainium and/or Graviton
    whenever it is not among the frameworks allowed for that hardware.
    """
    if instance_type is None:
        return

    # Trainium: any instance type containing "trn".
    # NOTE(review): TRAINIUM_ALLOWED_FRAMEWORKS is a plain string, so the
    # membership test below is a substring test - confirm that is intended.
    targets_trainium = "trn" in instance_type
    if targets_trainium and framework not in TRAINIUM_ALLOWED_FRAMEWORKS:
        _validate_framework(framework, TRAINIUM_ALLOWED_FRAMEWORKS, "framework", "Trainium")

    # Graviton: instance family is one of the allowed Graviton target families.
    family = get_instance_type_family(instance_type)
    targets_graviton = family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
    if targets_graviton and framework not in GRAVITON_ALLOWED_FRAMEWORKS:
        _validate_framework(framework, GRAVITON_ALLOWED_FRAMEWORKS, "framework", "Graviton")
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def config_for_framework(framework):
    """Loads the JSON config for the given framework.

    The config is read from ``image_uri_config/<framework>.json`` in the
    directory that contains this module.

    Args:
        framework: Framework name; used as the JSON file's base name.

    Returns:
        The parsed JSON content of the framework's config file.

    Raises:
        FileNotFoundError: If no config file exists for ``framework``.
    """
    # The previous body began with ``requests.get(s3_url)`` where ``s3_url``
    # was never defined, so every call raised NameError and the local file
    # loading below was unreachable.
    fname = os.path.join(os.path.dirname(__file__), "image_uri_config", "{}.json".format(framework))
    with open(fname, encoding="utf-8") as f:
        return json.load(f)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _get_final_image_scope(framework, instance_type, image_scope):
    """Resolve the image scope to use for a framework and instance type.

    Returns:
        ``INFERENCE_GRAVITON`` when the framework is Graviton-enabled and the
        instance family is a Graviton target; ``"training"`` when no scope was
        given for XGBoost/SKLearn; otherwise ``image_scope`` unchanged.
    """
    if framework in GRAVITON_ALLOWED_FRAMEWORKS:
        family = get_instance_type_family(instance_type)
        if family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY:
            return INFERENCE_GRAVITON

    # XGB/SKLearn configs dropped their top-level "scope" keys when Graviton
    # inference support was added. Their training and inference configs are
    # identical, so fall back to "training" to stay backwards compatible.
    is_xgb_or_sklearn = framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK)
    if image_scope is None and is_xgb_or_sklearn:
        return "training"

    return image_scope
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def _get_inference_tool(inference_tool, instance_type):
    """Return the inference tool, deriving it from the instance type if unset.

    A truthy ``inference_tool`` is returned as is. Otherwise ``"neuron"`` is
    returned when the instance family starts with ``inf`` or ``trn``, and the
    original (falsy) ``inference_tool`` in every other case.
    """
    if inference_tool:
        return inference_tool

    family = get_instance_type_family(instance_type)
    if family.startswith(("inf", "trn")):
        return "neuron"
    return inference_tool
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def _get_latest_versions(list_of_versions):
|
|
221
|
+
"""Extract the latest version from the input list of available versions."""
|
|
222
|
+
return sorted(list_of_versions, reverse=True)[0]
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def _get_latest_version(framework, version, image_scope):
|
|
226
|
+
"""Get the latest version from the input framework"""
|
|
227
|
+
if version:
|
|
228
|
+
return version
|
|
229
|
+
try:
|
|
230
|
+
framework_config = config_for_framework(framework)
|
|
231
|
+
except FileNotFoundError:
|
|
232
|
+
raise ValueError("Invalid framework {}".format(framework))
|
|
233
|
+
|
|
234
|
+
if not framework_config:
|
|
235
|
+
raise ValueError("Invalid framework {}".format(framework))
|
|
236
|
+
|
|
237
|
+
if not version:
|
|
238
|
+
version = _fetch_latest_version_from_config(framework_config, image_scope)
|
|
239
|
+
return version
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def _validate_accelerator_type(accelerator_type):
|
|
243
|
+
"""Raises a ``ValueError`` if ``accelerator_type`` is invalid."""
|
|
244
|
+
if not accelerator_type.startswith("ml.eia") and accelerator_type != "local_sagemaker_notebook":
|
|
245
|
+
raise ValueError(
|
|
246
|
+
"Invalid SageMaker Elastic Inference accelerator type: {}. "
|
|
247
|
+
"See https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html".format(accelerator_type)
|
|
248
|
+
)
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def _validate_version_and_set_if_needed(version, config, framework, image_scope):
|
|
252
|
+
"""Checks if the framework/algorithm version is one of the supported versions."""
|
|
253
|
+
if not config:
|
|
254
|
+
config = config_for_framework(framework)
|
|
255
|
+
available_versions = list(config["versions"].keys())
|
|
256
|
+
aliased_versions = list(config.get("version_aliases", {}).keys())
|
|
257
|
+
if len(available_versions) == 1 and version not in aliased_versions:
|
|
258
|
+
return available_versions[0]
|
|
259
|
+
if not version:
|
|
260
|
+
version = _get_latest_version(framework, version, image_scope)
|
|
261
|
+
_validate_arg(version, available_versions + aliased_versions, "{} version".format(framework))
|
|
262
|
+
return version
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def _version_for_config(version, config):
|
|
266
|
+
"""Returns the version string for retrieving a framework version's specific config."""
|
|
267
|
+
if "version_aliases" in config:
|
|
268
|
+
if version in config["version_aliases"].keys():
|
|
269
|
+
return config["version_aliases"][version]
|
|
270
|
+
|
|
271
|
+
return version
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def _registry_from_region(region, registry_dict):
    """Return the ECR registry (AWS account number) for the given region.

    Args:
        region: Region name to look up.
        registry_dict: Mapping of region name to registry.

    Returns:
        The value stored under ``region`` in ``registry_dict``.

    Raises:
        ValueError: Raised by ``_validate_arg`` when ``region`` is not a key of
            ``registry_dict``.
    """
    supported_regions = registry_dict.keys()
    _validate_arg(region, supported_regions, "region")
    registry = registry_dict[region]
    return registry
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def _processor(instance_type, available_processors, serverless_inference_config=None):
    """Return the processor type for the given instance type.

    Args:
        instance_type: Instance type string, e.g. ``"local"``, ``"local_gpu"``
            or a value understood by ``get_instance_type_family``.
        available_processors: Processors supported by the image config. When
            falsy, no processor is needed and ``None`` is returned.
        serverless_inference_config: When not ``None``, the processor defaults
            to ``"cpu"``.

    Returns:
        The processor name, or ``None`` when ``available_processors`` is falsy.

    Raises:
        ValueError: If the instance type is empty or no family can be extracted
            from it, or (via ``_validate_arg``) if the resolved processor is
            not in ``available_processors``.
    """
    if not available_processors:
        logger.info("Ignoring unnecessary instance type: %s.", instance_type)
        return None

    if len(available_processors) == 1 and not instance_type:
        # The defaulted value is a processor; the message used to call it an image scope.
        logger.info("Defaulting to only supported processor: %s.", available_processors[0])
        return available_processors[0]

    if serverless_inference_config is not None:
        logger.info("Defaulting to CPU type when using serverless inference")
        return "cpu"

    if not instance_type:
        raise ValueError(
            "Empty SageMaker instance type. For options, see: "
            "https://aws.amazon.com/sagemaker/pricing/instance-types"
        )

    if instance_type.startswith("local"):
        # Exactly "local" means CPU; any other "local*" value is treated as GPU.
        processor = "cpu" if instance_type == "local" else "gpu"
    elif instance_type.startswith("neuron"):
        processor = "neuron"
    else:
        # looks for either "ml.<family>.<size>" or "ml_<family>"
        family = get_instance_type_family(instance_type)
        if not family:
            raise ValueError(
                "Invalid SageMaker instance type: {}. For options, see: "
                "https://aws.amazon.com/sagemaker/pricing/instance-types".format(instance_type)
            )

        # For some frameworks, we have optimized images for specific families, e.g c5 or p3.
        # In those cases, we use the family name in the image tag. In other cases, we use
        # 'cpu' or 'gpu'.
        if family in available_processors:
            processor = family
        elif family.startswith("inf"):
            processor = "inf"
        elif family.startswith("trn"):
            processor = "trn"
        elif family[0] in ("g", "p"):
            processor = "gpu"
        else:
            processor = "cpu"

    _validate_arg(processor, available_processors, "processor")
    return processor
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def _should_auto_select_container_version(instance_type, distributed):
|
|
332
|
+
"""Returns a boolean that indicates whether to use an auto-selected container version."""
|
|
333
|
+
p4d = False
|
|
334
|
+
if instance_type:
|
|
335
|
+
# looks for either "ml.<family>.<size>" or "ml_<family>"
|
|
336
|
+
family = get_instance_type_family(instance_type)
|
|
337
|
+
if family:
|
|
338
|
+
p4d = family == "p4d"
|
|
339
|
+
|
|
340
|
+
return p4d or distributed
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def _validate_py_version_and_set_if_needed(py_version, version_config, framework):
|
|
344
|
+
"""Checks if the Python version is one of the supported versions."""
|
|
345
|
+
if "repository" in version_config:
|
|
346
|
+
available_versions = version_config.get("py_versions")
|
|
347
|
+
else:
|
|
348
|
+
available_versions = list(version_config.keys())
|
|
349
|
+
|
|
350
|
+
if not available_versions:
|
|
351
|
+
if py_version:
|
|
352
|
+
logger.info("Ignoring unnecessary Python version: %s.", py_version)
|
|
353
|
+
return None
|
|
354
|
+
|
|
355
|
+
if py_version is None and defaults.SPARK_NAME == framework:
|
|
356
|
+
return None
|
|
357
|
+
|
|
358
|
+
if py_version is None and len(available_versions) == 1:
|
|
359
|
+
logger.info("Defaulting to only available Python version: %s", available_versions[0])
|
|
360
|
+
return available_versions[0]
|
|
361
|
+
|
|
362
|
+
_validate_arg(py_version, available_versions, "Python version")
|
|
363
|
+
return py_version
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
def _validate_arg(arg, available_options, arg_name):
|
|
367
|
+
"""Checks if the arg is in the available options, and raises a ``ValueError`` if not."""
|
|
368
|
+
if arg not in available_options:
|
|
369
|
+
raise ValueError(
|
|
370
|
+
"Unsupported {arg_name}: {arg}. You may need to upgrade your SDK version "
|
|
371
|
+
"(pip install -U sagemaker) for newer {arg_name}s. Supported {arg_name}(s): "
|
|
372
|
+
"{options}.".format(arg_name=arg_name, arg=arg, options=", ".join(available_options))
|
|
373
|
+
)
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
def _validate_framework(framework, allowed_frameworks, arg_name, hardware_name):
|
|
377
|
+
"""Checks if the framework is in the allowed frameworks, and raises a ``ValueError`` if not."""
|
|
378
|
+
if framework not in allowed_frameworks:
|
|
379
|
+
raise ValueError(
|
|
380
|
+
f"Unsupported {arg_name}: {framework}. "
|
|
381
|
+
f"Supported {arg_name}(s) for {hardware_name} instances: {allowed_frameworks}."
|
|
382
|
+
)
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
def _format_tag(tag_prefix, processor, py_version, container_version, inference_tool=None):
|
|
386
|
+
"""Creates a tag for the image URI."""
|
|
387
|
+
if inference_tool:
|
|
388
|
+
return "-".join(x for x in (tag_prefix, inference_tool, py_version, container_version) if x)
|
|
389
|
+
return "-".join(x for x in (tag_prefix, processor, py_version, container_version) if x)
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def _fetch_latest_version_from_config( # pylint: disable=R0911
|
|
393
|
+
framework_config: dict, image_scope: Optional[str] = None
|
|
394
|
+
) -> Optional[str]:
|
|
395
|
+
"""Helper function to fetch the latest version as a string from a framework's config
|
|
396
|
+
|
|
397
|
+
Args:
|
|
398
|
+
framework_config (dict): A framework config dict.
|
|
399
|
+
image_scope (str): Scope of the image, eg: training, inference
|
|
400
|
+
Returns:
|
|
401
|
+
Version string if latest version found else None
|
|
402
|
+
"""
|
|
403
|
+
if image_scope in framework_config:
|
|
404
|
+
if image_scope_config := framework_config[image_scope]:
|
|
405
|
+
if "version_aliases" in image_scope_config:
|
|
406
|
+
if "latest" in image_scope_config["version_aliases"]:
|
|
407
|
+
return image_scope_config["version_aliases"]["latest"]
|
|
408
|
+
top_version = None
|
|
409
|
+
bottom_version = None
|
|
410
|
+
|
|
411
|
+
if "versions" in framework_config:
|
|
412
|
+
versions = list(framework_config["versions"].keys())
|
|
413
|
+
if len(versions) == 1:
|
|
414
|
+
return versions[0]
|
|
415
|
+
top_version = versions[0]
|
|
416
|
+
bottom_version = versions[-1]
|
|
417
|
+
if top_version == "latest" or bottom_version == "latest":
|
|
418
|
+
return None
|
|
419
|
+
elif (
|
|
420
|
+
image_scope is not None
|
|
421
|
+
and image_scope in framework_config
|
|
422
|
+
and "versions" in framework_config[image_scope]
|
|
423
|
+
):
|
|
424
|
+
versions = list(framework_config[image_scope]["versions"].keys())
|
|
425
|
+
top_version = versions[0]
|
|
426
|
+
bottom_version = versions[-1]
|
|
427
|
+
elif "processing" in framework_config and "versions" in framework_config["processing"]:
|
|
428
|
+
versions = list(framework_config["processing"]["versions"].keys())
|
|
429
|
+
top_version = versions[0]
|
|
430
|
+
bottom_version = versions[-1]
|
|
431
|
+
if top_version and bottom_version:
|
|
432
|
+
if top_version.endswith(".x") or bottom_version.endswith(".x"):
|
|
433
|
+
top_number = int(top_version[:-2])
|
|
434
|
+
bottom_number = int(bottom_version[:-2])
|
|
435
|
+
max_version = max(top_number, bottom_number)
|
|
436
|
+
return f"{max_version}.x"
|
|
437
|
+
if Version(top_version) >= Version(bottom_version):
|
|
438
|
+
return top_version
|
|
439
|
+
return bottom_version
|
|
440
|
+
|
|
441
|
+
return None
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
def _retrieve_pytorch_uri_inputs_are_all_default(
    version: Optional[str] = None,
    py_version: Optional[str] = None,
    instance_type: Optional[str] = None,
    accelerator_type: Optional[str] = None,
    image_scope: Optional[str] = None,
    container_version: str = None,
    distributed: bool = False,
    smp: bool = False,
    training_compiler_config: TrainingCompilerConfig = None,
    sdk_version: Optional[str] = None,
    inference_tool: Optional[str] = None,
    serverless_inference_config: ServerlessInferenceConfig = None,
) -> bool:
    """
    Determine if the inputs for _retrieve_pytorch_uri() are all default values.

    Returns:
        bool: ``True`` only when every argument is falsy.
    """
    supplied_inputs = (
        version,
        py_version,
        instance_type,
        accelerator_type,
        image_scope,
        container_version,
        distributed,
        smp,
        training_compiler_config,
        sdk_version,
        inference_tool,
        serverless_inference_config,
    )
    # Any truthy argument means the caller overrode at least one default.
    return not any(supplied_inputs)
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
def _retrieve_latest_pytorch_training_uri(region: str):
    """
    Retrieve the URI for the latest PyTorch training image for CPU

    Args:
        region (str): AWS region used to resolve the ECR endpoint and registry.

    Returns:
        The ``ECR_URI_TEMPLATE`` filled in with the registry, the ECR hostname
        and the repository (with the image tag appended when one is produced).
    """
    image_scope = "training"
    pytorch_config = config_for_framework("pytorch")

    latest_version = _fetch_latest_version_from_config(pytorch_config, image_scope)
    version_config = pytorch_config[image_scope]["versions"][latest_version]
    py_version = _validate_py_version_and_set_if_needed(None, version_config, None)

    endpoint_data = _botocore_resolver().construct_endpoint("ecr", region)
    if region == "il-central-1" and not endpoint_data:
        # Fall back to a hand-built hostname when the resolver returns nothing for this region.
        endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)}

    registry = _registry_from_region(region, version_config["registries"])
    hostname = endpoint_data["hostname"]
    repository = version_config["repository"]

    tag_arguments = {
        "container_version": "ec2",
        "distributed": False,
        "final_image_scope": image_scope,
        "framework": "pytorch",
        "inference_tool": "",
        "instance_type": "",
        "processor": "cpu",
        "py_version": py_version,
        "tag_prefix": latest_version,
        "version": latest_version,
    }
    tag = _get_image_tag(**tag_arguments)
    if tag:
        repository = repository + ":{}".format(tag)

    return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repository)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Image URI configuration data."""
|