sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,2281 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Placeholder docstring"""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import sys
|
|
17
|
+
import contextlib
|
|
18
|
+
import copy
|
|
19
|
+
import errno
|
|
20
|
+
import inspect
|
|
21
|
+
import logging
|
|
22
|
+
import os
|
|
23
|
+
import random
|
|
24
|
+
import re
|
|
25
|
+
import shutil
|
|
26
|
+
import tarfile
|
|
27
|
+
import tempfile
|
|
28
|
+
import time
|
|
29
|
+
from functools import lru_cache
|
|
30
|
+
from typing import Union, Any, List, Optional, Dict
|
|
31
|
+
import json
|
|
32
|
+
import abc
|
|
33
|
+
import uuid
|
|
34
|
+
from datetime import datetime
|
|
35
|
+
from os.path import abspath, realpath, dirname, normpath, join as joinpath
|
|
36
|
+
|
|
37
|
+
from importlib import import_module
|
|
38
|
+
|
|
39
|
+
import boto3
|
|
40
|
+
import botocore
|
|
41
|
+
from botocore.utils import merge_dicts
|
|
42
|
+
from botocore import exceptions
|
|
43
|
+
from botocore.exceptions import ClientError
|
|
44
|
+
from six.moves.urllib import parse
|
|
45
|
+
from six import viewitems
|
|
46
|
+
|
|
47
|
+
import sagemaker
|
|
48
|
+
|
|
49
|
+
from sagemaker.core.enums import RoutingStrategy
|
|
50
|
+
from sagemaker.core.session_settings import SessionSettings
|
|
51
|
+
from sagemaker.core.workflow import is_pipeline_variable, is_pipeline_parameter_string
|
|
52
|
+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
|
|
53
|
+
from enum import Enum
|
|
54
|
+
|
|
55
|
+
# Region name -> domain suffix, for regions that use a non-default domain.
ALTERNATE_DOMAINS = {
    "cn-north-1": "amazonaws.com.cn",
    "cn-northwest-1": "amazonaws.com.cn",
    "us-iso-east-1": "c2s.ic.gov",
    "us-isob-east-1": "sc2s.sgov.gov",
    "us-isof-south-1": "csp.hci.ic.gov",
    "us-isof-east-1": "csp.hci.ic.gov",
}

# Regular expressions for ECR image URIs and SageMaker model / model-package ARNs.
ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$"
MODEL_PACKAGE_ARN_PATTERN = (
    r"arn:aws([a-z\-]*)?:sagemaker:([a-z0-9\-]*):([0-9]{12})"
    r":model-package/(.*)"
)
MODEL_ARN_PATTERN = (
    r"arn:aws([a-z\-]*):sagemaker:([a-z0-9\-]*):([0-9]{12})"
    r":model/(.*)"
)

MAX_BUCKET_PATHS_COUNT = 5

# URI scheme prefixes.
S3_PREFIX = "s3://"
HTTP_PREFIX = "http://"
HTTPS_PREFIX = "https://"

# Polling / pagination defaults.
DEFAULT_SLEEP_TIME_SECONDS = 10
WAITING_DOT_NUMBER = 10
MAX_ITEMS = 100
PAGE_SIZE = 10

logger = logging.getLogger(__name__)

# A tag mapping, and the accepted "tags" argument shape: one mapping or a list of them.
TagsDict = Dict[str, Union[str, PipelineVariable]]
Tags = Union[List[TagsDict], TagsDict]
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class ModelApprovalStatusEnum(str, Enum):
    """String-valued enumeration of model package approval statuses.

    Members also subclass ``str``, so each one compares equal to its
    plain string value.
    """

    # Member values are the literal status strings.
    APPROVED = "Approved"
    REJECTED = "Rejected"
    PENDING_MANUAL_APPROVAL = "PendingManualApproval"
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
# Use the base name of the image as the job name if the user doesn't give us one
|
|
93
|
+
def name_from_image(image, max_length=63):
    """Build a timestamped job name from an image name.

    Args:
        image (str): Image name.
        max_length (int): Maximum length for the resulting string (default: 63).

    Returns:
        str: The base name extracted from the image by
            ``base_name_from_image`` with a timestamp appended by
            ``name_from_base``.
    """
    algorithm_name = base_name_from_image(image)
    return name_from_base(algorithm_name, max_length=max_length)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def name_from_base(base, max_length=63, short=False):
    """Return ``base`` followed by ``-`` and a timestamp.

    ``base`` is cut down as needed so that the prefix, the separator and the
    timestamp together fit within ``max_length``.

    Args:
        base (str): String used as prefix to generate the unique name.
        max_length (int): Maximum length for the resulting string (default: 63).
        short (bool): Use ``sagemaker_short_timestamp`` instead of
            ``sagemaker_timestamp`` for the suffix (default: False).

    Returns:
        str: The (possibly trimmed) prefix with the timestamp appended.
    """
    if short:
        suffix = sagemaker_short_timestamp()
    else:
        suffix = sagemaker_timestamp()

    # One character of the budget is reserved for the "-" separator.
    prefix_budget = max_length - len(suffix) - 1
    return "{}-{}".format(base[:prefix_budget], suffix)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def unique_name_from_base_uuid4(base, max_length=63):
    """Append a random UUID to the provided string.

    This function generates a name that relies on a UUID (version 4) rather
    than a timestamp for uniqueness. ``base`` is trimmed so that the prefix,
    the ``-`` separator and the UUID fit within ``max_length``.

    Args:
        base (str): String used as prefix to generate the unique name.
        max_length (int): Maximum length for the resulting string (default: 63).

    Returns:
        str: Input parameter with an appended UUID.
    """
    # uuid4 draws its own randomness, so the global ``random`` generator is
    # deliberately left untouched (reseeding it here would discard any seed
    # the caller had set).
    unique = str(uuid.uuid4())
    trimmed_base = base[: max_length - len(unique) - 1]
    return "{}-{}".format(trimmed_base, unique)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def unique_name_from_base(base, max_length=63):
    """Append a Unix timestamp and a random 4-digit hex token to ``base``.

    The result has the form ``<base>-<epoch seconds>-<4 hex digits>``.
    ``base`` is trimmed so that the whole name fits within ``max_length``.

    Args:
        base (str): String used as prefix to generate the unique name.
        max_length (int): Maximum length for the resulting string (default: 63).

    Returns:
        str: The (possibly trimmed) prefix with timestamp and token appended.
    """
    # Take the token straight from a fresh UUID instead of reseeding the
    # global ``random`` generator, which would discard any seed set by the caller.
    unique = uuid.uuid4().hex[:4]  # 4-digit lowercase hex
    ts = str(int(time.time()))
    # Two characters of the budget are reserved for the two "-" separators.
    available_length = max_length - 2 - len(ts) - len(unique)
    trimmed = base[:available_length]
    return "{}-{}-{}".format(trimmed, ts, unique)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def base_name_from_image(image, default_base_name=None):
    """Derive the 'algorithm name' for a job from an image name.

    Args:
        image (str): Image name. May also be a pipeline variable; a pipeline
            string parameter with a default value is resolved to that default.
        default_base_name (str): Name returned when ``image`` is a pipeline
            variable that cannot be resolved to a string. Falls back to
            ``"base_name"`` when not given.

    Returns:
        str: The repository portion of the image name (without registry path
            and tag), or the input string itself when it does not match the
            expected image pattern.
    """
    image_str = image
    if is_pipeline_variable(image):
        resolvable = is_pipeline_parameter_string(image) and image.default_value
        if not resolvable:
            return default_base_name or "base_name"
        image_str = image.default_value

    # Optional "<path>/" prefix, the name itself, optional ":<tag>" suffix.
    matched = re.match("^(.+/)?([^:/]+)(:[^:]+)?$", image_str)
    if matched:
        return matched.group(2)
    return image_str
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def base_from_name(name):
    """Strip a generated timestamp suffix from a resource name.

    Two suffix shapes are recognized: the long form
    ``YYYY-MM-DD-HH-MM-SS-mmm`` and the short form ``yymmdd-HHMM``, each
    preceded by a ``-``. The result can be used as the base for generating
    further resource names.

    Args:
        name (str): The resource name.

    Returns:
        str: The part of ``name`` preceding the timestamp, or ``name``
            unchanged when no timestamp is found.
    """
    timestamp_suffix = r"^(.+)-(\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}-\d{3}|\d{6}-\d{4})"
    matched = re.match(timestamp_suffix, name)
    if matched is None:
        return name
    return matched.group(1)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def sagemaker_timestamp():
    """Return a UTC timestamp with millisecond precision.

    Returns:
        str: The current time formatted as ``YYYY-MM-DD-HH-MM-SS-mmm``, where
            ``mmm`` is always exactly three digits.
    """
    moment = time.time()
    # repr() drops trailing zeros (e.g. "....5"), so pad the fractional digits
    # back to three places; otherwise half a second would be rendered as "5".
    moment_ms = repr(moment).split(".")[1][:3].ljust(3, "0")
    return time.strftime("%Y-%m-%d-%H-%M-%S-{}".format(moment_ms), time.gmtime(moment))
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def sagemaker_short_timestamp():
    """Return a compact timestamp of the form ``yymmdd-HHMM``.

    The value is formatted from the local time.
    """
    now_local = time.localtime()
    return time.strftime("%y%m%d-%H%M", now_local)
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def build_dict(key, value):
    """Return ``{key: value}`` when ``value`` is truthy, otherwise an empty dict.

    Any falsy value (``None``, ``""``, ``0``, an empty container, ...) yields
    the empty dict.

    Args:
        key (str): input key
        value (str): input value

    Returns:
        dict: dict of key and value or an empty dict.
    """
    return {key: value} if value else {}
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def get_config_value(key_path, config):
    """Look up a value in a nested config using a dot-separated key path.

    Args:
        key_path (str): Keys joined by ``.``, e.g. ``"section.sub.key"``.
        config (dict): Nested configuration, or ``None``.

    Returns:
        The value found at ``key_path``, or ``None`` when ``config`` is
        ``None``, a key along the path is missing, or an intermediate value
        cannot be descended into (for example ``None`` or a scalar).
    """
    if config is None:
        return None

    current_section = config
    for key in key_path.split("."):
        try:
            if key not in current_section:
                return None
            current_section = current_section[key]
        except TypeError:
            # The path runs through a value that is not a container keyed by
            # strings (None, a number, a string, ...): treat it as "not found"
            # like any other missing path instead of crashing.
            return None

    return current_section
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def get_nested_value(dictionary: dict, nested_keys: List[str]):
    """Return a nested value from the given dictionary, or None if not present.

    Args:
        dictionary (dict): The dictionary to read from.
        nested_keys (List[str]): Keys to follow, outermost first.

    Returns:
        The value stored under the full key path. ``None`` is returned when
        ``dictionary`` is not a dict, ``nested_keys`` is ``None`` or empty, or
        any key along the path is missing or maps to ``None``.

    Raises:
        ValueError: If an intermediate key maps to a value that is neither
            ``None`` nor a dict, i.e. the dictionary structure does not match
            ``nested_keys``.
    """
    if not isinstance(dictionary, dict) or nested_keys is None or len(nested_keys) == 0:
        return None

    current_section = dictionary
    for key in nested_keys[:-1]:
        current_section = current_section.get(key, None)
        if current_section is None:
            # The full path of nested_keys doesn't exist in the dictionary,
            # or the value was explicitly set to None.
            return None
        if not isinstance(current_section, dict):
            # Adjacent literals form a single message string (previously the
            # two parts were passed as separate exception arguments).
            raise ValueError(
                "Unexpected structure of dictionary. "
                "Expected value of type dict at key '{}' but got '{}' for dict '{}'".format(
                    key, current_section, dictionary
                )
            )

    return current_section.get(nested_keys[-1], None)
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def set_nested_value(dictionary: dict, nested_keys: List[str], value_to_set: object):
    """Set a value at a nested key path inside a dictionary.

    The dictionary is modified in place and also returned; a new dictionary
    is created when ``None`` is passed. Intermediate keys that are missing,
    or that hold anything other than a dict, are replaced by a fresh empty
    dict. Note: because of that, an unintended list of nested keys can
    overwrite an unexpected part of the dictionary, so it is recommended to
    check with ``get_nested_value`` first.

    Args:
        dictionary (dict): Dictionary to update, or ``None``.
        nested_keys (List[str]): Keys to follow, outermost first.
        value_to_set (object): Value stored under the final key.

    Returns:
        The updated dictionary. It is returned unchanged when it is not a
        dict or when ``nested_keys`` is ``None`` or empty.
    """
    target = {} if dictionary is None else dictionary

    usable_keys = nested_keys is not None and len(nested_keys) > 0
    if isinstance(target, dict) and usable_keys:
        *parents, leaf = nested_keys
        node = target
        for parent in parents:
            # Missing, None and non-dict entries are all replaced by a new dict.
            if not isinstance(node.get(parent), dict):
                node[parent] = {}
            node = node[parent]
        node[leaf] = value_to_set

    return target
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def get_short_version(framework_version):
    """Shorten a dotted version string to its first two components.

    Args:
        framework_version (str): The version string to be shortened, e.g. "1.2.3".

    Returns:
        str: The first two dot-separated components joined by "." (e.g. "1.2").
        A version with fewer than two components is returned unchanged.
    """
    leading_components = framework_version.split(".")[:2]
    return ".".join(leading_components)
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def secondary_training_status_changed(current_job_description, prev_job_description):
    """Tell whether the latest secondary status message of a training job changed.

    Args:
        current_job_description (dict): Current job description, returned from a
            DescribeTrainingJob call.
        prev_job_description (dict): Previous job description, returned from a
            DescribeTrainingJob call. May be None.

    Returns:
        bool: True when the "StatusMessage" of the last entry in the current
        "SecondaryStatusTransitions" differs from that of the previous description
        (an absent or empty previous list counts as the message ""). False when the
        current description has no secondary status transitions.
    """
    current_transitions = current_job_description.get("SecondaryStatusTransitions")
    if current_transitions is None:
        return False
    if len(current_transitions) == 0:
        return False

    previous_transitions = None
    if prev_job_description is not None:
        previous_transitions = prev_job_description.get("SecondaryStatusTransitions")

    # No previous transitions means there is no previous message to compare against.
    previous_message = ""
    if previous_transitions is not None and len(previous_transitions) > 0:
        previous_message = previous_transitions[-1]["StatusMessage"]

    latest_message = current_job_description["SecondaryStatusTransitions"][-1]["StatusMessage"]
    return latest_message != previous_message
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
def secondary_training_status_message(job_description, prev_description):
    """Build the printable status lines for a training job's secondary status.

    Each line has the form "<time> <Status> - <StatusMessage>", where the time is
    derived from the job's "LastModifiedTime" and formatted as
    "%Y-%m-%d %H:%M:%S".

    Args:
        job_description (dict): Returned response from a DescribeTrainingJob call.
        prev_description (dict): Previous job description from a
            DescribeTrainingJob call. May be None.

    Returns:
        str: Job status string to be printed, one line per transition. Empty when
        the job description is None or has no secondary status transitions.
    """
    if job_description is None:
        return ""
    current_transitions = job_description.get("SecondaryStatusTransitions")
    if current_transitions is None or len(current_transitions) == 0:
        return ""

    previous_count = 0
    if prev_description is not None:
        if prev_description.get("SecondaryStatusTransitions") is not None:
            previous_count = len(prev_description["SecondaryStatusTransitions"])

    current_count = len(current_transitions)
    if current_count == previous_count:
        # Same number of transitions: only the message of the last one may have changed.
        pending = current_transitions[-1:]
    else:
        # The number of transitions changed: print every entry not seen before.
        pending = current_transitions[previous_count - current_count :]

    lines = []
    for entry in pending:
        status_message = entry["StatusMessage"]
        epoch_seconds = time.mktime(job_description["LastModifiedTime"].timetuple())
        timestamp = datetime.utcfromtimestamp(epoch_seconds).strftime("%Y-%m-%d %H:%M:%S")
        lines.append("{} {} - {}".format(timestamp, entry["Status"], status_message))

    return "\n".join(lines)
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def download_folder(bucket_name, prefix, target, sagemaker_session):
    """Download an S3 prefix (or a single object) to a local path.

    Args:
        bucket_name (str): S3 bucket name
        prefix (str): S3 prefix within the bucket that will be downloaded. Can
            be a single file. Leading "/" characters are stripped.
        target (str): destination path where the downloaded items will be placed
        sagemaker_session (sagemaker.core.helper.session.Session): a sagemaker session to
            interact with S3; its ``s3_resource`` is used.

    Raises:
        botocore.exceptions.ClientError: when the single-object download fails with
            anything other than a 404 "Not Found" error.
    """
    s3_resource = sagemaker_session.s3_resource
    key_prefix = prefix.lstrip("/")

    # A prefix without a trailing "/" may name a single object. Try that first, in
    # case the object has broader permissions than the bucket.
    if not key_prefix.endswith("/"):
        local_file = os.path.join(target, os.path.basename(key_prefix))
        try:
            s3_resource.Object(bucket_name, key_prefix).download_file(local_file)
        except botocore.exceptions.ClientError as client_error:
            error = client_error.response["Error"]
            is_not_found = error["Code"] == "404" and error["Message"] == "Not Found"
            # S3 also reports 404 when the key is a "folder", so fall through to the
            # prefix download in that case; any other error is a real failure.
            if not is_not_found:
                raise
        else:
            return

    _download_files_under_prefix(bucket_name, key_prefix, target, s3_resource)
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
def _download_files_under_prefix(bucket_name, prefix, target, s3):
    """Download all S3 files which match the given prefix.

    Each object is written below ``target`` at its key path relative to ``prefix``.
    Missing local directories are created on demand.

    Args:
        bucket_name (str): S3 bucket name
        prefix (str): S3 prefix within the bucket that will be downloaded
        target (str): destination path where the downloaded items will be placed
        s3 (boto3.resources.base.ServiceResource): S3 resource
    """
    bucket = s3.Bucket(bucket_name)
    for obj_sum in bucket.objects.filter(Prefix=prefix):
        # Keys ending in "/" are folder placeholder objects; there is nothing to download.
        if obj_sum.key.endswith("/"):
            continue
        obj = s3.Object(obj_sum.bucket_name, obj_sum.key)
        s3_relative_path = obj_sum.key[len(prefix) :].lstrip("/")
        file_path = os.path.join(target, s3_relative_path)

        # exist_ok=True replaces the manual EEXIST check. An empty dirname means the
        # file goes into the current directory, which needs no creation (and
        # os.makedirs("") would raise).
        parent_dir = os.path.dirname(file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        obj.download_file(file_path)
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def create_tar_file(source_files, target=None):
    """Create a gzip-compressed tar file containing all the source_files.

    Every entry is stored at the root of the archive under its base name, and
    symbolic links are dereferenced (their targets are archived).

    Args:
        source_files (List[str]): List of file paths that will be contained in the tar file
        target (str): Path of the tar file to create. When omitted (or empty), a
            new temporary file is created and used instead.

    Returns:
        (str): path to created tar file
    """
    if target:
        filename = target
    else:
        # mkstemp returns an open OS-level file descriptor together with the path.
        # Only the path is needed, so close the descriptor instead of leaking it.
        fd, filename = tempfile.mkstemp()
        os.close(fd)

    with tarfile.open(filename, mode="w:gz", dereference=True) as archive:
        for source_file in source_files:
            # Add all files from the directory into the root of the directory structure of the tar
            archive.add(source_file, arcname=os.path.basename(source_file))
    return filename
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
@contextlib.contextmanager
def _tmpdir(suffix="", prefix="tmp", directory=None):
    """Context manager that provides a temporary directory and removes it on exit.

    The directory and its contents are deleted when the context exits, even when
    there's an exception. The suffix, prefix and directory arguments are forwarded
    to ``tempfile.mkdtemp()``.

    Args:
        suffix (str): If suffix is specified, the directory name will end with that
            suffix, otherwise there will be no suffix.
        prefix (str): If prefix is specified, the directory name will begin with that
            prefix; otherwise, a default prefix is used.
        directory (str): If a directory is specified, the temporary directory is
            created inside it; otherwise, a default location is used.

    Yields:
        str: path to the directory

    Raises:
        ValueError: if ``directory`` is given but is not an existing directory.
    """
    if directory is not None:
        parent_is_usable = os.path.exists(directory) and os.path.isdir(directory)
        if not parent_is_usable:
            raise ValueError(
                "Inputted directory for storing newly generated temporary "
                f"directory does not exist: '{directory}'"
            )

    scratch_dir = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=directory)
    try:
        yield scratch_dir
    finally:
        shutil.rmtree(scratch_dir)
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
def repack_model(
    inference_script,
    source_directory,
    dependencies,
    model_uri,
    repacked_model_uri,
    sagemaker_session,
    kms_key=None,
):
    """Repack a model tarball so that it contains the provided inference code.

    The work happens in a temporary directory that is removed afterwards:

    - the model tarball is extracted from S3 or the local file system,
    - its ``code`` directory is created or replaced from the inference script,
      source directory and dependencies,
    - the result is compressed into a new tarball and saved to
      ``repacked_model_uri`` (S3 or local file system).

    Args:
        inference_script (str): path or basename of the inference script that
            will be packed into the model
        source_directory (str): path including all the files that will be packed
            into the model
        dependencies (list[str]): A list of paths to directories (absolute or
            relative) with any additional libraries that will be exported to the
            container (default: []). They are placed in a ``lib`` folder inside
            the model's ``code`` directory.
        model_uri (str): S3 or file system location of the original model tar
        repacked_model_uri (str): path or file system location where the new
            model will be saved
        sagemaker_session (sagemaker.core.helper.session.Session): a sagemaker session to
            interact with S3. When its ``settings.local_download_dir`` is set, the
            temporary directory is created inside it.
        kms_key (str): KMS key ARN for encrypting the repacked model file
    """
    dependencies = dependencies or []

    session_settings = sagemaker_session.settings
    local_download_dir = (
        None if session_settings is None else session_settings.local_download_dir
    )

    with _tmpdir(directory=local_download_dir) as workdir:
        model_dir = _extract_model(model_uri, sagemaker_session, workdir)

        _create_or_update_code_dir(
            model_dir,
            inference_script,
            source_directory,
            dependencies,
            sagemaker_session,
            workdir,
        )

        repacked_tarball = os.path.join(workdir, "temp-model.tar.gz")
        with tarfile.open(repacked_tarball, mode="w:gz") as archive:
            archive.add(model_dir, arcname=os.path.sep)

        # Must run inside the context: the tarball lives in the temporary directory.
        _save_model(repacked_model_uri, repacked_tarball, sagemaker_session, kms_key=kms_key)
|
|
583
|
+
|
|
584
|
+
|
|
585
|
+
def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
    """Store the repacked model tarball at its final S3 or local location.

    Args:
        repacked_model_uri (str): Destination. URIs starting with "s3://"
            (case-insensitive) are uploaded to S3; anything else is treated as a
            local path, with any "file://" text removed.
        tmp_model_path (str): Local path of the tarball to upload or move.
        sagemaker_session: Session providing ``settings``, ``boto_session`` and
            ``boto_region_name`` for the S3 upload.
        kms_key (str): KMS key used for server-side encryption of the upload.
    """
    if not repacked_model_uri.lower().startswith("s3://"):
        shutil.move(tmp_model_path, repacked_model_uri.replace("file://", ""))
        return

    parsed_uri = parse.urlparse(repacked_model_uri)
    bucket = parsed_uri.netloc
    key = parsed_uri.path.lstrip("/")
    new_key = key.replace(os.path.basename(key), os.path.basename(repacked_model_uri))

    if sagemaker_session is not None:
        settings = sagemaker_session.settings
    else:
        settings = SessionSettings()
    encrypt_artifact = settings.encrypt_repacked_artifacts

    # An explicit KMS key takes precedence over the session-level encryption setting.
    if kms_key:
        extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key}
    elif encrypt_artifact:
        extra_args = {"ServerSideEncryption": "aws:kms"}
    else:
        extra_args = None

    s3_resource = sagemaker_session.boto_session.resource(
        "s3", region_name=sagemaker_session.boto_region_name
    )
    s3_resource.Object(bucket, new_key).upload_file(tmp_model_path, ExtraArgs=extra_args)
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
def _create_or_update_code_dir(
    model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp
):
    """Populate the ``code`` directory of an extracted model.

    The content of ``<model_dir>/code`` comes from one of three sources:

    - an "s3://" ``source_directory``: the tarball is downloaded and extracted,
    - a local ``source_directory``: it replaces any existing ``code`` directory,
    - no ``source_directory``: ``inference_script`` is copied into ``code``.

    Afterwards every dependency is copied into ``<model_dir>/code/lib``.

    Args:
        model_dir (str): Directory holding the extracted model.
        inference_script (str): Path or basename of the inference script.
        source_directory (str): Local directory or S3 URI of the code, or None.
        dependencies (list[str]): Files or directories to copy into ``code/lib``.
        sagemaker_session: Session used to download from S3.
        tmp (str): Scratch directory for the downloaded code tarball.

    Raises:
        FileNotFoundError: if ``inference_script`` cannot be copied and no file of
            that name already exists in the ``code`` directory.
    """
    code_dir = os.path.join(model_dir, "code")

    if not source_directory:
        if not os.path.exists(code_dir):
            os.mkdir(code_dir)
        try:
            shutil.copy2(inference_script, code_dir)
        except FileNotFoundError:
            # The script may already be part of the model's code directory.
            if not os.path.exists(os.path.join(code_dir, inference_script)):
                raise
    elif source_directory.lower().startswith("s3://"):
        local_code_path = os.path.join(tmp, "local_code.tar.gz")
        download_file_from_url(source_directory, local_code_path, sagemaker_session)

        with tarfile.open(name=local_code_path, mode="r:gz") as archive:
            custom_extractall_tarfile(archive, code_dir)
    else:
        if os.path.exists(code_dir):
            shutil.rmtree(code_dir)
        shutil.copytree(source_directory, code_dir)

    lib_dir = os.path.join(code_dir, "lib")
    for dependency in dependencies:
        if os.path.isdir(dependency):
            shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency)))
            continue
        if not os.path.exists(lib_dir):
            os.mkdir(lib_dir)
        shutil.copy2(dependency, lib_dir)
|
|
645
|
+
|
|
646
|
+
|
|
647
|
+
def _extract_model(model_uri, sagemaker_session, tmp):
    """Extract a model tarball into ``<tmp>/model``.

    Args:
        model_uri (str): Location of the gzip-compressed model tarball. URIs
            starting with "s3://" (case-insensitive) are downloaded first; anything
            else is treated as a local path, with any "file://" text removed.
        sagemaker_session: Session used to download from S3.
        tmp (str): Existing scratch directory.

    Returns:
        str: Path of the directory the model was extracted into.
    """
    extraction_dir = os.path.join(tmp, "model")
    os.mkdir(extraction_dir)

    is_s3_uri = model_uri.lower().startswith("s3://")
    if is_s3_uri:
        tarball_path = os.path.join(tmp, "tar_file")
        download_file_from_url(model_uri, tarball_path, sagemaker_session)
    else:
        tarball_path = model_uri.replace("file://", "")

    with tarfile.open(name=tarball_path, mode="r:gz") as archive:
        custom_extractall_tarfile(archive, extraction_dir)
    return extraction_dir
|
|
659
|
+
|
|
660
|
+
|
|
661
|
+
def download_file_from_url(url, dst, sagemaker_session):
    """Download the S3 object named by a URL to a local path.

    The URL's network location is used as the bucket name and its path, without
    leading "/" characters, as the object key.

    Args:
        url (str): URL of the object, e.g. "s3://bucket/path/to/file".
        dst (str): Local destination passed on to ``download_file``.
        sagemaker_session: Session used to interact with S3.
    """
    parsed_url = parse.urlparse(url)
    bucket_name = parsed_url.netloc
    object_key = parsed_url.path.lstrip("/")
    download_file(bucket_name, object_key, dst, sagemaker_session)
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
def download_file(bucket_name, path, target, sagemaker_session):
    """Download a single file from S3 to a local path.

    Args:
        bucket_name (str): S3 bucket name
        path (str): file path within the bucket; leading "/" characters are stripped
        target (str): local destination handed to the bucket's ``download_file`` call
        sagemaker_session (sagemaker.core.helper.session.Session): a sagemaker session to
            interact with S3; its boto session and region are used to create the
            S3 resource.
    """
    object_key = path.lstrip("/")
    s3_resource = sagemaker_session.boto_session.resource(
        "s3", region_name=sagemaker_session.boto_region_name
    )
    s3_resource.Bucket(bucket_name).download_file(object_key, target)
|
|
685
|
+
|
|
686
|
+
|
|
687
|
+
def sts_regional_endpoint(region):
    """Return the AWS STS endpoint URL specific to the given region.

    We need this function because the AWS SDK does not yet honor
    the ``region_name`` parameter when creating an AWS STS client.

    For the list of regional endpoints, see
    https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_region-endpoints.

    Args:
        region (str): AWS region name

    Returns:
        str: AWS STS regional endpoint, as an "https://" URL
    """
    resolved = _botocore_resolver().construct_endpoint("sts", region)
    if not resolved and region == "il-central-1":
        # The resolver returned nothing for this region, so build the hostname by hand.
        hostname = "sts.{}.amazonaws.com".format(region)
    else:
        hostname = resolved["hostname"]
    return "https://{}".format(hostname)
|
|
706
|
+
|
|
707
|
+
|
|
708
|
+
def retries(
    max_retry_count,
    exception_message_prefix,
    seconds_to_sleep=DEFAULT_SLEEP_TIME_SECONDS,
):
    """Generator that yields attempt numbers and fails once they are used up.

    After each yielded attempt the generator sleeps before producing the next one.
    A caller that succeeds simply stops iterating; iterating past the last attempt
    raises.

    Args:
        max_retry_count (int): The retry count.
        exception_message_prefix (str): The message to include in the exception on failure.
        seconds_to_sleep (int): The number of seconds to sleep between executions.

    Yields:
        int: The zero-based attempt number.

    Raises:
        Exception: when all ``max_retry_count`` attempts have been consumed.
    """
    for attempt_index in range(max_retry_count):
        yield attempt_index
        time.sleep(seconds_to_sleep)

    failure_message = "'{}' has reached the maximum retry count of {}".format(
        exception_message_prefix, max_retry_count
    )
    raise Exception(failure_message)
|
|
730
|
+
|
|
731
|
+
|
|
732
|
+
def retry_with_backoff(callable_func, num_attempts=8, botocore_client_error_code=None):
    """Call a function, retrying with exponential backoff on failure.

    Between attempts the function sleeps ``2**i`` seconds, where ``i`` is the
    zero-based number of the failed attempt.

    Args:
        callable_func (Callable): The callable function to retry.
        num_attempts (int): The maximum number of attempts to retry.(Default: 8)
        botocore_client_error_code (str): The specific Botocore ClientError exception error code
            on which to retry on.
            If provided other exceptions will be raised directly w/o retry.
            If not provided, retry on any exception.
            (Default: None)

    Returns:
        The return value of ``callable_func``.

    Raises:
        ValueError: if ``num_attempts`` is smaller than 1.
        Exception: the exception raised by ``callable_func`` when it is not
            retryable or when the last attempt failed.
    """
    if num_attempts < 1:
        raise ValueError(
            "The num_attempts must be >= 1, but the given value is {}.".format(num_attempts)
        )

    final_attempt = num_attempts - 1
    for attempt in range(num_attempts):
        try:
            return callable_func()
        except Exception as ex:  # pylint: disable=broad-except
            # Without an error code every exception is retryable; with one, only a
            # ClientError carrying exactly that code is.
            retryable = not botocore_client_error_code or (
                isinstance(ex, botocore.exceptions.ClientError)
                and ex.response["Error"]["Code"]  # pylint: disable=no-member
                == botocore_client_error_code
            )
            if not retryable or attempt == final_attempt:
                raise ex
            logger.error("Retrying in attempt %s, due to %s", (attempt + 1), str(ex))
            time.sleep(2**attempt)
|
|
764
|
+
|
|
765
|
+
|
|
766
|
+
def _botocore_resolver():
    """Create a botocore endpoint resolver.

    The resolver is built from the "endpoints" data loaded through a freshly
    created botocore loader.

    Returns:
        botocore.regions.EndpointResolver: resolver over the loaded endpoint data
    """
    endpoint_data = botocore.loaders.create_loader().load_data("endpoints")
    return botocore.regions.EndpointResolver(endpoint_data)
|
|
777
|
+
|
|
778
|
+
|
|
779
|
+
def aws_partition(region):
    """Given a region name (ex: "cn-north-1"), return the corresponding aws partition ("aws-cn").

    Args:
        region (str): The region name for which to return the corresponding partition.
            Ex: "cn-north-1"

    Returns:
        str: partition corresponding to the region name passed in. Ex: "aws-cn"
    """
    endpoint_data = _botocore_resolver().construct_endpoint("sts", region)
    if region == "il-central-1" and not endpoint_data:
        # The resolver has no data for this region. The previous fallback dict only
        # carried a "hostname" entry, so the "partition" lookup below raised
        # KeyError. il-central-1 belongs to the standard "aws" partition.
        return "aws"
    return endpoint_data["partition"]
|
|
793
|
+
|
|
794
|
+
|
|
795
|
+
class DeferredError:
    """Placeholder that re-raises a stored exception as soon as it is used.

    This allows soft dependencies on imports: the ImportError is kept and only
    raised later, if code actually relies on the missing library.

    Example::

        try:
            import obscurelib
        except ImportError as e:
            logger.warning("Failed to import obscurelib. Obscure features will not work.")
            obscurelib = DeferredError(e)
    """

    def __init__(self, exception):
        """Keep the exception that will be raised on later attribute access.

        Args:
            exception (Exception): The exception to store.
        """
        self.exc = exception

    def __getattr__(self, name):
        """Raise the stored exception for any attribute that is not found normally.

        Python calls this hook only when regular attribute lookup fails, which is
        the case for every attribute other than ``exc``. So this short-circuits
        essentially any access to this object.

        Args:
            name (str): Name of the requested attribute (unused).
        """
        stored_exception = self.exc
        raise stored_exception
|
|
823
|
+
|
|
824
|
+
|
|
825
|
+
def _module_import_error(py_module, feature, extras):
    """Build the error message for a failed module import, with install instructions.

    Args:
        py_module (str): Module that failed to be imported
        feature (str): Affected SageMaker feature
        extras (str): Name of the `extras_require` to install the relevant dependencies

    Returns:
        str: Error message with installation instructions.
    """
    return (
        f"Failed to import {py_module}. {feature} features will be impaired or broken. "
        f"Please run \"pip install 'sagemaker[{extras}]'\" "
        "to install all required dependencies."
    )
|
|
842
|
+
|
|
843
|
+
|
|
844
|
+
class DataConfig(abc.ABC):
    """Abstract interface for reading a data config hosted in AWS resources.

    Subclasses customize the behavior by overriding ``fetch_data_config``.
    """

    @abc.abstractmethod
    def fetch_data_config(self):
        """Retrieve the data config from a pre-configured data source.

        Must be implemented by subclasses.

        Returns:
            object: The data configuration object.
        """
|
|
857
|
+
|
|
858
|
+
|
|
859
|
+
class S3DataConfig(DataConfig):
    """``DataConfig`` implementation that reads a JSON data config file hosted on S3."""

    def __init__(
        self,
        sagemaker_session,
        bucket_name,
        prefix,
    ):
        """Initialize a ``S3DataConfig`` instance.

        Args:
            sagemaker_session (Session): SageMaker session instance to use for boto configuration.
            bucket_name (str): Required. Bucket name from which data config needs to be fetched.
            prefix (str): Required. The object prefix for the hosted data config.

        Raises:
            ValueError: if ``bucket_name`` or ``prefix`` is None.
        """
        required_argument_missing = bucket_name is None or prefix is None
        if required_argument_missing:
            raise ValueError(
                "Bucket Name and S3 file Prefix are required arguments and must be provided."
            )

        super().__init__()

        self.sagemaker_session = sagemaker_session
        self.bucket_name = bucket_name
        self.prefix = prefix

    def fetch_data_config(self):
        """Read the data configuration file from S3 and parse it as JSON.

        Returns:
            object: The JSON object containing data configuration.
        """
        raw_config = self.sagemaker_session.read_s3_file(self.bucket_name, self.prefix)
        return json.loads(raw_config)

    def get_data_bucket(self, region_requested=None):
        """Provide the bucket containing the data for the specified region.

        Args:
            region_requested (str): The region for which the data is being requested.
                Defaults to the session's region.

        Returns:
            str: Name of the S3 bucket containing datasets in the requested region,
            or the config's "default" entry when the region has no entry of its own.
        """
        config = self.fetch_data_config()
        region = region_requested or self.sagemaker_session.boto_region_name
        if region in config.keys():
            return config[region]
        return config["default"]
|
|
910
|
+
|
|
911
|
+
|
|
912
|
+
def update_container_with_inference_params(
    framework=None,
    framework_version=None,
    nearest_model_name=None,
    data_input_configuration=None,
    container_def=None,
    container_list=None,
):
    """Add the provided inference recommender parameters to container definitions.

    Every container in ``container_list`` and the single ``container_def`` are
    updated in place.

    Args:
        framework (str): Machine learning framework of the model package container image
            (default: None).
        framework_version (str): Framework version of the Model Package Container Image
            (default: None).
        nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
            Amazon SageMaker Inference Recommender (default: None).
        data_input_configuration (str): Input object for the model (default: None).
        container_def (dict): object to be updated.
        container_list (list): list to be updated.

    Returns:
        ``container_list`` when it is non-empty, otherwise ``container_def``.
    """
    containers = []
    if container_list is not None:
        containers.extend(container_list)
    if container_def is not None:
        containers.append(container_def)

    for container in containers:
        construct_container_object(
            container,
            data_input_configuration,
            framework,
            framework_version,
            nearest_model_name,
        )

    return container_list or container_def
|
|
953
|
+
|
|
954
|
+
|
|
955
|
+
def construct_container_object(
    obj, data_input_configuration, framework, framework_version, nearest_model_name
):
    """Update a container object with the inference parameters that are not None.

    Args:
        obj (dict): object to be updated (in place).
        data_input_configuration (str): Input object for the model; stored under
            "ModelInput" as ``{"DataInputConfig": ...}``.
        framework (str): Machine learning framework of the model package container
            image; stored under "Framework".
        framework_version (str): Framework version of the Model Package Container
            Image; stored under "FrameworkVersion".
        nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
            Amazon SageMaker Inference Recommender; stored under "NearestModelName".

    Returns:
        dict: container object (the same ``obj`` that was passed in)
    """
    optional_fields = (
        ("Framework", framework),
        ("FrameworkVersion", framework_version),
        ("NearestModelName", nearest_model_name),
    )
    for field_name, field_value in optional_fields:
        if field_value is not None:
            obj.update({field_name: field_value})

    if data_input_configuration is not None:
        obj.update({"ModelInput": {"DataInputConfig": data_input_configuration}})

    return obj
|
|
1005
|
+
|
|
1006
|
+
|
|
1007
|
+
def pop_out_unused_kwarg(arg_name: str, kwargs: dict, override_val: Optional[str] = None):
    """Remove an unused key-word argument from ``kwargs`` and log a warning.

    Nothing happens when the argument is not present in ``kwargs``.

    Args:
        arg_name (str): The name of the argument to be checked if it is unused.
        kwargs (dict): The key-word argument dict; modified in place.
        override_val (str): The value used to override the unused argument (default: None).
            When truthy, it is mentioned in the warning.
    """
    if arg_name in kwargs:
        message_parts = ["{} supplied in kwargs will be ignored".format(arg_name)]
        if override_val:
            message_parts.append(" and further overridden with {}.".format(override_val))
        logging.warning("".join(message_parts))
        kwargs.pop(arg_name)
|
|
1022
|
+
|
|
1023
|
+
|
|
1024
|
+
def to_string(obj: object):
    """Convert an object to its string form.

    Pipeline variables are converted through their own ``to_string()`` method;
    every other object goes through ``str()``.

    Args:
        obj (object): The object to be converted
    """
    if is_pipeline_variable(obj):
        return obj.to_string()
    return str(obj)
|
|
1033
|
+
|
|
1034
|
+
|
|
1035
|
+
def _start_waiting(waiting_time: int):
    """Wait while printing an in-progress animation of growing dots to stdout.

    The waiting time is split evenly over ``WAITING_DOT_NUMBER`` steps; each step
    prints one more dot on the same line. The line is blanked out at the end.

    Args:
        waiting_time (int): The total waiting time.
    """
    pause = float(waiting_time) / WAITING_DOT_NUMBER

    dots_shown = 0
    for dots_shown in range(1, WAITING_DOT_NUMBER + 1):
        # "\r" returns to the start of the line so the next print overwrites it.
        print("." * dots_shown, end="\r")
        time.sleep(pause)
    print(" " * dots_shown, end="\r")
|
|
1049
|
+
|
|
1050
|
+
|
|
1051
|
+
def get_module(module_name):
    """Import a module by name.

    Args:
        module_name (str): name of the module to import.

    Returns:
        object: The imported module.

    Raises:
        Exception: when the module cannot be imported
    """
    try:
        module = import_module(module_name)
    except ImportError:
        raise Exception("Cannot import module {}, please try again.".format(module_name))
    return module
|
|
1067
|
+
|
|
1068
|
+
|
|
1069
|
+
def check_and_get_run_experiment_config(experiment_config: Optional[dict] = None) -> dict:
    """Pick the experiment_config to use: the user's, or the current Run's.

    Args:
        experiment_config (dict): The experiment_config supplied by the user.

    Returns:
        dict: The user supplied experiment_config when it is not empty (a warning
        is logged if a Run is active as well). Otherwise the experiment_config of
        the current Run object, or None when there is no current Run.
    """
    from sagemaker.core.experiments._run_context import _RunContext

    current_run = _RunContext.get_current_run()

    if not experiment_config:
        if current_run:
            return current_run.experiment_config
        return None

    if current_run:
        logger.warning(
            "The function is invoked within an Experiment Run context "
            "but another experiment_config (%s) was supplied, so "
            "ignoring the experiment_config fetched from the Run object.",
            experiment_config,
        )
    return experiment_config
|
|
1093
|
+
|
|
1094
|
+
|
|
1095
|
+
def resolve_value_from_config(
    direct_input=None,
    config_path: str = None,
    default_value=None,
    sagemaker_session=None,
    sagemaker_config: dict = None,
):
    """Choose the value a caller should use, consulting the sagemaker config.

    The first candidate that is not None wins, in this order:
        1. direct_input
        2. config value found at ``config_path``
        3. default_value
    If all three are None, None is returned.

    Args:
        direct_input: The value that the caller of this method starts with. Usually this is an
            input to the caller's class or method.
        config_path (str): A string denoting the path used to lookup the value in the
            sagemaker config. When falsy, no config lookup is performed.
        default_value: The value used if not present elsewhere.
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session object,
            used for SageMaker interactions (default: None).
        sagemaker_config (dict): The sdk defaults config that is normally accessed through a
            Session object by doing `session.sagemaker_config`. (default: None) It is forwarded
            to ``get_sagemaker_config_value`` together with ``sagemaker_session`` and exists for
            the rare cases where a config dictionary has been loaded before a default Session
            object could be created.

    Returns:
        The value that should be used by the caller
    """
    if config_path:
        config_value = get_sagemaker_config_value(
            sagemaker_session, config_path, sagemaker_config=sagemaker_config
        )
    else:
        config_value = None

    from sagemaker.core.config.config_utils import _log_sagemaker_config_single_substitution

    _log_sagemaker_config_single_substitution(direct_input, config_value, config_path)

    # Walk the candidates in priority order and return the first one that is set.
    for candidate in (direct_input, config_value):
        if candidate is not None:
            return candidate
    return default_value
|
|
1150
|
+
|
|
1151
|
+
|
|
1152
|
+
def get_sagemaker_config_value(sagemaker_session, key, sagemaker_config: dict = None):
    """Look up ``key`` in the sagemaker config and return a copy of its value.

    Args:
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session object,
            used for SageMaker interactions. When it is truthy and has a
            ``sagemaker_config`` attribute, that config is the one consulted.
        key: Key Path of the config file entry.
        sagemaker_config (dict): The sdk defaults config that is normally accessed through a
            Session object by doing `session.sagemaker_config`. (default: None) It is only
            consulted when the session does not provide a config, for the rare cases where
            the config dictionary is loaded before a default Session object is created.

    Returns:
        object: A deep copy of the configured value, or None when no config is available.
    """
    from sagemaker.core.config.config_manager import SageMakerConfig

    session_has_config = sagemaker_session and hasattr(sagemaker_session, "sagemaker_config")
    config_to_check = (
        sagemaker_session.sagemaker_config if session_has_config else sagemaker_config
    )

    if not config_to_check:
        return None

    SageMakerConfig().validate_sagemaker_config(config_to_check)
    # Deep-copy so that callers mutating the result never touch the source config.
    return copy.deepcopy(get_config_value(key, config_to_check))
|
|
1183
|
+
|
|
1184
|
+
|
|
1185
|
+
def get_resource_name_from_arn(arn):
    """Return the resource name portion of an ARN string.

    The ARN is cut after its fifth ``:``; within that remainder, everything
    after the first ``/`` is the resource name.

    Args:
        arn (str): An ARN.

    Returns:
        str: The resource name.
    """
    resource_part = arn.split(":", 5)[5]
    name_part = resource_part.split("/", 1)[1]
    return name_part
|
|
1195
|
+
|
|
1196
|
+
|
|
1197
|
+
def list_tags(sagemaker_session, resource_arn, max_results=50):
    """List the tags given an Amazon Resource Name.

    All result pages are fetched and concatenated, so the caller receives the
    full list regardless of ``max_results``.

    Args:
        sagemaker_session: Object whose ``sagemaker_client.list_tags`` is called to
            fetch the tags.
        resource_arn (str): The Amazon Resource Name (ARN) for which to get the tags list.
        max_results (int): The maximum number of results to include in a single page.
            This method takes care of that abstraction and returns a full list.

    Returns:
        list: The tags whose ``Key`` does not contain ``"aws:"``.

    Raises:
        ClientError: re-raised (after logging) when the service call fails.
    """
    request = {"ResourceArn": resource_arn, "MaxResults": max_results}
    tags_list = []

    try:
        while True:
            response = sagemaker_session.sagemaker_client.list_tags(**request)
            tags_list.extend(response["Tags"])

            # The ListTags response carries its continuation token under
            # "NextToken"; reading "nextToken" never matches and silently
            # truncated the result to the first page.
            next_token = response.get("NextToken")
            if next_token is None:
                break
            request["NextToken"] = next_token

        return [tag for tag in tags_list if "aws:" not in tag["Key"]]
    except ClientError:
        logger.error("Error retrieving tags. resource_arn: %s", resource_arn)
        raise
|
|
1229
|
+
|
|
1230
|
+
|
|
1231
|
+
def resolve_class_attribute_from_config(
    clazz: Optional[type],
    instance: Optional[object],
    attribute: str,
    config_path: str,
    default_value=None,
    sagemaker_session=None,
):
    """Utility method that merges config values to data classes.

    Takes an instance of a class and, if not already set, sets the instance's attribute to a
    value fetched from the sagemaker_config or the default_value.

    Uses this order of prioritization to determine what the value of the attribute should be:
        1. current value of attribute
        2. config value
        3. default_value
        4. does not set it

    Args:
        clazz (Optional[type]): Class of 'instance'. Used to generate a new instance if the
            instance is None. If None is provided here, no new object will be created
            if 'instance' doesn't exist. Note: if provided, the constructor should set default
            values to None; Otherwise, the constructor's non-None default will be left
            as-is even if a config value was defined.
        instance (Optional[object]): instance of the Class 'clazz' that has an attribute
            of 'attribute' to set
        attribute (str): attribute of the instance to set if not already set
        config_path (str): a string denoting the path to use to lookup the config value in the
            sagemaker config
        default_value: the value to use if not present elsewhere
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session object,
            used for SageMaker interactions (default: None).

    Returns:
        The updated class instance that should be used by the caller instead of the
        'instance' parameter that was passed in.

    Raises:
        TypeError: if the (possibly newly constructed) instance has no attribute
            named ``attribute``.
    """
    config_value = get_sagemaker_config_value(sagemaker_session, config_path)

    if config_value is None and default_value is None:
        # return instance unmodified. Could be None or populated
        return instance

    if instance is None:
        if clazz is None or not inspect.isclass(clazz):
            return instance
        # construct a new instance if the instance does not exist
        instance = clazz()

    if not hasattr(instance, attribute):
        # Build ONE message string. Passing two comma-separated strings made the
        # exception's args a tuple, so str(error) rendered as a tuple repr.
        raise TypeError(
            "Unexpected structure of object. "
            "Expected attribute {} to be present inside instance {} of class {}".format(
                attribute, instance, clazz
            )
        )

    current_value = getattr(instance, attribute)
    if current_value is None:
        # only set value if object does not already have a value set
        if config_value is not None:
            setattr(instance, attribute, config_value)
        elif default_value is not None:
            setattr(instance, attribute, default_value)

    from sagemaker.core.config.config_utils import _log_sagemaker_config_single_substitution

    _log_sagemaker_config_single_substitution(current_value, config_value, config_path)

    return instance
|
|
1302
|
+
|
|
1303
|
+
|
|
1304
|
+
def resolve_nested_dict_value_from_config(
    dictionary: dict,
    nested_keys: List[str],
    config_path: str,
    default_value: object = None,
    sagemaker_session=None,
):
    """Utility method that sets the value of a key path in a nested dictionary.

    This method takes a dictionary and, if not already set, sets the value for the provided
    list of nested keys to the value fetched from the sagemaker_config or the default_value.

    Uses this order of prioritization to determine what the value of the attribute should be:
    (1) current value of nested key, (2) config value, (3) default_value, (4) does not set it

    Args:
        dictionary: The dict to update.
        nested_keys: The paths of keys where the value should be checked and set if needed.
        config_path (str): A string denoting the path used to find the config value in the
            sagemaker config.
        default_value: The value to use if not present elsewhere.
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session object,
            used for SageMaker interactions (default: None).

    Returns:
        The updated dictionary that should be used by the caller instead of the
        'dictionary' parameter that was passed in. The input is returned unchanged when
        there is nothing to apply or when the nested lookup raises ``ValueError``.
    """
    config_value = get_sagemaker_config_value(sagemaker_session, config_path)

    if config_value is None and default_value is None:
        # if there is nothing to set, return early. And there is no need to traverse through
        # the dictionary or add nested dicts to it
        return dictionary

    try:
        current_nested_value = get_nested_value(dictionary, nested_keys)
    except ValueError as e:
        # Use the module logger (as the rest of this module does) rather than the
        # root logger, so the record carries this module's name and configuration.
        logger.error("Failed to check dictionary for applying sagemaker config: %s", e)
        return dictionary

    if current_nested_value is None:
        # only set value if not already set
        if config_value is not None:
            dictionary = set_nested_value(dictionary, nested_keys, config_value)
        elif default_value is not None:
            dictionary = set_nested_value(dictionary, nested_keys, default_value)

    from sagemaker.core.config.config_utils import _log_sagemaker_config_single_substitution

    _log_sagemaker_config_single_substitution(current_nested_value, config_value, config_path)

    return dictionary
|
|
1357
|
+
|
|
1358
|
+
|
|
1359
|
+
def update_list_of_dicts_with_values_from_config(
    input_list,
    config_key_path,
    required_key_paths: List[str] = None,
    union_key_paths: List[List[str]] = None,
    sagemaker_session=None,
):
    """Fill in, position by position, values for a list of dicts from the Config.

    Entry ``i`` of ``input_list`` is merged with entry ``i`` of the list found in the
    config at ``config_key_path``. A merged entry only replaces the original when it
    passes two checks:

    * every path in ``required_key_paths`` resolves to a non-None value (the config may
      introduce a parameter that needs a companion parameter the input did not supply);
    * for every group in ``union_key_paths`` at most one path is set (the service accepts
      either parameter, or neither, but not both).

    Args:
        input_list: The input list that was provided as a method parameter.
        config_key_path: The Key Path in the Config file that corresponds to the input_list
            parameter.
        required_key_paths (List[str]): List of required key paths that should be verified in the
            merged output. If a required key path is missing, we will not perform the merge for
            that item.
        union_key_paths (List[List[str]]): List of List of Key paths for which we need to verify
            whether exactly zero/one of the parameters exist.
            For example: If the resultant dictionary can have either 'X1' or 'X2' as parameter or
            neither but not both, then pass [['X1', 'X2']]
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session object,
            used for SageMaker interactions (default: None).

    Returns:
        No output. In place merge happens.
    """
    if not input_list:
        return

    original_inputs = copy.deepcopy(input_list)
    config_entries = get_sagemaker_config_value(sagemaker_session, config_key_path) or []
    original_config_entries = copy.deepcopy(config_entries)

    overlap = min(len(input_list), len(config_entries))
    for index in range(overlap):
        merged = config_entries[index]
        # merged (the config entry) is the one handed to the checks and stored below.
        merge_dicts(merged, input_list[index])

        # Accept the merged entry only if both validations pass; the union check
        # is skipped when the required-path check already failed.
        if _validate_required_paths_in_a_dict(
            merged, required_key_paths
        ) and _validate_union_key_paths_in_a_dict(merged, union_key_paths):
            input_list[index] = merged

    from sagemaker.core.config.config_utils import _log_sagemaker_config_merge

    _log_sagemaker_config_merge(
        source_value=original_inputs,
        config_value=original_config_entries,
        merged_source_and_config_value=input_list,
        config_key_path=config_key_path,
    )
|
|
1427
|
+
|
|
1428
|
+
|
|
1429
|
+
def _validate_required_paths_in_a_dict(source_dict, required_key_paths: List[str] = None) -> bool:
|
|
1430
|
+
"""Placeholder docstring"""
|
|
1431
|
+
if not required_key_paths:
|
|
1432
|
+
return True
|
|
1433
|
+
for required_key_path in required_key_paths:
|
|
1434
|
+
if get_config_value(required_key_path, source_dict) is None:
|
|
1435
|
+
return False
|
|
1436
|
+
return True
|
|
1437
|
+
|
|
1438
|
+
|
|
1439
|
+
def _validate_union_key_paths_in_a_dict(
|
|
1440
|
+
source_dict, union_key_paths: List[List[str]] = None
|
|
1441
|
+
) -> bool:
|
|
1442
|
+
"""Placeholder docstring"""
|
|
1443
|
+
if not union_key_paths:
|
|
1444
|
+
return True
|
|
1445
|
+
for union_key_path in union_key_paths:
|
|
1446
|
+
union_parameter_present = False
|
|
1447
|
+
for key_path in union_key_path:
|
|
1448
|
+
if get_config_value(key_path, source_dict):
|
|
1449
|
+
if union_parameter_present:
|
|
1450
|
+
return False
|
|
1451
|
+
union_parameter_present = True
|
|
1452
|
+
return True
|
|
1453
|
+
|
|
1454
|
+
|
|
1455
|
+
def update_nested_dictionary_with_values_from_config(
    source_dict, config_key_path, sagemaker_session=None
) -> dict:
    """Merge a nested dictionary with the values configured at ``config_key_path``.

    Args:
        source_dict: The input nested dictionary that was provided as method parameter.
        config_key_path: The Key Path in the Config file which corresponds to this
            source_dict parameter.
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session object,
            used for SageMaker interactions (default: None).

    Returns:
        dict: The merged nested dictionary that is updated with missing values that are present
            in the Config file. When the config holds nothing for the path, ``source_dict``
            itself is returned untouched (which may be None).
    """
    merged_dict = get_sagemaker_config_value(sagemaker_session, config_key_path) or {}
    config_snapshot = copy.deepcopy(merged_dict)
    merge_dicts(merged_dict, source_dict or {})

    if config_snapshot == {}:
        # No config value exists, so the merge result is either source_dict's content
        # or {} (when source_dict was None). Hand back exactly what was passed in:
        # turning None into {} can break boto calls - e.g. an empty VpcConfig triggers
        # ParamValidationError because its required members are missing.
        # Nothing is logged, since no config value was used or defined.
        return source_dict

    from sagemaker.core.config.config_utils import _log_sagemaker_config_merge

    log_arguments = {
        "source_value": source_dict,
        "config_value": config_snapshot,
        "merged_source_and_config_value": merged_dict,
        "config_key_path": config_key_path,
    }
    _log_sagemaker_config_merge(**log_arguments)

    return merged_dict
|
|
1497
|
+
|
|
1498
|
+
|
|
1499
|
+
def stringify_object(obj: Any) -> str:
    """Render ``obj`` as ``"<ClassName>: {attrs}"``, leaving out attributes that are None.

    Args:
        obj (Any): An object exposing its attributes through ``__dict__``.

    Returns:
        str: The class name followed by the dict of non-None instance attributes.
    """
    populated = {}
    for name, value in obj.__dict__.items():
        if value is not None:
            populated[name] = value
    return "{}: {}".format(type(obj).__name__, str(populated))
|
|
1503
|
+
|
|
1504
|
+
|
|
1505
|
+
def volume_size_supported(instance_type: str) -> bool:
    """Returns True if SageMaker allows volume_size to be used for the instance type.

    Args:
        instance_type (str): Instance type such as ``ml.m5.xlarge`` or ``m5.xlarge``.
            Pipeline variables and ``local*`` instance types are accepted and
            always yield False.

    Returns:
        bool: False for pipeline variables, local mode, families containing a ``d``
            and the families listed below; True otherwise.

    Raises:
        ValueError: If the instance type is improperly formatted.
    """
    try:
        # local mode does not support volume size
        # instance type given as pipeline parameter does not support volume size
        # do not change the order of the two checks below.
        if is_pipeline_variable(instance_type) or instance_type.startswith("local"):
            return False

        parts: List[str] = instance_type.split(".")
    except Exception as e:
        # Anything that is not string-like ends up here; keep the root cause attached.
        raise ValueError(f"Failed to parse instance type '{instance_type}': {str(e)}") from e

    if len(parts) == 3 and parts[0] == "ml":
        parts = parts[1:]

    if len(parts) != 2:
        # Raised outside the try block so it is not caught and wrapped a second
        # time (which used to duplicate the message).
        raise ValueError(f"Failed to parse instance type '{instance_type}'")

    # Any instance type with a "d" in the instance family (i.e. c5d, p4d, etc)
    # + g5, g6, p5 or trn1 does not support attaching an EBS volume.
    family = parts[0]
    unsupported_families = ("g5", "g6", "p5", "trn1")

    return "d" not in family and not family.startswith(unsupported_families)
|
|
1539
|
+
|
|
1540
|
+
|
|
1541
|
+
def instance_supports_kms(instance_type: str) -> bool:
    """Tell whether SageMaker allows KMS keys to be attached to the instance.

    The answer is exactly the one given by ``volume_size_supported`` for the
    same instance type.

    Raises:
        ValueError: If the instance type is improperly formatted.
    """
    supports_volume = volume_size_supported(instance_type)
    return supports_volume
|
|
1548
|
+
|
|
1549
|
+
|
|
1550
|
+
def get_instance_type_family(instance_type: str) -> str:
    """Extract the family from an instance type string.

    Accepted shapes are "ml.<family>.<size>" and "ml_<family>". Anything that is
    not a string, or does not match, produces an empty string.
    """
    if not isinstance(instance_type, str):
        return ""

    match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
    return "" if match is None else match[1]
|
|
1562
|
+
|
|
1563
|
+
|
|
1564
|
+
def create_paginator_config(max_items: int = None, page_size: int = None) -> Dict[str, int]:
    """Build a paginator configuration dictionary.

    Args:
        max_items (int): Value for "MaxItems"; ``MAX_ITEMS`` is used when falsy.
        page_size (int): Value for "PageSize"; ``PAGE_SIZE`` is used when falsy.

    Returns:
        Dict[str, int]: A dict with the keys "MaxItems" and "PageSize".
    """
    effective_max_items = max_items or MAX_ITEMS
    effective_page_size = page_size or PAGE_SIZE
    return {"MaxItems": effective_max_items, "PageSize": effective_page_size}
|
|
1570
|
+
|
|
1571
|
+
|
|
1572
|
+
def format_tags(tags: Tags) -> List[TagsDict]:
    """Convert tags into the list-of-dicts format expected by SageMaker.

    A dict is expanded into ``{"Key": ..., "Value": ...}`` entries with both
    sides converted to ``str``; any other input is returned unchanged.
    """
    if not isinstance(tags, dict):
        return tags

    formatted = []
    for tag_key, tag_value in tags.items():
        formatted.append({"Key": str(tag_key), "Value": str(tag_value)})
    return formatted
|
|
1578
|
+
|
|
1579
|
+
|
|
1580
|
+
def _get_resolved_path(path):
|
|
1581
|
+
"""Return the normalized absolute path of a given path.
|
|
1582
|
+
|
|
1583
|
+
abspath - returns the absolute path without resolving symlinks
|
|
1584
|
+
realpath - resolves the symlinks and gets the actual path
|
|
1585
|
+
normpath - normalizes paths (e.g. remove redudant separators)
|
|
1586
|
+
and handles platform-specific differences
|
|
1587
|
+
"""
|
|
1588
|
+
return normpath(realpath(abspath(path)))
|
|
1589
|
+
|
|
1590
|
+
|
|
1591
|
+
def _is_bad_path(path, base):
    """Checks if the joined path (base directory + file path) is rooted under the base directory

    Ensuring that the file does not attempt to access paths
    outside the expected directory structure.

    Args:
        path (str): The file path.
        base (str): The base directory. Callers in this module pass an already
            resolved path (see ``_get_resolved_path``).

    Returns:
        bool: True if the path is not rooted under the base directory, False otherwise.
    """
    from os import sep

    # joinpath will ignore base if path is absolute
    resolved = _get_resolved_path(joinpath(base, path))
    if resolved == base:
        return False

    # Compare on a directory boundary. A bare ``startswith(base)`` also accepts
    # sibling directories that merely share the prefix, e.g. base "/x/foo" and
    # a resolved path of "/x/foobar/evil".
    base_prefix = base.rstrip(sep) + sep
    return not resolved.startswith(base_prefix)
|
|
1606
|
+
|
|
1607
|
+
|
|
1608
|
+
def _is_bad_link(info, base):
    """Checks whether a link's target stays under the directory that contains the link.

    Ensuring that the link does not attempt to access paths outside the expected
    directory structure.

    Args:
        info (tarfile.TarInfo): The tar file info.
        base (str): The base directory.

    Returns:
        bool: True if the link is not rooted under the resolved directory that
            contains it, False otherwise.
    """
    # Links are interpreted relative to the directory containing the link
    containing_dir = dirname(info.name)
    tip = _get_resolved_path(joinpath(base, containing_dir))
    return _is_bad_path(info.linkname, base=tip)
|
|
1623
|
+
|
|
1624
|
+
|
|
1625
|
+
def _get_safe_members(members):
    """Yield only the archive members that are safe to extract.

    Members with a bad path, and symbolic or hard links judged bad by
    ``_is_bad_link``, are logged as errors and skipped. The base directory
    used for the checks is the resolved form of the empty path.

    Args:
        members (list): A list of members to check.

    Yields:
        tarfile.TarInfo: The tar file info.
    """
    root = _get_resolved_path("")

    for member in members:
        if _is_bad_path(member.name, root):
            logger.error("%s is blocked (illegal path)", member.name)
            continue
        if member.issym() and _is_bad_link(member, root):
            logger.error("%s is blocked: Symlink to %s", member.name, member.linkname)
            continue
        if member.islnk() and _is_bad_link(member, root):
            logger.error("%s is blocked: Hard link to %s", member.name, member.linkname)
            continue
        yield member
|
|
1647
|
+
|
|
1648
|
+
|
|
1649
|
+
def custom_extractall_tarfile(tar, extract_path):
    """Extract a tarfile, preferring the built-in ``data`` filter when available.

    # TODO: The function and its usages can be deprecated once SageMaker Python SDK
    is upgraded to use Python 3.12+

    When the ``tarfile`` module exposes ``data_filter``, extraction is delegated to
    ``extractall(..., filter="data")``. Otherwise the members are first screened by
    ``_get_safe_members`` to drop bad paths and bad links.

    Args:
        tar (tarfile.TarFile): The opened tarfile object.
        extract_path (str): The path to extract the contents of the tarfile.

    Returns:
        None
    """
    if not hasattr(tarfile, "data_filter"):
        # Older interpreters: filter the members ourselves.
        tar.extractall(path=extract_path, members=_get_safe_members(tar))
        return

    tar.extractall(path=extract_path, filter="data")
|
|
1669
|
+
|
|
1670
|
+
|
|
1671
|
+
def can_model_package_source_uri_autopopulate(source_uri: str):
    """Tell whether ``source_uri`` can auto-populate information in the Model registry.

    The uri qualifies when it matches ``MODEL_PACKAGE_ARN_PATTERN`` or
    ``MODEL_ARN_PATTERN`` (tried in that order).

    Args:
        source_uri (str): The source uri.

    Returns:
        bool: True if the source_uri can lead to auto-population, False otherwise.
    """
    accepted_patterns = (MODEL_PACKAGE_ARN_PATTERN, MODEL_ARN_PATTERN)
    return any(re.match(pattern, source_uri) for pattern in accepted_patterns)
|
|
1683
|
+
|
|
1684
|
+
|
|
1685
|
+
def flatten_dict(
    d: Dict[str, Any],
    max_flatten_depth=None,
) -> Dict[str, Any]:
    """Flatten a dictionary object.

    Every leaf value is stored under a tuple holding the sequence of keys that
    leads to it. Empty nested dicts are kept as leaf values.

    Args:
        d (Dict[str, Any]):
            The dict that will be flattened.
        max_flatten_depth (Optional[int]):
            Maximum depth to merge. Dicts nested deeper than this are kept as values.

    Returns:
        Dict[str, Any]: A new dict mapping key tuples to values.

    Raises:
        ValueError: If ``max_flatten_depth`` is less than 1, or a flattened key
            is produced twice.
    """
    # check max_flatten_depth
    if max_flatten_depth is not None and max_flatten_depth < 1:
        raise ValueError("max_flatten_depth should not be less than 1.")

    flat_dict = {}

    def _flatten(_d, depth, parent=None):
        """Flatten ``_d`` into ``flat_dict``; return whether ``_d`` had any item."""
        has_item = False
        # Plain ``.items()`` replaces the Python 2 era ``viewitems`` shim.
        for key, value in _d.items():
            has_item = True
            flat_key = (key,) if parent is None else parent + (key,)
            if isinstance(value, dict) and (max_flatten_depth is None or depth < max_flatten_depth):
                # A non-empty child dict is fully represented by its own leaves.
                if _flatten(value, depth=depth + 1, parent=flat_key):
                    continue

            if flat_key in flat_dict:
                raise ValueError("duplicated key '{}'".format(flat_key))
            flat_dict[flat_key] = value

        return has_item

    _flatten(d, depth=1)
    return flat_dict
|
|
1729
|
+
|
|
1730
|
+
|
|
1731
|
+
def nested_set_dict(d: Dict[str, Any], keys: List[str], value: Any) -> None:
    """Assign ``value`` at the location described by a sequence of nested keys.

    Intermediate dictionaries are created on demand; existing ones are reused.

    Args:
        d (Dict[str, Any]): The dictionary to modify in place.
        keys (List[str]): The key path; its last element receives ``value``.
        value (Any): The value to store.
    """
    leaf_key = keys[-1]

    node = d
    for parent_key in keys[:-1]:
        node = node.setdefault(parent_key, {})
    node[leaf_key] = value
|
|
1742
|
+
|
|
1743
|
+
|
|
1744
|
+
def unflatten_dict(d: Dict[str, Any]) -> Dict[str, Any]:
    """Unflatten dict-like object.

    Each key of ``d`` is used as a key path and handed to ``nested_set_dict``
    to rebuild the nested structure.

    Args:
        d (Dict[str, Any]):
            The dict that will be unflattened.

    Returns:
        Dict[str, Any]: A new nested dict.
    """
    unflattened_dict = {}
    # Plain ``.items()`` replaces the Python 2 era ``viewitems`` shim.
    for key_tuple, value in d.items():
        nested_set_dict(unflattened_dict, key_tuple, value)

    return unflattened_dict
|
|
1757
|
+
|
|
1758
|
+
|
|
1759
|
+
def deep_override_dict(
    dict1: Dict[str, Any], dict2: Dict[str, Any], skip_keys: Optional[List[str]] = None
) -> Dict[str, Any]:
    """Overlay ``dict2`` on top of ``dict1``, leaf by leaf.

    Both dicts are flattened; None-valued leaves of ``dict1`` are discarded,
    top-level keys of ``dict2`` listed in ``skip_keys`` are ignored, and the
    remaining ``dict2`` leaves replace or extend those of ``dict1``.

    Args:
        dict1 (Dict[str, Any]): The base dictionary.
        dict2 (Dict[str, Any]): The dictionary whose contents take precedence.
        skip_keys (Optional[List[str]]): Top-level keys of ``dict2`` to leave out.

    Returns:
        Dict[str, Any]: The combined nested dict, or ``{}`` when nothing remains.
    """
    excluded = [] if skip_keys is None else skip_keys

    combined = {}
    for path, leaf in flatten_dict(dict1).items():
        if leaf is not None:
            combined[path] = leaf

    overrides = flatten_dict({key: value for key, value in dict2.items() if key not in excluded})
    combined.update(overrides)

    if not combined:
        return {}
    return unflatten_dict(combined)
|
|
1773
|
+
|
|
1774
|
+
|
|
1775
|
+
def _resolve_routing_config(routing_config: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
|
1776
|
+
"""Resolve Routing Config
|
|
1777
|
+
|
|
1778
|
+
Args:
|
|
1779
|
+
routing_config (Optional[Dict[str, Any]]): The routing config.
|
|
1780
|
+
|
|
1781
|
+
Returns:
|
|
1782
|
+
Optional[Dict[str, Any]]: The resolved routing config.
|
|
1783
|
+
|
|
1784
|
+
Raises:
|
|
1785
|
+
ValueError: If the RoutingStrategy is invalid.
|
|
1786
|
+
"""
|
|
1787
|
+
|
|
1788
|
+
if routing_config:
|
|
1789
|
+
routing_strategy = routing_config.get("RoutingStrategy", None)
|
|
1790
|
+
if routing_strategy:
|
|
1791
|
+
if isinstance(routing_strategy, RoutingStrategy):
|
|
1792
|
+
return {"RoutingStrategy": routing_strategy.name}
|
|
1793
|
+
if isinstance(routing_strategy, str) and (
|
|
1794
|
+
routing_strategy.upper() == RoutingStrategy.RANDOM.name
|
|
1795
|
+
or routing_strategy.upper() == RoutingStrategy.LEAST_OUTSTANDING_REQUESTS.name
|
|
1796
|
+
):
|
|
1797
|
+
return {"RoutingStrategy": routing_strategy.upper()}
|
|
1798
|
+
raise ValueError(
|
|
1799
|
+
"RoutingStrategy must be either RoutingStrategy.RANDOM "
|
|
1800
|
+
"or RoutingStrategy.LEAST_OUTSTANDING_REQUESTS"
|
|
1801
|
+
)
|
|
1802
|
+
return None
|
|
1803
|
+
|
|
1804
|
+
|
|
1805
|
+
@lru_cache
def get_instance_rate_per_hour(
    instance_type: str,
    region: str,
) -> Optional[Dict[str, str]]:
    """Look up the on-demand hourly rate of an instance type via the AWS Pricing API.

    Results are memoized per ``(instance_type, region)`` pair by ``lru_cache``.

    Args:
        instance_type (str): The instance type.
        region (str): The region.

    Returns:
        Optional[Dict[str, str]]: Instance rate per hour.
        Example: {'name': 'On-demand Instance Rate', 'unit': 'USD/Hr', 'value': '1.125'}.

    Raises:
        Exception: An exception is raised if
            the IAM role is not authorized to perform pricing:GetProducts,
            or if no rate could be extracted from the response.
    """
    # Pick the region in which the pricing client is created, based on the
    # prefix of the requested region.
    if region.startswith(("eu", "af")):
        region_name = "eu-central-1"
    elif region.startswith(("ap", "cn")):
        region_name = "ap-south-1"
    else:
        region_name = "us-east-1"

    pricing_client = boto3.client("pricing", region_name=region_name)
    filters = [
        {"Type": "TERM_MATCH", "Field": "instanceName", "Value": instance_type},
        {"Type": "TERM_MATCH", "Field": "locationType", "Value": "AWS Region"},
        {"Type": "TERM_MATCH", "Field": "regionCode", "Value": region},
    ]
    res = pricing_client.get_products(ServiceCode="AmazonSageMaker", Filters=filters)

    price_list = res.get("PriceList", [])
    if len(price_list) > 0:
        # Only the first product is inspected; it may arrive as a JSON string.
        first_product = price_list[0]
        if isinstance(first_product, str):
            first_product = json.loads(first_product)

        rate = extract_instance_rate_per_hour(first_product)
        if rate is not None:
            return rate

    raise Exception(f"Unable to get instance rate per hour for instance type: {instance_type}.")
|
|
1850
|
+
|
|
1851
|
+
|
|
1852
|
+
def extract_instance_rate_per_hour(price_data: Dict[str, Any]) -> Optional[Dict[str, str]]:
    """Pull the first on-demand price out of a Price List product document.

    The document is walked as ``terms -> OnDemand -> * -> priceDimensions -> *
    -> pricePerUnit`` and the first currency entry found is returned.

    Args:
        price_data (Dict[str, Any]): The Price JSON data.

    Returns:
        Optional[Dict[str, str]]: Instance rate per hour with the keys
            ``unit``, ``value`` (rounded to 3 decimals) and ``name``, or
            ``None`` when no price entry is present.
    """
    if price_data is None:
        return None

    on_demand_terms = price_data.get("terms", {}).get("OnDemand", {})
    for term in on_demand_terms.values():
        for dimension in term.get("priceDimensions", {}).values():
            for currency, amount in dimension.get("pricePerUnit", {}).items():
                # The first currency encountered wins.
                if amount is None:
                    value = None
                else:
                    value = str(round(float(amount), 3))
                return {
                    "unit": f"{currency}/Hr",
                    "value": value,
                    "name": "On-demand Instance Rate",
                }
    return None
|
|
1875
|
+
|
|
1876
|
+
|
|
1877
|
+
def camel_case_to_pascal_case(data: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``data`` with every snake_case key rewritten in PascalCase.

    Nested dictionaries, including dictionaries inside lists, are converted
    as well. Keys are split on underscores and each part is passed through
    ``str.capitalize``, so any upper-case letters after the first character
    of a part are lowered (``"fooBar"`` becomes ``"Foobar"``).

    Args:
        data (dict): The dictionary to convert. It is not modified.

    Returns:
        dict: A new dictionary with keys in PascalCase.
    """

    def _to_pascal(name):
        """Join the capitalized underscore-separated parts of ``name``."""
        return "".join([piece.capitalize() for piece in name.split("_")])

    def _convert(node):
        """Convert nested dicts and lists; return other values unchanged."""
        if isinstance(node, dict):
            return camel_case_to_pascal_case(node)
        if isinstance(node, list):
            return [_convert(element) for element in node]
        return node

    return {_to_pascal(name): _convert(content) for name, content in data.items()}
