sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,422 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""SageMaker remote function data serializer/deserializer."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import dataclasses
|
|
17
|
+
import json
|
|
18
|
+
|
|
19
|
+
import io
|
|
20
|
+
|
|
21
|
+
import sys
|
|
22
|
+
import hmac
|
|
23
|
+
import hashlib
|
|
24
|
+
import pickle
|
|
25
|
+
|
|
26
|
+
from typing import Any, Callable, Union
|
|
27
|
+
|
|
28
|
+
import cloudpickle
|
|
29
|
+
from tblib import pickling_support
|
|
30
|
+
|
|
31
|
+
from sagemaker.core.remote_function.errors import (
|
|
32
|
+
ServiceError,
|
|
33
|
+
SerializationError,
|
|
34
|
+
DeserializationError,
|
|
35
|
+
)
|
|
36
|
+
from sagemaker.core.s3 import S3Downloader, S3Uploader
|
|
37
|
+
from sagemaker.core.helper.session_helper import Session
|
|
38
|
+
from ._custom_dispatch_table import dispatch_table
|
|
39
|
+
|
|
40
|
+
# Note: do not use os.path.join for s3 uris, fails on windows
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _get_python_version():
    """Return the running interpreter's version as a ``major.minor.micro`` string."""
    # The first three fields of sys.version_info are major, minor and micro.
    version_parts = sys.version_info[:3]
    return ".".join(str(part) for part in version_parts)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclasses.dataclass
