sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/__init__.py +0 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/__init__.py +0 -4
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.47.dist-info/RECORD +0 -35
- sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,2965 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
from __future__ import absolute_import, annotations, print_function
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import logging
|
|
17
|
+
import os
|
|
18
|
+
import re
|
|
19
|
+
import sys
|
|
20
|
+
import time
|
|
21
|
+
import uuid
|
|
22
|
+
import six
|
|
23
|
+
import warnings
|
|
24
|
+
from functools import reduce
|
|
25
|
+
from typing import Dict, Any, Optional, List
|
|
26
|
+
|
|
27
|
+
import boto3
|
|
28
|
+
import botocore
|
|
29
|
+
import botocore.config
|
|
30
|
+
from botocore.exceptions import ClientError
|
|
31
|
+
from sagemaker.core import exceptions
|
|
32
|
+
from sagemaker.core.common_utils import (
|
|
33
|
+
secondary_training_status_changed,
|
|
34
|
+
secondary_training_status_message,
|
|
35
|
+
)
|
|
36
|
+
import sagemaker.core.logs
|
|
37
|
+
from sagemaker.core.session_settings import SessionSettings
|
|
38
|
+
from sagemaker.core.common_utils import (
|
|
39
|
+
secondary_training_status_changed,
|
|
40
|
+
secondary_training_status_message,
|
|
41
|
+
sts_regional_endpoint,
|
|
42
|
+
retries,
|
|
43
|
+
resolve_value_from_config,
|
|
44
|
+
get_sagemaker_config_value,
|
|
45
|
+
resolve_class_attribute_from_config,
|
|
46
|
+
resolve_nested_dict_value_from_config,
|
|
47
|
+
update_nested_dictionary_with_values_from_config,
|
|
48
|
+
update_list_of_dicts_with_values_from_config,
|
|
49
|
+
format_tags,
|
|
50
|
+
Tags,
|
|
51
|
+
TagsDict,
|
|
52
|
+
instance_supports_kms,
|
|
53
|
+
create_paginator_config,
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
from sagemaker.core.config.config_utils import _log_sagemaker_config_merge
|
|
57
|
+
from sagemaker.core._studio import _append_project_tags
|
|
58
|
+
from sagemaker.core.config.config import load_sagemaker_config, validate_sagemaker_config
|
|
59
|
+
from sagemaker.core.config.config_schema import (
|
|
60
|
+
KEY,
|
|
61
|
+
TRANSFORM_JOB,
|
|
62
|
+
TRANSFORM_JOB_ENVIRONMENT_PATH,
|
|
63
|
+
TRANSFORM_JOB_KMS_KEY_ID_PATH,
|
|
64
|
+
TRANSFORM_OUTPUT_KMS_KEY_ID_PATH,
|
|
65
|
+
VOLUME_KMS_KEY_ID,
|
|
66
|
+
TRANSFORM_JOB_VOLUME_KMS_KEY_ID_PATH,
|
|
67
|
+
MODEL,
|
|
68
|
+
MODEL_CONTAINERS_PATH,
|
|
69
|
+
MODEL_EXECUTION_ROLE_ARN_PATH,
|
|
70
|
+
MODEL_ENABLE_NETWORK_ISOLATION_PATH,
|
|
71
|
+
MODEL_PRIMARY_CONTAINER_PATH,
|
|
72
|
+
MODEL_VPC_CONFIG_PATH,
|
|
73
|
+
ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH,
|
|
74
|
+
KMS_KEY_ID,
|
|
75
|
+
ENDPOINT_CONFIG_KMS_KEY_ID_PATH,
|
|
76
|
+
ENDPOINT_CONFIG,
|
|
77
|
+
ENDPOINT_CONFIG_DATA_CAPTURE_PATH,
|
|
78
|
+
ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH,
|
|
79
|
+
ENDPOINT_CONFIG_VPC_CONFIG_PATH,
|
|
80
|
+
ENDPOINT_CONFIG_ENABLE_NETWORK_ISOLATION_PATH,
|
|
81
|
+
ENDPOINT_CONFIG_EXECUTION_ROLE_ARN_PATH,
|
|
82
|
+
ENDPOINT,
|
|
83
|
+
INFERENCE_COMPONENT,
|
|
84
|
+
SAGEMAKER,
|
|
85
|
+
TAGS,
|
|
86
|
+
SESSION_DEFAULT_S3_BUCKET_PATH,
|
|
87
|
+
SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
# ``LOGGER`` is kept as an alias of ``logger`` for backward compatibility,
# in case users import it...
LOGGER = logging.getLogger("sagemaker")
logger = LOGGER

NOTEBOOK_METADATA_FILE = "/opt/ml/metadata/resource-metadata.json"
MODEL_MONITOR_ONE_TIME_SCHEDULE = "NOW"

# Maps upper-case status codes to their mixed-case spellings. Both
# "INPROGRESS" and "IN_PROGRESS" resolve to "InProgress".
_STATUS_CODE_TABLE = dict(
    COMPLETED="Completed",
    INPROGRESS="InProgress",
    IN_PROGRESS="InProgress",
    FAILED="Failed",
    STOPPED="Stopped",
    STOPPING="Stopping",
    STARTING="Starting",
    PENDING="Pending",
)

# Polling intervals (presumably seconds -- confirm at the call sites, which
# are outside this section).
DEFAULT_EP_POLL = 30
EP_LOGGER_POLL = 30
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class LogState(object):
    """Namespace of integer constants naming log-following states.

    The members are plain ``int`` class attributes numbered 1 through 5, in
    the order STARTING, WAIT_IN_PROGRESS, TAILING, JOB_COMPLETE, COMPLETE.
    NOTE(review): the state machine that consumes these values is not
    visible in this section -- confirm transitions against its users.
    """

    STARTING, WAIT_IN_PROGRESS, TAILING, JOB_COMPLETE, COMPLETE = range(1, 6)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class Session(object): # pylint: disable=too-many-public-methods
|
|
120
|
+
"""Manage interactions with the Amazon SageMaker APIs and any other AWS services needed.
|
|
121
|
+
|
|
122
|
+
This class provides convenient methods for manipulating entities and resources that Amazon
|
|
123
|
+
SageMaker uses, such as training jobs, endpoints, and input datasets in S3.
|
|
124
|
+
AWS service calls are delegated to an underlying Boto3 session, which by default
|
|
125
|
+
is initialized using the AWS configuration chain. When you make an Amazon SageMaker API call
|
|
126
|
+
that accesses an S3 bucket location and one is not specified, the ``Session`` creates a default
|
|
127
|
+
bucket based on a naming convention which includes the current AWS account ID.
|
|
128
|
+
"""
|
|
129
|
+
|
|
130
|
+
def __init__(
    self,
    boto_session=None,
    sagemaker_client=None,
    sagemaker_runtime_client=None,
    sagemaker_featurestore_runtime_client=None,
    default_bucket=None,
    sagemaker_config: dict = None,
    settings=None,
    sagemaker_metrics_client=None,
    default_bucket_prefix: str = None,
):
    """Initialize a SageMaker ``Session``.

    Records the bucket-related arguments on the instance, resets the lazily
    created client attributes to ``None``, resolves ``settings`` and then
    delegates client and configuration setup to :func:`_initialize`.

    Args:
        boto_session (boto3.session.Session): The underlying Boto3 session which AWS
            service calls are delegated to (default: None). If not provided, one is
            created with default AWS configuration chain.
        sagemaker_client (boto3.SageMaker.Client): Client which makes Amazon SageMaker
            service calls other than ``InvokeEndpoint`` (default: None). Estimators
            created using this ``Session`` use this client. If not provided, one will
            be created using this instance's ``boto_session``.
        sagemaker_runtime_client (boto3.SageMakerRuntime.Client): Client which makes
            ``InvokeEndpoint`` calls to Amazon SageMaker (default: None). Predictors
            created using this ``Session`` use this client. If not provided, one will
            be created using this instance's ``boto_session``.
        sagemaker_featurestore_runtime_client (boto3.SageMakerFeatureStoreRuntime.Client):
            Client which makes SageMaker FeatureStore record related calls to Amazon
            SageMaker (default: None). If not provided, one will be created using
            this instance's ``boto_session``.
        default_bucket (str): The default Amazon S3 bucket to be used by this session.
            This will be created the next time an Amazon S3 bucket is needed (by
            calling :func:`default_bucket`).
            If not provided, it will be fetched from the sagemaker_config. If not
            configured there either, a default bucket will be created based on the
            following format: "sagemaker-{region}-{aws-account-id}".
            Example: "sagemaker-my-custom-bucket".
        sagemaker_config (dict): Configuration dictionary passed through to
            :func:`_initialize` (default: None). There, a truthy value is validated
            and used as given; otherwise a configuration is loaded via
            ``load_sagemaker_config``.
        settings: Settings object stored on ``self.settings`` (default: None). When
            falsy, a new ``SessionSettings()`` is created instead.
        sagemaker_metrics_client (boto3.SageMakerMetrics.Client):
            Client which makes SageMaker Metrics related calls to Amazon SageMaker
            (default: None). If not provided, one will be created using
            this instance's ``boto_session``.
        default_bucket_prefix (str): The default prefix to use for S3 Object Keys.
            (default: None). If provided and where applicable, it will be used by the
            SDK to construct default S3 URIs, in the format:
            `s3://{default_bucket}/{default_bucket_prefix}/<rest of object key>`
            This parameter can also be specified via `{sagemaker_config}` instead of
            here. If not provided here or within `{sagemaker_config}`, default S3 URIs
            will have the format: `s3://{default_bucket}/<rest of object key>`
    """
    self.endpoint_arn = None

    # Bucket bookkeeping. The sagemaker_config is only validated and loaded inside
    # :func:`_initialize`, so when ``default_bucket`` is None and the config names a
    # default S3 bucket, ``_default_bucket_name_override`` is assigned again there.
    self._default_bucket = None
    self._default_bucket_name_override = default_bucket
    # May likewise be assigned again inside :func:`_initialize` if it is None.
    self.default_bucket_prefix = default_bucket_prefix
    self._default_bucket_set_by_sdk = False

    # Client handles start out unset.
    self.s3_resource = None
    self.s3_client = None
    self.resource_groups_client = None
    self.resource_group_tagging_client = None
    self._config = None
    self.lambda_client = None

    # Fall back to default settings when none (or a falsy value) is supplied.
    self.settings = settings or SessionSettings()

    initialize_kwargs = {
        "boto_session": boto_session,
        "sagemaker_client": sagemaker_client,
        "sagemaker_runtime_client": sagemaker_runtime_client,
        "sagemaker_featurestore_runtime_client": sagemaker_featurestore_runtime_client,
        "sagemaker_metrics_client": sagemaker_metrics_client,
        "sagemaker_config": sagemaker_config,
    }
    self._initialize(**initialize_kwargs)
|
|
206
|
+
|
|
207
|
+
def _initialize(
    self,
    boto_session,
    sagemaker_client,
    sagemaker_runtime_client,
    sagemaker_featurestore_runtime_client,
    sagemaker_metrics_client,
    sagemaker_config: dict = None,
):
    """Set up the boto session, the service clients and the SageMaker config.

    Every client argument that is supplied is stored as-is; a missing one is
    created from the boto session. ``self._default_bucket_name_override`` and
    ``self.default_bucket_prefix`` are read at the end of this method, so they
    must already be set on the instance before it is called.

    Args:
        boto_session: Boto session to use. Falls back to ``boto3.DEFAULT_SESSION``
            and then to a fresh ``boto3.Session()``.
        sagemaker_client: Client for the ``sagemaker`` service, or None.
        sagemaker_runtime_client: Client for ``runtime.sagemaker``, or None.
        sagemaker_featurestore_runtime_client: Client for
            ``sagemaker-featurestore-runtime``, or a falsy value.
        sagemaker_metrics_client: Client for ``sagemaker-metrics``, or a falsy value.
        sagemaker_config (dict): Config dictionary. When truthy it is validated and
            stored; otherwise the config is loaded via ``load_sagemaker_config``.

    Raises:
        ValueError: If the boto session has no region configured.
    """
    session = boto_session or boto3.DEFAULT_SESSION or boto3.Session()
    self.boto_session = session

    self._region_name = session.region_name
    if self._region_name is None:
        raise ValueError(
            "Must setup local AWS configuration with a region supported by SageMaker."
        )

    # The default SageMaker client carries an SDK-specific suffix in
    # ``user_agent_extra`` so the suffix is appended to boto's own User-Agent header.
    # Ref: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
    if sagemaker_client is None:
        from sagemaker.core.user_agent import get_user_agent_extra_suffix

        user_agent_config = botocore.config.Config(
            user_agent_extra=get_user_agent_extra_suffix()
        )
        sagemaker_client = session.client("sagemaker", config=user_agent_config)
    self.sagemaker_client = sagemaker_client

    # The default runtime client uses a read timeout of 80.
    if sagemaker_runtime_client is None:
        runtime_config = botocore.config.Config(read_timeout=80)
        sagemaker_runtime_client = session.client("runtime.sagemaker", config=runtime_config)
    self.sagemaker_runtime_client = sagemaker_runtime_client

    # These two clients are replaced whenever the argument is falsy (not only None).
    self.sagemaker_featurestore_runtime_client = (
        sagemaker_featurestore_runtime_client
        or session.client("sagemaker-featurestore-runtime")
    )
    self.sagemaker_metrics_client = sagemaker_metrics_client or session.client(
        "sagemaker-metrics"
    )

    self.s3_client = session.client("s3", region_name=self.boto_region_name)
    self.s3_resource = session.resource("s3", region_name=self.boto_region_name)

    self.local_mode = False

    if not sagemaker_config:
        # No config supplied: load it, handing over the S3 resource created above.
        self.sagemaker_config = load_sagemaker_config(s3_resource=self.s3_resource)
    else:
        validate_sagemaker_config(sagemaker_config)
        self.sagemaker_config = sagemaker_config

    # With the config in place, let it fill in the bucket name override and the
    # bucket prefix; the current attribute value is passed as the direct input.
    self._default_bucket_name_override = resolve_value_from_config(
        direct_input=self._default_bucket_name_override,
        config_path=SESSION_DEFAULT_S3_BUCKET_PATH,
        sagemaker_session=self,
    )
    self.default_bucket_prefix = resolve_value_from_config(
        direct_input=self.default_bucket_prefix,
        config_path=SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,
        sagemaker_session=self,
    )
|
|
291
|
+
|
|
292
|
+
def account_id(self) -> str:
    """Look up the AWS account id of the caller through STS.

    An STS client is created for the boto session's region, using the endpoint
    returned by ``sts_regional_endpoint`` for that region.

    Returns:
        AWS account ID.
    """
    session = self.boto_session
    region_name = session.region_name
    sts = session.client(
        "sts",
        region_name=region_name,
        endpoint_url=sts_regional_endpoint(region_name),
    )
    identity = sts.get_caller_identity()
    return identity["Account"]
|
|
303
|
+
|
|
304
|
+
@property
|
|
305
|
+
def config(self) -> Dict | None:
|
|
306
|
+
"""The config for the local mode, unused in a normal session"""
|
|
307
|
+
return self._config
|
|
308
|
+
|
|
309
|
+
@config.setter
|
|
310
|
+
def config(self, value: Dict | None):
|
|
311
|
+
"""The config for the local mode, unused in a normal session"""
|
|
312
|
+
self._config = value
|
|
313
|
+
|
|
314
|
+
@property
def boto_region_name(self):
    """Return the region name stored in ``self._region_name``."""
    return self._region_name
|
|
318
|
+
|
|
319
|
+
def get_caller_identity_arn(self):
    """Return the ARN of the user or role whose credentials are used to call the API.

    When the notebook metadata file exists, the role is looked up from it (or from
    the notebook instance, user profile or domain it points to). If that lookup
    raises ``ClientError``, or the file is absent, the ARN comes from STS
    ``get_caller_identity`` and is converted from an assumed-role ARN into an IAM
    role ARN.

    Returns:
        str: The ARN user or role
    """
    if os.path.exists(NOTEBOOK_METADATA_FILE):
        with open(NOTEBOOK_METADATA_FILE, "rb") as metadata_file:
            metadata = json.loads(metadata_file.read())
        instance_name = metadata.get("ResourceName")
        domain_id = metadata.get("DomainId")
        user_profile_name = metadata.get("UserProfileName")
        execution_role_arn = metadata.get("ExecutionRoleArn")

        try:
            # No domain id: describe the notebook instance named in the metadata.
            if domain_id is None:
                notebook = self.sagemaker_client.describe_notebook_instance(
                    NotebookInstanceName=instance_name
                )
                return notebook["RoleArn"]

            # A role recorded directly in the metadata file wins.
            if execution_role_arn is not None:
                return execution_role_arn

            profile = self.sagemaker_client.describe_user_profile(
                DomainId=domain_id, UserProfileName=user_profile_name
            )

            # Prefer the role from the user profile settings ...
            profile_role = profile.get("UserSettings", {}).get("ExecutionRole")
            if profile_role:
                return profile_role

            # ... and fall back to the domain's default user settings.
            domain = self.sagemaker_client.describe_domain(DomainId=domain_id)
            return domain["DefaultUserSettings"]["ExecutionRole"]
        except ClientError:
            # Any ClientError from the lookups above lands here, whichever call failed.
            logger.debug(
                "Couldn't call 'describe_notebook_instance' to get the Role "
                "ARN of the instance %s.",
                instance_name,
            )

    sts = self.boto_session.client(
        "sts",
        region_name=self.boto_region_name,
        endpoint_url=sts_regional_endpoint(self.boto_region_name),
    )
    assumed_role = sts.get_caller_identity()["Arn"]

    # Rewrite "...sts::<account>:assumed-role/<name>/<session>" as an IAM role ARN.
    role = re.sub(r"^(.+)sts::(\d+):assumed-role/(.+?)/.*$", r"\1iam::\2:role/\3", assumed_role)

    # Call IAM to get the role's path
    last_slash = role.rfind("/")
    role_name = role[last_slash + 1 :]
    try:
        role = self.boto_session.client("iam").get_role(RoleName=role_name)["Role"]["Arn"]
    except ClientError:
        logger.warning(
            "Couldn't call 'get_role' to get Role ARN from role name %s to get Role path.",
            role_name,
        )

        # This conditional has been present since the inception of SageMaker
        # Guessing this conditional's purpose was to handle lack of IAM permissions
        # https://github.com/aws/sagemaker-python-sdk/issues/2089#issuecomment-791802713
        if "AmazonSageMaker-ExecutionRole" in assumed_role:
            logger.warning(
                "Assuming role was created in SageMaker AWS console, "
                "as the name contains `AmazonSageMaker-ExecutionRole`. "
                "Defaulting to Role ARN with service-role in path. "
                "If this Role ARN is incorrect, please add "
                "IAM read permissions to your role or supply the "
                "Role Arn directly."
            )
            role = re.sub(
                r"^(.+)sts::(\d+):assumed-role/(.+?)/.*$",
                r"\1iam::\2:role/service-role/\3",
                assumed_role,
            )

    return role
|
|
398
|
+
|
|
399
|
+
def upload_data(self, path, bucket=None, key_prefix="data", callback=None, extra_args=None):
    """Upload local file or directory to S3.

    If a single file is specified for upload, the resulting S3 object key is
    ``{key_prefix}/{filename}`` (filename does not include the local path, if any specified).
    If a directory is specified for upload, the API uploads all content, recursively,
    preserving relative structure of subdirectories. The resulting object key names are:
    ``{key_prefix}/{relative_subdirectory_path}/filename``. Subdirectory separators in
    the key are always ``/``, whatever the local path separator is.

    Args:
        path (str): Path (absolute or relative) of local file or directory to upload.
        bucket (str): Name of the S3 Bucket to upload to (default: None). If not specified, the
            default bucket of the ``Session`` is used (if default bucket does not exist, the
            ``Session`` creates it).
        key_prefix (str): Optional S3 object key name prefix (default: 'data'). S3 uses the
            prefix to create a directory structure for the bucket content that it displays in
            the S3 console.
        callback: Optional value forwarded as ``Callback`` to every ``upload_file``
            call (default: None).
        extra_args (dict): Optional extra arguments that may be passed to the upload operation.
            Similar to ExtraArgs parameter in S3 upload_file function. Please refer to the
            ExtraArgs parameter documentation here:
            https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-uploading-files.html#the-extraargs-parameter

    Returns:
        str: The S3 URI of the uploaded file(s). If a file is specified in the path argument,
            the URI format is: ``s3://{bucket name}/{key_prefix}/{original_file_name}``.
            If a directory is specified in the path argument, the URI format is
            ``s3://{bucket name}/{key_prefix}``.
    """
    bucket, key_prefix = self.determine_bucket_and_prefix(
        bucket=bucket, key_prefix=key_prefix, sagemaker_session=self
    )

    # Generate a tuple for each file that we want to upload of the form (local_path, s3_key).
    files = []
    key_suffix = None
    if os.path.isdir(path):
        for dirpath, _, filenames in os.walk(path):
            if path == dirpath:
                s3_relative_prefix = ""
            else:
                # S3 keys use "/" as the delimiter; os.path.relpath returns the
                # platform separator (a backslash on Windows), so normalise it.
                relative_dir = os.path.relpath(dirpath, start=path)
                s3_relative_prefix = relative_dir.replace(os.sep, "/") + "/"
            for name in filenames:
                local_path = os.path.join(dirpath, name)
                s3_key = "{}/{}{}".format(key_prefix, s3_relative_prefix, name)
                files.append((local_path, s3_key))
    else:
        _, name = os.path.split(path)
        s3_key = "{}/{}".format(key_prefix, name)
        files.append((path, s3_key))
        key_suffix = name

    if self.s3_resource is None:
        s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
    else:
        s3 = self.s3_resource

    for local_path, s3_key in files:
        s3.Object(bucket, s3_key).upload_file(
            local_path, Callback=callback, ExtraArgs=extra_args
        )

    s3_uri = "s3://{}/{}".format(bucket, key_prefix)
    # If a specific file was used as input (instead of a directory), we return the full S3 key
    # of the uploaded object. This prevents unintentionally using other files under the same
    # prefix during training.
    if key_suffix:
        s3_uri = "{}/{}".format(s3_uri, key_suffix)
    return s3_uri
|
|
466
|
+
|
|
467
|
+
def upload_string_as_file_body(self, body, bucket, key, kms_key=None):
    """Write ``body`` to S3 as the object ``key`` in ``bucket``.

    Args:
        body (str): String representing the body of the file.
        bucket (str): Name of the S3 Bucket to upload to.
        key (str): S3 object key. This is the s3 path to the file.
        kms_key (str): The KMS key to use for encrypting the file. When given, the
            object is stored with ``aws:kms`` server-side encryption (default: None).

    Returns:
        str: The S3 URI of the uploaded file.
            The URI format is: ``s3://{bucket name}/{key}``.
    """
    s3 = self.s3_resource
    if s3 is None:
        s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)

    # The encryption arguments are only sent when a KMS key was supplied.
    put_arguments = {"Body": body}
    if kms_key is not None:
        put_arguments["SSEKMSKeyId"] = kms_key
        put_arguments["ServerSideEncryption"] = "aws:kms"

    s3.Object(bucket_name=bucket, key=key).put(**put_arguments)

    return "s3://{}/{}".format(bucket, key)
|
|
494
|
+
|
|
495
|
+
def download_data(self, path, bucket, key_prefix="", extra_args=None):
    """Download file or directory from S3.

    All objects under ``key_prefix`` are listed page by page. If any page comes
    back without contents, nothing is downloaded and an empty list is returned.

    Args:
        path (str): Local path where the file or directory should be downloaded to.
        bucket (str): Name of the S3 Bucket to download from.
        key_prefix (str): Optional S3 object key name prefix.
        extra_args (dict): Optional extra arguments that may be passed to the
            download operation. Please refer to the ExtraArgs parameter in the boto3
            documentation here:
            https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-example-download-file.html

    Returns:
        list[str]: List of local paths of downloaded files
    """
    # Initialize the S3 client.
    s3 = self.s3_client
    if s3 is None:
        s3 = self.boto_session.client("s3", region_name=self.boto_region_name)

    object_keys = []
    local_directories = []

    # Walk the listing one page at a time; "" marks the first request, None the end.
    continuation_token = ""
    while continuation_token is not None:
        list_arguments = {"Bucket": bucket, "Prefix": key_prefix}
        if continuation_token != "":
            list_arguments["ContinuationToken"] = continuation_token
        page = s3.list_objects_v2(**list_arguments)

        page_contents = page.get("Contents", None)
        if not page_contents:
            logger.info(
                "Nothing to download from bucket: %s, key_prefix: %s.", bucket, key_prefix
            )
            return []

        for entry in page_contents:
            key: str = entry.get("Key")
            size = entry.get("Size")
            # A zero-byte key ending in "/" is treated as a directory marker.
            is_directory_marker = key.endswith("/") and int(size) == 0
            if is_directory_marker:
                local_directories.append(os.path.join(path, key))
            else:
                object_keys.append(key)

        continuation_token = page.get("NextContinuationToken")

    # Create the directories for the markers first, then fetch every object.
    for directory in local_directories:
        os.makedirs(os.path.dirname(directory), exist_ok=True)

    downloaded_paths = []
    for key in object_keys:
        # A prefix without a file extension keeps the key's path relative to the
        # prefix; otherwise only the key's base name is used.
        if not os.path.splitext(key_prefix)[1]:
            relative_target = os.path.relpath(key, key_prefix)
        else:
            relative_target = os.path.basename(key)
        destination_path = os.path.join(path, relative_target)

        parent_directory = os.path.dirname(destination_path)
        if not os.path.exists(parent_directory):
            os.makedirs(parent_directory, exist_ok=True)

        s3.download_file(
            Bucket=bucket, Key=key, Filename=destination_path, ExtraArgs=extra_args
        )
        downloaded_paths.append(destination_path)
    return downloaded_paths
|
|
560
|
+
|
|
561
|
+
def read_s3_file(self, bucket, key_prefix):
    """Read a single S3 object and return its contents decoded as UTF-8.

    Args:
        bucket (str): Name of the S3 Bucket to download from.
        key_prefix (str): S3 object key; it is passed as ``Key`` to ``get_object``.

    Returns:
        str: The body of the s3 file as a string.
    """
    s3 = self.s3_client
    if s3 is None:
        s3 = self.boto_session.client("s3", region_name=self.boto_region_name)

    response = s3.get_object(Bucket=bucket, Key=key_prefix)
    raw_body = response["Body"].read()
    return raw_body.decode("utf-8")
|
|
580
|
+
|
|
581
|
+
def list_s3_files(self, bucket, key_prefix):
    """List the keys of the S3 objects in ``bucket`` that match ``key_prefix``.

    Args:
        bucket (str): Name of the S3 Bucket to list.
        key_prefix (str): S3 object key name prefix.

    Returns:
        [str]: The list of files at the S3 path.
    """
    s3 = self.s3_resource
    if s3 is None:
        s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)

    matching_objects = s3.Bucket(name=bucket).objects.filter(Prefix=key_prefix).all()

    found_keys = []
    for s3_object in matching_objects:
        found_keys.append(s3_object.key)
    return found_keys