|
|
1905
|
+
|
|
1906
|
+
|
|
1907
|
+
def tag_exists(tag: TagsDict, curr_tags: Optional[Tags]) -> bool:
    """Returns True if a tag with the same ``Key`` as ``tag`` already exists.

    Only the ``Key`` entries are compared; values are ignored.

    Args:
        tag (TagsDict): The tag dictionary.
        curr_tags (Optional[Tags]): The current tags, given either as a list
            of tag dictionaries or as a single tag dictionary.

    Returns:
        bool: True if the tag exists.
    """
    if curr_tags is None:
        return False

    # The other tag helpers in this module accept a single tag dictionary in
    # place of a list. Normalize it here too; iterating the dictionary
    # directly would walk its string keys and fail on ``curr_tag["Key"]``.
    if isinstance(curr_tags, dict):
        curr_tags = [curr_tags] if curr_tags else []

    return any(tag["Key"] == curr_tag["Key"] for curr_tag in curr_tags)
|
|
1925
|
+
|
|
1926
|
+
|
|
1927
|
+
def _validate_new_tags(new_tags: Optional[Tags], curr_tags: Optional[Tags]) -> Optional[Tags]:
    """Merge ``new_tags`` into ``curr_tags``, skipping keys that already exist.

    When ``curr_tags`` is a list, new tags are appended to that same list
    object (it is modified in place) and the list is returned.

    Args:
        new_tags (Optional[Tags]): The new tags, as a single tag dictionary or
            a list of tag dictionaries.
        curr_tags (Optional[Tags]): The current tags.

    Returns:
        Optional[Tags]: ``new_tags`` unchanged when there are no current tags,
            otherwise the current tags extended with the missing new tags.
    """
    if curr_tags is None:
        return new_tags

    # A single current tag is wrapped so that it can be appended to.
    if curr_tags and isinstance(curr_tags, dict):
        curr_tags = [curr_tags]

    if isinstance(new_tags, dict):
        candidates = [new_tags]
    elif isinstance(new_tags, list):
        candidates = new_tags
    else:
        candidates = []

    for candidate in candidates:
        if not tag_exists(candidate, curr_tags):
            curr_tags.append(candidate)

    return curr_tags
