sagemaker-core 1.0.62__py3-none-any.whl → 2.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/__init__.py +2 -0
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2399 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +247 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1220 -0
- sagemaker/core/git_utils.py +415 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2977 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +509 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +770 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1267 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +686 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
- sagemaker/core/image_uri_config/huggingface.json +2287 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +252 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +494 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +972 -0
- sagemaker/core/image_uris.py +816 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +197 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +501 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +171 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +423 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +246 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1497 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +239 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +205 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1599 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1310 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +347 -0
- sagemaker/core/remote_function/core/serialization.py +410 -0
- sagemaker/core/remote_function/core/stored_function.py +223 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +102 -0
- sagemaker/core/remote_function/invoke_function.py +167 -0
- sagemaker/core/remote_function/job.py +2121 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +82 -0
- sagemaker/core/telemetry/telemetry_logging.py +285 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +345 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +514 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.3.1.dist-info/RECORD +351 -0
- sagemaker_core-2.3.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/modules/local_core}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,595 @@
|
|
|
1
|
+
# sagemaker_config.py
|
|
2
|
+
|
|
3
|
+
import pathlib
|
|
4
|
+
import copy
|
|
5
|
+
import inspect
|
|
6
|
+
import os
|
|
7
|
+
from typing import List, Optional
|
|
8
|
+
import boto3
|
|
9
|
+
import yaml
|
|
10
|
+
import jsonschema
|
|
11
|
+
from platformdirs import site_config_dir, user_config_dir
|
|
12
|
+
from botocore.utils import merge_dicts
|
|
13
|
+
from six.moves.urllib.parse import urlparse
|
|
14
|
+
from sagemaker.core.config.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA
|
|
15
|
+
from sagemaker.core.config.config_utils import (
|
|
16
|
+
non_repeating_log_factory,
|
|
17
|
+
get_sagemaker_config_logger,
|
|
18
|
+
_log_sagemaker_config_single_substitution,
|
|
19
|
+
_log_sagemaker_config_merge,
|
|
20
|
+
)
|
|
21
|
+
from functools import lru_cache
|
|
22
|
+
|
|
23
|
+
# Module-level logger obtained from the shared SageMaker config logging helper.
logger = get_sagemaker_config_logger()
# Info-level logging callable built by non_repeating_log_factory
# (presumably suppresses repeated identical messages — confirm in config_utils).
log_info_function = non_repeating_log_factory(logger, "info")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class SageMakerConfig:
|
|
28
|
+
# Application name passed to platformdirs to resolve the site/user config directories.
_APP_NAME = "sagemaker"
# File name looked up when a local directory or an S3 prefix is given instead of a file.
_CONFIG_FILE_NAME = "config.yaml"
# Default admin (site-wide) config location, resolved via platformdirs.site_config_dir.
_DEFAULT_ADMIN_CONFIG_FILE_PATH = os.path.join(site_config_dir(_APP_NAME), _CONFIG_FILE_NAME)
# Default per-user config location, resolved via platformdirs.user_config_dir.
_DEFAULT_USER_CONFIG_FILE_PATH = os.path.join(user_config_dir(_APP_NAME), _CONFIG_FILE_NAME)
# Local-mode config file: ~/.sagemaker/config.yaml.
_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH = os.path.join(
    os.path.expanduser("~"), ".sagemaker", _CONFIG_FILE_NAME
)
# Environment variables which, when set, replace the default admin / user config paths.
ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE = "SAGEMAKER_ADMIN_CONFIG_OVERRIDE"
ENV_VARIABLE_USER_CONFIG_OVERRIDE = "SAGEMAKER_USER_CONFIG_OVERRIDE"
# Config paths starting with this prefix are fetched from S3 instead of the local filesystem.
S3_PREFIX = "s3://"
|
|
38
|
+
|
|
39
|
+
def __init__(self):
    """Attach the config logger and an info-level log function to this instance.

    ``log_info_function`` comes from ``non_repeating_log_factory`` (presumably it
    suppresses repeated identical messages — confirm in config_utils).
    """
    cfg_logger = get_sagemaker_config_logger()
    self.logger = cfg_logger
    self.log_info_function = non_repeating_log_factory(cfg_logger, "info")