|
|
597
|
+
|
|
598
|
+
def default_bucket(self):
    """Return the name of the default bucket to use in relevant Amazon SageMaker interactions.

    The resolved name is cached in ``self._default_bucket``; later calls return the
    cached value without any further checks. On the first call the bucket is
    created if it does not exist.

    Returns:
        str: The name of the default bucket. If the name was not explicitly specified through
            the Session or sagemaker_config, the bucket will take the form:
            ``sagemaker-{region}-{AWS account ID}``.
    """
    if self._default_bucket:
        return self._default_bucket

    region = self.boto_session.region_name

    bucket_name = self._default_bucket_name_override
    if not bucket_name:
        # No override: derive the name and remember that the SDK picked it.
        bucket_name = self.generate_default_sagemaker_bucket_name(self.boto_session)
        self._default_bucket_set_by_sdk = True

    self._create_s3_bucket_if_it_does_not_exist(bucket_name=bucket_name, region=region)

    self._default_bucket = bucket_name
    return self._default_bucket
|
|
627
|
+
|
|
628
|
+
def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
|
|
629
|
+
"""Creates an S3 Bucket if it does not exist.
|
|
630
|
+
|
|
631
|
+
Also swallows a few common exceptions that indicate that the bucket already exists or
|
|
632
|
+
that it is being created.
|
|
633
|
+
|
|
634
|
+
Args:
|
|
635
|
+
bucket_name (str): Name of the S3 bucket to be created.
|
|
636
|
+
region (str): The region in which to create the bucket.
|
|
637
|
+
|
|
638
|
+
Raises:
|
|
639
|
+
botocore.exceptions.ClientError: If S3 throws an unexpected exception during bucket
|
|
640
|
+
creation.
|
|
641
|
+
If the exception is due to the bucket already existing or
|
|
642
|
+
already being created, no exception is raised.
|
|
643
|
+
"""
|
|
644
|
+
if self.s3_resource is None:
|
|
645
|
+
s3 = self.boto_session.resource("s3", region_name=region)
|
|
646
|
+
else:
|
|
647
|
+
s3 = self.s3_resource
|
|
648
|
+
|
|
649
|
+
bucket = s3.Bucket(name=bucket_name)
|
|
650
|
+
if bucket.creation_date is None:
|
|
651
|
+
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True)
|
|
652
|
+
|
|
653
|
+
elif self._default_bucket_set_by_sdk:
|
|
654
|
+
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False)
|
|
655
|
+
|
|
656
|
+
expected_bucket_owner_id = self.account_id()
|
|
657
|
+
self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id)
|
|
658
|
+
|
|
659
|
+
def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket_owner_id):
    """Check that the bucket is owned by the expected account.

    Calls ``head_bucket`` with ``ExpectedBucketOwner``. A ``403`` / ``Forbidden``
    response is logged as an ownership problem and the ``ClientError`` is re-raised.

    Args:
        bucket_name (str): Name of the S3 bucket
        s3: S3 resource object; the check goes through ``s3.meta.client``
        expected_bucket_owner_id (str): Owner ID string

    Raises:
        botocore.exceptions.ClientError: If ``head_bucket`` fails with ``403`` / ``Forbidden``.
    """
    try:
        s3.meta.client.head_bucket(
            Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
        )
    except ClientError as e:
        error_code = e.response["Error"]["Code"]
        message = e.response["Error"]["Message"]
        if error_code == "403" and message == "Forbidden":
            LOGGER.error(
                "Since default_bucket param was not set, SageMaker Python SDK tried to use "
                "%s bucket. "
                "This bucket cannot be configured to use as it is not owned by Account %s. "
                "To unblock it's recommended to use custom default_bucket "
                "parameter in sagemaker.Session",
                bucket_name,
                expected_bucket_owner_id,
            )
            # NOTE(review): nesting of this ``raise`` was reconstructed from a flattened
            # diff; as placed, other ClientErrors are not re-raised -- confirm.
            raise