|
|
1952
|
+
|
|
1953
|
+
|
|
1954
|
+
def remove_tag_with_key(key: str, tags: Optional[Tags]) -> Optional[Tags]:
    """Drop every tag whose ``Key`` equals ``key``.

    Args:
        key (str): The key of the tag to remove.
        tags (Optional[Tags]): The current tags, as a single tag dictionary or
            a list of tag dictionaries.

    Returns:
        Optional[Tags]: ``None`` when no tags remain (or none were given), the
            tag dictionary itself when exactly one remains, otherwise the list
            of remaining tags.
    """
    if tags is None:
        return tags

    candidates = [tags] if isinstance(tags, dict) else tags
    remaining = [entry for entry in candidates if entry["Key"] != key]

    if len(remaining) == 1:
        # A lone survivor is returned bare rather than wrapped in a list.
        return remaining[0]
    return remaining or None
|
|
1979
|
+
|
|
1980
|
+
|
|
1981
|
+
def get_domain_for_region(region: str) -> str:
    """Returns the domain for the given region.

    Args:
        region (str): AWS region name.

    Returns:
        str: The entry registered for ``region`` in ``ALTERNATE_DOMAINS``,
            or ``"amazonaws.com"`` when the region has no entry.
    """
    default_domain = "amazonaws.com"
    return ALTERNATE_DOMAINS.get(region, default_domain)
