sagemaker-core 1.0.62__py3-none-any.whl → 2.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/__init__.py +2 -0
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2399 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +247 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1220 -0
- sagemaker/core/git_utils.py +415 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2977 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +509 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +770 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1267 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +686 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
- sagemaker/core/image_uri_config/huggingface.json +2287 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +252 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +494 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +972 -0
- sagemaker/core/image_uris.py +816 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +197 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +501 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +171 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +423 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +246 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1497 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +239 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +205 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1599 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1310 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +347 -0
- sagemaker/core/remote_function/core/serialization.py +410 -0
- sagemaker/core/remote_function/core/stored_function.py +223 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +102 -0
- sagemaker/core/remote_function/invoke_function.py +167 -0
- sagemaker/core/remote_function/job.py +2121 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +82 -0
- sagemaker/core/telemetry/telemetry_logging.py +285 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +345 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +514 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.3.1.dist-info/RECORD +351 -0
- sagemaker_core-2.3.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/modules/local_core}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,424 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Implements base methods for deserializing data returned from an inference endpoint."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import csv
|
|
17
|
+
|
|
18
|
+
import abc
|
|
19
|
+
import codecs
|
|
20
|
+
import io
|
|
21
|
+
import json
|
|
22
|
+
|
|
23
|
+
import numpy as np
|
|
24
|
+
from six import with_metaclass
|
|
25
|
+
|
|
26
|
+
# Lazy import to avoid circular dependency with amazon modules
|
|
27
|
+
# from sagemaker.core.serializers.utils import read_records
|
|
28
|
+
from sagemaker.core.common_utils import DeferredError
|
|
29
|
+
|
|
30
|
+
try:
|
|
31
|
+
import pandas
|
|
32
|
+
except ImportError as e:
|
|
33
|
+
pandas = DeferredError(e)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class BaseDeserializer(abc.ABC):
    """Interface that every inference-response deserializer implements.

    Concrete subclasses must supply two things: a ``deserialize`` method that turns the
    endpoint response into a Python object, and an ``ACCEPT`` attribute/property naming
    the content type(s) the deserializer expects to receive.
    """

    @abc.abstractmethod
    def deserialize(self, stream, content_type):
        """Convert the body of an inference response into a Python object.

        Args:
            stream (botocore.response.StreamingBody): The response body to decode.
            content_type (str): MIME type reported for ``stream``.

        Returns:
            object: Whatever Python object the concrete deserializer produces.
        """

    @property
    @abc.abstractmethod
    def ACCEPT(self):
        """Content type(s) this deserializer expects the inference endpoint to return."""
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class SimpleBaseDeserializer(BaseDeserializer):
    """Abstract base class for creation of new deserializers.

    This class extends the API of :class:`~sagemaker.deserializers.BaseDeserializer` with more
    user-friendly options for setting the ACCEPT content type header, in situations where it can be
    provided at init and freely updated.

    Subclasses still have to implement ``deserialize``.
    """

    # NOTE: BaseDeserializer derives from abc.ABC, so the metaclass is already
    # abc.ABCMeta; the former ``six.with_metaclass(abc.ABCMeta, ...)`` wrapper was a
    # Python 2 compatibility shim that produced exactly this same class.

    def __init__(self, accept="*/*"):
        """Initialize a ``SimpleBaseDeserializer`` instance.

        Args:
            accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
                is expected from the inference endpoint (default: "*/*").
        """
        super(SimpleBaseDeserializer, self).__init__()
        self.accept = accept

    @property
    def ACCEPT(self):
        """The tuple of possible content types that are expected from the inference endpoint.

        A single string passed as ``accept`` is wrapped in a one-element tuple; any other
        value is returned unchanged.
        """
        if isinstance(self.accept, str):
            return (self.accept,)
        return self.accept
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class StringDeserializer(SimpleBaseDeserializer):
    """Turn an inference endpoint response into a decoded text string."""

    def __init__(self, encoding="UTF-8", accept="application/json"):
        """Create a ``StringDeserializer``.

        Args:
            encoding (str): Codec used to decode the response bytes (default: UTF-8).
            accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
                is expected from the inference endpoint (default: "application/json").
        """
        super().__init__(accept=accept)
        self.encoding = encoding

    def deserialize(self, stream, content_type):
        """Read the whole response and decode it with ``self.encoding``.

        The stream is always closed, whether or not decoding succeeds.

        Args:
            stream (botocore.response.StreamingBody): Response body to read.
            content_type (str): MIME type of the response (not used).

        Returns:
            str: The decoded response text.
        """
        try:
            payload = stream.read()
            return payload.decode(self.encoding)
        finally:
            stream.close()
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class BytesDeserializer(SimpleBaseDeserializer):
    """Return an inference endpoint response as raw ``bytes``."""

    def deserialize(self, stream, content_type):
        """Drain the response stream and hand back its bytes unchanged.

        The stream is closed after reading, even if the read fails.

        Args:
            stream (botocore.response.StreamingBody): Response body to read.
            content_type (str): MIME type of the response (not used).

        Returns:
            bytes: Everything read from ``stream``.
        """
        try:
            data = stream.read()
        finally:
            stream.close()
        return data
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class CSVDeserializer(SimpleBaseDeserializer):
    """Deserialize a stream of bytes into a list of lists.

    Consider using :class:`~sagemaker.deserializers.NumpyDeserializer` or
    :class:`~sagemaker.deserializers.PandasDeserializer` instead, if you'd like to convert text/csv
    responses directly into other data types.
    """

    def __init__(self, encoding="utf-8", accept="text/csv"):
        """Initialize a ``CSVDeserializer`` instance.

        Args:
            encoding (str): The string encoding to use (default: "utf-8").
            accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
                is expected from the inference endpoint (default: "text/csv").
        """
        super(CSVDeserializer, self).__init__(accept=accept)
        self.encoding = encoding

    def deserialize(self, stream, content_type):
        """Deserialize data from an inference endpoint into a list of lists.

        Each CSV record becomes a list of string fields. The stream is always closed.

        Args:
            stream (botocore.response.StreamingBody): Data to be deserialized.
            content_type (str): The MIME type of the data.

        Returns:
            list: The data deserialized into a list of lists representing the
                contents of a CSV file.
        """
        try:
            decoded_string = stream.read().decode(self.encoding)
            # Hand the csv module a file-like object opened with newline="" (as the csv
            # docs require) instead of pre-split lines: ``str.splitlines`` removes the
            # line breaks inside quoted fields (and also splits on \x0b, \x0c, \u2028, ...),
            # which silently corrupted multi-line field values.
            return list(csv.reader(io.StringIO(decoded_string, newline="")))
        finally:
            stream.close()
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class StreamDeserializer(SimpleBaseDeserializer):
    """Pass the raw response stream and its content type straight through.

    Nothing is read or closed here; callers own the returned stream and must
    close it themselves once finished with it.
    """

    def deserialize(self, stream, content_type):
        """Return the untouched response stream together with its MIME type.

        Args:
            stream (botocore.response.StreamingBody): Response body (left open).
            content_type (str): MIME type of the response.

        Returns:
            tuple: ``(stream, content_type)``.
        """
        result = (stream, content_type)
        return result
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
class NumpyDeserializer(SimpleBaseDeserializer):
    """Deserialize a stream of data in .npy, .npz or UTF-8 CSV/JSON format to a numpy array.

    Note that when using application/x-npz archive format, the result will usually be a
    dictionary-like object containing multiple arrays (as per ``numpy.load()``) - instead of a
    single array.
    """

    def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=False):
        """Initialize a ``NumpyDeserializer`` instance.

        Args:
            dtype (str): The dtype of the data (default: None).
            accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
                is expected from the inference endpoint (default: "application/x-npy").
            allow_pickle (bool): Allow loading pickled object arrays (default: False).
                Unpickling can execute arbitrary code, so only enable this for endpoints
                whose output you trust.
        """
        super(NumpyDeserializer, self).__init__(accept=accept)
        self.dtype = dtype
        self.allow_pickle = allow_pickle

    def deserialize(self, stream, content_type):
        """Deserialize data from an inference endpoint into a NumPy array.

        Supported content types are "text/csv", "application/json", "application/x-npy"
        and "application/x-npz". ``dtype`` applies only to the CSV and JSON paths. The
        stream is closed exactly once, on every path.

        Args:
            stream (botocore.response.StreamingBody): Data to be deserialized.
            content_type (str): The MIME type of the data.

        Returns:
            numpy.ndarray: The data deserialized into a NumPy array (for .npz payloads,
                the dictionary-like object returned by ``numpy.load``).

        Raises:
            ValueError: If ``numpy.load`` rejects a .npy/.npz payload (for example, an
                object array while ``allow_pickle`` is False), or if ``content_type`` is
                not supported.
        """
        try:
            if content_type == "text/csv":
                return np.genfromtxt(
                    codecs.getreader("utf-8")(stream), delimiter=",", dtype=self.dtype
                )
            if content_type == "application/json":
                return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype)
            if content_type in ("application/x-npy", "application/x-npz"):
                # Buffer the payload in memory so numpy.load receives a seekable file object.
                # NOTE(review): for .npz, numpy.load returns a lazy NpzFile, so a pickle error
                # may only surface when a member array is accessed - confirm.
                payload = io.BytesIO(stream.read())
                try:
                    return np.load(payload, allow_pickle=self.allow_pickle)
                except ValueError as ve:
                    raise ValueError(
                        "Please set the param allow_pickle=True to deserialize pickle objects "
                        "in NumpyDeserializer. Original error: %s" % ve
                    ) from ve
        finally:
            stream.close()

        raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