class _MetaData:
    """Metadata about the serialized data or functions.

    The metadata is persisted as JSON (see ``to_json``) and parsed back and
    validated with ``from_json``.
    """

    # Hash of the serialized payload; the only field without a default.
    sha256_hash: str
    # Format version tag; ``from_json`` only accepts "2023-04-24".
    version: str = "2023-04-24"
    # Evaluated once, when the class body runs, not per instance.
    python_version: str = _get_python_version()
    # Serializer name; ``from_json`` only accepts "cloudpickle".
    serialization_module: str = "cloudpickle"

    def to_json(self):
        """Converts metadata to a JSON document encoded as bytes."""
        return json.dumps(dataclasses.asdict(self)).encode()

    @staticmethod
    def from_json(s):
        """Converts a JSON document to a validated metadata object.

        Args:
            s: JSON text (``str`` or ``bytes``) as produced by ``to_json``.
        Returns:
            The parsed ``_MetaData`` instance.
        Raises:
            DeserializationError: if ``s`` is not valid JSON, is not a JSON
                object, has no ``sha256_hash``, or declares a ``version`` /
                ``serialization_module`` other than the supported pair.
        """
        try:
            obj = json.loads(s)
        except json.decoder.JSONDecodeError as e:
            raise DeserializationError(
                "Corrupt metadata file. It is not a valid json file."
            ) from e

        # Valid JSON can still be a list, string, number or null; those have no
        # ``.get`` and would otherwise escape as an AttributeError.
        if not isinstance(obj, dict):
            raise DeserializationError("Corrupt metadata file. It is not a json object.")

        sha256_hash = obj.get("sha256_hash")
        metadata = _MetaData(sha256_hash=sha256_hash)
        # Overwrite the defaults with whatever the file declares (None if absent)
        # so the compatibility check below sees the file's values.
        metadata.version = obj.get("version")
        metadata.python_version = obj.get("python_version")
        metadata.serialization_module = obj.get("serialization_module")

        if not sha256_hash:
            raise DeserializationError(
                "Corrupt metadata file. SHA256 hash for the serialized data does not exist. "
                "Please make sure to install SageMaker SDK version >= 2.156.0 on the client side "
                "and try again."
            )

        if not (
            metadata.version == "2023-04-24" and metadata.serialization_module == "cloudpickle"
        ):
            raise DeserializationError(
                f"Corrupt metadata file. Serialization approach {s} is not supported."
            )

        return metadata
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class CloudpickleSerializer:
    """Serializer using cloudpickle."""

    @staticmethod
    def serialize(obj: Any) -> bytes:
        """Serializes a python object to bytes using cloudpickle.

        This method only produces the bytes; it performs no upload itself.

        Args:
            obj: object to be serialized.
        Returns:
            The pickled representation of ``obj``.
        Raises:
            SerializationError: when fail to serialize object to bytes.
        """
        try:
            io_buffer = io.BytesIO()
            custom_pickler = cloudpickle.CloudPickler(io_buffer)
            # Read this pickler's dispatch table through the descriptor on
            # pickle.Pickler, layer the module's custom ``dispatch_table`` on top
            # via ``new_child``, and install the combined table on this pickler only.
            dt = pickle.Pickler.dispatch_table.__get__(custom_pickler)  # pylint: disable=no-member
            new_dt = dt.new_child(dispatch_table)
            pickle.Pickler.dispatch_table.__set__(  # pylint: disable=no-member
                custom_pickler, new_dt
            )
            custom_pickler.dump(obj)
            return io_buffer.getvalue()
        except Exception as e:
            run_not_picklable = isinstance(
                e, NotImplementedError
            ) and "Instance of Run type is not allowed to be pickled." in str(e)
            if run_not_picklable:
                # Built from adjacent literals so that no raw newlines or source
                # indentation leak into the user-facing message.
                raise SerializationError(
                    "You are trying to pass a sagemaker.experiments.run.Run object to "
                    "a remote function "
                    "or are trying to access a global sagemaker.experiments.run.Run object "
                    "from within the function. This is not supported. "
                    "You must use `load_run` to load an existing Run in the remote function "
                    "or instantiate a new Run in the function."
                ) from e

            raise SerializationError(
                "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
            ) from e

    @staticmethod
    def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any:
        """Deserializes bytes into a python object using cloudpickle.

        This method performs no download itself.

        Args:
            s3_uri (str): S3 uri the bytes came from; only used in the error message.
            bytes_to_deserialize: bytes to be deserialized.
        Returns :
            The deserialized python object.
        Raises:
            DeserializationError: when fail to deserialize the bytes.
        """
        # NOTE: unpickling executes code embedded in the payload; callers are
        # expected to verify the payload's integrity before calling this.
        try:
            return cloudpickle.loads(bytes_to_deserialize)
        except Exception as e:
            raise DeserializationError(
                "Error when deserializing bytes downloaded from {}: {}. "
                "NOTE: this may be caused by inconsistent sagemaker python sdk versions "
                "where remote function runs versus the one used on client side. "
                "If the sagemaker versions do not match, a warning message would "
                "be logged starting with 'Inconsistent sagemaker versions found'. "
                "Please check it to validate.".format(s3_uri, repr(e))
            ) from e
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
# TODO: use dask serializer in case dask distributed is installed in users' environment.
|
|
158
|
+
def serialize_func_to_s3(
    func: Callable, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
):
    """Pickle a function and hand the bytes off for upload to S3.

    Args:
        func: function to be serialized and persisted.
        sagemaker_session (sagemaker.core.helper.session.Session):
            Session which AWS service calls are delegated to.
        s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
        s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
    Raises:
        SerializationError: when the function cannot be serialized to bytes.
    """
    # Serialize first; the upload helper receives plain bytes.
    pickled_func = CloudpickleSerializer.serialize(func)

    _upload_payload_and_metadata_to_s3(
        bytes_to_upload=pickled_func,
        hmac_key=hmac_key,
        s3_uri=s3_uri,
        sagemaker_session=sagemaker_session,
        s3_kms_key=s3_kms_key,
    )
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Callable:
    """Read a serialized function from S3, verify it, and deserialize it.

    Reads ``metadata.json`` and ``payload.pkl`` under ``s3_uri``, runs the
    integrity check on the payload against the hash recorded in the metadata,
    and then unpickles the payload with ``CloudpickleSerializer``.

    Args:
        sagemaker_session (sagemaker.core.helper.session.Session):
            The underlying sagemaker session which AWS service calls are delegated to.
        s3_uri (str): S3 root uri under which the serialized artifacts are stored.
        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
    Returns :
        The deserialized function.
    Raises:
        DeserializationError: when the metadata is invalid or the payload
            cannot be deserialized.
    """
    metadata_bytes = _read_bytes_from_s3(f"{s3_uri}/metadata.json", sagemaker_session)
    metadata = _MetaData.from_json(metadata_bytes)

    payload_uri = f"{s3_uri}/payload.pkl"
    serialized_func = _read_bytes_from_s3(payload_uri, sagemaker_session)

    # The integrity check runs before any unpickling.
    _perform_integrity_check(
        expected_hash_value=metadata.sha256_hash,
        secret_key=hmac_key,
        buffer=serialized_func,
    )

    return CloudpickleSerializer.deserialize(payload_uri, serialized_func)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def serialize_obj_to_s3(
    obj: Any, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
):
    """Pickle a data object and hand the bytes off for upload to S3.

    Args:
        obj: object to be serialized and persisted.
        sagemaker_session (sagemaker.core.helper.session.Session):
            Session which AWS service calls are delegated to.
        s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
        s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
    Raises:
        SerializationError: when the object cannot be serialized to bytes.
    """
    # Serialize first; the upload helper receives plain bytes.
    pickled_obj = CloudpickleSerializer.serialize(obj)

    _upload_payload_and_metadata_to_s3(
        bytes_to_upload=pickled_obj,
        hmac_key=hmac_key,
        s3_uri=s3_uri,
        sagemaker_session=sagemaker_session,
        s3_kms_key=s3_kms_key,
    )
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def json_serialize_obj_to_s3(
    obj: Any,
    json_key: str,
    sagemaker_session: Session,
    s3_uri: str,
    s3_kms_key: str = None,
):
    """Json serializes data object and uploads it to S3.

    If a function step's output is data referenced by other steps via JsonGet,
    its output should be json serialized and uploaded to S3.

    The uploaded document always has two keys: ``json_key`` holding the object
    (or ``null`` when it could not be encoded) and ``"Exception"`` holding
    ``null`` (or a message describing why encoding failed).

    Args:
        obj: (Any) object to be serialized and persisted.
        json_key: (str) the json key pointing to function step output.
        sagemaker_session (sagemaker.core.helper.session.Session):
            The underlying Boto3 session which AWS service calls are delegated to.
        s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
        s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
    """
    try:
        json_serialized_result = json.dumps({json_key: obj, "Exception": None})
    except (TypeError, ValueError):
        # json.dumps raises TypeError for unsupported value or key types and
        # ValueError for circular references. Every such failure is recorded in
        # the document, so a JSON string is always uploaded rather than an
        # empty placeholder dict.
        json_serialized_result = json.dumps(
            {
                json_key: None,
                "Exception": f"The function return ({obj}) is not JSON serializable.",
            }
        )

    S3Uploader.upload_string_as_file_body(
        body=json_serialized_result,
        desired_s3_uri=s3_uri,
        sagemaker_session=sagemaker_session,
        kms_key=s3_kms_key,
    )
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
    """Downloads a serialized object from S3, verifies it and deserializes it.

    Args:
        sagemaker_session (sagemaker.core.helper.session.Session):
            The underlying sagemaker session which AWS service calls are delegated to.
        s3_uri (str): S3 root uri under which ``metadata.json`` and ``payload.pkl``
            are read.
        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
    Returns:
        The value produced by ``CloudpickleSerializer.deserialize`` for the payload.
    Raises:
        ServiceError: when reading either S3 object fails.
        DeserializationError: when the integrity check of the payload fails.
    """
    metadata_uri = f"{s3_uri}/metadata.json"
    payload_uri = f"{s3_uri}/payload.pkl"

    raw_metadata = _read_bytes_from_s3(metadata_uri, sagemaker_session)
    metadata = _MetaData.from_json(raw_metadata)

    payload = _read_bytes_from_s3(payload_uri, sagemaker_session)

    # The hash is verified before the payload is handed to the deserializer.
    _perform_integrity_check(
        expected_hash_value=metadata.sha256_hash,
        secret_key=hmac_key,
        buffer=payload,
    )

    return CloudpickleSerializer.deserialize(payload_uri, payload)
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def serialize_exception_to_s3(
    exc: Exception, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
):
    """Serializes an exception and uploads it, with its metadata, to S3.

    Args:
        exc: Exception to be serialized and persisted.
        sagemaker_session (sagemaker.core.helper.session.Session):
            Session forwarded to the S3 upload helpers.
        s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
        s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
    Raises:
        ServiceError: when an upload to S3 fails.
    """
    # NOTE(review): presumably enables pickling of traceback objects so the
    # traceback travels with the exception - confirm against pickling_support.
    pickling_support.install()

    serialized_exception = CloudpickleSerializer.serialize(exc)
    _upload_payload_and_metadata_to_s3(
        bytes_to_upload=serialized_exception,
        hmac_key=hmac_key,
        s3_uri=s3_uri,
        sagemaker_session=sagemaker_session,
        s3_kms_key=s3_kms_key,
    )
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def _upload_payload_and_metadata_to_s3(
    bytes_to_upload: Union[bytes, io.BytesIO],
    hmac_key: str,
    s3_uri: str,
    sagemaker_session: Session,
    s3_kms_key,
):
    """Uploads serialized payload and metadata to s3.

    The payload is written to ``<s3_uri>/payload.pkl`` and a metadata document
    holding its hmac-sha256 hash to ``<s3_uri>/metadata.json``.

    Args:
        bytes_to_upload (bytes or io.BytesIO): Serialized bytes to upload. A
            ``BytesIO`` is converted to its full contents first.
        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
        s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
        sagemaker_session (sagemaker.core.helper.session.Session):
            Session forwarded to the S3 upload helper.
        s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
    Raises:
        ServiceError: when an upload to S3 fails.
    """
    # hmac needs a bytes-like message and io.BytesIO is not one, so normalise to
    # bytes up front. The same bytes are then both hashed and uploaded.
    if isinstance(bytes_to_upload, io.BytesIO):
        bytes_to_upload = bytes_to_upload.getvalue()

    # Hash before uploading so a hashing failure cannot leave a payload in S3
    # without its metadata.
    sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)

    _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)

    _upload_bytes_to_s3(
        _MetaData(sha256_hash).to_json(),
        f"{s3_uri}/metadata.json",
        s3_kms_key,
        sagemaker_session,
    )
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
    """Downloads a serialized exception from S3, verifies it and deserializes it.

    Args:
        sagemaker_session (sagemaker.core.helper.session.Session):
            The underlying sagemaker session which AWS service calls are delegated to.
        s3_uri (str): S3 root uri under which ``metadata.json`` and ``payload.pkl``
            are read.
        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
    Returns:
        The value produced by ``CloudpickleSerializer.deserialize`` for the payload.
    Raises:
        ServiceError: when reading either S3 object fails.
        DeserializationError: when the integrity check of the payload fails.
    """
    stored_metadata = _MetaData.from_json(
        _read_bytes_from_s3(f"{s3_uri}/metadata.json", sagemaker_session)
    )

    payload_location = f"{s3_uri}/payload.pkl"
    serialized_exception = _read_bytes_from_s3(payload_location, sagemaker_session)

    # The hash is verified before the payload is handed to the deserializer.
    _perform_integrity_check(
        expected_hash_value=stored_metadata.sha256_hash,
        secret_key=hmac_key,
        buffer=serialized_exception,
    )

    return CloudpickleSerializer.deserialize(payload_location, serialized_exception)
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
def _upload_bytes_to_s3(b: Union[bytes, io.BytesIO], s3_uri, s3_kms_key, sagemaker_session):
    """Upload ``b`` to ``s3_uri``, translating any failure into a ServiceError.

    Raises:
        ServiceError: wraps whatever ``S3Uploader.upload_bytes`` raised; the
            original exception is kept as the cause.
    """
    try:
        S3Uploader.upload_bytes(b, s3_uri, kms_key=s3_kms_key, sagemaker_session=sagemaker_session)
    except Exception as upload_error:
        message = "Failed to upload serialized bytes to {}: {}".format(
            s3_uri, repr(upload_error)
        )
        raise ServiceError(message) from upload_error
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def _read_bytes_from_s3(s3_uri, sagemaker_session):
    """Read the object at ``s3_uri``, translating any failure into a ServiceError.

    Returns:
        Whatever ``S3Downloader.read_bytes`` returns for ``s3_uri``.
    Raises:
        ServiceError: wraps whatever ``S3Downloader.read_bytes`` raised; the
            original exception is kept as the cause.
    """
    try:
        content = S3Downloader.read_bytes(s3_uri, sagemaker_session=sagemaker_session)
    except Exception as download_error:
        message = "Failed to read serialized bytes from {}: {}".format(
            s3_uri, repr(download_error)
        )
        raise ServiceError(message) from download_error
    return content
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
def _compute_hash(buffer: bytes, secret_key: str) -> str:
|
|
407
|
+
"""Compute the hmac-sha256 hash"""
|
|
408
|
+
return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def _perform_integrity_check(expected_hash_value: str, secret_key: str, buffer: bytes):
    """Verify that ``buffer`` matches the hash recorded for it.

    The hmac-sha256 of ``buffer`` is recomputed with ``secret_key`` and compared
    against ``expected_hash_value`` using ``hmac.compare_digest``.

    Raises:
        DeserializationError: when the two hashes differ.
    """
    recomputed_hash = _compute_hash(buffer=buffer, secret_key=secret_key)
    if hmac.compare_digest(expected_hash_value, recomputed_hash):
        return

    raise DeserializationError(
        "Integrity check for the serialized function or data failed. "
        "Please restrict access to your S3 bucket"
    )