|
|
1988
|
+
|
|
1989
|
+
|
|
1990
|
+
def camel_to_snake(camel_case_string: str) -> str:
    """Converts camelCase to snake_case_string using a regex.

    An underscore is inserted before every upper-case ASCII letter that is not
    at the start of the string, then the whole result is lower-cased.

    This regex cannot handle whitespace ("camelString TwoWords")
    """
    # Zero-width match: any position followed by A-Z, except the very start.
    word_boundary = re.compile(r"(?<!^)(?=[A-Z])")
    underscored = word_boundary.sub("_", camel_case_string)
    return underscored.lower()
|
|
1996
|
+
|
|
1997
|
+
|
|
1998
|
+
def walk_and_apply_json(
    json_obj: Dict[Any, Any], apply, stop_keys: Optional[List[str]] = ["metrics"]
) -> Dict[Any, Any]:
    """Recursively walks a json object and applies a given function to the keys.

    A new structure is built; ``json_obj`` itself is not modified. Dictionaries
    are rebuilt with ``apply(key)`` as the key, lists are rebuilt element by
    element (so dictionaries nested anywhere inside lists are converted too),
    and every other value is carried over unchanged.

    Args:
        json_obj (Dict[Any, Any]): The json object to walk.
        apply (Callable): Function called with each dictionary key; its return
            value is used as the key in the result.
        stop_keys (Optional[list[str]]): List of field keys that should stop the
            application function. The key itself is still converted, but any
            children of these keys will not have the application function
            applied to them. The comparison is made against the converted key.
            ``None`` or an empty list disables stopping. The default list is
            only read, never mutated.

    Returns:
        Dict[Any, Any]: The converted copy of ``json_obj``.
    """

    def _convert(node):
        if isinstance(node, dict):
            converted = {}
            for key, value in node.items():
                new_key = apply(key)
                if stop_keys and new_key in stop_keys:
                    # Children of a stop key are kept exactly as they are.
                    converted[new_key] = value
                else:
                    converted[new_key] = _convert(value)
            return converted
        if isinstance(node, list):
            # Keep every element, including numbers, booleans and None.
            return [_convert(item) for item in node]
        return node

    return _convert(json_obj)