|
|
686
|
+
|
|
687
|
+
def general_bucket_check_if_user_has_permission(
    self, bucket_name, s3, bucket, region, bucket_creation_date_none
):
    """Check whether the caller can reach the bucket, creating it when it is missing.

    Calls ``head_bucket``. A ``ClientError`` is only acted upon when
    ``bucket_creation_date_none`` is true: ``404`` / ``Not Found`` creates the
    bucket, ``403`` / ``Forbidden`` is logged and re-raised, and any other error is
    re-raised.

    Args:
        bucket_name (str): Name of the S3 bucket
        s3: S3 resource object; the check goes through ``s3.meta.client``
        bucket: Bucket object; only its ``name`` is read, for the error log
        region (str): The region in which to create the bucket.
        bucket_creation_date_none (bool): True when the bucket's creation date is
            None, i.e. the bucket is taken not to exist yet

    Raises:
        botocore.exceptions.ClientError: If ``head_bucket`` fails with anything other
            than ``404`` / ``Not Found`` while ``bucket_creation_date_none`` is true.
    """
    try:
        s3.meta.client.head_bucket(Bucket=bucket_name)
    except ClientError as e:
        error_code = e.response["Error"]["Code"]
        message = e.response["Error"]["Message"]
        # bucket does not exist or forbidden to access
        # NOTE(review): nesting below was reconstructed from a flattened diff. As placed,
        # the trailing ``else`` pairs with the inner if/elif, so errors are ignored when
        # ``bucket_creation_date_none`` is false -- confirm.
        if bucket_creation_date_none:
            if error_code == "404" and message == "Not Found":
                self.create_bucket_for_not_exist_error(bucket_name, region, s3)
            elif error_code == "403" and message == "Forbidden":
                LOGGER.error(
                    "Bucket %s exists, but access is forbidden. Please try again after "
                    "adding appropriate access.",
                    bucket.name,
                )
                raise
            else:
                raise
|
|
719
|
+
|
|
720
|
+
def create_bucket_for_not_exist_error(self, bucket_name, region, s3):
    """Create the S3 bucket in the given region.

    An ``OperationAborted`` error whose message mentions a conflicting conditional
    operation is ignored; every other ``ClientError`` is re-raised.

    Args:
        bucket_name (str): Name of the S3 bucket
        region (str): The region in which to create the bucket.
        s3: S3 resource object; ``s3.create_bucket`` is called on it

    Raises:
        botocore.exceptions.ClientError: If bucket creation fails for any other reason.
    """
    create_arguments = {"Bucket": bucket_name}
    if region != "us-east-1":
        # 'us-east-1' cannot be specified because it is the default region:
        # https://github.com/boto/boto3/issues/125
        create_arguments["CreateBucketConfiguration"] = {"LocationConstraint": region}

    try:
        s3.create_bucket(**create_arguments)
        logger.info("Created S3 bucket: %s", bucket_name)
    except ClientError as e:
        error_code = e.response["Error"]["Code"]
        message = e.response["Error"]["Message"]

        # If this bucket is already being concurrently created, we don't need to
        # create it again; anything else is a real failure.
        concurrent_creation = (
            error_code == "OperationAborted" and "conflicting conditional operation" in message
        )
        if not concurrent_creation:
            raise
|
|
751
|
+
|
|
752
|
+
def generate_default_sagemaker_bucket_name(self, boto_session):
    """Generate a name for the default sagemaker S3 bucket.

    Args:
        boto_session (boto3.session.Session): Boto3 session that supplies the region
            name and the STS client used to look up the caller's account id.

    Returns:
        str: A name of the form ``sagemaker-{region}-{account}``.
    """
    region_name = boto_session.region_name
    sts = boto_session.client(
        "sts",
        region_name=region_name,
        endpoint_url=sts_regional_endpoint(region_name),
    )
    account_id = sts.get_caller_identity()["Account"]
    return "sagemaker-{}-{}".format(region_name, account_id)
|
|
763
|
+
|
|
764
|
+
def determine_bucket_and_prefix(
    self, bucket: Optional[str] = None, key_prefix: Optional[str] = None, sagemaker_session=None
):
    """Helper function that returns the correct S3 bucket and prefix to use depending on the inputs.

    Args:
        bucket (Optional[str]): S3 Bucket to use (if it exists)
        key_prefix (Optional[str]): S3 Object Key Prefix to use or append to (if it exists)
        sagemaker_session (sagemaker.core.session.Session): Session to fetch a default bucket and
            prefix from, if bucket doesn't exist. When omitted, this session (``self``)
            is used.

    Returns: The correct S3 Bucket and S3 Object Key Prefix that should be used
    """
    if bucket:
        # An explicit bucket is returned untouched, and so is the prefix. The default
        # prefix must not be added here even when the bucket equals the default
        # bucket, because either:
        # (1) the bucket was explicitly passed in by the user and just happens to be
        #     the same as the default_bucket (in which case we don't want to change
        #     the user's input), or
        # (2) the default_bucket was fetched from Session earlier already (and the
        #     default prefix should have been fetched then as well), and then this
        #     function was called with it. Appending the default prefix here would
        #     append it more than once in total.
        return bucket, key_prefix

    # ``sagemaker_session`` defaults to None; fall back to this session rather than
    # failing with an AttributeError on None.
    session = self if sagemaker_session is None else sagemaker_session

    final_bucket = session.default_bucket()

    # default_bucket_prefix (if it exists) should be appended if (and only if) 'bucket' does not
    # exist and we are using the Session's default_bucket.
    final_key_prefix = s3_path_join(session.default_bucket_prefix, key_prefix)

    return final_bucket, final_key_prefix
|
|
797
|
+
|
|
798
|
+
def _append_sagemaker_config_tags(self, tags: List[TagsDict], config_path_to_tags: str):
    """Appends tags specified in the sagemaker_config to the given list of tags.

    To minimize the chance of duplicate tags being applied, this is intended to be used
    immediately before calls to sagemaker_client, rather than during initialization of
    classes like EstimatorBase.

    The list passed in is never modified: when config tags exist, a new list is
    built and returned.

    Args:
        tags: The list of tags to append to.
        config_path_to_tags: The path to look up tags in the config.

    Returns:
        A list of tags. When the config holds no tags, ``tags`` itself is returned
        unchanged.
    """
    config_tags = get_sagemaker_config_value(self, config_path_to_tags)

    if config_tags is None or len(config_tags) == 0:
        return tags

    # Work on a copy so the caller's list is left alone and the merge log below
    # can show the original tags next to the merged result.
    all_tags = list(tags) if tags else []
    for config_tag in config_tags:
        config_tag_key = config_tag[KEY]
        if not any(tag.get("Key", None) == config_tag_key for tag in all_tags):
            # This check prevents new tags with duplicate keys from being added
            # (to prevent API failure and/or overwriting of tags). If there is a conflict,
            # the user-provided tag should take precedence over the config-provided tag.
            # Note: this does not check user-provided tags for conflicts with other
            # user-provided tags.
            all_tags.append(config_tag)

    _log_sagemaker_config_merge(
        source_value=tags,
        config_value=config_tags,
        merged_source_and_config_value=all_tags,
        config_key_path=config_path_to_tags,
    )

    return all_tags
|
|
836
|
+
|
|
837
|
+
def endpoint_from_production_variants(
    self,
    name,
    production_variants,
    tags=None,
    kms_key=None,
    wait=True,
    data_capture_config_dict=None,
    async_inference_config_dict=None,
    explainer_config_dict=None,
    live_logging=False,
    vpc_config=None,
    enable_network_isolation=None,
    role=None,
):
    """Create an SageMaker ``Endpoint`` from a list of production variants.

    An endpoint config named ``name`` is created first, then the endpoint itself
    is created from it under the same name.

    Args:
        name (str): The name of the ``Endpoint`` to create.
        production_variants (list[dict[str, str]]): The list of production variants to deploy.
        tags (Optional[Tags]): A list of key-value pairs for tagging the endpoint
            (default: None).
        kms_key (str): The KMS key that is used to encrypt the data on the storage volume
            attached to the instance hosting the endpoint.
        wait (bool): Whether to wait for the endpoint deployment to complete before returning
            (default: True).
        data_capture_config_dict (dict): Specifies configuration related to Endpoint data
            capture for use with Amazon SageMaker Model Monitoring. Default: None.
        async_inference_config_dict (dict) : specifies configuration related to async endpoint.
            Use this configuration when trying to create async endpoint and make async inference
            (default: None)
        explainer_config_dict (dict) : Specifies configuration related to explainer.
            Use this configuration when trying to use online explainability.
            (default: None).
        live_logging (bool): Passed through to ``create_endpoint`` (default: False).
        vpc_config (dict[str, list[str]]):
            The VpcConfig set on the model (default: None).
            * 'Subnets' (list[str]): List of subnet ids.
            * 'SecurityGroupIds' (list[str]): List of security group ids.
        enable_network_isolation (Boolean): If True, enables network isolation in the
            endpoint, isolating the model container. No inbound or outbound network
            calls can be made to or from the model container. Only sent when the
            resolved value is not None (default: None).
        role (str): An AWS IAM role (either name or full ARN). The Amazon
            SageMaker training jobs and APIs that create Amazon SageMaker
            endpoints use this role to access training data and model
            artifacts. After the endpoint is created, the inference code
            might use the IAM role if it needs to access some AWS resources.
            (default: None).
    Returns:
        str: The name of the created ``Endpoint``.
    """
    # Evaluate KMS support for every variant that names an instance type.
    kms_support_flags = [
        instance_supports_kms(variant["InstanceType"])
        for variant in production_variants
        if "InstanceType" in variant
    ]
    supports_kms = any(kms_support_flags)

    update_list_of_dicts_with_values_from_config(
        production_variants,
        ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH,
        required_key_paths=["CoreDumpConfig.DestinationS3Uri"],
        sagemaker_session=self,
    )

    config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants}

    # The config is only consulted for a KMS key when some variant supports KMS.
    if supports_kms:
        kms_key = resolve_value_from_config(
            kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self
        )

    vpc_config = resolve_value_from_config(
        vpc_config,
        ENDPOINT_CONFIG_VPC_CONFIG_PATH,
        sagemaker_session=self,
    )
    enable_network_isolation = resolve_value_from_config(
        enable_network_isolation,
        ENDPOINT_CONFIG_ENABLE_NETWORK_ISOLATION_PATH,
        sagemaker_session=self,
    )
    role = resolve_value_from_config(
        role,
        ENDPOINT_CONFIG_EXECUTION_ROLE_ARN_PATH,
        sagemaker_session=self,
        sagemaker_config=load_sagemaker_config() if (self is None) else None,
    )

    # For Amazon SageMaker inference component based endpoint, it will not pass
    # Model names during endpoint creation. Instead, ExecutionRoleArn will be
    # needed in the endpoint config to create Endpoint
    variants_with_model = [
        variant["ModelName"] for variant in production_variants if "ModelName" in variant
    ]
    if not variants_with_model:
        # Currently, SageMaker Python SDK allow using RoleName to deploy models.
        # Use expand_role method to handle this situation.
        role = self.expand_role(role)
        config_options["ExecutionRoleArn"] = role

    # The endpoint config and the endpoint each get their own tag list.
    endpoint_config_tags = _append_project_tags(format_tags(tags))
    endpoint_tags = _append_project_tags(format_tags(tags))

    endpoint_config_tags = self._append_sagemaker_config_tags(
        endpoint_config_tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_CONFIG, TAGS)
    )
    if endpoint_config_tags:
        config_options["Tags"] = endpoint_config_tags
    if kms_key:
        config_options["KmsKeyId"] = kms_key

    if data_capture_config_dict is not None:
        config_options["DataCaptureConfig"] = update_nested_dictionary_with_values_from_config(
            data_capture_config_dict, ENDPOINT_CONFIG_DATA_CAPTURE_PATH, sagemaker_session=self
        )
    if async_inference_config_dict is not None:
        config_options["AsyncInferenceConfig"] = update_nested_dictionary_with_values_from_config(
            async_inference_config_dict,
            ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH,
            sagemaker_session=self,
        )

    # The remaining options are passed through whenever they are not None.
    passthrough_options = (
        ("ExplainerConfig", explainer_config_dict),
        ("VpcConfig", vpc_config),
        ("EnableNetworkIsolation", enable_network_isolation),
        ("ExecutionRoleArn", role),
    )
    for option_name, option_value in passthrough_options:
        if option_value is not None:
            config_options[option_name] = option_value

    logger.info("Creating endpoint-config with name %s", name)
    self.sagemaker_client.create_endpoint_config(**config_options)

    return self.create_endpoint(
        endpoint_name=name,
        config_name=name,
        tags=endpoint_tags,
        wait=wait,
        live_logging=live_logging,
    )
|
|
983
|
+
|
|
984
|
+
def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True, live_logging=False):
    """Create an Amazon SageMaker ``Endpoint`` from an existing endpoint configuration.

    The supplied tags are normalized with ``format_tags``, extended by
    ``_append_project_tags`` and by the session's SageMaker-config tags for
    endpoints, and then sent with the ``CreateEndpoint`` request. When the
    service returns a non-empty response, its ``EndpointArn`` is stored on
    ``self.endpoint_arn``.

    Args:
        endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` to create.
        config_name (str): Name of the endpoint configuration to deploy.
        tags (Optional[Tags]): Tags to attach to the endpoint (default: None).
        wait (bool): Whether to block on ``wait_for_endpoint`` before returning
            (default: True).
        live_logging (bool): Forwarded to ``wait_for_endpoint`` when ``wait``
            is true (default: False).

    Returns:
        str: ``endpoint_name``, unchanged.

    Raises:
        Exception: Any error raised while creating (or waiting for) the endpoint
            is re-raised after a pointer to the troubleshooting guide is logged.
    """
    logger.info("Creating endpoint with name %s", endpoint_name)

    endpoint_tags = format_tags(tags) or []
    endpoint_tags = _append_project_tags(endpoint_tags)
    endpoint_tags = self._append_sagemaker_config_tags(
        endpoint_tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT, TAGS)
    )

    try:
        response = self.sagemaker_client.create_endpoint(
            EndpointName=endpoint_name,
            EndpointConfigName=config_name,
            Tags=endpoint_tags,
        )
        if response:
            self.endpoint_arn = response["EndpointArn"]
        if wait:
            self.wait_for_endpoint(endpoint_name, live_logging=live_logging)
    except Exception as e:
        # Point the user at the troubleshooting guide, then surface the error.
        guide_url = (
            "https://docs.aws.amazon.com/sagemaker/latest/dg/"
            "sagemaker-python-sdk-troubleshooting.html"
            "#sagemaker-python-sdk-troubleshooting-create-endpoint"
        )
        logger.error("Please check the troubleshooting guide for common errors: %s", guide_url)
        raise e
    return endpoint_name
