sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,510 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Implements base methods for serializing data for an inference endpoint."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import abc
|
|
17
|
+
from collections.abc import Iterable
|
|
18
|
+
import csv
|
|
19
|
+
import io
|
|
20
|
+
import json
|
|
21
|
+
import numpy as np
|
|
22
|
+
from pandas import DataFrame
|
|
23
|
+
from six import with_metaclass
|
|
24
|
+
|
|
25
|
+
# Lazy import to avoid circular dependency with amazon modules
|
|
26
|
+
# from sagemaker.core.serializers.utils import write_numpy_to_dense_tensor
|
|
27
|
+
from sagemaker.core.common_utils import DeferredError
|
|
28
|
+
|
|
29
|
+
try:
|
|
30
|
+
import scipy.sparse
|
|
31
|
+
except ImportError as e:
|
|
32
|
+
scipy = DeferredError(e)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class BaseSerializer(abc.ABC):
    """Interface that every request serializer has to implement.

    A concrete serializer must override both abstract members declared here:
    the ``serialize`` method and the ``CONTENT_TYPE`` attribute/property.
    Until both are provided the subclass cannot be instantiated.
    """

    @property
    @abc.abstractmethod
    def CONTENT_TYPE(self):
        """The MIME type of the data sent to the inference endpoint."""

    @abc.abstractmethod
    def serialize(self, data):
        """Convert ``data`` into the media type named by ``CONTENT_TYPE``.

        Args:
            data (object): Data to be serialized.

        Returns:
            object: Serialized data used for a request.
        """
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class SimpleBaseSerializer(BaseSerializer):
    """Abstract base class for creation of new serializers.

    This class extends the API of :class:`~sagemaker.serializers.BaseSerializer` with more
    user-friendly options for setting the Content-Type header, in situations where it can be
    provided at init and freely updated.

    ``serialize`` is still abstract here, so subclasses must implement it.
    """

    def __init__(self, content_type="application/json"):
        """Initialize a ``SimpleBaseSerializer`` instance.

        Args:
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "application/json").

        Raises:
            ValueError: If ``content_type`` is not a string.
        """
        super(SimpleBaseSerializer, self).__init__()
        if not isinstance(content_type, str):
            # Wrap the value in a 1-tuple: with a bare ``% content_type`` a tuple argument
            # would be unpacked as multiple format arguments and raise TypeError instead
            # of the intended ValueError.
            raise ValueError(
                "content_type must be a string specifying the MIME type of the data sent in "
                "requests: e.g. 'application/json', 'text/csv', etc. Got %s" % (content_type,)
            )
        self.content_type = content_type

    @property
    def CONTENT_TYPE(self):
        """The data MIME type set in the Content-Type header on prediction endpoint requests."""
        # Read through to the mutable attribute so later updates to
        # ``self.content_type`` are reflected here.
        return self.content_type
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class CSVSerializer(SimpleBaseSerializer):
    """Serialize data of various formats to a CSV-formatted string."""

    def __init__(self, content_type="text/csv"):
        """Initialize a ``CSVSerializer`` instance.

        Args:
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "text/csv").
        """
        super(CSVSerializer, self).__init__(content_type=content_type)

    def serialize(self, data):
        """Serialize data of various formats to a CSV-formatted string.

        Args:
            data (object): Data to be serialized. Can be a NumPy array, list,
                file, Pandas DataFrame, or buffer. Objects exposing ``read`` are
                assumed to already hold CSV content; whatever ``read()`` returns
                is passed through unchanged.

        Returns:
            str: The data serialized as a CSV-formatted string.

        Raises:
            ValueError: If the data is empty or is of a type that cannot be
                written as a CSV row.
        """
        if hasattr(data, "read"):
            return data.read()

        if isinstance(data, DataFrame):
            return data.to_csv(header=False, index=False)

        is_mutable_sequence_like = self._is_sequence_like(data) and hasattr(data, "__setitem__")
        # Probe len()/data[0] only after confirming the input is a mutable sequence.
        # Otherwise scalars or generators would fail here with an unrelated TypeError
        # instead of reaching the ValueError raised by _serialize_row.
        has_multiple_rows = (
            is_mutable_sequence_like and len(data) > 0 and self._is_sequence_like(data[0])
        )

        if has_multiple_rows:
            return "\n".join([self._serialize_row(row) for row in data])

        return self._serialize_row(data)

    def _serialize_row(self, data):
        """Serialize data as a CSV-formatted row.

        Args:
            data (object): Data to be serialized in a row. Strings are returned
                as-is; NumPy arrays are flattened before being written.

        Returns:
            str: The data serialized as a CSV-formatted row.

        Raises:
            ValueError: If ``data`` is empty or has no length.
        """
        if isinstance(data, str):
            return data

        if isinstance(data, np.ndarray):
            data = np.ndarray.flatten(data)

        if hasattr(data, "__len__"):
            if len(data) == 0:
                raise ValueError("Cannot serialize empty array")
            csv_buffer = io.StringIO()
            csv_writer = csv.writer(csv_buffer, delimiter=",")
            csv_writer.writerow(data)
            # csv.writer terminates the row with a line ending; drop it so rows can be
            # joined by the caller.
            return csv_buffer.getvalue().rstrip("\r\n")

        raise ValueError("Unable to handle input format: %s" % type(data))

    def _is_sequence_like(self, data):
        """Returns true if obj is iterable and subscriptable."""
        return hasattr(data, "__iter__") and hasattr(data, "__getitem__")
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class NumpySerializer(SimpleBaseSerializer):
    """Serialize data to bytes using the .npy format."""

    def __init__(self, dtype=None, content_type="application/x-npy"):
        """Initialize a ``NumpySerializer`` instance.

        Args:
            dtype (str): The dtype used when converting non-array input to a
                NumPy array (default: None, letting NumPy infer it).
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "application/x-npy").
        """
        super(NumpySerializer, self).__init__(content_type=content_type)
        self.dtype = dtype

    def serialize(self, data):
        """Serialize data to bytes using the .npy format.

        Args:
            data (object): Data to be serialized. Can be a NumPy array, list,
                file, or buffer. NumPy arrays are saved as given (``dtype`` is
                not applied to them); objects exposing ``read`` are assumed to
                already hold npy-formatted data.

        Returns:
            bytes: The data serialized in the .npy format (or, for file-like
                input, whatever ``read()`` returns).

        Raises:
            ValueError: If ``data`` is an empty NumPy array or an empty list.
        """
        if isinstance(data, np.ndarray):
            if data.size == 0:
                raise ValueError("Cannot serialize empty array.")
            return self._serialize_array(data)

        if isinstance(data, list):
            if len(data) == 0:
                raise ValueError("Cannot serialize empty array.")
            return self._serialize_array(np.array(data, self.dtype))

        # files and buffers. Assumed to hold npy-formatted data.
        if hasattr(data, "read"):
            return data.read()

        # Any other array-like (tuple, scalar, ...): honour the configured dtype,
        # consistently with the list branch above.
        return self._serialize_array(np.array(data, self.dtype))

    def _serialize_array(self, array):
        """Saves a NumPy array in the .npy format.

        Args:
            array (numpy.ndarray): The array to serialize.

        Returns:
            bytes: The contents of the buffer the array was saved into.
        """
        buffer = io.BytesIO()
        np.save(buffer, array)
        return buffer.getvalue()
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
class JSONSerializer(SimpleBaseSerializer):
    """Serialize data to a JSON formatted string."""

    def serialize(self, data):
        """Serialize data of various formats to a JSON formatted string.

        NumPy arrays (and NumPy scalars) are converted to native Python values
        wherever they appear in ``data``, including inside nested lists and dicts.

        Args:
            data (object): Data to be serialized. Objects exposing ``read``
                (other than dicts) are assumed to already hold JSON content;
                whatever ``read()`` returns is passed through unchanged.

        Returns:
            str: The data serialized as a JSON string.

        Raises:
            TypeError: If ``data`` contains a value that is not JSON serializable.
        """
        # Dicts are always encoded, even if a dict subclass happens to expose ``read``.
        if not isinstance(data, dict) and hasattr(data, "read"):
            return data.read()

        return json.dumps(data, default=self._convert_numpy)

    @staticmethod
    def _convert_numpy(value):
        """Fallback hook for ``json.dumps``: turn NumPy values into native Python ones.

        Args:
            value (object): A value the JSON encoder could not serialize on its own.

        Returns:
            object: A list for arrays, or a native Python scalar for NumPy scalars.

        Raises:
            TypeError: If ``value`` is not a convertible NumPy value.
        """
        if isinstance(value, np.ndarray):
            return value.tolist()

        if isinstance(value, np.generic):
            native = value.item()
            # Some NumPy scalars have no native Python equivalent and come back as NumPy
            # scalars again; reject them rather than recursing forever.
            if not isinstance(native, np.generic):
                return native

        raise TypeError("Object of type %s is not JSON serializable" % type(value).__name__)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