|
|
2034
|
+
|
|
2035
|
+
|
|
2036
|
+
def _wait_until(callable_fn, poll=5):
|
|
2037
|
+
"""Placeholder docstring"""
|
|
2038
|
+
elapsed_time = 0
|
|
2039
|
+
result = None
|
|
2040
|
+
while result is None:
|
|
2041
|
+
try:
|
|
2042
|
+
elapsed_time += poll
|
|
2043
|
+
time.sleep(poll)
|
|
2044
|
+
result = callable_fn()
|
|
2045
|
+
except botocore.exceptions.ClientError as err:
|
|
2046
|
+
# For initial 5 mins we accept/pass AccessDeniedException.
|
|
2047
|
+
# The reason is to await tag propagation to avoid false AccessDenied claims for an
|
|
2048
|
+
# access policy based on resource tags, The caveat here is for true AccessDenied
|
|
2049
|
+
# cases the routine will fail after 5 mins
|
|
2050
|
+
if err.response["Error"]["Code"] == "AccessDeniedException" and elapsed_time <= 300:
|
|
2051
|
+
logger.warning(
|
|
2052
|
+
"Received AccessDeniedException. This could mean the IAM role does not "
|
|
2053
|
+
"have the resource permissions, in which case please add resource access "
|
|
2054
|
+
"and retry. For cases where the role has tag based resource policy, "
|
|
2055
|
+
"continuing to wait for tag propagation.."
|
|
2056
|
+
)
|
|
2057
|
+
continue
|
|
2058
|
+
raise err
|
|
2059
|
+
return result
|
|
2060
|
+
|
|
2061
|
+
|
|
2062
|
+
def _flush_log_streams(
|
|
2063
|
+
stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap
|
|
2064
|
+
):
|
|
2065
|
+
"""Placeholder docstring"""
|
|
2066
|
+
if len(stream_names) < instance_count:
|
|
2067
|
+
# Log streams are created whenever a container starts writing to stdout/err, so this list
|
|
2068
|
+
# may be dynamic until we have a stream for every instance.
|
|
2069
|
+
try:
|
|
2070
|
+
streams = client.describe_log_streams(
|
|
2071
|
+
logGroupName=log_group,
|
|
2072
|
+
logStreamNamePrefix=job_name + "/",
|
|
2073
|
+
orderBy="LogStreamName",
|
|
2074
|
+
limit=min(instance_count, 50),
|
|
2075
|
+
)
|
|
2076
|
+
stream_names = [s["logStreamName"] for s in streams["logStreams"]]
|
|
2077
|
+
|
|
2078
|
+
while "nextToken" in streams:
|
|
2079
|
+
streams = client.describe_log_streams(
|
|
2080
|
+
logGroupName=log_group,
|
|
2081
|
+
logStreamNamePrefix=job_name + "/",
|
|
2082
|
+
orderBy="LogStreamName",
|
|
2083
|
+
limit=50,
|
|
2084
|
+
)
|
|
2085
|
+
|
|
2086
|
+
stream_names.extend([s["logStreamName"] for s in streams["logStreams"]])
|
|
2087
|
+
|
|
2088
|
+
positions.update(
|
|
2089
|
+
[
|
|
2090
|
+
(s, sagemaker.core.logs.Position(timestamp=0, skip=0))
|
|
2091
|
+
for s in stream_names
|
|
2092
|
+
if s not in positions
|
|
2093
|
+
]
|
|
2094
|
+
)
|
|
2095
|
+
except ClientError as e:
|
|
2096
|
+
# On the very first training job run on an account, there's no log group until
|
|
2097
|
+
# the container starts logging, so ignore any errors thrown about that
|
|
2098
|
+
err = e.response.get("Error", {})
|
|
2099
|
+
if err.get("Code", None) != "ResourceNotFoundException":
|
|
2100
|
+
raise
|
|
2101
|
+
|
|
2102
|
+
if len(stream_names) > 0:
|
|
2103
|
+
if dot:
|
|
2104
|
+
print("")
|
|
2105
|
+
dot = False
|
|
2106
|
+
for idx, event in sagemaker.core.logs.multi_stream_iter(
|
|
2107
|
+
client, log_group, stream_names, positions
|
|
2108
|
+
):
|
|
2109
|
+
color_wrap(idx, event["message"])
|
|
2110
|
+
ts, count = positions[stream_names[idx]]
|
|
2111
|
+
if event["timestamp"] == ts:
|
|
2112
|
+
positions[stream_names[idx]] = sagemaker.core.logs.Position(
|
|
2113
|
+
timestamp=ts, skip=count + 1
|
|
2114
|
+
)
|
|
2115
|
+
else:
|
|
2116
|
+
positions[stream_names[idx]] = sagemaker.core.logs.Position(
|
|
2117
|
+
timestamp=event["timestamp"], skip=1
|
|
2118
|
+
)
|
|
2119
|
+
else:
|
|
2120
|
+
dot = True
|
|
2121
|
+
print(".", end="")
|
|
2122
|
+
sys.stdout.flush()
|
|
2123
|
+
|
|
2124
|
+
|
|
2125
|
+
class LogState:
    """Integer constants naming the phases of following a job's logs.

    The values are, in order: STARTING = 1, WAIT_IN_PROGRESS = 2, TAILING = 3,
    JOB_COMPLETE = 4 and COMPLETE = 5.
    """

    STARTING, WAIT_IN_PROGRESS, TAILING, JOB_COMPLETE, COMPLETE = range(1, 6)