|
|
@@ -0,0 +1,226 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""SageMaker job function serializer/deserializer."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
from typing import Any
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
from sagemaker.core.s3 import s3_path_join
|
|
22
|
+
from sagemaker.core.remote_function import logging_config
|
|
23
|
+
from sagemaker.core.remote_function.core.pipeline_variables import (
|
|
24
|
+
Context,
|
|
25
|
+
resolve_pipeline_variables,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
import sagemaker.core.remote_function.core.serialization as serialization
|
|
29
|
+
from sagemaker.core.helper.session_helper import Session
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Module-level logger obtained from the remote_function logging_config helper.
logger = logging_config.get_logger()
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Sub-folder (under the function upload path) holding the serialized function code.
FUNCTION_FOLDER = "function"
|
|
36
|
+
# Sub-folder (under the function upload path) holding the serialized (args, kwargs).
ARGUMENTS_FOLDER = "arguments"
|
|
37
|
+
# Sub-folder (under the results upload path) holding the serialized function return.
RESULTS_FOLDER = "results"
|
|
38
|
+
# NOTE(review): not referenced in this module's visible code; presumably the
# sub-folder for serialized exceptions - confirm against callers.
EXCEPTION_FOLDER = "exception"
|
|
39
|
+
# Key under which the function return is stored in the JSON results document.
JSON_SERIALIZED_RESULT_KEY = "Result"
|
|
40
|
+
# File name of the JSON results document written inside RESULTS_FOLDER.
JSON_RESULTS_FILE = "results.json"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
|
|
44
|
+
class _SerializedData:
|
|
45
|
+
"""Data class to store serialized function and arguments"""
|
|
46
|
+
|
|
47
|
+
func: bytes
|
|
48
|
+
args: bytes
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class StoredFunction:
    """Class representing a remote function stored in S3."""

    def __init__(
        self,
        sagemaker_session: Session,
        s3_base_uri: str,
        hmac_key: str,
        s3_kms_key: str = None,
        context: Context = None,
    ):
        """Construct a StoredFunction object.

        Args:
            sagemaker_session: (sagemaker.session.Session): The underlying sagemaker session which
                AWS service calls are delegated to.
            s3_base_uri: the base uri to which serialized artifacts will be uploaded.
            hmac_key: Key passed to the serialization helpers for the hmac-sha256
                integrity hash of the serialized function, arguments and results.
            s3_kms_key: KMS key used to encrypt artifacts uploaded to S3.
            context: Build or run context of a pipeline step. When omitted, a new
                ``Context()`` is created for this instance.
        """
        # Create the default per instance: a ``Context()`` default in the
        # signature is evaluated once and would be shared by every
        # StoredFunction constructed without an explicit context.
        if context is None:
            context = Context()

        self.sagemaker_session = sagemaker_session
        self.s3_base_uri = s3_base_uri
        self.s3_kms_key = s3_kms_key
        self.hmac_key = hmac_key
        self.context = context

        # For pipeline steps, function code is at: base/step_name/build_timestamp/
        # For results, path is: base/step_name/build_timestamp/execution_id/
        # This ensures uniqueness: build_timestamp per build, execution_id per run
        # The function upload path is built the same way in both cases.
        self.func_upload_path = s3_path_join(
            s3_base_uri, context.step_name, context.func_step_s3_dir
        )

        if context.step_name and context.func_step_s3_dir:
            # Pipeline step: include build timestamp in the results path too
            self.results_upload_path = s3_path_join(
                s3_base_uri, context.step_name, context.func_step_s3_dir, context.execution_id
            )
        else:
            # Regular remote function: original behavior
            self.results_upload_path = s3_path_join(
                s3_base_uri, context.execution_id, context.step_name
            )

    def save(self, func, *args, **kwargs):
        """Serialize and persist the function and arguments.

        Args:
            func: the python function.
            args: the positional arguments to func.
            kwargs: the keyword arguments to func.
        Returns:
            None
        """
        function_uri = s3_path_join(self.func_upload_path, FUNCTION_FOLDER)
        arguments_uri = s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER)

        logger.info("Serializing function code to %s", function_uri)
        serialization.serialize_func_to_s3(
            func=func,
            sagemaker_session=self.sagemaker_session,
            s3_uri=function_uri,
            s3_kms_key=self.s3_kms_key,
            hmac_key=self.hmac_key,
        )

        logger.info("Serializing function arguments to %s", arguments_uri)
        # Positional and keyword arguments are stored together as one tuple.
        serialization.serialize_obj_to_s3(
            obj=(args, kwargs),
            sagemaker_session=self.sagemaker_session,
            s3_uri=arguments_uri,
            hmac_key=self.hmac_key,
            s3_kms_key=self.s3_kms_key,
        )

    def save_pipeline_step_function(self, serialized_data):
        """Upload serialized function and arguments to s3.

        Args:
            serialized_data (_SerializedData): The serialized function
                and function arguments of a function step.
        """
        function_uri = s3_path_join(self.func_upload_path, FUNCTION_FOLDER)
        arguments_uri = s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER)

        logger.info("Uploading serialized function code to %s", function_uri)
        serialization._upload_payload_and_metadata_to_s3(
            bytes_to_upload=serialized_data.func,
            hmac_key=self.hmac_key,
            s3_uri=function_uri,
            sagemaker_session=self.sagemaker_session,
            s3_kms_key=self.s3_kms_key,
        )

        logger.info("Uploading serialized function arguments to %s", arguments_uri)
        serialization._upload_payload_and_metadata_to_s3(
            bytes_to_upload=serialized_data.args,
            hmac_key=self.hmac_key,
            s3_uri=arguments_uri,
            sagemaker_session=self.sagemaker_session,
            s3_kms_key=self.s3_kms_key,
        )

    def load_and_invoke(self) -> Any:
        """Load and deserialize the function and the arguments and then execute it.

        The function return is serialized and uploaded under the results folder;
        when ``context.serialize_output_to_json`` is set, a JSON rendering is
        uploaded as well. The return value is persisted rather than returned:
        this method has no ``return`` statement.
        """
        function_uri = s3_path_join(self.func_upload_path, FUNCTION_FOLDER)
        arguments_uri = s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER)
        results_uri = s3_path_join(self.results_upload_path, RESULTS_FOLDER)

        logger.info("Deserializing function code from %s", function_uri)
        func = serialization.deserialize_func_from_s3(
            sagemaker_session=self.sagemaker_session,
            s3_uri=function_uri,
            hmac_key=self.hmac_key,
        )

        logger.info("Deserializing function arguments from %s", arguments_uri)
        args, kwargs = serialization.deserialize_obj_from_s3(
            sagemaker_session=self.sagemaker_session,
            s3_uri=arguments_uri,
            hmac_key=self.hmac_key,
        )

        logger.info("Resolving pipeline variables")
        resolved_args, resolved_kwargs = resolve_pipeline_variables(
            self.context,
            args,
            kwargs,
            hmac_key=self.hmac_key,
            s3_base_uri=self.s3_base_uri,
            sagemaker_session=self.sagemaker_session,
        )

        logger.info("Invoking the function")
        result = func(*resolved_args, **resolved_kwargs)

        logger.info("Serializing the function return and uploading to %s", results_uri)
        serialization.serialize_obj_to_s3(
            obj=result,
            sagemaker_session=self.sagemaker_session,
            s3_uri=results_uri,
            hmac_key=self.hmac_key,
            s3_kms_key=self.s3_kms_key,
        )

        if self.context and self.context.serialize_output_to_json:
            logger.info("JSON Serializing the function return and uploading to %s", results_uri)
            serialization.json_serialize_obj_to_s3(
                obj=result,
                json_key=JSON_SERIALIZED_RESULT_KEY,
                sagemaker_session=self.sagemaker_session,
                # Built with s3_path_join rather than os.path.join so the S3
                # uri does not depend on the local OS path separator.
                s3_uri=s3_path_join(self.results_upload_path, RESULTS_FOLDER, JSON_RESULTS_FILE),
                s3_kms_key=self.s3_kms_key,
            )
|