|
|
1032
|
+
|
|
1033
|
+
def wait_for_endpoint(self, endpoint, poll=DEFAULT_EP_POLL, live_logging=False):
    """Block until an Amazon SageMaker endpoint deployment finishes.

    Args:
        endpoint (str): Name of the ``Endpoint`` to wait for.
        poll (int): Polling interval in seconds used when live logging is not
            active (default: ``DEFAULT_EP_POLL``).
        live_logging (bool): When true, and ``_has_permission_for_live_logging``
            approves, the wait goes through ``_live_logging_deploy_done`` with a
            CloudWatch Logs ``filter_log_events`` paginator and polls every
            ``EP_LOGGER_POLL`` seconds instead of ``poll`` (default: False).

    Raises:
        exceptions.CapacityError: If the final status is not ``InService`` and
            the failure reason mentions ``CapacityError``.
        exceptions.UnexpectedStatusException: If the final status is not
            ``InService`` for any other reason.

    Returns:
        dict: The description returned by the polling helper once the endpoint
            is ``InService``.
    """
    # Short-circuits exactly like the original: the permission probe only runs
    # when live logging was requested.
    use_live_logs = live_logging and _has_permission_for_live_logging(
        self.boto_session, endpoint
    )

    if use_live_logs:
        cloudwatch_client = self.boto_session.client("logs")
        paginator = cloudwatch_client.get_paginator("filter_log_events")
        paginator_config = create_paginator_config()
        desc = _wait_until(
            lambda: _live_logging_deploy_done(
                self.sagemaker_client, endpoint, paginator, paginator_config, EP_LOGGER_POLL
            ),
            poll=EP_LOGGER_POLL,
        )
    else:
        desc = _wait_until(lambda: _deploy_done(self.sagemaker_client, endpoint), poll)

    status = desc["EndpointStatus"]
    if status == "InService":
        return desc

    reason = desc.get("FailureReason", None)
    advice = (
        "Try changing the instance type or reference the troubleshooting page "
        "https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference-troubleshooting"
        ".html"
    )
    message = "Error hosting endpoint {}: {}. Reason: {}. {}".format(
        endpoint, status, reason, advice
    )
    if "CapacityError" in str(reason):
        failure_type = exceptions.CapacityError
    else:
        failure_type = exceptions.UnexpectedStatusException
    raise failure_type(
        message=message,
        allowed_statuses=["InService"],
        actual_status=status,
    )
|
|
1084
|
+
|
|
1085
|
+
def create_inference_component(
    self,
    inference_component_name: str,
    endpoint_name: str,
    variant_name: str,
    specification: Dict[str, Any],
    runtime_config: Optional[Dict[str, Any]] = None,
    tags: Optional[Tags] = None,
    wait: bool = True,
):
    """Create an Amazon SageMaker Inference Component.

    Args:
        inference_component_name (str): Name of the inference component to create.
        endpoint_name (str): Name of the endpoint the inference component
            will deploy to.
        variant_name (str): Name of the variant the inference component
            will deploy to.
        specification (Dict[str, Any]): The inference component specification.
        runtime_config (Optional[Dict[str, Any]]): The inference component
            runtime configuration. When ``None``, ``{"CopyCount": 1}`` is sent.
            (Default: None).
        tags (Optional[Tags]): Either a dictionary or a list of dictionaries
            containing key-value pairs. They are normalized with ``format_tags``
            and extended with project tags and SageMaker-config tags before
            being sent. (Default: None).
        wait (bool): Whether to block on ``wait_for_inference_component``
            before returning. (Default: True).

    Returns:
        str: ``inference_component_name``, unchanged.
    """
    # Use the same module logger name as the sibling methods of this class.
    logger.info(
        "Creating inference component with name %s for endpoint %s",
        inference_component_name,
        endpoint_name,
    )

    if runtime_config is None:
        runtime_config = {"CopyCount": 1}

    request = {
        "InferenceComponentName": inference_component_name,
        "EndpointName": endpoint_name,
        "VariantName": variant_name,
        "Specification": specification,
        "RuntimeConfig": runtime_config,
    }

    tags = format_tags(tags)
    tags = _append_project_tags(tags)
    tags = self._append_sagemaker_config_tags(
        tags, "{}.{}.{}".format(SAGEMAKER, INFERENCE_COMPONENT, TAGS)
    )
    # Only send "Tags" when there is at least one tag; truthiness already
    # covers both ``None`` and the empty list.
    if tags:
        request["Tags"] = tags

    self.sagemaker_client.create_inference_component(**request)
    if wait:
        self.wait_for_inference_component(inference_component_name)
    return inference_component_name
|
|
1144
|
+
|
|
1145
|
+
def wait_for_inference_component(self, inference_component_name, poll=20):
    """Block until an Amazon SageMaker ``Inference Component`` deployment finishes.

    Args:
        inference_component_name (str): Name of the ``Inference Component`` to wait for.
        poll (int): Polling interval in seconds (default: 20).

    Raises:
        exceptions.CapacityError: If the final status is not ``InService`` and
            the failure reason mentions ``CapacityError``.
        exceptions.UnexpectedStatusException: If the final status is not
            ``InService`` for any other reason.

    Returns:
        dict: The description produced by ``_inference_component_done`` once the
            component is ``InService``.
    """
    desc = _wait_until(
        lambda: self._inference_component_done(self.sagemaker_client, inference_component_name),
        poll,
    )
    status = desc["InferenceComponentStatus"]
    if status == "InService":
        return desc

    # Build the error text, appending the service-provided reason when present.
    message = f"Error creating inference component '{inference_component_name}'"
    reason = desc.get("FailureReason")
    if reason:
        message = f"{message}: {reason}"

    if "CapacityError" in str(reason):
        failure_type = exceptions.CapacityError
    else:
        failure_type = exceptions.UnexpectedStatusException
    raise failure_type(
        message=message,
        allowed_statuses=["InService"],
        actual_status=status,
    )
|
|
1182
|
+
|
|
1183
|
+
def describe_inference_component(self, inference_component_name):
    """Return the description of an Amazon SageMaker ``InferenceComponent``.

    Args:
        inference_component_name (str): Name of the ``InferenceComponent``.

    Returns:
        dict: Whatever the client's ``describe_inference_component`` call returns.
    """
    client = self.sagemaker_client
    return client.describe_inference_component(InferenceComponentName=inference_component_name)
|
|
1196
|
+
|
|
1197
|
+
def _inference_component_done(self, sagemaker_client, inference_component_name):
    """Poll an inference component once and report whether it has settled.

    A single progress character for the current status is printed to stdout
    without a trailing newline (``?`` for a status that has no mapping).

    Args:
        sagemaker_client (boto3.SageMaker.Client): Client which makes Amazon SageMaker
            service calls
        inference_component_name (str): Name of the Amazon SageMaker ``InferenceComponent``.

    Returns:
        Optional[dict]: ``None`` while the status is ``Creating``, ``Updating``
            or ``Deleting``; otherwise the describe response.
    """
    desc = sagemaker_client.describe_inference_component(
        InferenceComponentName=inference_component_name
    )
    status = desc["InferenceComponentStatus"]

    progress_marks = {
        "InService": "!",
        "Creating": "-",
        "Updating": "-",
        "Failed": "*",
        "Deleting": "o",
    }
    print(progress_marks.get(status, "?"), end="", flush=True)

    if status in ("Creating", "Updating", "Deleting"):
        return None
    return desc
|
|
1225
|
+
|
|
1226
|
+
def update_endpoint(self, endpoint_name, endpoint_config_name, wait=True):
    """Point an existing Amazon SageMaker ``Endpoint`` at another endpoint configuration.

    Args:
        endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` to update.
        endpoint_config_name (str): Name of the endpoint configuration to deploy.
        wait (bool): Whether to block on ``wait_for_endpoint`` before returning
            (default: True).

    Returns:
        str: ``endpoint_name``, unchanged.

    Raises:
        ValueError: If ``_deployment_entity_exists`` reports that the endpoint
            does not exist.
        Exception: Any error raised while updating (or waiting for) the endpoint
            is re-raised after a pointer to the troubleshooting guide is logged.
    """
    endpoint_exists = _deployment_entity_exists(
        lambda: self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
    )
    if not endpoint_exists:
        raise ValueError(
            "Endpoint with name '{}' does not exist; please use an "
            "existing endpoint name".format(endpoint_name)
        )

    try:
        response = self.sagemaker_client.update_endpoint(
            EndpointName=endpoint_name,
            EndpointConfigName=endpoint_config_name,
        )
        if response:
            self.endpoint_arn = response["EndpointArn"]
        if wait:
            self.wait_for_endpoint(endpoint_name)
    except Exception as e:
        # Point the user at the troubleshooting guide, then surface the error.
        guide_url = (
            "https://docs.aws.amazon.com/sagemaker/latest/dg/"
            "sagemaker-python-sdk-troubleshooting.html"
            "#sagemaker-python-sdk-troubleshooting-update-endpoint"
        )
        logger.error("Please check the troubleshooting guide for common errors: %s", guide_url)
        raise e
    return endpoint_name
|
|
1273
|
+
|
|
1274
|
+
def endpoint_in_service_or_not(self, endpoint_name: str):
    """Tell whether an Amazon SageMaker ``Endpoint`` currently has status ``InService``.

    Args:
        endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` to
            check status.

    Returns:
        bool: True if the endpoint status is ``InService``; False if it has any
            other status or if the describe call fails with a "not found" style
            client error.

    Raises:
        botocore.exceptions.ClientError: Any client error whose text contains
            neither "could not find" nor "not found" (case-insensitive).
    """
    try:
        desc = self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
    except botocore.exceptions.ClientError as e:
        error_text = str(e).lower()
        # A missing endpoint is reported as "not in service" rather than an error.
        if any(marker in error_text for marker in ("could not find", "not found")):
            return False
        raise
    return desc["EndpointStatus"] == "InService"
|
|
1302
|
+
|
|
1303
|
+
def _intercept_create_request(
    self,
    request: Dict,
    create,
    func_name: str = None,
    # pylint: disable=unused-argument
):
    """Hook through which create requests are submitted.

    This base implementation simply hands ``request`` to ``create``. Per the
    original documentation, PipelineSession inherits from this Session class
    and overrides this function to intercept the create request.

    Args:
        request (dict): the create job request
        create (functor): a callable that invokes the sagemaker client create method
        func_name (str): the name of the function being intercepted (unused here)

    Returns:
        Whatever ``create(request)`` returns.
    """
    outcome = create(request)
    return outcome
|
|
1321
|
+
|
|
1322
|
+
def _create_inference_recommendations_job_request(
    self,
    role: str,
    job_name: str,
    job_description: str,
    framework: str,
    sample_payload_url: str,
    supported_content_types: List[str],
    tags: Optional[Tags],
    model_name: str = None,
    model_package_version_arn: str = None,
    job_duration_in_seconds: int = None,
    job_type: str = "Default",
    framework_version: str = None,
    nearest_model_name: str = None,
    supported_instance_types: List[str] = None,
    endpoint_configurations: List[Dict[str, Any]] = None,
    traffic_pattern: Dict[str, Any] = None,
    stopping_conditions: Dict[str, Any] = None,
    resource_limit: Dict[str, Any] = None,
) -> Dict[str, Any]:
    """Assemble the request dictionary for the CreateInferenceRecommendationsJob API.

    Args:
        role (str): IAM role sent as ``RoleArn``.
        job_name (str): The name of the Inference Recommendations Job.
        job_description (str): Job description; only sent when truthy.
        framework (str): The machine learning framework of the Image URI.
        sample_payload_url (str): The S3 path where the sample payload is stored.
        supported_content_types (List[str]): The supported MIME types for the input data.
        tags (Optional[Tags]): Tags for the job; passed through ``format_tags``.
        model_name (str): Name of the Amazon SageMaker ``Model``; used when
            ``model_package_version_arn`` is falsy.
        model_package_version_arn (str): ARN of a versioned model package; takes
            precedence over ``model_name`` when truthy.
        job_duration_in_seconds (int): Maximum job duration; only sent when truthy.
        job_type (str): ``Default`` or ``Advanced``.
        framework_version (str): The framework version of the Image URI.
        nearest_model_name (str): Name of a benchmarked pre-trained model that
            matches the model.
        supported_instance_types (List[str]): Instance types used to generate
            inferences in real-time.
        endpoint_configurations (List[Dict[str, Any]]): Endpoint configurations;
            only used for ``Advanced`` jobs.
        traffic_pattern (Dict[str, Any]): Traffic pattern; only used for
            ``Advanced`` jobs.
        stopping_conditions (Dict[str, Any]): Stopping conditions; only used for
            ``Advanced`` jobs.
        resource_limit (Dict[str, Any]): Resource limit; only used for
            ``Advanced`` jobs.

    Returns:
        Dict[str, Any]: request dictionary for the CreateInferenceRecommendationsJob API
    """
    container_config = {
        "Domain": "MACHINE_LEARNING",
        "Task": "OTHER",
        "Framework": framework,
        "PayloadConfig": {
            "SamplePayloadUrl": sample_payload_url,
            "SupportedContentTypes": supported_content_types,
        },
    }
    # Optional container fields are only included when a truthy value was given.
    optional_container_fields = (
        ("FrameworkVersion", framework_version),
        ("NearestModelName", nearest_model_name),
        ("SupportedInstanceTypes", supported_instance_types),
    )
    for field_name, field_value in optional_container_fields:
        if field_value:
            container_config[field_name] = field_value

    # The model is identified either by package ARN or, failing that, by name.
    input_config = {"ContainerConfig": container_config}
    if model_package_version_arn:
        input_config["ModelPackageVersionArn"] = model_package_version_arn
    else:
        input_config["ModelName"] = model_name

    request = {
        "JobName": job_name,
        "JobType": job_type,
        "RoleArn": role,
        "InputConfig": input_config,
        "Tags": format_tags(tags),
    }

    if job_description:
        request["JobDescription"] = job_description
    if job_duration_in_seconds:
        input_config["JobDurationInSeconds"] = job_duration_in_seconds

    if job_type == "Advanced":
        if stopping_conditions:
            request["StoppingConditions"] = stopping_conditions
        advanced_input_fields = (
            ("ResourceLimit", resource_limit),
            ("TrafficPattern", traffic_pattern),
            ("EndpointConfigurations", endpoint_configurations),
        )
        for field_name, field_value in advanced_input_fields:
            if field_value:
                input_config[field_name] = field_value

    return request