class JSONDeserializer(SimpleBaseDeserializer):
    """Parse a JSON inference response into native Python objects."""

    def __init__(self, accept="application/json"):
        """Create a ``JSONDeserializer``.

        Args:
            accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
                is expected from the inference endpoint (default: "application/json").
        """
        super().__init__(accept=accept)

    def deserialize(self, stream, content_type):
        """Decode the response as UTF-8 and parse it with ``json.load``.

        The stream is closed afterwards regardless of whether parsing succeeds.

        Args:
            stream (botocore.response.StreamingBody): Response body to parse.
            content_type (str): MIME type of the response (not used).

        Returns:
            object: The parsed JSON value (dict, list, str, number, bool or None).
        """
        try:
            text_stream = codecs.getreader("utf-8")(stream)
            parsed = json.load(text_stream)
        finally:
            stream.close()
        return parsed
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
class PandasDeserializer(SimpleBaseDeserializer):
    """Deserialize CSV or JSON data from an inference endpoint into a pandas dataframe."""

    def __init__(self, accept=("text/csv", "application/json")):
        """Initialize a ``PandasDeserializer`` instance.

        Args:
            accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
                is expected from the inference endpoint (default: ("text/csv","application/json")).
        """
        super(PandasDeserializer, self).__init__(accept=accept)

    def deserialize(self, stream, content_type):
        """Deserialize CSV or JSON data from an inference endpoint into a pandas dataframe.

        If the data is JSON, the data should be formatted in the 'columns' orient.
        See https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_json.html

        The stream is closed once reading finishes or fails, the same way
        ``JSONDeserializer`` and ``JSONLinesDeserializer`` release their streams.

        Args:
            stream (botocore.response.StreamingBody): Data to be deserialized.
            content_type (str): The MIME type of the data.

        Returns:
            pandas.DataFrame: The data deserialized into a pandas DataFrame.

        Raises:
            ValueError: If ``content_type`` is neither "text/csv" nor "application/json".
        """
        try:
            if content_type == "text/csv":
                return pandas.read_csv(stream)

            if content_type == "application/json":
                return pandas.read_json(stream)
        finally:
            # Previously the stream was never closed here, leaking the
            # underlying connection/file handle.
            stream.close()

        raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