|
|
42
|
+
|
|
43
|
+
def load_sagemaker_config(
    self,
    additional_config_paths: Optional[List[str]] = None,
    s3_resource=None,
    repeat_log: bool = False,
) -> dict:
    """Load every configured SDK defaults file and merge them into one dict.

    Locations are visited in this order:
      1. the admin config (``ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE`` or the site default),
      2. the user config (``ENV_VARIABLE_USER_CONFIG_OVERRIDE`` or the user default),
      3. each entry of ``additional_config_paths``.
    ``None`` entries are skipped. Each non-empty file is validated, then merged into
    the result with botocore's ``merge_dicts``, so later files win on conflicting keys.

    Args:
        additional_config_paths: Extra local paths or ``s3://`` URIs to load after the
            defaults.
        s3_resource: Optional boto3 S3 resource used for ``s3://`` locations.
        repeat_log: If True, log via ``self.logger.info`` instead of
            ``self.log_info_function``.

    Returns:
        dict: The merged configuration; empty when nothing was loaded.

    Raises:
        ValueError: When a local path other than the two built-in defaults cannot be
            found. A missing default admin/user file is only logged at debug level.
    """
    admin_path = os.getenv(
        self.ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE, self._DEFAULT_ADMIN_CONFIG_FILE_PATH
    )
    user_path = os.getenv(
        self.ENV_VARIABLE_USER_CONFIG_OVERRIDE, self._DEFAULT_USER_CONFIG_FILE_PATH
    )
    candidates = [admin_path, user_path, *(additional_config_paths or [])]
    merged_config = {}

    log_info = self.logger.info if repeat_log else self.log_info_function

    for path in [p for p in candidates if p is not None]:
        if path.startswith(self.S3_PREFIX):
            loaded = self._load_config_from_s3(path, s3_resource)
        else:
            loaded = {}
            try:
                loaded = self._load_config_from_file(path)
            except ValueError as error:
                # Only the built-in default locations are allowed to be absent.
                is_default_location = path in (
                    self._DEFAULT_ADMIN_CONFIG_FILE_PATH,
                    self._DEFAULT_USER_CONFIG_FILE_PATH,
                )
                if not is_default_location:
                    raise
                self.logger.debug(error)

        if not loaded:
            log_info("Not applying SDK defaults from location: %s", path)
            continue

        self.validate_sagemaker_config(loaded)
        merge_dicts(merged_config, loaded)
        log_info("Fetched defaults config from location: %s", path)

    return merged_config
|
|
87
|
+
|
|
88
|
+
@staticmethod
def validate_sagemaker_config(sagemaker_config: Optional[dict] = None):
    """Check ``sagemaker_config`` against ``SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA``.

    Returns None on success.

    Raises:
        jsonschema.exceptions.ValidationError: If the config does not conform to the schema.
    """
    jsonschema.validate(instance=sagemaker_config, schema=SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA)
|
|
91
|
+
|
|
92
|
+
def load_local_mode_config(self) -> Optional[dict]:
    """Read the local-mode config file, returning None when it cannot be located.

    Only ``ValueError`` raised by ``_load_config_from_file`` is treated as "no config";
    any other error propagates.
    """
    try:
        return self._load_config_from_file(self._DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH)
    except ValueError:
        return None