|
|
1431
|
+
|
|
1432
|
+
def create_inference_recommendations_job(
    self,
    role: str,
    sample_payload_url: str,
    supported_content_types: List[str],
    job_name: str = None,
    job_type: str = "Default",
    model_name: str = None,
    model_package_version_arn: str = None,
    job_duration_in_seconds: int = None,
    nearest_model_name: str = None,
    supported_instance_types: List[str] = None,
    framework: str = None,
    framework_version: str = None,
    endpoint_configurations: List[Dict[str, Any]] = None,
    traffic_pattern: Dict[str, Any] = None,
    stopping_conditions: Dict[str, Any] = None,
    resource_limit: Dict[str, Any] = None,
):
    """Creates an Inference Recommendations Job

    Args:
        role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training
            jobs and APIs that create Amazon SageMaker endpoints use this role to access
            training data and model artifacts.
            You must grant sufficient permissions to this role.
        sample_payload_url (str): The S3 path where the sample payload is stored.
        supported_content_types (List[str]): The supported MIME types for the input data.
        model_name (str): Name of the Amazon SageMaker ``Model`` to be used.
        model_package_version_arn (str): The Amazon Resource Name (ARN) of a
            versioned model package.
        job_name (str): The name of the job being run. When falsy, a name of the
            form ``SMPYTHONSDK-<uuid4>`` is generated.
        job_type (str): The type of job being run. Must either be `Default` or `Advanced`.
        job_duration_in_seconds (int): The maximum job duration that a job
            can run for. Will be used for `Advanced` jobs.
        nearest_model_name (str): The name of a pre-trained machine learning model
            benchmarked by Amazon SageMaker Inference Recommender that matches your model.
        supported_instance_types (List[str]): A list of the instance types that are used
            to generate inferences in real-time.
        framework (str): The machine learning framework of the Image URI.
        framework_version (str): The framework version of the Image URI.
        endpoint_configurations (List[Dict[str, Any]]): Specifies the endpoint configurations
            to use for a job. Will be used for `Advanced` jobs.
        traffic_pattern (Dict[str, Any]): Specifies the traffic pattern for the job.
            Will be used for `Advanced` jobs.
        stopping_conditions (Dict[str, Any]): A set of conditions for stopping a
            recommendation job.
            If any of the conditions are met, the job is automatically stopped.
            Will be used for `Advanced` jobs.
        resource_limit (Dict[str, Any]): Defines the resource limit for the job.
            Will be used for `Advanced` jobs.

    Returns:
        str: The name of the job created: ``job_name`` when supplied, otherwise
            the generated ``SMPYTHONSDK-<uuid4>`` name.

    Raises:
        ValueError: If neither or both of ``model_name`` and
            ``model_package_version_arn`` are provided.
    """
    # Exactly one of the two model identifiers must be given.
    if (model_name is None) == (model_package_version_arn is None):
        raise ValueError("Please provide either model_name or model_package_version_arn.")

    if not job_name:
        job_name = "SMPYTHONSDK-" + str(uuid.uuid4())

    # Fixed markers identifying that the call originated from this SDK path.
    job_description = "#python-sdk-create"
    tags = [{"Key": "ClientType", "Value": "PythonSDK-RightSize"}]

    create_inference_recommendations_job_request = (
        self._create_inference_recommendations_job_request(
            role=role,
            model_name=model_name,
            model_package_version_arn=model_package_version_arn,
            job_name=job_name,
            job_type=job_type,
            job_duration_in_seconds=job_duration_in_seconds,
            job_description=job_description,
            framework=framework,
            framework_version=framework_version,
            nearest_model_name=nearest_model_name,
            sample_payload_url=sample_payload_url,
            supported_content_types=supported_content_types,
            supported_instance_types=supported_instance_types,
            endpoint_configurations=endpoint_configurations,
            traffic_pattern=traffic_pattern,
            stopping_conditions=stopping_conditions,
            resource_limit=resource_limit,
            tags=tags,
        )
    )

    def submit(request):
        logger.info("Creating Inference Recommendations job with name: %s", job_name)
        logger.debug("process request: %s", json.dumps(request, indent=4))
        self.sagemaker_client.create_inference_recommendations_job(**request)

    # Routed through the interception hook so subclasses can capture the request.
    self._intercept_create_request(
        create_inference_recommendations_job_request,
        submit,
        self.create_inference_recommendations_job.__name__,
    )
    return job_name
|
|
1534
|
+
|
|
1535
|
+
def wait_for_inference_recommendations_job(
    self, job_name: str, poll: int = 120, log_level: str = "Verbose"
) -> Dict[str, Any]:
    """Block until an Amazon SageMaker Inference Recommender job finishes.

    Args:
        job_name (str): Name of the Inference Recommender job to wait for.
        poll (int): Polling interval in seconds, used in "Quiet" mode (default: 120).
        log_level (str): The level of verbosity for the logs.
            Can be "Quiet" or "Verbose" (default: "Verbose").

    Returns:
        (dict): The job description obtained after waiting.

    Raises:
        ValueError: If ``log_level`` is neither "Quiet" nor "Verbose".
        exceptions.CapacityError: If the Inference Recommender job fails with CapacityError.
        exceptions.UnexpectedStatusException: If the Inference Recommender job fails.
    """
    if log_level not in ("Quiet", "Verbose"):
        raise ValueError("log_level must be either Quiet or Verbose")

    if log_level == "Quiet":
        # Silent mode: just poll the job status until it settles.
        _wait_until(
            lambda: _describe_inference_recommendations_job_status(
                self.sagemaker_client, job_name
            ),
            poll,
        )
    else:
        # Verbose mode: delegate to the helper that displays per-step status.
        _display_inference_recommendations_job_steps_status(
            self, self.sagemaker_client, job_name
        )

    desc = _describe_inference_recommendations_job_status(self.sagemaker_client, job_name)
    _check_job_status(job_name, desc, "Status")
    return desc
|
|
1569
|
+
|
|
1570
|
+
def delete_model(self, model_name):
    """Log the request and delete the named Amazon SageMaker Model.

    Args:
        model_name (str): Name of the Amazon SageMaker model to delete.
    """
    logger.info("Deleting model with name: %s", model_name)
    client = self.sagemaker_client
    client.delete_model(ModelName=model_name)
|
|
1578
|
+
|
|
1579
|
+
def delete_endpoint(self, endpoint_name):
    """Log the request and delete the named Amazon SageMaker ``Endpoint``.

    Args:
        endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` to delete.
    """
    logger.info("Deleting endpoint with name: %s", endpoint_name)
    client = self.sagemaker_client
    client.delete_endpoint(EndpointName=endpoint_name)
|
|
1587
|
+
|
|
1588
|
+
def delete_endpoint_config(self, endpoint_config_name):
    """Log the request and delete the named Amazon SageMaker endpoint configuration.

    Args:
        endpoint_config_name (str): Name of the Amazon SageMaker endpoint configuration to
            delete.
    """
    logger.info("Deleting endpoint configuration with name: %s", endpoint_config_name)
    client = self.sagemaker_client
    client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
|
|
1597
|
+
|
|
1598
|
+
def wait_for_optimization_job(self, job, poll=5):
    """Block until an Amazon SageMaker Optimization job finishes.

    Args:
        job (str): Name of optimization job to wait for.
        poll (int): Polling interval in seconds (default: 5).

    Returns:
        (dict): The job description produced by the polling helper.

    Raises:
        Whatever ``_check_job_status`` raises for an unsuccessful
        ``OptimizationJobStatus``.
        NOTE(review): the previous docstring listed ``exceptions.ResourceNotFound``
        and ``exceptions.UnexpectedStatusException``; confirm against
        ``_check_job_status``.
    """

    def _poll_once():
        return _optimization_job_status(self.sagemaker_client, job)

    desc = _wait_until(_poll_once, poll)
    _check_job_status(job, desc, "OptimizationJobStatus")
    return desc
|
|
1615
|
+
|
|
1616
|
+
def update_inference_component(
    self, inference_component_name, specification=None, runtime_config=None, wait=True
):
    """Update an Amazon SageMaker ``InferenceComponent``

    Only the settings that are actually supplied are sent with the request;
    an argument left as ``None`` is omitted rather than transmitted as a null
    value (which client-side parameter validation would reject).

    Args:
        inference_component_name (str): Name of the Amazon SageMaker ``InferenceComponent``.
        specification ([dict[str,int]]): Resource configuration. Optional.
            Example: {
                "MinMemoryRequiredInMb": 1024,
                "NumberOfCpuCoresRequired": 1,
                "NumberOfAcceleratorDevicesRequired": 1,
                "MaxMemoryRequiredInMb": 4096,
            },
        runtime_config ([dict[str,int]]): Number of copies. Optional.
            Example: {
                "CopyCount": 1
            }
            No default is applied here; when omitted, the key is not sent.
        wait: Wait for the inference component update to finish before
            returning. Optional. Default is True.

    Return:
        str: inference component name

    Raises:
        ValueError: If the inference_component_name does not exist.
    """
    if not _deployment_entity_exists(
        lambda: self.sagemaker_client.describe_inference_component(
            InferenceComponentName=inference_component_name
        )
    ):
        raise ValueError(
            "InferenceComponent with name '{}' does not exist; please use an "
            "existing model name".format(inference_component_name)
        )

    request = {"InferenceComponentName": inference_component_name}
    # Leave out unset optional members: passing None for a structure-typed
    # parameter fails botocore's parameter validation.
    if specification is not None:
        request["Specification"] = specification
    if runtime_config is not None:
        request["RuntimeConfig"] = runtime_config

    self.sagemaker_client.update_inference_component(**request)

    if wait:
        self.wait_for_inference_component(inference_component_name)
    return inference_component_name
|
|
1664
|
+
|
|
1665
|
+
def _create_model_request(
    self,
    name,
    role,
    container_defs,
    vpc_config=None,
    enable_network_isolation=False,
    primary_container=None,
    tags=None,
):  # pylint: disable=redefined-outer-name
    """Build the keyword arguments for a ``CreateModel`` request.

    Args:
        name: Value stored under ``ModelName``.
        role: IAM role; passed through ``self.expand_role`` and stored under
            ``ExecutionRoleArn``.
        container_defs: Either a list of container definitions or a single definition.
        vpc_config: Stored under ``VpcConfig`` when truthy.
        enable_network_isolation: Stored under ``EnableNetworkIsolation`` when truthy.
        primary_container: Deprecated alternative to ``container_defs``; emits a
            ``DeprecationWarning`` and is then used in place of ``container_defs``.
        tags: Passed through ``format_tags`` and stored under ``Tags`` when truthy.

    Returns:
        dict: The request dictionary.

    Raises:
        ValueError: If both ``container_defs`` and ``primary_container`` are truthy.
    """
    if container_defs and primary_container:
        raise ValueError("Both container_defs and primary_container can not be passed as input")

    if primary_container:
        warnings.warn(
            "primary_container is going to be deprecated in a future release. Please use "
            "container_defs instead.",
            DeprecationWarning,
        )
        container_defs = primary_container

    execution_role_arn = self.expand_role(role)

    if isinstance(container_defs, list):
        # The helper's return value is not used here, so config values are presumably
        # merged into the list in place -- confirm against the helper.
        update_list_of_dicts_with_values_from_config(
            container_defs, MODEL_CONTAINERS_PATH, sagemaker_session=self
        )
        resolved_containers = container_defs
    else:
        resolved_containers = update_nested_dictionary_with_values_from_config(
            _expand_container_def(container_defs),
            MODEL_PRIMARY_CONTAINER_PATH,
            sagemaker_session=self,
        )

    # A list, or a single definition naming a model package, goes under "Containers";
    # any other single definition goes under "PrimaryContainer".
    if isinstance(resolved_containers, list):
        container_entry = {"Containers": resolved_containers}
    elif "ModelPackageName" in resolved_containers:
        container_entry = {"Containers": [resolved_containers]}
    else:
        container_entry = {"PrimaryContainer": resolved_containers}

    request = {"ModelName": name, "ExecutionRoleArn": execution_role_arn, **container_entry}

    if tags:
        request["Tags"] = format_tags(tags)
    if vpc_config:
        request["VpcConfig"] = vpc_config
    if enable_network_isolation:
        # enable_network_isolation may be a pipeline variable which is
        # parsed in execution time
        request["EnableNetworkIsolation"] = enable_network_isolation

    return request
|
|
1721
|
+
|
|
1722
|
+
def create_model(
    self,
    name,
    role=None,
    container_defs=None,
    vpc_config=None,
    enable_network_isolation=None,
    primary_container=None,
    tags=None,
):
    """Create an Amazon SageMaker ``Model``.

    Specify the S3 location of the model artifacts and the Docker image containing the
    inference code. This method can also be used to create a Model for an Inference
    Pipeline by passing a list of container definitions as ``container_defs``.

    Args:
        name (str): Name of the Amazon SageMaker ``Model`` to create.
        role (str): An AWS IAM role (either name or full ARN). Resolved through
            ``resolve_value_from_config`` before use.
        container_defs (list[dict[str, str]] or dict[str, str]): A single container
            definition or a list of container definitions. The choice between the
            ``PrimaryContainer`` and ``Containers`` request members is made in
            ``_create_model_request``.
        vpc_config (dict[str, list[str]]): The VpcConfig set on the model (default: None)
            * 'Subnets' (list[str]): List of subnet ids.
            * 'SecurityGroupIds' (list[str]): List of security group ids.
        enable_network_isolation (bool): Whether the model requires network isolation.
            Resolved through ``resolve_value_from_config`` with a default of ``False``.
        primary_container (str or dict[str, str]): Deprecated; please use
            ``container_defs`` instead.
        tags (Optional[Tags]): Optional. The list of tags to add to the model.

    Example:
        >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
        For more information about tags, see https://boto3.amazonaws.com/v1/documentation\
        /api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags

    Returns:
        str: Name of the Amazon SageMaker ``Model`` created.
    """
    model_tags = _append_project_tags(format_tags(tags))
    model_tags = self._append_sagemaker_config_tags(
        model_tags, "{}.{}.{}".format(SAGEMAKER, MODEL, TAGS)
    )
    execution_role = resolve_value_from_config(
        role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self
    )
    model_vpc_config = resolve_value_from_config(
        vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self
    )
    network_isolation = resolve_value_from_config(
        direct_input=enable_network_isolation,
        config_path=MODEL_ENABLE_NETWORK_ISOLATION_PATH,
        default_value=False,
        sagemaker_session=self,
    )

    # container_defs may be a single definition (dict) or a list of definitions, so
    # config-derived values are injected by the helper _create_model_request rather
    # than here.
    request_payload = self._create_model_request(
        name=name,
        role=execution_role,
        container_defs=container_defs,
        vpc_config=model_vpc_config,
        enable_network_isolation=network_isolation,
        primary_container=primary_container,
        tags=model_tags,
    )

    def _submit(request):
        """Send the CreateModel request, tolerating an already existing model."""
        logger.info("Creating model with name: %s", name)
        logger.debug("CreateModel request: %s", json.dumps(request, indent=4))
        try:
            self.sagemaker_client.create_model(**request)
        except ClientError as err:
            failure = err.response["Error"]
            code, detail = failure["Code"], failure["Message"]
            duplicate_model = (
                code == "ValidationException"
                and "Cannot create already existing model" in detail
            )
            if not duplicate_model:
                raise
            logger.warning("Using already existing model: %s", name)

    self._intercept_create_request(request_payload, _submit, self.create_model.__name__)
    return name
|
|
1819
|
+
|
|
1820
|
+
def create_model_package_from_algorithm(self, name, description, algorithm_arn, model_data):
    """Create a SageMaker Model Package from the results of training with an Algorithm Package.

    A ``ValidationException`` whose message reports that the ModelPackage already exists
    is logged as a warning instead of being raised; any other ``ClientError`` propagates.

    Args:
        name (str): ModelPackage name
        description (str): Model Package description
        algorithm_arn (str): arn or name of the algorithm used for training.
        model_data (str or dict[str, Any]): s3 URI or a dictionary representing a
            ``ModelDataSource`` to the model artifacts produced by training
    """
    # A dict is sent as a ModelDataSource; anything else is sent as a ModelDataUrl.
    artifact_key = "ModelDataSource" if isinstance(model_data, dict) else "ModelDataUrl"
    source_algorithm = {"AlgorithmName": algorithm_arn, artifact_key: model_data}

    request = {
        "ModelPackageName": name,
        "ModelPackageDescription": description,
        "SourceAlgorithmSpecification": {"SourceAlgorithms": [source_algorithm]},
    }
    try:
        logger.info("Creating model package with name: %s", name)
        self.sagemaker_client.create_model_package(**request)
    except ClientError as err:
        failure = err.response["Error"]
        code, detail = failure["Code"], failure["Message"]

        if not (code == "ValidationException" and "ModelPackage already exists" in detail):
            raise
        logger.warning("Using already existing model package: %s", name)
|
|
1852
|
+
|
|
1853
|
+
def expand_role(self, role):
    """Expand an IAM role name into an ARN.

    A value containing "/" is treated as already being an ARN and is returned as is.
    Any other value is looked up as a role name through the IAM resource of
    ``self.boto_session``.

    Args:
        role (str): An AWS IAM role (either name or full ARN).

    Returns:
        str: The corresponding AWS IAM role ARN.
    """
    if "/" not in role:
        return self.boto_session.resource("iam").Role(role).arn
    return role