class JSONLinesDeserializer(SimpleBaseDeserializer):
    """Deserialize JSON lines data from an inference endpoint."""

    def __init__(self, accept="application/jsonlines"):
        """Initialize a ``JSONLinesDeserializer`` instance.

        Args:
            accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
                is expected from the inference endpoint (default: "application/jsonlines").
        """
        super(JSONLinesDeserializer, self).__init__(accept=accept)

    def deserialize(self, stream, content_type):
        """Deserialize JSON lines data from an inference endpoint.

        See https://docs.python.org/3/library/json.html#py-to-json-table to
        understand how JSON values are converted to Python objects.

        Each non-blank line is parsed as one JSON value. Empty or
        whitespace-only lines are skipped, so an empty body yields ``[]``.
        The stream is closed after it has been read.

        Args:
            stream (botocore.response.StreamingBody): Data to be deserialized.
            content_type (str): The MIME type of the data.

        Returns:
            list: A list of JSON serializable objects.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        try:
            body = stream.read().decode("utf-8")
        finally:
            stream.close()
        # Split only on "\n" rather than str.splitlines(): JSON strings may legally
        # contain unescaped U+2028/U+2029, which splitlines() would treat as breaks.
        return [json.loads(line) for line in body.split("\n") if line.strip()]
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
class TorchTensorDeserializer(SimpleBaseDeserializer):
    """Deserialize stream to torch.Tensor.

    The payload is decoded as NPY data with ``NumpyDeserializer`` and the
    resulting array is converted with ``torch.from_numpy``.
    """

    def __init__(self, accept="tensor/pt"):
        """Initialize a ``TorchTensorDeserializer`` instance.

        Args:
            accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
                is expected from the inference endpoint (default: "tensor/pt").

        Raises:
            Exception: If ``torch`` cannot be imported (chained to the ImportError).
        """
        super(TorchTensorDeserializer, self).__init__(accept=accept)
        self.numpy_deserializer = NumpyDeserializer()
        try:
            from torch import from_numpy
        except ImportError as e:
            # Chain the original ImportError so the real cause stays visible.
            raise Exception("Unable to import pytorch.") from e
        self.convert_npy_to_tensor = from_numpy

    def deserialize(self, stream, content_type="tensor/pt"):
        """Deserialize streamed data to a ``torch.Tensor``.

        See https://pytorch.org/docs/stable/generated/torch.from_numpy.html

        Args:
            stream (botocore.response.StreamingBody): Data to be deserialized.
            content_type (str): The MIME type of the data. Not inspected: the
                payload is always decoded as "application/x-npy".

        Returns:
            torch.Tensor: The data deserialized into a torch Tensor.

        Raises:
            ValueError: If decoding or tensor conversion fails; the underlying
                error is attached as ``__cause__``.
        """
        try:
            numpy_array = self.numpy_deserializer.deserialize(
                stream=stream, content_type="application/x-npy"
            )
            return self.convert_npy_to_tensor(numpy_array)
        except Exception as e:
            # Use implicit concatenation instead of a backslash continuation so
            # no source indentation leaks into the message; chain the cause.
            raise ValueError(
                "Unable to deserialize your data to torch.Tensor. "
                "Please provide custom deserializer in InferenceSpec."
            ) from e
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
# TODO fix the unit test for this deserializer
class RecordDeserializer(SimpleBaseDeserializer):
    """Deserialize RecordIO Protobuf data from an inference endpoint."""

    def __init__(self, accept="application/x-recordio-protobuf"):
        """Create a deserializer for RecordIO-encoded protobuf responses.

        Args:
            accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
                is expected from the inference endpoint (default:
                "application/x-recordio-protobuf").
        """
        super().__init__(accept=accept)

    def deserialize(self, data, content_type):
        """Parse a RecordIO protobuf payload into records.

        ``data`` is closed once parsing completes or fails. ``content_type``
        is accepted for interface compatibility but is not inspected.

        Args:
            data (object): The protobuf message to deserialize; must expose ``close()``.
            content_type (str): The MIME type of the data.

        Returns:
            list: A list of records.
        """
        try:
            # Imported lazily to avoid a circular dependency.
            from sagemaker.core.serializers.utils import read_records as _read_records

            records = _read_records(data)
        finally:
            data.close()
        return records
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Implements methods for deserializing data returned from an inference endpoint."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
from typing import List, Optional
|
|
18
|
+
|
|
19
|
+
# base_deserializers was refactored from deserializers.
|
|
20
|
+
# this import ensures backward compatibility.
|
|
21
|
+
from sagemaker.core.deserializers.base import ( # noqa: F401 # pylint: disable=W0611
|
|
22
|
+
BaseDeserializer,
|
|
23
|
+
BytesDeserializer,
|
|
24
|
+
CSVDeserializer,
|
|
25
|
+
DeferredError,
|
|
26
|
+
JSONDeserializer,
|
|
27
|
+
JSONLinesDeserializer,
|
|
28
|
+
NumpyDeserializer,
|
|
29
|
+
PandasDeserializer,
|
|
30
|
+
SimpleBaseDeserializer,
|
|
31
|
+
StreamDeserializer,
|
|
32
|
+
StringDeserializer,
|
|
33
|
+
TorchTensorDeserializer,
|
|
34
|
+
RecordDeserializer,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
from sagemaker.core.deprecations import deprecated_class
|
|
38
|
+
from sagemaker.core.jumpstart import artifacts, utils as jumpstart_utils
|
|
39
|
+
from sagemaker.core.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
|
|
40
|
+
from sagemaker.core.jumpstart.enums import JumpStartModelType
|
|
41
|
+
from sagemaker.core.helper.session_helper import Session
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def retrieve_options(
    region: Optional[str] = None,
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[BaseDeserializer]:
    """Retrieves the supported deserializers for the model matching the given arguments.

    Args:
        region (str): The AWS Region for which to retrieve the supported deserializers.
            Defaults to ``None``.
        model_id (str): The model ID of the model for which to
            retrieve the supported deserializers. (Default: None).
        model_version (str): The version of the model for which to retrieve the
            supported deserializers. (Default: None).
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated models should be tolerated
            (exception not raised). False if these models should raise an exception.
            (Default: False).
        sagemaker_session (sagemaker.core.helper.session_helper.Session): A SageMaker Session
            object, used for SageMaker interactions. If not specified, the module-level
            default session is used.
            (Default: sagemaker.core.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        List[BaseDeserializer]: The supported deserializers to use for the model.

    Raises:
        ValueError: If ``model_id``/``model_version`` are not accepted as JumpStart model
            input by ``jumpstart_utils.is_jumpstart_model_input``, or if the combination
            of arguments specified is otherwise not supported.
    """

    if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
        raise ValueError(
            "Must specify JumpStart `model_id` and `model_version` when retrieving deserializers."
        )

    return artifacts._retrieve_deserializer_options(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
    )
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def retrieve_default(
    region: Optional[str] = None,
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> BaseDeserializer:
    """Retrieves the default deserializer for the model matching the given arguments.

    Args:
        region (str): The AWS Region for which to retrieve the default deserializer.
            Defaults to ``None``.
        model_id (str): The model ID of the model for which to
            retrieve the default deserializer. (Default: None).
        model_version (str): The version of the model for which to retrieve the
            default deserializer. (Default: None).
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated models should be tolerated
            (exception not raised). False if these models should raise an exception.
            (Default: False).
        sagemaker_session (sagemaker.core.helper.session_helper.Session): A SageMaker Session
            object, used for SageMaker interactions. If not specified, the module-level
            default session is used.
            (Default: sagemaker.core.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): The type of JumpStart model to look up; forwarded
            unchanged to the artifact lookup. (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).

    Returns:
        BaseDeserializer: The default deserializer to use for the model.

    Raises:
        ValueError: If ``model_id``/``model_version`` are not accepted as JumpStart model
            input by ``jumpstart_utils.is_jumpstart_model_input``, or if the combination
            of arguments specified is otherwise not supported.
    """

    if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
        raise ValueError(
            "Must specify JumpStart `model_id` and `model_version` when retrieving deserializers."
        )

    return artifacts._retrieve_default_deserializer(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
        config_name=config_name,
    )
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
# Legacy lowercase name for RecordDeserializer, wrapped by ``deprecated_class`` from
# sagemaker.core.deprecations. NOTE(review): presumably emits a deprecation warning
# when used; confirm in that module.
record_deserializer = deprecated_class(RecordDeserializer, "record_deserializer")
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This file contains code related to drift check baselines"""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
from typing import Optional
|
|
17
|
+
|
|
18
|
+
from sagemaker.core.model_metrics import MetricsSource, FileSource
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class DriftCheckBaselines(object):
    """Accepts drift check baselines parameters for conversion to request dict."""

    def __init__(
        self,
        model_statistics: Optional[MetricsSource] = None,
        model_constraints: Optional[MetricsSource] = None,
        model_data_statistics: Optional[MetricsSource] = None,
        model_data_constraints: Optional[MetricsSource] = None,
        bias_config_file: Optional[FileSource] = None,
        bias_pre_training_constraints: Optional[MetricsSource] = None,
        bias_post_training_constraints: Optional[MetricsSource] = None,
        explainability_constraints: Optional[MetricsSource] = None,
        explainability_config_file: Optional[FileSource] = None,
    ):
        """Initialize a ``DriftCheckBaselines`` instance and turn parameters into dict.

        Args:
            model_statistics (MetricsSource): A metric source object that represents
                model statistics (default: None).
            model_constraints (MetricsSource): A metric source object that represents
                model constraints (default: None).
            model_data_statistics (MetricsSource): A metric source object that represents
                model data statistics (default: None).
            model_data_constraints (MetricsSource): A metric source object that represents
                model data constraints (default: None).
            bias_config_file (FileSource): A file source object that represents bias config
                (default: None).
            bias_pre_training_constraints (MetricsSource):
                A metric source object that represents Pre-training constraints (default: None).
            bias_post_training_constraints (MetricsSource):
                A metric source object that represents Post-training constraints (default: None).
            explainability_constraints (MetricsSource):
                A metric source object that represents explainability constraints (default: None).
            explainability_config_file (FileSource): A file source object that represents
                explainability config (default: None).
        """
        self.model_statistics = model_statistics
        self.model_constraints = model_constraints
        self.model_data_statistics = model_data_statistics
        self.model_data_constraints = model_data_constraints
        self.bias_config_file = bias_config_file
        self.bias_pre_training_constraints = bias_pre_training_constraints
        self.bias_post_training_constraints = bias_post_training_constraints
        self.explainability_constraints = explainability_constraints
        self.explainability_config_file = explainability_config_file

    def _to_request_dict(self):
        """Generates a request dictionary using the parameters provided to the class.

        Members that are ``None`` are omitted. A section ("ModelQuality",
        "ModelDataQuality", "Bias", "Explainability") appears only when at least
        one of its members is set. Each present member is rendered through its
        own ``_to_request_dict()``.

        Returns:
            dict: The drift check baselines request dictionary (possibly empty).
        """
        # (request section, ((request field, source object), ...)) in output order.
        sections = (
            (
                "ModelQuality",
                (
                    ("Statistics", self.model_statistics),
                    ("Constraints", self.model_constraints),
                ),
            ),
            (
                "ModelDataQuality",
                (
                    ("Statistics", self.model_data_statistics),
                    ("Constraints", self.model_data_constraints),
                ),
            ),
            (
                "Bias",
                (
                    ("ConfigFile", self.bias_config_file),
                    ("PreTrainingConstraints", self.bias_pre_training_constraints),
                    ("PostTrainingConstraints", self.bias_post_training_constraints),
                ),
            ),
            (
                "Explainability",
                (
                    ("Constraints", self.explainability_constraints),
                    ("ConfigFile", self.explainability_config_file),
                ),
            ),
        )

        drift_check_baselines_request = {}
        for section_name, members in sections:
            section = {
                field_name: source._to_request_dict()
                for field_name, source in members
                if source is not None
            }
            if section:
                drift_check_baselines_request[section_name] = section

        return drift_check_baselines_request
|