|
|
2133
|
+
|
|
2134
|
+
|
|
2135
|
+
# Maps upper-case job status codes to their CamelCase spelling.
# ``_check_job_status`` looks statuses up here and falls back to the status
# itself when it has no entry.
_STATUS_CODE_TABLE = {
    "COMPLETED": "Completed",
    "INPROGRESS": "InProgress",
    "IN_PROGRESS": "InProgress",
    "FAILED": "Failed",
    "STOPPED": "Stopped",
    "STOPPING": "Stopping",
    "STARTING": "Starting",
    "PENDING": "Pending",
}
|
|
2145
|
+
|
|
2146
|
+
|
|
2147
|
+
def _get_initial_job_state(description, status_key, wait):
    """Choose the ``LogState`` to start from for a described job.

    Args:
        description (dict): Job description containing ``status_key``.
        status_key (str): Key under which the job status is stored.
        wait (bool): Whether the caller wants to wait for the job.

    Returns:
        int: ``LogState.TAILING`` when ``wait`` is truthy and the status is
            not one of "Completed", "Failed" or "Stopped"; otherwise
            ``LogState.COMPLETE``.
    """
    finished = description[status_key] in ("Completed", "Failed", "Stopped")
    if wait and not finished:
        return LogState.TAILING
    return LogState.COMPLETE
