sagemaker-core 1.0.62__py3-none-any.whl → 2.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/__init__.py +2 -0
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2399 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +247 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1220 -0
- sagemaker/core/git_utils.py +415 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2977 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +509 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +770 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1267 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +686 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
- sagemaker/core/image_uri_config/huggingface.json +2287 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +252 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +494 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +972 -0
- sagemaker/core/image_uris.py +816 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +197 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +501 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +171 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +423 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +246 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1497 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +239 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +205 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1599 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1310 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +347 -0
- sagemaker/core/remote_function/core/serialization.py +410 -0
- sagemaker/core/remote_function/core/stored_function.py +223 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +102 -0
- sagemaker/core/remote_function/invoke_function.py +167 -0
- sagemaker/core/remote_function/job.py +2121 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +82 -0
- sagemaker/core/telemetry/telemetry_logging.py +285 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +345 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +514 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.3.1.dist-info/RECORD +351 -0
- sagemaker_core-2.3.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/modules/local_core}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""S3 utilities for SageMaker."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
# Re-export from client
|
|
17
|
+
from sagemaker.core.s3.client import ( # noqa: F401
|
|
18
|
+
S3Uploader,
|
|
19
|
+
S3Downloader,
|
|
20
|
+
parse_s3_url,
|
|
21
|
+
is_s3_url,
|
|
22
|
+
s3_path_join,
|
|
23
|
+
determine_bucket_and_prefix,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
# Re-export from utils (these are duplicated but kept for compatibility)
|
|
27
|
+
from sagemaker.core.s3.utils import ( # noqa: F401
|
|
28
|
+
parse_s3_url as parse_s3_url_utils,
|
|
29
|
+
is_s3_url as is_s3_url_utils,
|
|
30
|
+
s3_path_join as s3_path_join_utils,
|
|
31
|
+
determine_bucket_and_prefix as determine_bucket_and_prefix_utils,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
# Public API of this package. The ``*_utils`` aliases imported above from
# ``sagemaker.core.s3.utils`` are intentionally left out of ``__all__``, so a
# star-import does not expose them; they stay importable by their explicit names.
__all__ = [
    "S3Uploader",
    "S3Downloader",
    "parse_s3_url",
    "is_s3_url",
    "s3_path_join",
    "determine_bucket_and_prefix",
]
|
|
@@ -0,0 +1,367 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module contains Enums and helper methods related to S3."""
|
|
14
|
+
from __future__ import print_function, absolute_import
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
import io
|
|
18
|
+
|
|
19
|
+
from typing import Union
|
|
20
|
+
from functools import reduce
|
|
21
|
+
from typing import Optional
|
|
22
|
+
|
|
23
|
+
from six.moves.urllib.parse import urlparse
|
|
24
|
+
from sagemaker.core.helper.session_helper import Session
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger("sagemaker")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class S3Uploader(object):
    """Contains static methods for uploading directories, files, strings, and bytes to S3."""

    @staticmethod
    def _kms_extra_args(kms_key):
        """Build the upload ``ExtraArgs`` enabling SSE-KMS with ``kms_key``, or ``None`` if unset."""
        if kms_key is None:
            return None
        return {"SSEKMSKeyId": kms_key, "ServerSideEncryption": "aws:kms"}

    @staticmethod
    def upload(local_path, desired_s3_uri, kms_key=None, sagemaker_session=None, callback=None):
        """Static method that uploads a given file or directory to S3.

        Args:
            local_path (str): Path (absolute or relative) of local file or directory to upload.
            desired_s3_uri (str): The desired S3 location to upload to. It is the prefix to
                which the local filename will be added.
            kms_key (str): The KMS key to use to encrypt the files.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created
                using the default AWS configuration chain.
            callback (callable): Optional callback forwarded unchanged to
                ``Session.upload_data``.

        Returns:
            The S3 uri of the uploaded file(s).

        """
        sagemaker_session = sagemaker_session or Session()
        bucket, key_prefix = parse_s3_url(url=desired_s3_uri)

        return sagemaker_session.upload_data(
            path=local_path,
            bucket=bucket,
            key_prefix=key_prefix,
            callback=callback,
            extra_args=S3Uploader._kms_extra_args(kms_key),
        )

    @staticmethod
    def upload_string_as_file_body(
        body: str, desired_s3_uri=None, kms_key=None, sagemaker_session=None
    ):
        """Static method that uploads a string as the body of a single S3 object.

        Args:
            body (str): String representing the body of the file.
            desired_s3_uri (str): The desired S3 uri to upload to. Required; the
                ``None`` default exists only for backward compatibility of the signature.
            kms_key (str): The KMS key to use to encrypt the file.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created
                using the default AWS configuration chain.

        Returns:
            str: The S3 uri of the uploaded file.

        Raises:
            ValueError: If ``desired_s3_uri`` is not provided or is not an S3 uri.
        """
        # Fail fast (before creating a Session) with a clear message instead of an
        # opaque scheme error from parse_s3_url(None).
        if desired_s3_uri is None:
            raise ValueError("desired_s3_uri is required to upload a string to S3.")

        sagemaker_session = sagemaker_session or Session()

        bucket, key = parse_s3_url(desired_s3_uri)

        sagemaker_session.upload_string_as_file_body(
            body=body, bucket=bucket, key=key, kms_key=kms_key
        )

        return desired_s3_uri

    @staticmethod
    def upload_bytes(b: Union[bytes, io.BytesIO], s3_uri, kms_key=None, sagemaker_session=None):
        """Static method that uploads raw bytes (or a ``BytesIO`` buffer) to a single S3 object.

        Args:
            b (bytes or io.BytesIO): The content to upload. A ``BytesIO`` is uploaded
                from its current position; plain ``bytes`` are wrapped in a new buffer.
            s3_uri (str): The S3 uri to upload to.
            kms_key (str): The KMS key to use to encrypt the file.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created
                using the default AWS configuration chain.

        Returns:
            str: The S3 uri of the uploaded file.

        """
        sagemaker_session = sagemaker_session or Session()

        bucket, object_key = parse_s3_url(s3_uri)

        extra_args = S3Uploader._kms_extra_args(kms_key)

        fileobj = b if isinstance(b, io.BytesIO) else io.BytesIO(b)
        sagemaker_session.s3_resource.Bucket(bucket).upload_fileobj(
            fileobj, object_key, ExtraArgs=extra_args
        )

        return s3_uri
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class S3Downloader(object):
    """Static helpers that fetch S3 objects, or listings of them, to the local machine."""

    @staticmethod
    def download(s3_uri, local_path, kms_key=None, sagemaker_session=None):
        """Copy the object(s) under an S3 uri into a local path.

        Args:
            s3_uri (str): The S3 uri (object or prefix) to download from.
            local_path (str): Local destination for the downloaded file(s).
            kms_key (str): Key forwarded to the download as ``SSECustomerKey``.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created
                using the default AWS configuration chain.

        Returns:
            list[str]: Local paths of the downloaded files, as reported by
            ``Session.download_data``.
        """
        session = sagemaker_session or Session()
        bucket, key_prefix = parse_s3_url(url=s3_uri)
        # NOTE(review): kms_key is sent as the SSE-C header ``SSECustomerKey`` (no
        # ``SSECustomerAlgorithm`` alongside it), while S3 normally decrypts SSE-KMS
        # objects without any download argument -- confirm this mapping is intended.
        extra_args = None if kms_key is None else {"SSECustomerKey": kms_key}
        return session.download_data(
            path=local_path,
            bucket=bucket,
            key_prefix=key_prefix,
            extra_args=extra_args,
        )

    @staticmethod
    def read_file(s3_uri, sagemaker_session=None) -> str:
        """Fetch a single S3 object and return its body as a string.

        Args:
            s3_uri (str): S3 uri of exactly one object.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created
                using the default AWS configuration chain.

        Returns:
            str: The object body, as returned by ``Session.read_s3_file``.
        """
        session = sagemaker_session or Session()
        bucket, object_key = parse_s3_url(url=s3_uri)
        return session.read_s3_file(bucket=bucket, key_prefix=object_key)

    @staticmethod
    def read_bytes(s3_uri, sagemaker_session=None) -> bytes:
        """Fetch a single S3 object and return its raw contents.

        Args:
            s3_uri (str): S3 uri of exactly one object.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created
                using the default AWS configuration chain.

        Returns:
            bytes: The full object body.
        """
        session = sagemaker_session or Session()
        bucket, object_key = parse_s3_url(s3_uri)

        buffer = io.BytesIO()
        session.s3_resource.Bucket(bucket).download_fileobj(object_key, buffer)
        # getvalue() returns the whole buffer regardless of the write position.
        return buffer.getvalue()

    @staticmethod
    def list(s3_uri, sagemaker_session=None):
        """Enumerate the objects stored under an S3 uri.

        Args:
            s3_uri (str): The S3 base uri (prefix) to list.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created
                using the default AWS configuration chain.

        Returns:
            [str]: Full ``s3://`` uris of every key found under the prefix.
        """
        session = sagemaker_session or Session()
        bucket, key_prefix = parse_s3_url(url=s3_uri)

        keys = session.list_s3_files(bucket=bucket, key_prefix=key_prefix)
        return [s3_path_join("s3://", bucket, key) for key in keys]
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def parse_s3_url(url):
    """Returns an (s3 bucket, key name/prefix) tuple from a url with an s3 scheme.

    The key is returned verbatim, minus leading slashes. Characters such as ``?``
    and ``#`` are legal in S3 object keys, so they are kept as part of the key.
    They are not split off as a URL query string or fragment.

    Args:
        url (str): An S3 uri of the form ``s3://bucket/key``.

    Returns:
        tuple: A tuple containing:

            - str: S3 bucket name
            - str: S3 key (empty when the uri names only a bucket)

    Raises:
        ValueError: If ``url`` does not use the ``s3`` scheme.
    """
    # allow_fragments=False keeps any "#..." in the path/query instead of discarding it.
    parsed_url = urlparse(url, allow_fragments=False)
    if parsed_url.scheme != "s3":
        raise ValueError("Expecting 's3' scheme, got: {} in {}.".format(parsed_url.scheme, url))
    key = parsed_url.path
    # urlparse always splits at the first "?" (the bucket/netloc cannot contain one).
    # Re-attach the query part, including a bare trailing "?", to recover the full key.
    if "?" in url:
        key = "{}?{}".format(key, parsed_url.query)
    return parsed_url.netloc, key.lstrip("/")
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def is_s3_url(url):
    """Report whether ``url`` uses the ``s3`` URL scheme.

    Args:
        url (str): The URL or path to inspect.

    Returns:
        bool: True when the scheme parsed from ``url`` is ``s3``, otherwise False.
    """
    scheme = urlparse(url).scheme
    return scheme == "s3"
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def s3_path_join(*args, with_end_slash: bool = False):
    """Join the arguments with "/" like ``os.path.join()`` on Unix, tuned for S3 paths.

    Rules applied:
        - A leading "s3://" argument is kept intact.
        - By default the result has no slash at the beginning or end; ``with_end_slash`` is
          the one exception. For example, ``s3_path_join("/foo", "bar/")`` gives ``"foo/bar"``
          and ``s3_path_join("foo", "bar", with_end_slash=True)`` gives ``"foo/bar/"``.
        - Runs of slashes collapse to one slash (except the "//" of a leading "s3://").
          For example, ``s3_path_join("s3://", "//foo/", "/bar///baz")`` gives
          ``"s3://foo/bar/baz"``.
        - ``None`` and empty-string arguments are skipped, so
          ``s3_path_join("foo", "", None, "bar")`` gives ``"foo/bar"``.

    Alternatives that are NOT recommended for S3 paths:
        - ``os.path.join(...)`` behaves differently on Unix and non-Unix machines.
        - ``pathlib.PurePosixPath(...)`` may simplify single dots (".") and root directories
          (``pathlib.PurePosixPath("foo", "/bar/./", "baz")`` yields ``"/bar/baz"``).
        - ``"{}/{}/{}".format(...)`` and similar can leave repeated slashes behind.

    Args:
        *args: The strings to join with a slash.
        with_end_slash (bool): (default: False) When True, a non-empty result ends in "/".

    Returns:
        str: The joined path. It ends with "/" only when ``with_end_slash`` is True or the
        result is exactly "s3://".
    """
    slash = "/"
    s3_scheme = "s3://"

    segments = [segment for segment in args if segment is not None and segment != ""]

    joined = ""
    for position, segment in enumerate(segments):
        # Only insert a separator when neither side of the boundary already provides one.
        already_separated = (
            position == 0
            or (joined and joined[-1] == slash)
            or (segment and segment[0] == slash)
        )
        joined = joined + segment if already_separated else joined + slash + segment

    if with_end_slash and joined and joined[-1] != slash:
        joined += slash

    # The joined string may still contain leading/trailing or repeated slashes coming from the
    # inputs. They are normalized away below for backwards compatibility; supporting repeated
    # slashes in S3 paths would need a new option (or function) that skips this step.
    if joined:
        # Keep a leading "s3://" verbatim; otherwise start from the very first character.
        head_length = len(s3_scheme) if joined.startswith(s3_scheme) else 1
        collapsed = list(joined[:head_length])
        for char in joined[head_length:]:
            if char == slash and collapsed[-1] == slash:
                continue
            collapsed.append(char)
        joined = "".join(collapsed)

    joined = joined.lstrip(slash)

    if not with_end_slash and joined != s3_scheme:
        joined = joined.rstrip(slash)

    return joined
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def determine_bucket_and_prefix(
    bucket: Optional[str] = None, key_prefix: Optional[str] = None, sagemaker_session=None
):
    """Resolve which S3 bucket and key prefix an operation should use.

    Args:
        bucket (Optional[str]): Bucket requested by the caller; used as-is when truthy.
        key_prefix (Optional[str]): Object key prefix to use, or to append to the session's
            default prefix when falling back to the default bucket.
        sagemaker_session (sagemaker.core.helper.session_helper.Session): Session that supplies
            ``default_bucket()`` and ``default_bucket_prefix`` when ``bucket`` is not given.
            Expected to be provided in that case.

    Returns:
        tuple: ``(bucket, key_prefix)``. With an explicit ``bucket``, ``key_prefix`` is returned
        unchanged (it may be None). Otherwise the session's default bucket is returned with
        ``s3_path_join(default_bucket_prefix, key_prefix)``.
    """
    if bucket:
        # default_bucket_prefix is deliberately NOT applied here, even if `bucket` equals the
        # session's default bucket. Either the caller passed the bucket explicitly (so their
        # input must not be altered), or the default bucket was resolved earlier along with
        # its prefix, in which case adding the prefix again would apply it more than once.
        return bucket, key_prefix

    # Only when falling back to the session's default bucket is its default_bucket_prefix
    # (if any) prepended to key_prefix.
    default_bucket = sagemaker_session.default_bucket()
    return default_bucket, s3_path_join(sagemaker_session.default_bucket_prefix, key_prefix)
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module contains helper functions related to S3. You may want to use `s3.py` instead.
|
|
14
|
+
|
|
15
|
+
This has a subset of the functions available through s3.py. This module was initially created with
|
|
16
|
+
functions that were originally in `s3.py` so that those functions could be imported inside
|
|
17
|
+
`session.py` without circular dependencies. (`s3.py` imports Session as a dependency.)
|
|
18
|
+
"""
|
|
19
|
+
from __future__ import print_function, absolute_import
|
|
20
|
+
|
|
21
|
+
import logging
|
|
22
|
+
from functools import reduce
|
|
23
|
+
from typing import Optional
|
|
24
|
+
|
|
25
|
+
from six.moves.urllib.parse import urlparse
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger("sagemaker")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def parse_s3_url(url):
    """Returns an (s3 bucket, key name/prefix) tuple from a url with an s3 scheme.

    Everything after the bucket is treated as the key, including ``?`` and ``#``
    characters. S3 object keys may legally contain them, so they must not be split off as
    a URL query string or fragment (which would silently address a different object).

    Args:
        url (str): A URL using the ``s3`` scheme, e.g. ``"s3://my-bucket/path/to/key"``.

    Returns:
        tuple: A tuple containing:

            - str: S3 bucket name
            - str: S3 key, with leading slashes removed (empty if only a bucket is given)

    Raises:
        ValueError: If ``url`` does not use the ``s3`` scheme.
    """
    # allow_fragments=False keeps a "#" inside the path/query instead of discarding it.
    parsed_url = urlparse(url, allow_fragments=False)
    if parsed_url.scheme != "s3":
        raise ValueError("Expecting 's3' scheme, got: {} in {}.".format(parsed_url.scheme, url))
    key = parsed_url.path
    if parsed_url.query:
        # urlparse consumes the first "?" as a query delimiter; for s3:// URLs it is part of
        # the key, so put it back. (A bare trailing "?" with nothing after it is still lost.)
        key = "{}?{}".format(key, parsed_url.query)
    return parsed_url.netloc, key.lstrip("/")
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def is_s3_url(url):
    """Report whether ``url`` uses the ``s3`` URL scheme.

    Args:
        url (str): The URL or path to inspect.

    Returns:
        bool: True when the scheme parsed from ``url`` is ``s3``, otherwise False.
    """
    scheme = urlparse(url).scheme
    return scheme == "s3"
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def s3_path_join(*args, with_end_slash: bool = False):
    """Join the arguments with "/" like ``os.path.join()`` on Unix, tuned for S3 paths.

    Rules applied:
        - A leading "s3://" argument is kept intact.
        - By default the result has no slash at the beginning or end; ``with_end_slash`` is
          the one exception. For example, ``s3_path_join("/foo", "bar/")`` gives ``"foo/bar"``
          and ``s3_path_join("foo", "bar", with_end_slash=True)`` gives ``"foo/bar/"``.
        - Runs of slashes collapse to one slash (except the "//" of a leading "s3://").
          For example, ``s3_path_join("s3://", "//foo/", "/bar///baz")`` gives
          ``"s3://foo/bar/baz"``.
        - ``None`` and empty-string arguments are skipped, so
          ``s3_path_join("foo", "", None, "bar")`` gives ``"foo/bar"``.

    Alternatives that are NOT recommended for S3 paths:
        - ``os.path.join(...)`` behaves differently on Unix and non-Unix machines.
        - ``pathlib.PurePosixPath(...)`` may simplify single dots (".") and root directories
          (``pathlib.PurePosixPath("foo", "/bar/./", "baz")`` yields ``"/bar/baz"``).
        - ``"{}/{}/{}".format(...)`` and similar can leave repeated slashes behind.

    Args:
        *args: The strings to join with a slash.
        with_end_slash (bool): (default: False) When True, a non-empty result ends in "/".

    Returns:
        str: The joined path. It ends with "/" only when ``with_end_slash`` is True or the
        result is exactly "s3://".
    """
    slash = "/"
    s3_scheme = "s3://"

    segments = [segment for segment in args if segment is not None and segment != ""]

    joined = ""
    for position, segment in enumerate(segments):
        # Only insert a separator when neither side of the boundary already provides one.
        already_separated = (
            position == 0
            or (joined and joined[-1] == slash)
            or (segment and segment[0] == slash)
        )
        joined = joined + segment if already_separated else joined + slash + segment

    if with_end_slash and joined and joined[-1] != slash:
        joined += slash

    # The joined string may still contain leading/trailing or repeated slashes coming from the
    # inputs. They are normalized away below for backwards compatibility; supporting repeated
    # slashes in S3 paths would need a new option (or function) that skips this step.
    if joined:
        # Keep a leading "s3://" verbatim; otherwise start from the very first character.
        head_length = len(s3_scheme) if joined.startswith(s3_scheme) else 1
        collapsed = list(joined[:head_length])
        for char in joined[head_length:]:
            if char == slash and collapsed[-1] == slash:
                continue
            collapsed.append(char)
        joined = "".join(collapsed)

    joined = joined.lstrip(slash)

    if not with_end_slash and joined != s3_scheme:
        joined = joined.rstrip(slash)

    return joined
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def determine_bucket_and_prefix(
    bucket: Optional[str] = None, key_prefix: Optional[str] = None, sagemaker_session=None
):
    """Resolve which S3 bucket and key prefix an operation should use.

    Args:
        bucket (Optional[str]): Bucket requested by the caller; used as-is when truthy.
        key_prefix (Optional[str]): Object key prefix to use, or to append to the session's
            default prefix when falling back to the default bucket.
        sagemaker_session (sagemaker.core.helper.session_helper.Session): Session that supplies
            ``default_bucket()`` and ``default_bucket_prefix`` when ``bucket`` is not given.
            Expected to be provided in that case.

    Returns:
        tuple: ``(bucket, key_prefix)``. With an explicit ``bucket``, ``key_prefix`` is returned
        unchanged (it may be None). Otherwise the session's default bucket is returned with
        ``s3_path_join(default_bucket_prefix, key_prefix)``.
    """
    if bucket:
        # default_bucket_prefix is deliberately NOT applied here, even if `bucket` equals the
        # session's default bucket. Either the caller passed the bucket explicitly (so their
        # input must not be altered), or the default bucket was resolved earlier along with
        # its prefix, in which case adding the prefix again would apply it more than once.
        return bucket, key_prefix

    # Only when falling back to the session's default bucket is its default_bucket_prefix
    # (if any) prepended to key_prefix.
    default_bucket = sagemaker_session.default_bucket()
    return default_bucket, s3_path_join(sagemaker_session.default_bucket_prefix, key_prefix)
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Accessors to retrieve the script Amazon S3 URI to run pretrained machine learning models."""
|
|
14
|
+
|
|
15
|
+
from __future__ import absolute_import
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from typing import Optional
|
|
19
|
+
|
|
20
|
+
from sagemaker.core.jumpstart import utils as jumpstart_utils
|
|
21
|
+
from sagemaker.core.jumpstart import artifacts
|
|
22
|
+
from sagemaker.core.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
|
|
23
|
+
from sagemaker.core.jumpstart.enums import JumpStartModelType
|
|
24
|
+
from sagemaker.core.helper.session_helper import Session
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def retrieve(
    region: Optional[str] = None,
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    script_scope: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> str:
    """Retrieves the script S3 URI associated with the model matching the given arguments.

    Args:
        region (Optional[str]): The AWS Region for which to retrieve the model script S3 URI.
            (Default: None).
        model_id (Optional[str]): The model ID of the JumpStart model for which to
            retrieve the script S3 URI. Must be specified together with ``model_version``.
        model_version (Optional[str]): The version of the JumpStart model for which to
            retrieve the model script S3 URI. Must be specified together with ``model_id``.
        hub_arn (Optional[str]): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        script_scope (Optional[str]): The script type.
            Valid values: "training" and "inference". (Default: None).
        tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model
            specifications should be tolerated without raising an exception. If ``False``, raises
            an exception if the script used by this version of the model has dependencies with
            known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): ``True`` if deprecated models should be tolerated
            without raising an exception. ``False`` if these models should raise an exception.
            (Default: False).
        sagemaker_session (sagemaker.core.helper.session_helper.Session): A SageMaker Session
            object, used for SageMaker interactions. If not specified, one is created using
            the default AWS configuration chain.
            (Default: sagemaker.core.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
        model_type (JumpStartModelType): The type of the model, can be open weights model
            or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).

    Returns:
        str: The model script URI for the corresponding model.

    Raises:
        NotImplementedError: If the scope is not supported.
        ValueError: If the combination of arguments specified is not supported, including when
            ``model_id``/``model_version`` do not form a valid JumpStart model input.
        VulnerableJumpStartModelError: If any of the dependencies required by the script have
            known security vulnerabilities.
        DeprecatedJumpStartModelError: If the version of the model is deprecated.
    """
    # Fail fast with a clear message before doing any artifact lookups.
    if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
        raise ValueError(
            "Must specify JumpStart `model_id` and `model_version` when retrieving script URIs."
        )

    return artifacts._retrieve_script_uri(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        script_scope=script_scope,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        config_name=config_name,
        model_type=model_type,
    )
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Serializers for SageMaker inference."""
|
|
2
|
+
|
|
3
|
+
from __future__ import absolute_import
|
|
4
|
+
|
|
5
|
+
# Re-export from base
|
|
6
|
+
from sagemaker.core.serializers.base import * # noqa: F401, F403
|
|
7
|
+
|
|
8
|
+
# Note: implementations and utils are not imported here to avoid circular imports
|
|
9
|
+
# Import them explicitly if needed:
|
|
10
|
+
# from sagemaker.core.serializers.implementations import ...
|
|
11
|
+
# from sagemaker.core.serializers.utils import ...
|