class IdentitySerializer(SimpleBaseSerializer):
    """Pass-through serializer: hands the input back untouched.

    Handy when the payload is already in its wire format, e.g. raw bytes
    obtained from an image file's .read() method.
    """

    def __init__(self, content_type="application/octet-stream"):
        """Initialize an ``IdentitySerializer`` instance.

        Args:
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "application/octet-stream").
        """
        super().__init__(content_type=content_type)

    def serialize(self, data):
        """Hand ``data`` back exactly as received.

        Args:
            data (object): Data to be serialized.

        Returns:
            object: The very same object that was passed in.
        """
        return data
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
class JSONLinesSerializer(SimpleBaseSerializer):
    """Serialize data to a JSON Lines formatted string."""

    def __init__(self, content_type="application/jsonlines"):
        """Initialize a ``JSONLinesSerializer`` instance.

        Args:
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "application/jsonlines").
        """
        super().__init__(content_type=content_type)

    def serialize(self, data):
        """Turn ``data`` into newline-separated JSON values.

        Args:
            data (object): Data to be serialized. A string is returned as-is,
                a file-like object contributes the result of its ``read()``,
                and any other iterable has each element JSON-encoded onto its
                own line.

        Returns:
            str: The data serialized as a string containing newline-separated
                JSON values.

        Raises:
            ValueError: If ``data`` is none of the supported kinds.
        """
        # Strings are taken to be JSON Lines content already.
        if isinstance(data, str):
            return data

        # File-like objects: forward their content.
        if hasattr(data, "read"):
            return data.read()

        if not isinstance(data, Iterable):
            raise ValueError("Object of type %s is not JSON Lines serializable." % type(data))

        encoded_lines = [json.dumps(item) for item in data]
        return "\n".join(encoded_lines)
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
class SparseMatrixSerializer(SimpleBaseSerializer):
    """Serialize a sparse matrix to bytes using the .npz format."""

    def __init__(self, content_type="application/x-npz"):
        """Initialize a ``SparseMatrixSerializer`` instance.

        Args:
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "application/x-npz").
        """
        super(SparseMatrixSerializer, self).__init__(content_type=content_type)

    def serialize(self, data):
        """Serialize a sparse matrix to bytes using the .npz format.

        Sparse matrices can be in the ``csc``, ``csr``, ``bsr``, ``dia`` or
        ``coo`` formats.

        Args:
            data (scipy.sparse.spmatrix): The sparse matrix to serialize.

        Returns:
            bytes: The contents of the .npz archive holding the serialized
                sparse matrix (the in-memory buffer itself is not returned).
        """
        # save_npz writes into the in-memory buffer; only its byte content is
        # handed back to the caller.
        buffer = io.BytesIO()
        scipy.sparse.save_npz(buffer, data)
        return buffer.getvalue()
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
class LibSVMSerializer(SimpleBaseSerializer):
    """Pass LibSVM-formatted text through to the endpoint.

    The input is expected to be in LIBSVM file format already:
        <label> <index1>:<value1> <index2>:<value2> ...

    The format omits zero-valued features, which makes it a good fit for
    sparse datasets.
    """

    def __init__(self, content_type="text/libsvm"):
        """Create a ``LibSVMSerializer``.

        Args:
            content_type (str): MIME type advertised to the inference endpoint
                for the request body (default: "text/libsvm").
        """
        super().__init__(content_type=content_type)

    def serialize(self, data):
        """Return LibSVM-formatted content taken from ``data``.

        Args:
            data (object): Either a string, which is returned unchanged, or a
                file-like object (anything with a ``read`` attribute), whose
                ``read()`` result is returned.

        Returns:
            str: The LibSVM-formatted payload (for file-like input, whatever
                ``read()`` returns).

        Raises:
            ValueError: If ``data`` is neither a string nor file-like.
        """
        if isinstance(data, str):
            payload = data
        elif hasattr(data, "read"):
            payload = data.read()
        else:
            raise ValueError("Unable to handle input format: %s" % type(data))

        return payload
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
class DataSerializer(SimpleBaseSerializer):
    """Serialize data in any file by extracting raw bytes from the file."""

    def __init__(self, content_type="file-path/raw-bytes"):
        """Initialize a ``DataSerializer`` instance.

        Args:
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "file-path/raw-bytes").
        """
        super(DataSerializer, self).__init__(content_type=content_type)

    def serialize(self, data):
        """Serialize file data to raw bytes.

        Args:
            data (object): Data to be serialized. One of:
                * a string, interpreted as the path of a file whose content
                  is read in binary mode;
                * a ``bytes`` object, returned unchanged;
                * a dict with a ``"data"`` key, whose value is serialized
                  by calling this method again.

        Returns:
            bytes: The data serialized as raw bytes from the input.

        Raises:
            ValueError: If the file cannot be opened or read (the underlying
                error is attached as the cause), or if ``data`` is of an
                unsupported type.
        """
        if isinstance(data, str):
            try:
                with open(data, "rb") as data_file:
                    return data_file.read()
            except Exception as e:
                # Chain the original error so the traceback shows why the
                # file could not be read.
                raise ValueError(f"Could not open/read file: {data}. {e}") from e

        if isinstance(data, bytes):
            return data

        if isinstance(data, dict) and "data" in data:
            # Unwrap the payload and apply the same rules to it.
            return self.serialize(data["data"])

        raise ValueError(f"Object of type {type(data)} is not Data serializable.")
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
class StringSerializer(SimpleBaseSerializer):
    """Serializer that turns a Python string into UTF-8 encoded bytes."""

    def __init__(self, content_type="text/plain"):
        """Create a ``StringSerializer``.

        Args:
            content_type (str): MIME type advertised to the inference endpoint
                for the request body (default: "text/plain").
        """
        super().__init__(content_type=content_type)

    def serialize(self, data):
        """Encode ``data`` as UTF-8.

        Args:
            data (object): The value to serialize; only ``str`` is accepted.

        Returns:
            bytes: The UTF-8 encoding of ``data``.

        Raises:
            ValueError: If ``data`` is not a string.
        """
        # Reject anything that is not text up front.
        if not isinstance(data, str):
            raise ValueError(f"Object of type {type(data)} is not String serializable.")

        return data.encode("utf-8")
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
class TorchTensorSerializer(SimpleBaseSerializer):
    """Serialize a ``torch.Tensor`` by converting it to NumPy and using ``NumpySerializer``."""

    def __init__(self, content_type="tensor/pt"):
        """Initialize a ``TorchTensorSerializer`` instance.

        Args:
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "tensor/pt").
        """
        super(TorchTensorSerializer, self).__init__(content_type=content_type)
        # Imported here rather than at module level so that torch is only
        # needed when this serializer is actually instantiated.
        from torch import Tensor

        self.torch_tensor = Tensor
        self.numpy_serializer = NumpySerializer()

    def serialize(self, data):
        """Serialize a ``torch.Tensor`` through ``NumpySerializer``.

        The tensor is detached from the autograd graph and copied to host
        memory (a no-op for tensors already on the CPU) before being
        converted to a NumPy array.

        Args:
            data (object): Data to be serialized. The data must be of torch.Tensor type.

        Returns:
            raw-bytes: The result of ``NumpySerializer.serialize`` for the
                converted array.

        Raises:
            ValueError: If ``data`` is not a ``torch.Tensor``, or if the
                conversion/serialization fails (the underlying error is
                attached as the cause).
        """
        if not isinstance(data, self.torch_tensor):
            raise ValueError("Object of type %s is not a torch.Tensor" % type(data))

        try:
            return self.numpy_serializer.serialize(data.detach().cpu().numpy())
        except Exception as e:
            # Adjacent literals keep the message on one logical line without
            # embedding source indentation in it.
            raise ValueError(
                "Unable to serialize your data because: %s. "
                "Please provide custom serialization in InferenceSpec. " % e
            ) from e
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
# TODO fix the unit test for this serializer
|
|
474
|
+
class RecordSerializer(SimpleBaseSerializer):
    """Serializer that packs a NumPy array into RecordIO records for inference."""

    def __init__(self, content_type="application/x-recordio-protobuf"):
        """Create a ``RecordSerializer``.

        Args:
            content_type (str): MIME type advertised to the inference endpoint
                for the request body (default: "application/x-recordio-protobuf").
        """
        super().__init__(content_type=content_type)

    def serialize(self, data):
        """Write a NumPy array into an in-memory buffer of RecordIO records.

        A 1D array is first reshaped into a single-row 2D array.

        Args:
            data (numpy.ndarray): The array to serialize; must be 1D or 2D.

        Returns:
            io.BytesIO: Buffer holding the records, rewound to the start.

        Raises:
            ValueError: If the array is neither 1D nor 2D.
        """
        # Promote a vector to a one-row matrix so only the 2D case remains.
        if len(data.shape) == 1:
            data = data.reshape(1, data.shape[0])

        rank = len(data.shape)
        if rank != 2:
            raise ValueError("Expected a 1D or 2D array, but got a %dD array instead." % rank)

        # Lazy import to avoid circular dependency
        from sagemaker.core.serializers.utils import write_numpy_to_dense_tensor

        record_buffer = io.BytesIO()
        write_numpy_to_dense_tensor(record_buffer, data)
        # Rewind so the caller reads from the first record.
        record_buffer.seek(0)
        return record_buffer
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Implements methods for serializing data for an inference endpoint."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
from typing import List, Optional
|
|
17
|
+
|
|
18
|
+
# base_serializers was refactored from serializers.
|
|
19
|
+
# this import ensures backward compatibility.
|
|
20
|
+
from sagemaker.core.serializers.base import ( # noqa: F401 # pylint: disable=W0611
|
|
21
|
+
BaseSerializer,
|
|
22
|
+
CSVSerializer,
|
|
23
|
+
DataSerializer,
|
|
24
|
+
IdentitySerializer,
|
|
25
|
+
JSONLinesSerializer,
|
|
26
|
+
JSONSerializer,
|
|
27
|
+
LibSVMSerializer,
|
|
28
|
+
NumpySerializer,
|
|
29
|
+
SimpleBaseSerializer,
|
|
30
|
+
SparseMatrixSerializer,
|
|
31
|
+
TorchTensorSerializer,
|
|
32
|
+
StringSerializer,
|
|
33
|
+
RecordSerializer,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
from sagemaker.core.deprecations import deprecated_class
|
|
37
|
+
from sagemaker.core.jumpstart import artifacts, utils as jumpstart_utils
|
|
38
|
+
from sagemaker.core.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
|
|
39
|
+
from sagemaker.core.jumpstart.enums import JumpStartModelType
|
|
40
|
+
from sagemaker.core.helper.session_helper import Session
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def retrieve_options(
    region: Optional[str] = None,
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
) -> List[BaseSerializer]:
    """Look up the serializers supported by a JumpStart model.

    Args:
        region (str): AWS Region to retrieve the supported serializers for.
            (Default: None).
        model_id (str): ID of the model whose supported serializers are
            requested. (Default: None).
        model_version (str): Version of the model whose supported serializers
            are requested. (Default: None).
        hub_arn (str): ARN of the SageMaker Hub to retrieve model details
            from. (Default: None).
        tolerate_vulnerable_model (bool): True to tolerate vulnerable versions
            of model specifications (no exception raised). If False, an
            exception is raised when the script used by this model version has
            dependencies with known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True to tolerate deprecated models
            (no exception raised); False to have them raise an exception.
            (Default: False).
        sagemaker_session (sagemaker.session.Session): SageMaker Session used
            for SageMaker interactions.
            (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart Model config to
            apply. (Default: None).

    Returns:
        List[BaseSerializer]: The serializers supported by the model.

    Raises:
        ValueError: If ``model_id`` / ``model_version`` do not form a valid
            JumpStart model input.
    """
    if jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
        # All lookup logic lives in the JumpStart artifacts module; this
        # function only validates the input and forwards the arguments.
        return artifacts._retrieve_serializer_options(
            model_id=model_id,
            model_version=model_version,
            hub_arn=hub_arn,
            region=region,
            tolerate_vulnerable_model=tolerate_vulnerable_model,
            tolerate_deprecated_model=tolerate_deprecated_model,
            sagemaker_session=sagemaker_session,
            config_name=config_name,
        )

    raise ValueError(
        "Must specify JumpStart `model_id` and `model_version` when retrieving serializers."
    )
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def retrieve_default(
    region: Optional[str] = None,
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
) -> BaseSerializer:
    """Look up the default serializer of a JumpStart model.

    Args:
        region (str): AWS Region to retrieve the default serializer for.
            (Default: None).
        model_id (str): ID of the model whose default serializer is
            requested. (Default: None).
        model_version (str): Version of the model whose default serializer is
            requested. (Default: None).
        hub_arn (str): ARN of the SageMaker Hub to retrieve model details
            from. (Default: None).
        tolerate_vulnerable_model (bool): True to tolerate vulnerable versions
            of model specifications (no exception raised). If False, an
            exception is raised when the script used by this model version has
            dependencies with known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True to tolerate deprecated models
            (no exception raised); False to have them raise an exception.
            (Default: False).
        model_type (JumpStartModelType): JumpStart model type, forwarded to
            the serializer lookup. (Default: JumpStartModelType.OPEN_WEIGHTS).
        sagemaker_session (sagemaker.session.Session): SageMaker Session used
            for SageMaker interactions.
            (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart Model config to
            apply. (Default: None).

    Returns:
        BaseSerializer: The default serializer to use for the model.

    Raises:
        ValueError: If ``model_id`` / ``model_version`` do not form a valid
            JumpStart model input.
    """
    if jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
        # All lookup logic lives in the JumpStart artifacts module; this
        # function only validates the input and forwards the arguments.
        return artifacts._retrieve_default_serializer(
            model_id=model_id,
            model_version=model_version,
            hub_arn=hub_arn,
            region=region,
            tolerate_vulnerable_model=tolerate_vulnerable_model,
            tolerate_deprecated_model=tolerate_deprecated_model,
            sagemaker_session=sagemaker_session,
            model_type=model_type,
            config_name=config_name,
        )

    raise ValueError(
        "Must specify JumpStart `model_id` and `model_version` when retrieving serializers."
    )
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
# Legacy name kept for backward compatibility: ``RecordSerializer`` wrapped by
# ``deprecated_class`` under its old identifier (presumably so that use of the
# old name is flagged as deprecated; see ``sagemaker.core.deprecations``).
numpy_to_record_serializer = deprecated_class(RecordSerializer, "numpy_to_record_serializer")
|