|
|
2152
|
+
|
|
2153
|
+
|
|
2154
|
+
def _logs_init(boto_session, description, job):
|
|
2155
|
+
"""Placeholder docstring"""
|
|
2156
|
+
if job == "Training":
|
|
2157
|
+
if "InstanceGroups" in description["ResourceConfig"]:
|
|
2158
|
+
instance_count = 0
|
|
2159
|
+
for instanceGroup in description["ResourceConfig"]["InstanceGroups"]:
|
|
2160
|
+
instance_count += instanceGroup["InstanceCount"]
|
|
2161
|
+
else:
|
|
2162
|
+
instance_count = description["ResourceConfig"]["InstanceCount"]
|
|
2163
|
+
elif job == "Transform":
|
|
2164
|
+
instance_count = description["TransformResources"]["InstanceCount"]
|
|
2165
|
+
elif job == "Processing":
|
|
2166
|
+
instance_count = description["ProcessingResources"]["ClusterConfig"]["InstanceCount"]
|
|
2167
|
+
elif job == "AutoML":
|
|
2168
|
+
instance_count = 0
|
|
2169
|
+
|
|
2170
|
+
stream_names = [] # The list of log streams
|
|
2171
|
+
positions = {} # The current position in each stream, map of stream name -> position
|
|
2172
|
+
|
|
2173
|
+
# Increase retries allowed (from default of 4), as we don't want waiting for a training job
|
|
2174
|
+
# to be interrupted by a transient exception.
|
|
2175
|
+
config = botocore.config.Config(retries={"max_attempts": 15})
|
|
2176
|
+
client = boto_session.client("logs", config=config)
|
|
2177
|
+
log_group = "/aws/sagemaker/" + job + "Jobs"
|
|
2178
|
+
|
|
2179
|
+
dot = False
|
|
2180
|
+
|
|
2181
|
+
color_wrap = sagemaker.core.logs.ColorWrap()
|
|
2182
|
+
|
|
2183
|
+
return instance_count, stream_names, positions, client, log_group, dot, color_wrap
|
|
2184
|
+
|
|
2185
|
+
|
|
2186
|
+
def _check_job_status(job, desc, status_key_name):
    """Check to see if the job completed successfully.

    If not, construct and raise an exception (UnexpectedStatusException, or
    CapacityError when the failure reason mentions "CapacityError"). A job
    that ended as "Stopped" only triggers a warning.

    Args:
        job (str): The name of the job to check.
        desc (dict[str, str]): The result of ``describe_training_job()``.
        status_key_name (str): Status key name to check for.

    Raises:
        exceptions.CapacityError: If the training job fails with CapacityError.
        exceptions.UnexpectedStatusException: If the training job fails.
    """
    raw_status = desc[status_key_name]
    # If the status is capital case, then convert it to Camel case
    status = _STATUS_CODE_TABLE.get(raw_status, raw_status)

    if status == "Completed":
        return

    if status == "Stopped":
        logger.warning(
            "Job ended with status 'Stopped' rather than 'Completed'. "
            "This could mean the job timed out or stopped early for some other reason: "
            "Consider checking whether it completed as you expect."
        )
        return

    reason = desc.get("FailureReason", "(No reason provided)")
    job_type = status_key_name.replace("JobStatus", " job")
    troubleshooting = (
        "https://docs.aws.amazon.com/sagemaker/latest/dg/"
        "sagemaker-python-sdk-troubleshooting.html"
    )
    message = (
        "Error for {job_type} {job_name}: {status}. Reason: {reason}. "
        "Check troubleshooting guide for common errors: {troubleshooting}"
    ).format(
        job_type=job_type,
        job_name=job,
        status=status,
        reason=reason,
        troubleshooting=troubleshooting,
    )

    # Capacity problems get their own exception type.
    if "CapacityError" in str(reason):
        error_cls = exceptions.CapacityError
    else:
        error_cls = exceptions.UnexpectedStatusException

    raise error_cls(
        message=message,
        allowed_statuses=["Completed", "Stopped"],
        actual_status=status,
    )