|
|
1868
|
+
|
|
1869
|
+
|
|
1870
|
+
def _expand_container_def(c_def):
|
|
1871
|
+
"""Placeholder docstring"""
|
|
1872
|
+
if isinstance(c_def, six.string_types):
|
|
1873
|
+
return container_def(c_def)
|
|
1874
|
+
return c_def
|
|
1875
|
+
|
|
1876
|
+
|
|
1877
|
+
def expand_role(self, role):
    """Expand an IAM role name into an ARN.

    A value containing "/" is treated as already being an ARN and is returned as is.
    Otherwise the ARN is retrieved through the IAM resource of ``self.boto_session``.

    NOTE(review): this copy appears to sit at module level while still taking ``self``;
    callers would have to pass an object exposing ``boto_session`` explicitly. Confirm
    whether it is intentional or a leftover duplicate of the session method.

    Args:
        self: Object exposing a ``boto_session`` attribute.
        role (str): An AWS IAM role (either name or full ARN).

    Returns:
        str: The corresponding AWS IAM role ARN.
    """
    already_expanded = "/" in role
    if already_expanded:
        return role
    iam_resource = self.boto_session.resource("iam")
    return iam_resource.Role(role).arn
|
|
1892
|
+
|
|
1893
|
+
|
|
1894
|
+
def s3_path_join(*args, with_end_slash: bool = False):
    """Returns the arguments joined by a slash ("/"), similar to ``os.path.join()`` (on Unix).

    Behavior of this function:
    - If the first argument is "s3://", then that is preserved.
    - The output by default has no slashes at the beginning or end. The one exception is
      `with_end_slash`. For example, `s3_path_join("/foo", "bar/")` yields `"foo/bar"`
      and `s3_path_join("foo", "bar", with_end_slash=True)` yields `"foo/bar/"`.
    - Repeated slashes are removed from the output (except for a leading "s3://").
      For example, `s3_path_join("s3://", "//foo/", "/bar///baz")` yields
      `"s3://foo/bar/baz"`.
    - Empty or None arguments are skipped. For example
      `s3_path_join("foo", "", None, "bar")` yields `"foo/bar"`.

    Alternatives to this function that are NOT recommended for S3 paths:
    - `os.path.join(...)` behaves differently on Unix machines vs non-Unix machines
    - `pathlib.PurePosixPath(...)` applies potentially unintended simplification of single
      dots (".") and root directories (for example
      `pathlib.PurePosixPath("foo", "/bar/./", "baz")` would yield `"/bar/baz"`)
    - `"{}/{}/{}".format(...)` and similar may result in unintended repeat slashes

    Args:
        *args: The strings to join with a slash.
        with_end_slash (bool): (default: False) If true and if the path is not empty,
            appends a "/" to the end of the path

    Returns:
        str: The joined string, without a slash at the end unless with_end_slash is True.
    """
    sep = "/"
    scheme = "s3://"

    pieces = [piece for piece in args if piece is not None and piece != ""]

    # Concatenate the pieces, inserting a separator only where neither neighbour
    # already supplies one.
    joined = ""
    for piece in pieces:
        if joined and joined[-1] != sep and piece[0] != sep:
            joined += sep
        joined += piece

    if with_end_slash and joined and joined[-1] != sep:
        joined += sep

    # For backwards compatibility, runs of slashes and leading/trailing slashes are
    # filtered out below; only the slashes of a leading "s3://" survive.
    if joined:
        if joined.startswith(scheme):
            # The scheme is kept as one unit; it ends in "/", so slashes directly
            # after it are dropped as duplicates.
            kept = [scheme]
            tail = joined[len(scheme):]
        else:
            kept = [joined[0]]
            tail = joined[1:]
        for char in tail:
            if char == sep and kept[-1][-1] == sep:
                continue
            kept.append(char)
        joined = "".join(kept)

    joined = joined.lstrip(sep)

    if with_end_slash or joined == scheme:
        return joined
    return joined.rstrip(sep)
|
|
1974
|
+
|
|
1975
|
+
|
|
1976
|
+
def botocore_resolver():
    """Build a botocore endpoint resolver from the bundled "endpoints" data.

    Takes no arguments.

    Returns:
        botocore.regions.EndpointResolver: Resolver constructed from the "endpoints"
        data loaded through a fresh botocore loader.
    """
    endpoints = botocore.loaders.create_loader().load_data("endpoints")
    return botocore.regions.EndpointResolver(endpoints)
|
|
1987
|
+
|
|
1988
|
+
|
|
1989
|
+
def sts_regional_endpoint(region):
    """Get the AWS STS endpoint specific for the given region.

    Per the original note, this exists because the AWS SDK does not yet honor the
    ``region_name`` parameter when creating an AWS STS client.

    For the list of regional endpoints, see
    https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_region-endpoints.

    Args:
        region (str): AWS region name

    Returns:
        str: AWS STS regional endpoint
    """
    resolved = botocore_resolver().construct_endpoint("sts", region)
    # Only "il-central-1" gets a hand-built hostname when the resolver returns nothing.
    # NOTE(review): any other unresolved region fails on the subscript below.
    if not resolved and region == "il-central-1":
        resolved = {"hostname": "sts.{}.amazonaws.com".format(region)}
    hostname = resolved["hostname"]
    return "https://{}".format(hostname)
|
|
2008
|
+
|
|
2009
|
+
|
|
2010
|
+
def get_execution_role(sagemaker_session=None, use_default=False):
    """Return the role ARN whose credentials are used to call the API.

    Args:
        sagemaker_session (Session): Current sagemaker session. A new ``Session`` is
            created when this is falsy.
        use_default (bool): Use the "AmazonSageMaker-DefaultRole" role if
            ``get_caller_identity_arn`` does not return a role ARN. The role is created
            if it does not exist, and the AmazonSageMakerFullAccess policy is attached.
            Defaults to ``False``.

    Returns:
        (str): The role ARN

    Raises:
        ValueError: If the caller identity is not a role and ``use_default`` is falsy.
    """
    session = sagemaker_session or Session()
    caller_arn = session.get_caller_identity_arn()

    if ":role/" in caller_arn:
        return caller_arn

    if not use_default:
        not_a_role = (
            "The current AWS identity is not a role: {}, therefore it cannot be used as a "
            "SageMaker execution role"
        )
        raise ValueError(not_a_role.format(caller_arn))

    fallback_role = "AmazonSageMaker-DefaultRole"
    LOGGER.warning("Using default role: %s", fallback_role)

    boto_session = session.boto_session
    # Trust policy that lets the SageMaker service assume the role.
    trust_policy = json.dumps(
        {
            "Version": "2012-10-17",
            "Statement": [
                {
                    "Effect": "Allow",
                    "Principal": {"Service": ["sagemaker.amazonaws.com"]},
                    "Action": "sts:AssumeRole",
                }
            ],
        }
    )
    iam = boto_session.client("iam")
    try:
        iam.get_role(RoleName=fallback_role)
    except iam.exceptions.NoSuchEntityException:
        iam.create_role(RoleName=fallback_role, AssumeRolePolicyDocument=str(trust_policy))
        LOGGER.warning("Created new sagemaker execution role: %s", fallback_role)

    # The managed policy is attached whether the role was just created or already existed.
    iam.attach_role_policy(
        PolicyArn="arn:aws:iam::aws:policy/AmazonSageMakerFullAccess",
        RoleName=fallback_role,
    )
    return iam.get_role(RoleName=fallback_role)["Role"]["Arn"]
|
|
2070
|
+
|
|
2071
|
+
|
|
2072
|
+
def get_add_model_package_inference_args(
    model_package_arn,
    name,
    containers=None,
    content_types=None,
    response_types=None,
    inference_instances=None,
    transform_instances=None,
    description=None,
):
    """Get request dictionary for UpdateModelPackage API for additional inference.

    The request is only populated when ``containers`` is given; each optional argument
    is added to the additional inference specification only when it is not ``None``.

    Args:
        model_package_arn (str): Arn for the model package.
        name (str): Name to identify the additional inference specification
        containers: Value stored under ``Containers`` in the specification.
        content_types (list[str]): The supported MIME types
            for the input data.
        response_types (list[str]): The supported MIME types
            for the output data.
        inference_instances (list[str]): A list of the instance
            types that are used to generate inferences in real-time (default: None).
        transform_instances (list[str]): A list of the instance
            types on which a transformation job can be run or on which an endpoint can be
            deployed (default: None).
        description (str): Description for the additional inference specification

    Returns:
        dict: The request dictionary; empty when ``containers`` is ``None``.
    """
    if containers is None:
        return {}

    optional_members = (
        ("Name", name),
        ("Description", description),
        ("SupportedContentTypes", content_types),
        ("SupportedResponseMIMETypes", response_types),
        ("SupportedRealtimeInferenceInstanceTypes", inference_instances),
        ("SupportedTransformInstanceTypes", transform_instances),
    )
    specification = {"Containers": containers}
    specification.update(
        {member: value for member, value in optional_members if value is not None}
    )

    return {
        "AdditionalInferenceSpecificationsToAdd": [specification],
        "ModelPackageArn": model_package_arn,
    }
|
|
2140
|
+
|
|
2141
|
+
|
|
2142
|
+
def get_update_model_package_inference_args(
    model_package_arn,
    containers=None,
    content_types=None,
    response_types=None,
    inference_instances=None,
    transform_instances=None,
):
    """Get request dictionary for UpdateModelPackage API for inference specification.

    The request is only populated when ``containers`` is given; each optional argument
    is added to the inference specification only when it is not ``None``.

    Args:
        model_package_arn (str): Arn for the model package.
        containers: Value stored under ``Containers`` in the specification.
        content_types (list[str]): The supported MIME types
            for the input data.
        response_types (list[str]): The supported MIME types
            for the output data.
        inference_instances (list[str]): A list of the instance
            types that are used to generate inferences in real-time (default: None).
        transform_instances (list[str]): A list of the instance
            types on which a transformation job can be run or on which an endpoint can be
            deployed (default: None).

    Returns:
        dict: The request dictionary; empty when ``containers`` is ``None``.
    """
    if containers is None:
        return {}

    optional_members = (
        ("SupportedContentTypes", content_types),
        ("SupportedResponseMIMETypes", response_types),
        ("SupportedRealtimeInferenceInstanceTypes", inference_instances),
        ("SupportedTransformInstanceTypes", transform_instances),
    )
    specification = {"Containers": containers}
    for member, value in optional_members:
        if value is not None:
            specification[member] = value

    return {
        "InferenceSpecification": specification,
        "ModelPackageArn": model_package_arn,
    }
|
|
2199
|
+
|
|
2200
|
+
|
|
2201
|
+
def _logs_for_job(  # noqa: C901 - suppress complexity warning for this method
    sagemaker_session, job_name, wait=False, poll=10, log_type="All", timeout=None
):
    """Display logs for a given training job, optionally tailing them until job is complete.

    If the output is a tty or a Jupyter cell, it will be color-coded
    based on which instance the log entry is from.

    Args:
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions.
        job_name (str): Name of the training job to display the logs for.
        wait (bool): Whether to keep looking for new log entries until the job completes
            (default: False).
        poll (int): The interval in seconds between polling for new log entries and job
            completion (default: 10).
        log_type (str): Which logs to print. Rule evaluation statuses are printed only
            when this is "All" or "Rules". The value is tested for membership in a set,
            so it must be a single hashable value rather than a list.
        timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by
            default.
    Returns:
        dict: The DescribeTrainingJob response held in ``last_description``: the initial
        response, replaced only when a later poll reports a changed secondary status.
    Raises:
        exceptions.CapacityError: If the training job fails with CapacityError.
        exceptions.UnexpectedStatusException: If waiting and the training job fails.
    """
    sagemaker_client = sagemaker_session.sagemaker_client
    # Absolute deadline for the whole call; None means no deadline.
    request_end_time = time.time() + timeout if timeout else None
    description = _wait_until(
        lambda: sagemaker_client.describe_training_job(TrainingJobName=job_name)
    )
    print(secondary_training_status_message(description, None), end="")

    instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
        sagemaker_session.boto_session, description, job="Training"
    )

    state = _get_initial_job_state(description, "TrainingJobStatus", wait)

    # The loop below implements a state machine that alternates between checking the job status
    # and reading whatever is available in the logs at this point. Note, that if we were
    # called with wait == False, we never check the job status.
    #
    # If wait == TRUE and job is not completed, the initial state is TAILING
    # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is
    # complete).
    #
    # The state table:
    #
    # STATE               ACTIONS                        CONDITION             NEW STATE
    # ----------------    ----------------               -----------------     ----------------
    # TAILING             Read logs, Pause, Get status   Job complete          JOB_COMPLETE
    #                                                    Else                  TAILING
    # JOB_COMPLETE        Read logs, Pause               Any                   COMPLETE
    # COMPLETE            Read logs, Exit                                      N/A
    #
    # Notes:
    # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to
    #   Cloudwatch after the job was marked complete.
    last_describe_job_call = time.time()
    last_description = description
    last_debug_rule_statuses = None
    last_profiler_rule_statuses = None

    while True:
        _flush_log_streams(
            stream_names,
            instance_count,
            client,
            log_group,
            job_name,
            positions,
            dot,
            color_wrap,
        )
        if timeout and time.time() > request_end_time:
            print("Timeout Exceeded. {} seconds elapsed.".format(timeout))
            break

        if state == LogState.COMPLETE:
            break

        time.sleep(poll)

        if state == LogState.JOB_COMPLETE:
            state = LogState.COMPLETE
        elif time.time() - last_describe_job_call >= 30:
            # The job is re-described at most once every 30 seconds, independent of
            # how often the logs are polled.
            description = sagemaker_client.describe_training_job(TrainingJobName=job_name)
            last_describe_job_call = time.time()

            if secondary_training_status_changed(description, last_description):
                print()
                print(secondary_training_status_message(description, last_description), end="")
                last_description = description

            status = description["TrainingJobStatus"]

            if status in ("Completed", "Failed", "Stopped"):
                print()
                state = LogState.JOB_COMPLETE

            # Print prettified logs related to the status of SageMaker Debugger rules.
            debug_rule_statuses = description.get("DebugRuleEvaluationStatuses", {})
            if (
                debug_rule_statuses
                and _rule_statuses_changed(debug_rule_statuses, last_debug_rule_statuses)
                and (log_type in {"All", "Rules"})
            ):
                # NOTE: the loop variable reuses (and overwrites) ``status`` from above.
                for status in debug_rule_statuses:
                    rule_log = (
                        f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}"
                    )
                    print(rule_log)

                last_debug_rule_statuses = debug_rule_statuses

            # Print prettified logs related to the status of SageMaker Profiler rules.
            profiler_rule_statuses = description.get("ProfilerRuleEvaluationStatuses", {})
            if (
                profiler_rule_statuses
                and _rule_statuses_changed(profiler_rule_statuses, last_profiler_rule_statuses)
                and (log_type in {"All", "Rules"})
            ):
                for status in profiler_rule_statuses:
                    rule_log = (
                        f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}"
                    )
                    print(rule_log)

                last_profiler_rule_statuses = profiler_rule_statuses

    if wait:
        # Uses the most recently fetched description, which may be newer than
        # ``last_description``.
        _check_job_status(job_name, description, "TrainingJobStatus")
        if dot:
            print()
        # Customers are not billed for hardware provisioning, so billable time is less than
        # total time
        training_time = description.get("TrainingTimeInSeconds")
        billable_time = description.get("BillableTimeInSeconds")
        if training_time is not None:
            print("Training seconds:", training_time * instance_count)
        if billable_time is not None:
            print("Billable seconds:", billable_time * instance_count)
            if description.get("EnableManagedSpotTraining"):
                # NOTE(review): training_time is not re-checked here; a missing or zero
                # value would make this division fail -- confirm it cannot occur.
                saving = (1 - float(billable_time) / training_time) * 100
                print("Managed Spot Training savings: {:.1f}%".format(saving))
    return last_description