|
|
98
|
+
|
|
99
|
+
def _load_config_from_file(self, file_path: str) -> dict:
    """Parse a local YAML config file.

    If ``file_path`` is a directory, ``self._CONFIG_FILE_NAME`` inside it is read instead.

    Args:
        file_path (str): Path to a YAML file, or to a directory containing the config file.

    Returns:
        dict: The result of ``yaml.safe_load`` (may be None for an empty file).

    Raises:
        ValueError: If the resolved file path does not exist.
    """
    inferred_file_path = file_path
    if os.path.isdir(file_path):
        inferred_file_path = os.path.join(file_path, self._CONFIG_FILE_NAME)
    if not os.path.exists(inferred_file_path):
        # Sentences are separated explicitly; previously the two literals were glued together.
        raise ValueError(
            f"Unable to load the config file from the location: {file_path}. "
            "Provide a valid file path."
        )
    # Log the path actually opened (differs from file_path when a directory was given).
    self.logger.debug("Fetching defaults config from location: %s", inferred_file_path)
    # YAML config files are UTF-8; don't depend on the platform's default encoding.
    with open(inferred_file_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
|
|
112
|
+
|
|
113
|
+
def _load_config_from_s3(self, s3_uri, s3_resource_for_config) -> dict:
    """Download a YAML config from S3 and parse it.

    If no S3 resource is supplied, one is built from ``boto3.DEFAULT_SESSION``, or from a
    fresh ``boto3.Session()``, using that session's region. ``s3_uri`` may name a directory
    prefix; it is resolved via ``_get_inferred_s3_uri``.

    Raises:
        ValueError: If a resource must be created but the boto session has no region.
    """
    s3_resource = s3_resource_for_config
    if not s3_resource:
        session = boto3.DEFAULT_SESSION or boto3.Session()
        region = session.region_name
        if region is None:
            raise ValueError(
                "Must setup local AWS configuration with a region supported by SageMaker."
            )
        s3_resource = session.resource("s3", region_name=region)

    self.logger.debug("Fetching defaults config from location: %s", s3_uri)
    resolved = urlparse(self._get_inferred_s3_uri(s3_uri, s3_resource))
    body = s3_resource.Object(resolved.netloc, resolved.path.lstrip("/")).get()["Body"]
    raw_bytes = body.read()
    return yaml.safe_load(raw_bytes.decode("utf-8"))
|
|
130
|
+
|
|
131
|
+
def _get_inferred_s3_uri(self, s3_uri, s3_resource_for_config):
    """Resolve ``s3_uri`` to the URI of the S3 object holding the config.

    All objects under the key prefix of ``s3_uri`` are listed. Resolution order:
      1. ``<s3_uri>/<_CONFIG_FILE_NAME>`` if that object exists (directory-style URI),
      2. ``s3_uri`` itself if it names an existing object exactly,
      3. otherwise ``s3_uri`` when the prefix matches a single object (legacy behavior).

    Args:
        s3_uri (str): ``s3://bucket/key`` naming a config object or a directory prefix.
        s3_resource_for_config: boto3 S3 resource used to list bucket objects.

    Returns:
        str: The S3 URI to download.

    Raises:
        ValueError: If no object has the given prefix, or several do and neither the exact
            object nor ``<s3_uri>/<_CONFIG_FILE_NAME>`` is among them.
    """
    parsed_url = urlparse(s3_uri)
    bucket, key_prefix = parsed_url.netloc, parsed_url.path.lstrip("/")
    s3_bucket = s3_resource_for_config.Bucket(name=bucket)
    s3_objects = s3_bucket.objects.filter(Prefix=key_prefix).all()
    s3_files_with_same_prefix = [
        f"{self.S3_PREFIX}{bucket}/{s3_object.key}" for s3_object in s3_objects
    ]
    if not s3_files_with_same_prefix:
        raise ValueError(f"Provide a valid S3 path instead of {s3_uri}")

    # PurePosixPath collapses the "//" after the scheme; restore it for the scheme only.
    inferred_s3_uri = str(pathlib.PurePosixPath(s3_uri, self._CONFIG_FILE_NAME)).replace(
        "s3:/", "s3://", 1
    )
    if inferred_s3_uri in s3_files_with_same_prefix:
        # Directory-style URI containing a config file, even when it is the only object.
        return inferred_s3_uri

    if f"{self.S3_PREFIX}{bucket}/{key_prefix}" in s3_files_with_same_prefix:
        # Exact object URI; sibling objects that merely share the prefix are irrelevant.
        return s3_uri

    if len(s3_files_with_same_prefix) > 1:
        # Several candidates and no way to tell which one is the config.
        raise ValueError(
            f"Provide an S3 URI of a directory that has a {self._CONFIG_FILE_NAME} file."
        )
    return s3_uri
|
|
151
|
+
|
|
152
|
+
@staticmethod
def get_config_value(key_path, config):
    """Look up a dot-separated ``key_path`` (e.g. ``"a.b.c"``) in a nested config.

    Args:
        key_path (str): Keys joined with ``"."``, outermost first.
        config: The (nested) config mapping to search; may be None.

    Returns:
        The value at ``key_path``, or None if ``config`` is None, any key is missing, or the
        path runs into a non-mapping value (scalar, list, None) before it is exhausted.
    """
    from collections.abc import Mapping

    if config is None:
        return None

    current_section = config
    for key in key_path.split("."):
        # Only descend through mappings: `in` on a str is a substring test and indexing a
        # str/list/None with a str key raises TypeError.
        if not isinstance(current_section, Mapping) or key not in current_section:
            return None
        current_section = current_section[key]

    return current_section
|
|
166
|
+
|
|
167
|
+
@staticmethod
def get_nested_value(dictionary: dict, nested_keys: List[str]):
    """Return the value at the ``nested_keys`` path inside ``dictionary``, or None.

    Args:
        dictionary (dict): The dictionary to read from.
        nested_keys (List[str]): Key path, outermost key first.

    Returns:
        The value at the full key path. Returns None if ``dictionary`` is not a dict,
        ``nested_keys`` is None or empty, a key along the path is missing, or an
        intermediate value is None.

    Raises:
        ValueError: If an intermediate value along the path exists but is not a dict.
    """
    if not isinstance(dictionary, dict) or nested_keys is None or len(nested_keys) == 0:
        return None

    current_section = dictionary
    for key in nested_keys[:-1]:
        current_section = current_section.get(key, None)
        if current_section is None:
            # The full key path doesn't exist in the dictionary, or a value on it is None.
            return None
        if not isinstance(current_section, dict):
            # Single message string (previously two args, which rendered as a tuple).
            raise ValueError(
                "Unexpected structure of dictionary. "
                "Expected value of type dict at key '{}' but got '{}' for dict '{}'".format(
                    key, current_section, dictionary
                )
            )
    return current_section.get(nested_keys[-1], None)
|
|
199
|
+
|
|
200
|
+
@staticmethod
def set_nested_value(dictionary: dict, nested_keys: List[str], value_to_set: object):
    """Store ``value_to_set`` at the ``nested_keys`` path of ``dictionary`` and return it.

    A None ``dictionary`` becomes a new empty dict. Missing, None or non-dict intermediate
    values are replaced with fresh dicts, so a wrong key path can overwrite unrelated data;
    checking with ``get_nested_value`` first is recommended. The input dict is mutated in
    place. It is returned unchanged if it is not a dict or ``nested_keys`` is None or empty.
    """
    target = {} if dictionary is None else dictionary

    if not isinstance(target, dict) or nested_keys is None or len(nested_keys) == 0:
        return target

    *parent_keys, leaf_key = nested_keys
    node = target
    for parent in parent_keys:
        child = node.get(parent)
        if not isinstance(child, dict):
            child = {}
            node[parent] = child
        node = child

    node[leaf_key] = value_to_set
    return target
|
|
230
|
+
|
|
231
|
+
def resolve_value_from_config(
    self,
    direct_input=None,
    config_path: str = None,
    default_value=None,
    sagemaker_session=None,
    sagemaker_config: dict = None,
):
    """Pick the value a caller should use, consulting the SageMaker config.

    Precedence, highest first: ``direct_input``, the config value at ``config_path``,
    ``default_value`` (which may itself be None).

    Args:
        direct_input: Value the caller already has, typically one of its own arguments.
        config_path (str): Dot-separated path of the value in the SageMaker config. When
            falsy, the config is not consulted.
        default_value: Fallback when neither of the above is set.
        sagemaker_session (sagemaker.session.Session): Session whose ``sagemaker_config``
            is searched, when provided.
        sagemaker_config (dict): Config dict searched only when no usable session is given.
            This covers the case where a default Session cannot yet be created but config
            injection is already needed.

    Returns:
        The first non-None candidate in precedence order, else ``default_value``.
    """
    config_value = None
    if config_path:
        config_value = self.get_sagemaker_config_value(
            sagemaker_session, config_path, sagemaker_config=sagemaker_config
        )
    _log_sagemaker_config_single_substitution(direct_input, config_value, config_path)

    for candidate in (direct_input, config_value):
        if candidate is not None:
            return candidate
    return default_value
|
|
285
|
+
|
|
286
|
+
def get_sagemaker_config_value(self, sagemaker_session, key, sagemaker_config: dict = None):
    """Look up ``key`` in the SageMaker config and return a private copy of its value.

    The config dictionary comes from ``sagemaker_session.sagemaker_config`` when a truthy
    session exposing that attribute is supplied. Otherwise the explicitly passed
    ``sagemaker_config`` is used; this covers callers that must inject config before a
    default Session object can be created.

    Args:
        sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for
            SageMaker interactions.
        key: Key path of the config file entry.
        sagemaker_config (dict): Fallback sdk defaults config (default: None). Consulted
            only when no usable session was given.

    Returns:
        object: None when the selected config dictionary is missing or empty. Otherwise a
        deep copy of whatever ``self.get_config_value`` returns for ``key``.
    """
    use_session = bool(sagemaker_session) and hasattr(sagemaker_session, "sagemaker_config")
    config_to_check = sagemaker_session.sagemaker_config if use_session else sagemaker_config

    if not config_to_check:
        return None

    self.validate_sagemaker_config(config_to_check)
    # Hand back a copy so callers cannot mutate the shared source config via the result.
    return copy.deepcopy(self.get_config_value(key, config_to_check))
|
|
315
|
+
|
|
316
|
+
def resolve_class_attribute_from_config(
    self,
    clazz: Optional[type],
    instance: Optional[object],
    attribute: str,
    config_path: str,
    default_value=None,
    sagemaker_session=None,
):
    """Utility method that merges config values to data classes.

    Takes an instance of a class and, if not already set, sets the instance's attribute to a
    value fetched from the sagemaker_config or the default_value.

    Uses this order of prioritization to determine what the value of the attribute should be:
    1. current value of attribute
    2. config value
    3. default_value
    4. does not set it

    Args:
        clazz (Optional[type]): Class of 'instance'. Used to generate a new instance if the
            instance is None. If None (or not a class) is provided here, no new object will
            be created if 'instance' doesn't exist. Note: if provided, the constructor should
            set default values to None; otherwise, the constructor's non-None default will be
            left as-is even if a config value was defined.
        instance (Optional[object]): instance of the Class 'clazz' that has an attribute
            of 'attribute' to set
        attribute (str): attribute of the instance to set if not already set
        config_path (str): a string denoting the path to use to lookup the config value in the
            sagemaker config
        default_value: the value to use if not present elsewhere
        sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for
            SageMaker interactions (default: None).

    Returns:
        The updated class instance that should be used by the caller instead of the
        'instance' parameter that was passed in.

    Raises:
        TypeError: If the (possibly newly constructed) instance has no attribute named
            ``attribute``.
    """
    config_value = self.get_sagemaker_config_value(sagemaker_session, config_path)

    if config_value is None and default_value is None:
        # Nothing to apply: return the instance unmodified (could be None or populated).
        return instance

    if instance is None:
        if clazz is None or not inspect.isclass(clazz):
            return instance
        # Construct a new instance so the config/default value has somewhere to live.
        instance = clazz()

    if not hasattr(instance, attribute):
        # Pass a single message string. Giving TypeError several positional arguments
        # makes str(exc) render as a tuple repr instead of a readable message.
        raise TypeError(
            "Unexpected structure of object. "
            "Expected attribute {} to be present inside instance {} of class {}".format(
                attribute, instance, clazz
            )
        )

    current_value = getattr(instance, attribute)
    if current_value is None:
        # Only set a value if the object does not already have one.
        if config_value is not None:
            setattr(instance, attribute, config_value)
        elif default_value is not None:
            setattr(instance, attribute, default_value)

    _log_sagemaker_config_single_substitution(current_value, config_value, config_path)

    return instance
|
|
386
|
+
|
|
387
|
+
def resolve_nested_dict_value_from_config(
    self,
    dictionary: dict,
    nested_keys: List[str],
    config_path: str,
    default_value: object = None,
    sagemaker_session=None,
):
    """Fill the value at ``nested_keys`` inside ``dictionary`` unless it is already set.

    Precedence for the value at the nested key path:
    (1) the value already present, (2) the sagemaker config value found at
    ``config_path``, (3) ``default_value``, (4) leave it unset.

    Args:
        dictionary: The dict to update.
        nested_keys: The path of keys whose value should be checked and set if needed.
        config_path (str): A string denoting the path used to find the config value in the
            sagemaker config.
        default_value: The value to use if not present elsewhere.
        sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for
            SageMaker interactions (default: None).

    Returns:
        The dictionary the caller should use from now on in place of the ``dictionary``
        argument. It is returned unchanged when there is nothing to set or when
        ``self.get_nested_value`` raises ValueError; in the ValueError case the error is
        logged.
    """
    config_value = self.get_sagemaker_config_value(sagemaker_session, config_path)
    if config_value is None and default_value is None:
        # Nothing could be written, so avoid walking (or adding nested dicts to) the input.
        return dictionary

    try:
        current_nested_value = self.get_nested_value(dictionary, nested_keys)
    except ValueError as e:
        logger.error("Failed to check dictionary for applying sagemaker config: %s", e)
        return dictionary

    if current_nested_value is None:
        # Past the early return, at least one candidate is non-None; config wins.
        replacement = config_value if config_value is not None else default_value
        dictionary = self.set_nested_value(dictionary, nested_keys, replacement)

    _log_sagemaker_config_single_substitution(current_nested_value, config_value, config_path)

    return dictionary
|
|
439
|
+
|
|
440
|
+
def update_list_of_dicts_with_values_from_config(
    self,
    input_list,
    config_key_path,
    required_key_paths: List[str] = None,
    union_key_paths: List[List[str]] = None,
    sagemaker_session=None,
):
    """Fill in missing values of each dict in ``input_list`` from the matching config entry.

    Entries are paired by position with the list found at ``config_key_path``. For each pair,
    the user dict is merged on top of the config dict. The merged result replaces the user
    entry only when it passes two checks:

    * every path in ``required_key_paths`` is present, which guards against config adding
      a parameter whose required companion was not supplied; and
    * each group in ``union_key_paths`` has at most one populated path, for APIs that
      accept either an existing parameter or a config-supplied one, but not both.

    Args:
        input_list: The input list that was provided as a method parameter.
        config_key_path: The Key Path in the Config file that corresponds to the input_list
            parameter.
        required_key_paths (List[str]): Key paths that must exist in a merged item for the
            merge to be kept.
        union_key_paths (List[List[str]]): Groups of key paths of which at most one may be
            present. For example, if the result may hold 'X1' or 'X2' (or neither) but not
            both, pass [['X1', 'X2']].
        sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for
            SageMaker interactions (default: None).

    Returns:
        None. ``input_list`` is modified in place.
    """
    if not input_list:
        return
    original_inputs = copy.deepcopy(input_list)
    config_entries = self.get_sagemaker_config_value(sagemaker_session, config_key_path) or []
    pristine_config_entries = copy.deepcopy(config_entries)

    # zip pairs entries positionally and stops at the shorter of the two lists.
    for index, (user_entry, merged_entry) in enumerate(zip(input_list, config_entries)):
        # User-provided values take precedence: merge them on top of the config entry.
        merge_dicts(merged_entry, user_entry)
        if not self._validate_required_paths_in_a_dict(merged_entry, required_key_paths):
            # Config would introduce a parameter whose required companion is missing.
            continue
        if not self._validate_union_key_paths_in_a_dict(merged_entry, union_key_paths):
            # The merged entry sets more than one mutually exclusive parameter.
            continue
        input_list[index] = merged_entry

    _log_sagemaker_config_merge(
        source_value=original_inputs,
        config_value=pristine_config_entries,
        merged_source_and_config_value=input_list,
        config_key_path=config_key_path,
    )
|
|
509
|
+
|
|
510
|
+
def _validate_required_paths_in_a_dict(
|
|
511
|
+
self, source_dict, required_key_paths: List[str] = None
|
|
512
|
+
) -> bool:
|
|
513
|
+
"""Placeholder docstring"""
|
|
514
|
+
if not required_key_paths:
|
|
515
|
+
return True
|
|
516
|
+
for required_key_path in required_key_paths:
|
|
517
|
+
if self.get_config_value(required_key_path, source_dict) is None:
|
|
518
|
+
return False
|
|
519
|
+
return True
|
|
520
|
+
|
|
521
|
+
def _validate_union_key_paths_in_a_dict(
|
|
522
|
+
self, source_dict, union_key_paths: List[List[str]] = None
|
|
523
|
+
) -> bool:
|
|
524
|
+
"""Placeholder docstring"""
|
|
525
|
+
if not union_key_paths:
|
|
526
|
+
return True
|
|
527
|
+
for union_key_path in union_key_paths:
|
|
528
|
+
union_parameter_present = False
|
|
529
|
+
for key_path in union_key_path:
|
|
530
|
+
if self.get_config_value(key_path, source_dict):
|
|
531
|
+
if union_parameter_present:
|
|
532
|
+
return False
|
|
533
|
+
union_parameter_present = True
|
|
534
|
+
return True
|
|
535
|
+
|
|
536
|
+
def update_nested_dictionary_with_values_from_config(
    self, source_dict, config_key_path, sagemaker_session=None
) -> dict:
    """Merge ``source_dict`` on top of the config value found at ``config_key_path``.

    Values supplied by the caller take precedence; the config only fills in keys the caller
    left out.

    Args:
        source_dict: The input nested dictionary that was provided as method parameter.
        config_key_path: The Key Path in the Config file which corresponds to this
            source_dict parameter.
        sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for
            SageMaker interactions (default: None).

    Returns:
        dict: The merged dictionary when the config defines a non-empty value. Otherwise
        ``source_dict`` exactly as given, which may be None.
    """
    merged = self.get_sagemaker_config_value(sagemaker_session, config_key_path) or {}
    config_snapshot = copy.deepcopy(merged)
    merge_dicts(merged, source_dict or {})

    if config_snapshot != {}:
        _log_sagemaker_config_merge(
            source_value=source_dict,
            config_value=config_snapshot,
            merged_source_and_config_value=merged,
            config_key_path=config_key_path,
        )
        return merged

    # No config value was defined, so nothing needs logging. Return the caller's object
    # untouched rather than a fresh {}: e.g. a VpcConfig of {} instead of None makes some
    # boto calls fail with ParamValidationError because required sub-fields are missing.
    return source_dict
|
|
578
|
+
|
|
579
|
+
@lru_cache(maxsize=None)
def load_default_configs_for_resource_name(self, resource_name: str):
    """Return the defaults configured under ``SageMaker.PythonSDK.Resources.<resource_name>``.

    Args:
        resource_name (str): Key looked up under ``SageMaker.PythonSDK.Resources``.

    Returns:
        ``{}`` when ``self.load_sagemaker_config()`` yields nothing. Otherwise the entry for
        ``resource_name``, or None when the config does not define it. None is also returned
        when any of the ``SageMaker`` / ``PythonSDK`` / ``Resources`` sections is absent.
    """
    # NOTE(review): lru_cache on a method keys on ``self``, which keeps every instance
    # alive for the cache's lifetime, and cached results are not refreshed if the config
    # changes. The decorator is kept because callers may rely on ``cache_clear()``.
    configs_data = self.load_sagemaker_config()
    if not configs_data:
        logger.debug("No default configurations found for resource: %s", resource_name)
        return {}
    # Walk the section path defensively. A config that defines other sections but not
    # SageMaker.PythonSDK.Resources should mean "no resource defaults", not a KeyError.
    resources = configs_data
    for section in ("SageMaker", "PythonSDK", "Resources"):
        resources = resources.get(section) or {}
    return resources.get(resource_name)
|
|
586
|
+
|
|
587
|
+
def get_resolved_config_value(self, attribute, resource_defaults, global_defaults):
    """Look up ``attribute`` in resource-level defaults first, then in global defaults.

    Args:
        attribute: Key to look up.
        resource_defaults: Resource-specific defaults mapping; may be None or empty.
        global_defaults: Global defaults mapping; may be None or empty.

    Returns:
        The value from the first truthy mapping that contains ``attribute``, which may
        itself be None. If neither mapping has it, a debug message is logged and None
        is returned.
    """
    for defaults in (resource_defaults, global_defaults):
        if defaults and attribute in defaults:
            return defaults[attribute]
    logger.debug(
        f"Configurable value {attribute} not entered in parameters or present in the Config"
    )
    return None
|