|
|
2238
|
+
|
|
2239
|
+
|
|
2240
|
+
def _create_resource(create_fn):
|
|
2241
|
+
"""Call create function and accepts/pass when resource already exists.
|
|
2242
|
+
|
|
2243
|
+
This is a helper function to use an existing resource if found when creating.
|
|
2244
|
+
|
|
2245
|
+
Args:
|
|
2246
|
+
create_fn: Create resource function.
|
|
2247
|
+
|
|
2248
|
+
Returns:
|
|
2249
|
+
(bool): True if new resource was created, False if resource already exists.
|
|
2250
|
+
"""
|
|
2251
|
+
try:
|
|
2252
|
+
create_fn()
|
|
2253
|
+
# create function succeeded, resource does not exist already
|
|
2254
|
+
return True
|
|
2255
|
+
except ClientError as ce:
|
|
2256
|
+
error_code = ce.response["Error"]["Code"]
|
|
2257
|
+
error_message = ce.response["Error"]["Message"]
|
|
2258
|
+
already_exists_exceptions = ["ValidationException", "ResourceInUse"]
|
|
2259
|
+
already_exists_msg_patterns = ["Cannot create already existing", "already exists"]
|
|
2260
|
+
if not (
|
|
2261
|
+
error_code in already_exists_exceptions
|
|
2262
|
+
and any(p in error_message for p in already_exists_msg_patterns)
|
|
2263
|
+
):
|
|
2264
|
+
raise ce
|
|
2265
|
+
# no new resource created as resource already exists
|
|
2266
|
+
return False
|
|
2267
|
+
|
|
2268
|
+
|
|
2269
|
+
def _is_s3_uri(s3_uri: Optional[str]) -> bool:
|
|
2270
|
+
"""Checks whether an S3 URI is valid.
|
|
2271
|
+
|
|
2272
|
+
Args:
|
|
2273
|
+
s3_uri (Optional[str]): The S3 URI.
|
|
2274
|
+
|
|
2275
|
+
Returns:
|
|
2276
|
+
bool: Whether the S3 URI is valid.
|
|
2277
|
+
"""
|
|
2278
|
+
if s3_uri is None:
|
|
2279
|
+
return False
|
|
2280
|
+
|
|
2281
|
+
return re.match("^s3://([^/]+)/?(.*)$", s3_uri) is not None
|