|
|
2349
|
+
|
|
2350
|
+
|
|
2351
|
+
def _check_job_status(job, desc, status_key_name):
    """Verify that a finished job ended in an acceptable state.

    A ``Completed`` job passes silently and a ``Stopped`` job only logs a
    warning. Any other status raises an exception describing the failure.

    Args:
        job (str): The name of the job to check.
        desc (dict[str, str]): The job description, e.g. the result of
            ``describe_training_job()``.
        status_key_name (str): Key in ``desc`` holding the job status.

    Raises:
        exceptions.CapacityError: If the failure reason mentions ``CapacityError``.
        exceptions.UnexpectedStatusException: If the job ended in any other
            status that is neither ``Completed`` nor ``Stopped``.
    """
    raw_status = desc[status_key_name]
    # If the status is capital case, then convert it to Camel case
    status = _STATUS_CODE_TABLE.get(raw_status, raw_status)

    if status == "Completed":
        return

    if status == "Stopped":
        logger.warning(
            "Job ended with status 'Stopped' rather than 'Completed'. "
            "This could mean the job timed out or stopped early for some other reason: "
            "Consider checking whether it completed as you expect."
        )
        return

    reason = desc.get("FailureReason", "(No reason provided)")
    troubleshooting = (
        "https://docs.aws.amazon.com/sagemaker/latest/dg/"
        "sagemaker-python-sdk-troubleshooting.html"
    )
    template = (
        "Error for {job_type} {job_name}: {status}. Reason: {reason}. "
        "Check troubleshooting guide for common errors: {troubleshooting}"
    )
    message = template.format(
        # e.g. "TrainingJobStatus" -> "Training job"
        job_type=status_key_name.replace("JobStatus", " job"),
        job_name=job,
        status=status,
        reason=reason,
        troubleshooting=troubleshooting,
    )

    # Capacity problems get their own exception type so callers can tell them apart.
    if "CapacityError" in str(reason):
        error_cls = exceptions.CapacityError
    else:
        error_cls = exceptions.UnexpectedStatusException
    raise error_cls(
        message=message,
        allowed_statuses=["Completed", "Stopped"],
        actual_status=status,
    )
|
|
2403
|
+
|
|
2404
|
+
|
|
2405
|
+
def _logs_init(boto_session, description, job):
|
|
2406
|
+
"""Placeholder docstring"""
|
|
2407
|
+
if job == "Training":
|
|
2408
|
+
if "InstanceGroups" in description["ResourceConfig"]:
|
|
2409
|
+
instance_count = 0
|
|
2410
|
+
for instanceGroup in description["ResourceConfig"]["InstanceGroups"]:
|
|
2411
|
+
instance_count += instanceGroup["InstanceCount"]
|
|
2412
|
+
else:
|
|
2413
|
+
instance_count = description["ResourceConfig"]["InstanceCount"]
|
|
2414
|
+
elif job == "Transform":
|
|
2415
|
+
instance_count = description["TransformResources"]["InstanceCount"]
|
|
2416
|
+
elif job == "Processing":
|
|
2417
|
+
instance_count = description["ProcessingResources"]["ClusterConfig"]["InstanceCount"]
|
|
2418
|
+
elif job == "AutoML":
|
|
2419
|
+
instance_count = 0
|
|
2420
|
+
|
|
2421
|
+
stream_names = [] # The list of log streams
|
|
2422
|
+
positions = {} # The current position in each stream, map of stream name -> position
|
|
2423
|
+
|
|
2424
|
+
# Increase retries allowed (from default of 4), as we don't want waiting for a training job
|
|
2425
|
+
# to be interrupted by a transient exception.
|
|
2426
|
+
config = botocore.config.Config(retries={"max_attempts": 15})
|
|
2427
|
+
client = boto_session.client("logs", config=config)
|
|
2428
|
+
log_group = "/aws/sagemaker/" + job + "Jobs"
|
|
2429
|
+
|
|
2430
|
+
dot = False
|
|
2431
|
+
|
|
2432
|
+
color_wrap = sagemaker.core.logs.ColorWrap()
|
|
2433
|
+
|
|
2434
|
+
return instance_count, stream_names, positions, client, log_group, dot, color_wrap
|
|
2435
|
+
|
|
2436
|
+
|
|
2437
|
+
def _flush_log_streams(
    stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap
):
    """Discover a job's CloudWatch log streams and print their new events.

    While fewer streams than instances are known, the log group is queried for
    streams prefixed with ``job_name + "/"`` and every newly seen stream is
    given a starting position. New events from the known streams are then
    printed through ``color_wrap``; if no stream exists yet, a progress dot is
    printed instead.

    Args:
        stream_names (list[str]): Log stream names known so far.
        instance_count (int): Number of streams expected (one per instance).
        client: CloudWatch Logs client exposing ``describe_log_streams``.
        log_group (str): Name of the log group to search.
        job_name (str): Job name; used as the log stream name prefix.
        positions (dict): Map of stream name -> position; updated in place.
        dot (bool): Whether the previous output was a progress dot.
        color_wrap (callable): Called as ``color_wrap(index, message)`` per event.

    Raises:
        ClientError: For any failure while listing streams other than
            ``ResourceNotFoundException``.
    """
    # NOTE: ``stream_names`` and ``dot`` are rebound locally below and nothing is
    # returned, so only the in-place update of ``positions`` is visible to the caller.
    if len(stream_names) < instance_count:
        # Log streams are created whenever a container starts writing to stdout/err, so this list
        # may be dynamic until we have a stream for every instance.
        try:
            streams = client.describe_log_streams(
                logGroupName=log_group,
                logStreamNamePrefix=job_name + "/",
                orderBy="LogStreamName",
                limit=min(instance_count, 50),
            )
            stream_names = [s["logStreamName"] for s in streams["logStreams"]]

            while "nextToken" in streams:
                # Pass the token back so the service returns the NEXT page; without it
                # the first page (and its token) is returned again and the loop never ends.
                streams = client.describe_log_streams(
                    logGroupName=log_group,
                    logStreamNamePrefix=job_name + "/",
                    orderBy="LogStreamName",
                    limit=50,
                    nextToken=streams["nextToken"],
                )

                stream_names.extend([s["logStreamName"] for s in streams["logStreams"]])

            # Start newly discovered streams from the beginning; keep known positions.
            positions.update(
                [
                    (s, sagemaker.core.logs.Position(timestamp=0, skip=0))
                    for s in stream_names
                    if s not in positions
                ]
            )
        except ClientError as e:
            # On the very first training job run on an account, there's no log group until
            # the container starts logging, so ignore any errors thrown about that
            err = e.response.get("Error", {})
            if err.get("Code", None) != "ResourceNotFoundException":
                raise

    if len(stream_names) > 0:
        if dot:
            # Terminate the line of progress dots before printing log lines.
            print("")
            dot = False
        for idx, event in sagemaker.core.logs.multi_stream_iter(
            client, log_group, stream_names, positions
        ):
            color_wrap(idx, event["message"])
            ts, count = positions[stream_names[idx]]
            if event["timestamp"] == ts:
                # Another event with the same timestamp: skip one more next time.
                positions[stream_names[idx]] = sagemaker.core.logs.Position(
                    timestamp=ts, skip=count + 1
                )
            else:
                positions[stream_names[idx]] = sagemaker.core.logs.Position(
                    timestamp=event["timestamp"], skip=1
                )
    else:
        dot = True
        print(".", end="")
        sys.stdout.flush()
|
|
2498
|
+
|
|
2499
|
+
|
|
2500
|
+
def _wait_until(callable_fn, poll=5):
|
|
2501
|
+
"""Placeholder docstring"""
|
|
2502
|
+
elapsed_time = 0
|
|
2503
|
+
result = None
|
|
2504
|
+
while result is None:
|
|
2505
|
+
try:
|
|
2506
|
+
elapsed_time += poll
|
|
2507
|
+
time.sleep(poll)
|
|
2508
|
+
result = callable_fn()
|
|
2509
|
+
except botocore.exceptions.ClientError as err:
|
|
2510
|
+
# For initial 5 mins we accept/pass AccessDeniedException.
|
|
2511
|
+
# The reason is to await tag propagation to avoid false AccessDenied claims for an
|
|
2512
|
+
# access policy based on resource tags, The caveat here is for true AccessDenied
|
|
2513
|
+
# cases the routine will fail after 5 mins
|
|
2514
|
+
if err.response["Error"]["Code"] == "AccessDeniedException" and elapsed_time <= 300:
|
|
2515
|
+
logger.warning(
|
|
2516
|
+
"Received AccessDeniedException. This could mean the IAM role does not "
|
|
2517
|
+
"have the resource permissions, in which case please add resource access "
|
|
2518
|
+
"and retry. For cases where the role has tag based resource policy, "
|
|
2519
|
+
"continuing to wait for tag propagation.."
|
|
2520
|
+
)
|
|
2521
|
+
continue
|
|
2522
|
+
raise err
|
|
2523
|
+
return result
|
|
2524
|
+
|
|
2525
|
+
|
|
2526
|
+
def _get_initial_job_state(description, status_key, wait):
    """Pick the initial log state for a job.

    Args:
        description (dict): Job description containing ``status_key``.
        status_key (str): Key in ``description`` holding the job status.
        wait (bool): Whether the caller wants to wait for the job.

    Returns:
        ``LogState.TAILING`` when waiting on a job that has not reached a
        terminal status, otherwise ``LogState.COMPLETE``.
    """
    finished = description[status_key] in ("Completed", "Failed", "Stopped")
    if wait and not finished:
        return LogState.TAILING
    return LogState.COMPLETE
|
|
2531
|
+
|
|
2532
|
+
|
|
2533
|
+
def _rule_statuses_changed(current_statuses, last_statuses):
|
|
2534
|
+
"""Checks the rule evaluation statuses for SageMaker Debugger and Profiler rules."""
|
|
2535
|
+
if not last_statuses:
|
|
2536
|
+
return True
|
|
2537
|
+
|
|
2538
|
+
for current, last in zip(current_statuses, last_statuses):
|
|
2539
|
+
if (current["RuleConfigurationName"] == last["RuleConfigurationName"]) and (
|
|
2540
|
+
current["RuleEvaluationStatus"] != last["RuleEvaluationStatus"]
|
|
2541
|
+
):
|
|
2542
|
+
return True
|
|
2543
|
+
|
|
2544
|
+
return False
|
|
2545
|
+
|
|
2546
|
+
|
|
2547
|
+
def update_args(args: Dict[str, Any], **kwargs):
    """Copy every populated keyword argument into the request arguments dict.

    Keyword arguments whose value is ``None`` are skipped, because the service
    API doesn't like NoneTypes for argument values.

    Args:
        args (Dict[str, Any]): the request arguments dict; modified in place.
        kwargs: key, value pairs to update the args dict
    """
    populated = {name: value for name, value in kwargs.items() if value is not None}
    args.update(populated)
|
|
2559
|
+
|
|
2560
|
+
|
|
2561
|
+
def production_variant(
    model_name=None,
    instance_type=None,
    initial_instance_count=None,
    variant_name="AllTraffic",
    initial_weight=1,
    accelerator_type=None,
    serverless_inference_config=None,
    volume_size=None,
    model_data_download_timeout=None,
    container_startup_health_check_timeout=None,
    managed_instance_scaling=None,
    routing_config=None,
    inference_ami_version=None,
):
    """Create a production variant description suitable for use in a ``ProductionVariant`` list.

    This is also part of a ``CreateEndpointConfig`` request.

    Args:
        model_name (str): The name of the SageMaker model this production variant references.
        instance_type (str): The EC2 instance type for this production variant. For example,
            'ml.c4.8xlarge'.
        initial_instance_count (int): The initial instance count for this production variant.
            Falls back to 1 when not provided; ignored for serverless variants.
        variant_name (string): The ``VariantName`` of this production variant
            (default: 'AllTraffic').
        initial_weight (int): The relative ``InitialVariantWeight`` of this production variant
            (default: 1). Only included when ``model_name`` is given.
        accelerator_type (str): Type of Elastic Inference accelerator for this production variant.
            For example, 'ml.eia1.medium'.
            For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
        serverless_inference_config (dict): Specifies configuration dict related to serverless
            endpoint. The dict is converted from sagemaker.model_monitor.ServerlessInferenceConfig
            object (default: None)
        volume_size (int): The size, in GB, of the ML storage volume attached to individual
            inference instance associated with the production variant. Currently only Amazon EBS
            gp2 storage volumes are supported.
        model_data_download_timeout (int): The timeout value, in seconds, to download and extract
            model data from Amazon S3 to the individual inference instance associated with this
            production variant.
        container_startup_health_check_timeout (int): The timeout value, in seconds, for your
            inference container to pass health check by SageMaker Hosting. For more information
            about health check see:
            https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
        managed_instance_scaling: Value stored under ``ManagedInstanceScaling`` when truthy
            (default: None).
        routing_config: Value stored under ``RoutingConfig`` for instance-based variants when
            not None (default: None).
        inference_ami_version: Value stored under ``InferenceAmiVersion`` when truthy
            (default: None).

    Returns:
        dict[str, str]: An SageMaker ``ProductionVariant`` description
    """
    variant = {"VariantName": variant_name}

    if model_name:
        variant["ModelName"] = model_name
        variant["InitialVariantWeight"] = initial_weight

    if accelerator_type:
        variant["AcceleratorType"] = accelerator_type

    if managed_instance_scaling:
        variant["ManagedInstanceScaling"] = managed_instance_scaling

    if serverless_inference_config:
        variant["ServerlessConfig"] = serverless_inference_config
    else:
        # Instance-based variant: instance settings apply, optional ones only when set.
        variant["InitialInstanceCount"] = initial_instance_count or 1
        variant["InstanceType"] = instance_type
        update_args(
            variant,
            VolumeSizeInGB=volume_size,
            ModelDataDownloadTimeoutInSeconds=model_data_download_timeout,
            ContainerStartupHealthCheckTimeoutInSeconds=container_startup_health_check_timeout,
            RoutingConfig=routing_config,
        )

    if inference_ami_version:
        variant["InferenceAmiVersion"] = inference_ami_version

    return variant
|
|
2643
|
+
|
|
2644
|
+
|
|
2645
|
+
def _has_permission_for_live_logging(boto_session, endpoint_name) -> bool:
|
|
2646
|
+
"""Validate if customer's role has the right permission to access logs from CloudWatch"""
|
|
2647
|
+
try:
|
|
2648
|
+
cloudwatch_client = boto_session.client("logs")
|
|
2649
|
+
cloudwatch_client.filter_log_events(
|
|
2650
|
+
logGroupName=f"/aws/sagemaker/Endpoints/{endpoint_name}",
|
|
2651
|
+
logStreamNamePrefix="AllTraffic/",
|
|
2652
|
+
)
|
|
2653
|
+
return True
|
|
2654
|
+
except ClientError as e:
|
|
2655
|
+
if e.response["Error"]["Code"] == "AccessDeniedException":
|
|
2656
|
+
LOGGER.warning(
|
|
2657
|
+
("Failed to enable live logging: %s. Fallback to default logging..."),
|
|
2658
|
+
e,
|
|
2659
|
+
)
|
|
2660
|
+
|
|
2661
|
+
return False
|
|
2662
|
+
return True
|
|
2663
|
+
|
|
2664
|
+
|
|
2665
|
+
def _deploy_done(sagemaker_client, endpoint_name):
|
|
2666
|
+
"""Placeholder docstring"""
|
|
2667
|
+
hosting_status_codes = {
|
|
2668
|
+
"OutOfService": "x",
|
|
2669
|
+
"Creating": "-",
|
|
2670
|
+
"Updating": "-",
|
|
2671
|
+
"InService": "!",
|
|
2672
|
+
"RollingBack": "<",
|
|
2673
|
+
"Deleting": "o",
|
|
2674
|
+
"Failed": "*",
|
|
2675
|
+
}
|
|
2676
|
+
in_progress_statuses = ["Creating", "Updating"]
|
|
2677
|
+
|
|
2678
|
+
desc = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
|
|
2679
|
+
status = desc["EndpointStatus"]
|
|
2680
|
+
|
|
2681
|
+
print(hosting_status_codes.get(status, "?"), end="")
|
|
2682
|
+
sys.stdout.flush()
|
|
2683
|
+
|
|
2684
|
+
return None if status in in_progress_statuses else desc
|
|
2685
|
+
|
|
2686
|
+
|
|
2687
|
+
def _live_logging_deploy_done(sagemaker_client, endpoint_name, paginator, paginator_config, poll):
    """Run one polling step of an endpoint deployment while relaying its logs.

    The endpoint is described, available log events are forwarded to ``LOGGER``,
    and the description is returned once the endpoint status is no longer
    ``Creating``.

    Args:
        sagemaker_client: Client exposing ``describe_endpoint``.
        endpoint_name (str): Name of the endpoint being deployed.
        paginator: Object whose ``paginate(...)`` yields pages containing an
            ``"events"`` list and optionally a ``"nextToken"``
            (presumably a CloudWatch Logs paginator -- confirm with the caller).
        paginator_config (dict): Passed as ``PaginationConfig``; its
            ``"StartingToken"`` entry is updated in place from each page's
            ``"nextToken"``.
        poll (int): Seconds to sleep when the endpoint is neither ``Creating``
            nor ``InService``.

    Returns:
        dict: The ``describe_endpoint`` response once the status is not
        ``Creating``; ``None`` while still creating, or when the endpoint or its
        log group is not visible yet.

    Raises:
        ClientError: Any describe error other than ``ValidationException`` and
            any log-reading error other than ``ResourceNotFoundException``.
    """
    stop = False
    endpoint_status = None
    try:
        desc = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
        endpoint_status = desc["EndpointStatus"]
    except ClientError as e:
        # A ValidationException here is treated as "endpoint not visible yet".
        if e.response["Error"]["Code"] == "ValidationException":
            LOGGER.debug("Waiting for endpoint to become visible")
            return None
        raise e

    try:
        # if endpoint is in an invalid state -> set stop to true, sleep, and flush the logs
        if endpoint_status != "Creating":
            stop = True
            if endpoint_status == "InService":
                LOGGER.info("Created endpoint with name %s. Waiting for it to be InService", endpoint_name)
            else:
                time.sleep(poll)

        pages = paginator.paginate(
            logGroupName=f"/aws/sagemaker/Endpoints/{endpoint_name}",
            logStreamNamePrefix="AllTraffic/",
            PaginationConfig=paginator_config,
        )

        for page in pages:
            # Remember where to resume on the next call.
            if "nextToken" in page:
                paginator_config["StartingToken"] = page["nextToken"]
            for event in page["events"]:
                LOGGER.info(event["message"])
            # NOTE(review): this ``else`` is read as belonging to the inner ``for``
            # (it then runs after every page, since the loop has no ``break``) --
            # confirm against the original indentation.
            else:
                LOGGER.debug("No log events available")

        # if stop is true -> return the describe response and stop polling
        if stop:
            return desc
    except ClientError as e:
        # The log group may not exist until the endpoint starts logging.
        if e.response["Error"]["Code"] == "ResourceNotFoundException":
            LOGGER.debug("Waiting for endpoint log group to appear")
            return None
        raise e

    return None
|
|
2733
|
+
|
|
2734
|
+
|
|
2735
|
+
def _deployment_entity_exists(describe_fn):
|
|
2736
|
+
"""Placeholder docstring"""
|
|
2737
|
+
try:
|
|
2738
|
+
describe_fn()
|
|
2739
|
+
return True
|
|
2740
|
+
except ClientError as ce:
|
|
2741
|
+
error_code = ce.response["Error"]["Code"]
|
|
2742
|
+
if not (
|
|
2743
|
+
error_code == "ValidationException"
|
|
2744
|
+
and "Could not find" in ce.response["Error"]["Message"]
|
|
2745
|
+
):
|
|
2746
|
+
raise ce
|
|
2747
|
+
return False
|
|
2748
|
+
|
|
2749
|
+
|
|
2750
|
+
def get_log_events_for_inference_recommender(cw_client, log_group_name, log_stream_name):
    """Retrieves log events from the specified CloudWatch log group and log stream.

    The call is retried (up to 30 attempts, 10 seconds apart) while the log
    stream does not exist yet.

    Args:
        cw_client (boto3.client): A boto3 CloudWatch client.
        log_group_name (str): The name of the CloudWatch log group.
        log_stream_name (str): The name of the CloudWatch log stream.

    Returns:
        (dict): A dictionary containing log events from CloudWatch log group and log stream.

    Raises:
        ClientError: If ``get_log_events`` fails with anything other than
            ``ResourceNotFoundException``.
    """
    print("Fetching logs from CloudWatch...", flush=True)
    for _ in retries(
        max_retry_count=30,  # 30*10 = 5min
        exception_message_prefix="Waiting for cloudwatch stream to appear. ",
        seconds_to_sleep=10,
    ):
        try:
            return cw_client.get_log_events(
                logGroupName=log_group_name, logStreamName=log_stream_name
            )
        except ClientError as e:
            # Only "stream not there yet" is worth retrying. Previously every
            # client error (e.g. missing permissions) was silently swallowed and
            # retried for the full five minutes.
            if e.response["Error"]["Code"] != "ResourceNotFoundException":
                raise
|
|
2774
|
+
|
|
2775
|
+
|
|
2776
|
+
def _describe_inference_recommendations_job_status(sagemaker_client, job_name: str):
|
|
2777
|
+
"""Describes the status of a job and returns the job description.
|
|
2778
|
+
|
|
2779
|
+
Args:
|
|
2780
|
+
sagemaker_client (boto3.client.sagemaker): A SageMaker client.
|
|
2781
|
+
job_name (str): The name of the job.
|
|
2782
|
+
|
|
2783
|
+
Returns:
|
|
2784
|
+
dict: The job description, or None if the job is still in progress.
|
|
2785
|
+
"""
|
|
2786
|
+
inference_recommendations_job_status_codes = {
|
|
2787
|
+
"PENDING": ".",
|
|
2788
|
+
"IN_PROGRESS": ".",
|
|
2789
|
+
"COMPLETED": "!",
|
|
2790
|
+
"FAILED": "*",
|
|
2791
|
+
"STOPPING": "_",
|
|
2792
|
+
"STOPPED": "s",
|
|
2793
|
+
}
|
|
2794
|
+
in_progress_statuses = {"PENDING", "IN_PROGRESS", "STOPPING"}
|
|
2795
|
+
|
|
2796
|
+
desc = sagemaker_client.describe_inference_recommendations_job(JobName=job_name)
|
|
2797
|
+
status = desc["Status"]
|
|
2798
|
+
|
|
2799
|
+
print(inference_recommendations_job_status_codes.get(status, "?"), end="", flush=True)
|
|
2800
|
+
|
|
2801
|
+
if status in in_progress_statuses:
|
|
2802
|
+
return None
|
|
2803
|
+
|
|
2804
|
+
print("")
|
|
2805
|
+
return desc
|
|
2806
|
+
|
|
2807
|
+
|
|
2808
|
+
def _display_inference_recommendations_job_steps_status(
    sagemaker_session, sagemaker_client, job_name: str, poll: int = 60
):
    """Print the execution log of an inference recommendations job until it finishes.

    Log events are read from the ``<job_name>/execution`` stream of the
    ``/aws/sagemaker/InferenceRecommendationsJobs`` log group and printed one
    message per line, polling until the job leaves its in-progress statuses.

    Args:
        sagemaker_session: Session whose ``boto_session`` is used to create the
            ``logs`` client.
        sagemaker_client: Client exposing ``describe_inference_recommendations_job``.
        job_name (str): The name of the job.
        poll (int): Seconds to sleep between polls (default: 60).
    """
    cloudwatch_client = sagemaker_session.boto_session.client("logs")
    in_progress_statuses = {"PENDING", "IN_PROGRESS", "STOPPING"}
    log_group_name = "/aws/sagemaker/InferenceRecommendationsJobs"
    log_stream_name = job_name + "/execution"

    initial_logs_batch = get_log_events_for_inference_recommender(
        cloudwatch_client, log_group_name, log_stream_name
    )
    print(f"Retrieved logStream: {log_stream_name} from logGroup: {log_group_name}", flush=True)
    events = initial_logs_batch["events"]
    print(*[event["message"] for event in events], sep="\n", flush=True)

    # Only keep the forward token when the first batch actually contained events.
    next_forward_token = initial_logs_batch["nextForwardToken"] if events else None
    # Allows exactly one extra poll after the job has finished with no new events.
    flush_remaining = True
    while True:
        # Resume from the last token when there is one; otherwise read without a token.
        logs_batch = (
            cloudwatch_client.get_log_events(
                logGroupName=log_group_name,
                logStreamName=log_stream_name,
                nextToken=next_forward_token,
            )
            if next_forward_token
            else cloudwatch_client.get_log_events(
                logGroupName=log_group_name, logStreamName=log_stream_name
            )
        )

        events = logs_batch["events"]

        desc = sagemaker_client.describe_inference_recommendations_job(JobName=job_name)
        status = desc["Status"]

        if not events:
            # Nothing new: keep waiting while the job runs ...
            if status in in_progress_statuses:
                time.sleep(poll)
                continue
            # ... and wait one more round after it ends, in case late events arrive.
            if flush_remaining:
                flush_remaining = False
                time.sleep(poll)
                continue

        next_forward_token = logs_batch["nextForwardToken"]
        print(*[event["message"] for event in events], sep="\n", flush=True)

        if status not in in_progress_statuses:
            break

        time.sleep(poll)
|
|
2860
|
+
|
|
2861
|
+
|
|
2862
|
+
def _optimization_job_status(sagemaker_client, job_name):
|
|
2863
|
+
"""Placeholder docstring"""
|
|
2864
|
+
optimization_job_status_codes = {
|
|
2865
|
+
"INPROGRESS": ".",
|
|
2866
|
+
"COMPLETED": "!",
|
|
2867
|
+
"FAILED": "*",
|
|
2868
|
+
"STARTING": ".",
|
|
2869
|
+
"STOPPING": "_",
|
|
2870
|
+
"STOPPED": "s",
|
|
2871
|
+
}
|
|
2872
|
+
in_progress_statuses = ["INPROGRESS", "STARTING", "STOPPING"]
|
|
2873
|
+
|
|
2874
|
+
desc = sagemaker_client.describe_optimization_job(OptimizationJobName=job_name)
|
|
2875
|
+
status = desc["OptimizationJobStatus"]
|
|
2876
|
+
|
|
2877
|
+
print(optimization_job_status_codes.get(status, "?"), end="")
|
|
2878
|
+
sys.stdout.flush()
|
|
2879
|
+
|
|
2880
|
+
if status in in_progress_statuses:
|
|
2881
|
+
return None
|
|
2882
|
+
|
|
2883
|
+
print("")
|
|
2884
|
+
return desc
|
|
2885
|
+
|
|
2886
|
+
|
|
2887
|
+
def container_def(
    image_uri,
    model_data_url=None,
    env=None,
    container_mode=None,
    image_config=None,
    accept_eula=None,
    additional_model_data_sources=None,
    model_reference_arn=None,
):
    """Create a definition for executing a container as part of a SageMaker model.

    Args:
        image_uri (str): Docker image URI to run for this container.
        model_data_url (str or dict[str, Any]): S3 location of model data required by this
            container, e.g. SageMaker training job model artifacts. It can either be a string
            representing S3 URI of model data, or a dictionary representing a
            ``ModelDataSource`` object. (default: None).
        env (dict[str, str]): Environment variables to set inside the container (default: None).
        container_mode (str): The model container mode. Valid modes:
                * MultiModel: Indicates that model container can support hosting multiple models
                * SingleModel: Indicates that model container can support hosting a single model
                This is the default model container mode when container_mode = None
        image_config (dict[str, str]): Specifies whether the image of model container is pulled
            from ECR, or private registry in your VPC. By default it is set to pull model
            container image from ECR. (default: None).
        accept_eula (bool): For models that require a Model Access Config, specify True or
            False to indicate whether model terms of use have been accepted.
            The `accept_eula` value must be explicitly defined as `True` in order to
            accept the end-user license agreement (EULA) that some
            models require. (Default: None).
        additional_model_data_sources (PipelineVariable or dict): Additional location
            of SageMaker model data (default: None).
        model_reference_arn (str): When set, stored as ``HubAccessConfig.HubContentArn``
            of the generated ``S3DataSource`` (default: None).

    Returns:
        dict[str, str]: A complete container definition object usable with the CreateModel API if
        passed via `PrimaryContainers` field.
    """
    if env is None:
        env = {}
    c_def = {"Image": image_uri, "Environment": env}

    if additional_model_data_sources:
        c_def["AdditionalModelDataSources"] = additional_model_data_sources

    # A plain string is passed through as ``ModelDataUrl`` unless it is an
    # "s3://...tar.gz" URI AND an explicit ``accept_eula`` value was given.
    if isinstance(model_data_url, str) and (
        not (model_data_url.startswith("s3://") and model_data_url.endswith("tar.gz"))
        or accept_eula is None
    ):
        c_def["ModelDataUrl"] = model_data_url

    elif isinstance(model_data_url, (dict, str)):
        if isinstance(model_data_url, dict):
            # Caller supplied a ready-made ``ModelDataSource``; use it as is.
            c_def["ModelDataSource"] = model_data_url
        else:
            # Wrap the S3 tarball URI in a ``ModelDataSource`` so access configs can be attached.
            c_def["ModelDataSource"] = {
                "S3DataSource": {
                    "S3Uri": model_data_url,
                    "S3DataType": "S3Object",
                    "CompressionType": "Gzip",
                }
            }
            # NOTE(review): these two checks are read as applying only to the
            # string case above -- confirm against the original indentation.
            if accept_eula is not None:
                c_def["ModelDataSource"]["S3DataSource"]["ModelAccessConfig"] = {
                    "AcceptEula": accept_eula
                }
            if model_reference_arn:
                c_def["ModelDataSource"]["S3DataSource"]["HubAccessConfig"] = {
                    "HubContentArn": model_reference_arn
                }

    elif model_data_url is not None:
        # Any other non-None value (neither str nor dict) is passed through unchanged.
        c_def["ModelDataUrl"] = model_data_url

    if container_mode:
        c_def["Mode"] = container_mode
    if image_config:
        c_def["ImageConfig"] = image_config
    return c_def
|