sagemaker-core 1.0.62__py3-none-any.whl → 2.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/__init__.py +2 -0
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2399 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +247 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1220 -0
- sagemaker/core/git_utils.py +415 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2977 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +509 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +770 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1267 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +686 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
- sagemaker/core/image_uri_config/huggingface.json +2287 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +252 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +494 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +972 -0
- sagemaker/core/image_uris.py +816 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +197 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +501 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +171 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +423 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +246 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1497 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +239 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +205 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1599 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1310 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +347 -0
- sagemaker/core/remote_function/core/serialization.py +410 -0
- sagemaker/core/remote_function/core/stored_function.py +223 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +102 -0
- sagemaker/core/remote_function/invoke_function.py +167 -0
- sagemaker/core/remote_function/job.py +2121 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +82 -0
- sagemaker/core/telemetry/telemetry_logging.py +285 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +345 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +514 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.3.1.dist-info/RECORD +351 -0
- sagemaker_core-2.3.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/modules/local_core}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,678 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Placeholder docstring"""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import datetime
|
|
17
|
+
import json
|
|
18
|
+
import logging
|
|
19
|
+
import os
|
|
20
|
+
import tempfile
|
|
21
|
+
import time
|
|
22
|
+
|
|
23
|
+
import sagemaker.core.local.data
|
|
24
|
+
|
|
25
|
+
from sagemaker.core.local.image import _SageMakerContainer
|
|
26
|
+
from sagemaker.core.local.utils import (
|
|
27
|
+
copy_directory_structure,
|
|
28
|
+
move_to_destination,
|
|
29
|
+
get_docker_host,
|
|
30
|
+
)
|
|
31
|
+
from sagemaker.core.common_utils import DeferredError, get_config_value, format_tags
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
try:
|
|
36
|
+
import urllib3
|
|
37
|
+
except ImportError as e:
|
|
38
|
+
logger.warning("urllib3 failed to import. Local mode features will be impaired or broken.")
|
|
39
|
+
# Any subsequent attempt to use urllib3 will raise the ImportError
|
|
40
|
+
urllib3 = DeferredError(e)
|
|
41
|
+
|
|
42
|
+
_UNUSED_ARN = "local:arn-does-not-matter"
|
|
43
|
+
HEALTH_CHECK_TIMEOUT_LIMIT = 120
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class _LocalProcessingJob:
    """Defines and starts a local processing job.

    The job validates its input/output configuration against what this class
    accepts for Local Mode, hands the work to ``container.process`` and keeps
    the bookkeeping (state, timestamps, configuration) that ``describe``
    reports.
    """

    _STARTING = "Starting"
    _PROCESSING = "Processing"
    _COMPLETED = "Completed"

    def __init__(self, container):
        """Creates a local processing job.

        Args:
            container: the local container object. ``start`` calls its
                ``process`` method and ``describe`` reads its ``image``,
                ``container_entrypoint``, ``container_arguments``,
                ``instance_count`` and ``instance_type`` attributes.
        """
        self.container = container
        self.state = "Created"
        self.start_time = None
        self.end_time = None
        self.processing_job_name = ""
        self.processing_inputs = None
        self.processing_output_config = None
        self.environment = None

    def start(self, processing_inputs, processing_output_config, environment, processing_job_name):
        """Starts a local processing job.

        Every input and output is validated before the container is run. As a
        side effect each input dict gains a ``"DataUri"`` key holding the value
        of its ``S3Input["S3Uri"]``.

        Args:
            processing_inputs: The processing input configuration (an iterable
                of input dicts).
            processing_output_config: The processing output configuration. May
                be falsy or lack ``"Outputs"``, in which case no output
                validation is performed.
            environment: The collection of environment variables passed to the job.
            processing_job_name: The processing job name.

        Raises:
            ValueError: If an input has no ``"S3Input"`` or an output has no
                ``"S3Output"``.
            RuntimeError: If an input or output requests a feature that is not
                supported in Local Mode.
        """
        self.state = self._STARTING

        for item in processing_inputs:
            if "DatasetDefinition" in item:
                raise RuntimeError("DatasetDefinition is not currently supported in Local Mode")

            try:
                s3_input = item["S3Input"]
            except KeyError as err:
                # Chain explicitly so the traceback reads as a translated error
                # rather than as a failure inside the handler.
                raise ValueError("Processing input must have a valid ['S3Input']") from err

            item["DataUri"] = s3_input["S3Uri"]

            # The settings below are optional: an absent key is accepted, a
            # present key must carry the single value handled here.
            if "S3InputMode" in s3_input and s3_input["S3InputMode"] != "File":
                raise RuntimeError(
                    "S3InputMode: %s is not currently supported in Local Mode"
                    % s3_input["S3InputMode"]
                )

            if (
                "S3DataDistributionType" in s3_input
                and s3_input["S3DataDistributionType"] != "FullyReplicated"
            ):
                raise RuntimeError(
                    "DataDistribution: %s is not currently supported in Local Mode"
                    % s3_input["S3DataDistributionType"]
                )

            if "S3CompressionType" in s3_input and s3_input["S3CompressionType"] != "None":
                raise RuntimeError(
                    "CompressionType: %s is not currently supported in Local Mode"
                    % s3_input["S3CompressionType"]
                )

        if processing_output_config and "Outputs" in processing_output_config:
            for item in processing_output_config["Outputs"]:
                if "FeatureStoreOutput" in item:
                    raise RuntimeError(
                        "FeatureStoreOutput is not currently supported in Local Mode"
                    )

                try:
                    s3_output = item["S3Output"]
                except KeyError as err:
                    raise ValueError("Processing output must have a valid ['S3Output']") from err

                # Treat a missing upload mode like the optional input settings
                # above instead of failing with a bare KeyError: "EndOfJob" is
                # the only mode accepted here, so it is the only possible default.
                upload_mode = s3_output.get("S3UploadMode", "EndOfJob")
                if upload_mode != "EndOfJob":
                    raise RuntimeError(
                        "UploadMode: %s is not currently supported in Local Mode." % upload_mode
                    )

        self.start_time = datetime.datetime.now()
        self.state = self._PROCESSING

        # Record the request before running so describe() reflects it even if
        # the container call raises.
        self.processing_job_name = processing_job_name
        self.processing_inputs = processing_inputs
        self.processing_output_config = processing_output_config
        self.environment = environment

        self.container.process(
            processing_inputs, processing_output_config, environment, processing_job_name
        )

        self.end_time = datetime.datetime.now()
        self.state = self._COMPLETED

    def describe(self):
        """Describes a local processing job.

        Returns:
            dict: A description of the processing job built from the values
            recorded by ``start`` and the container's attributes. The job name
            is reported as the ARN as well, and the volume size, KMS key, role
            and stopping condition are fixed values.
        """
        return {
            "ProcessingJobArn": self.processing_job_name,
            "ProcessingJobName": self.processing_job_name,
            "AppSpecification": {
                "ImageUri": self.container.image,
                "ContainerEntrypoint": self.container.container_entrypoint,
                "ContainerArguments": self.container.container_arguments,
            },
            "Environment": self.environment,
            "ProcessingInputs": self.processing_inputs,
            "ProcessingOutputConfig": self.processing_output_config,
            "ProcessingResources": {
                "ClusterConfig": {
                    "InstanceCount": self.container.instance_count,
                    "InstanceType": self.container.instance_type,
                    "VolumeSizeInGB": 30,
                    "VolumeKmsKeyId": None,
                }
            },
            "RoleArn": "<no_role>",
            "StoppingCondition": {"MaxRuntimeInSeconds": 86400},
            "ProcessingJobStatus": self.state,
            "ProcessingStartTime": self.start_time,
            "ProcessingEndTime": self.end_time,
        }
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class _LocalTrainingJob(object):
    """Defines and starts a local training job.

    The job validates the input channels, delegates the work to
    ``container.train`` and keeps the bookkeeping (state, timestamps, model
    artifacts, configuration) that ``describe`` reports.
    """

    _STARTING = "Starting"
    _TRAINING = "Training"
    _COMPLETED = "Completed"
    _states = ["Starting", "Training", "Completed"]

    def __init__(self, container):
        """Creates a local training job.

        Args:
            container: the local container object. ``start`` calls its
                ``train`` method and ``describe`` reads its ``instance_count``
                and ``container_entrypoint`` attributes.
        """
        self.container = container
        self.model_artifacts = None
        # NOTE: lower-case initial state is kept as-is; callers may compare against it.
        self.state = "created"
        self.start_time = None
        self.end_time = None
        self.environment = None
        self.training_job_name = ""
        self.output_data_config = None

    def start(self, input_data_config, output_data_config, hyperparameters, environment, job_name):
        """Starts a local training job.

        As a side effect each channel dict gains a ``"DataUri"`` key holding the
        S3 or file URI of its data source.

        Args:
            input_data_config (dict): The Input Data Configuration, this contains data such as the
                channels to be used for training.
            output_data_config (dict): The configuration of the output data.
            hyperparameters (dict): The HyperParameters for the training job.
            environment (dict): The collection of environment variables passed to the job.
            job_name (str): Name of the local training job being run.

        Raises:
            ValueError: If the input data configuration is not valid.
            RuntimeError: If the data distribution type is not supported.
        """
        supported_distributions = ["FullyReplicated"]

        for channel in input_data_config:
            # .get() so that a channel with no "DataSource" at all is reported
            # through the documented ValueError rather than a bare KeyError.
            data_source = channel.get("DataSource")

            if data_source and "S3DataSource" in data_source:
                s3_source = data_source["S3DataSource"]
                data_distribution = s3_source.get("S3DataDistributionType", None)
                data_uri = s3_source["S3Uri"]
            elif data_source and "FileDataSource" in data_source:
                file_source = data_source["FileDataSource"]
                # Optional, mirroring the S3 branch: an absent distribution type
                # is accepted by the check below.
                data_distribution = file_source.get("FileDataDistributionType", None)
                data_uri = file_source["FileUri"]
            else:
                raise ValueError(
                    "Need channel['DataSource'] to have ['S3DataSource'] or ['FileDataSource']"
                )

            # use a single Data URI - this makes handling S3 and File Data easier down the stack
            channel["DataUri"] = data_uri

            if data_distribution and data_distribution not in supported_distributions:
                raise RuntimeError(
                    "Invalid DataDistribution: '{}'. Local mode currently supports: {}.".format(
                        data_distribution, ", ".join(supported_distributions)
                    )
                )

        # Record the request before training so describe() reports the job name
        # and output configuration even if container.train raises.
        self.training_job_name = job_name
        self.output_data_config = output_data_config
        self.environment = environment

        self.start_time = datetime.datetime.now()
        self.state = self._TRAINING

        self.model_artifacts = self.container.train(
            input_data_config, output_data_config, hyperparameters, environment, job_name
        )
        self.end_time = datetime.datetime.now()
        self.state = self._COMPLETED

    def describe(self):
        """Describes a local training job.

        Returns:
            dict: A description of the training job built from the values
            recorded by ``start`` and the container's attributes. The ARN is
            the module-level placeholder ``_UNUSED_ARN``.
        """
        return {
            "TrainingJobName": self.training_job_name,
            "TrainingJobArn": _UNUSED_ARN,
            "ResourceConfig": {"InstanceCount": self.container.instance_count},
            "TrainingJobStatus": self.state,
            "TrainingStartTime": self.start_time,
            "TrainingEndTime": self.end_time,
            "ModelArtifacts": {"S3ModelArtifacts": self.model_artifacts},
            "OutputDataConfig": self.output_data_config,
            "Environment": self.environment,
            "AlgorithmSpecification": {
                "ContainerEntrypoint": self.container.container_entrypoint,
            },
        }
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
class _LocalTransformJob(object):
    """Batch transform job that runs against a locally served model container.

    The job looks up the model's primary container through the local
    SageMaker client, serves it, sends the input data to it in batches and
    writes the responses to the requested output location.
    """

    _CREATING = "Creating"
    _COMPLETED = "Completed"

    def __init__(self, transform_job_name, model_name, local_session=None):
        """Create the job and resolve the primary container of ``model_name``.

        Args:
            transform_job_name (str): name reported by :meth:`describe`.
            model_name (str): name of the model to look up with the local client.
            local_session: session to use; a new ``LocalSession`` is created
                when omitted.
        """
        # Imported at call time rather than at module import time.
        from sagemaker.core.local.local_session import LocalSession

        self.local_session = local_session or LocalSession()
        local_client = self.local_session.sagemaker_client

        self.name = transform_job_name
        self.model_name = model_name

        # TODO - support SageMaker Models not just local models. This is not
        # ideal but it may be a good thing to do.
        self.primary_container = local_client.describe_model(model_name)["PrimaryContainer"]
        self.container = None
        self.start_time = None
        self.end_time = None
        self.batch_strategy = None
        self.transform_resources = None
        self.input_data = None
        self.output_data = None
        self.environment = {}
        self.state = _LocalTransformJob._CREATING

    def start(self, input_data, output_data, transform_resources, **kwargs):
        """Start the Local Transform Job

        Args:
            input_data (dict): Describes the dataset to be transformed and the
                location where it is stored.
            output_data (dict): Identifies the location where to save the
                results from the transform job
            transform_resources (dict): compute instances for the transform job.
                Currently only supports local or local_gpu
            **kwargs: additional arguments coming from the boto request object
        """
        self.transform_resources = transform_resources
        self.input_data = input_data
        self.output_data = output_data

        image = self.primary_container["Image"]
        instance_type = transform_resources["InstanceType"]
        instance_count = 1

        environment = self._get_container_environment(**kwargs)

        # Start the container, pass the environment and wait for it to start up
        self.container = _SageMakerContainer(
            instance_type, instance_count, image, self.local_session
        )
        self.container.serve(self.primary_container["ModelDataUrl"], environment)

        serving_port = get_config_value("local.serving_port", self.local_session.config) or 8080
        _wait_for_serving_container(serving_port)

        # Get capabilities from Container if needed
        endpoint_url = "http://%s:%d/execution-parameters" % (get_docker_host(), serving_port)
        response, code = _perform_request(endpoint_url)
        if code == 200:
            execution_parameters = json.loads(response.data.decode("utf-8"))
            # MaxConcurrentTransforms is ignored because we currently only support 1
            for setting in ("BatchStrategy", "MaxPayloadInMB"):
                if setting not in kwargs and setting in execution_parameters:
                    kwargs[setting] = execution_parameters[setting]

        # Apply Defaults if none was provided
        kwargs.update(self._get_required_defaults(**kwargs))

        self.start_time = datetime.datetime.now()
        self.batch_strategy = kwargs["BatchStrategy"]
        if "Environment" in kwargs:
            self.environment = kwargs["Environment"]

        # run the batch inference requests; the serving container is stopped
        # by _perform_batch_inference whether or not inference succeeds.
        self._perform_batch_inference(input_data, output_data, **kwargs)
        self.end_time = datetime.datetime.now()
        self.state = self._COMPLETED

    def describe(self):
        """Describe this _LocalTransformJob

        The response is a JSON-like dictionary that follows the response of
        the boto describe_transform_job() API.

        Returns:
            dict: description of this _LocalTransformJob
        """
        response = {
            "TransformJobStatus": self.state,
            "ModelName": self.model_name,
            "TransformJobName": self.name,
            "TransformJobArn": _UNUSED_ARN,
            "TransformEndTime": self.end_time,
            "CreationTime": self.start_time,
            "TransformStartTime": self.start_time,
            # Report the environment recorded by start() instead of a
            # hard-coded empty dict; a copy keeps callers from mutating it.
            "Environment": dict(self.environment or {}),
            "BatchStrategy": self.batch_strategy,
        }

        if self.transform_resources:
            response["TransformResources"] = self.transform_resources

        if self.output_data:
            response["TransformOutput"] = self.output_data

        if self.input_data:
            response["TransformInput"] = self.input_data

        return response

    def _get_container_environment(self, **kwargs):
        """Get all the Environment variables that will be passed to the container.

        Certain input fields such as BatchStrategy have different values for
        the API vs the Environment variables, such as SingleRecord vs
        SINGLE_RECORD. This method also handles this conversion.

        Args:
            **kwargs: existing transform arguments

        Returns:
            dict: All the environment variables that should be set in the
            container

        Raises:
            ValueError: if ``BatchStrategy`` is given and is neither
                'SingleRecord' nor 'MultiRecord'.
        """
        environment = {}
        # The primary container may not define an Environment at all.
        environment.update(self.primary_container.get("Environment") or {})
        environment["SAGEMAKER_BATCH"] = "True"
        if "MaxPayloadInMB" in kwargs:
            environment["SAGEMAKER_MAX_PAYLOAD_IN_MB"] = str(kwargs["MaxPayloadInMB"])

        if "BatchStrategy" in kwargs:
            strategy_env_values = {
                "SingleRecord": "SINGLE_RECORD",
                "MultiRecord": "MULTI_RECORD",
            }
            strategy_env_value = strategy_env_values.get(kwargs["BatchStrategy"])
            if strategy_env_value is None:
                raise ValueError("Invalid BatchStrategy, must be 'SingleRecord' or 'MultiRecord'")
            environment["SAGEMAKER_BATCH_STRATEGY"] = strategy_env_value

        # we only do 1 max concurrent transform in Local Mode
        if "MaxConcurrentTransforms" in kwargs and int(kwargs["MaxConcurrentTransforms"]) > 1:
            logger.warning(
                "Local Mode only supports 1 ConcurrentTransform. Setting MaxConcurrentTransforms "
                "to 1"
            )
        environment["SAGEMAKER_MAX_CONCURRENT_TRANSFORMS"] = "1"

        # if there were environment variables passed to the Transformer we will pass them to the
        # container as well. They are applied last, so they win over the values set above.
        if kwargs.get("Environment"):
            environment.update(kwargs["Environment"])
        return environment

    def _get_required_defaults(self, **kwargs):
        """Return the default values.

        The values might be anything that was not provided by either the user or the container

        Args:
            **kwargs: current transform arguments

        Returns:
            dict: key/values for the default parameters that are missing.
        """
        defaults = {}
        if "BatchStrategy" not in kwargs:
            defaults["BatchStrategy"] = "MultiRecord"

        if "MaxPayloadInMB" not in kwargs:
            defaults["MaxPayloadInMB"] = 6

        return defaults

    def _get_working_directory(self):
        """Create and return a temporary directory for intermediate output.

        Returns:
            str: path of a new temporary directory, created under the
                configured ``local.container_root`` when that is set.
        """
        # Root dir to use for intermediate data location. To make things simple we will write here
        # regardless of the final destination. At the end the files will either be moved or
        # uploaded to S3 and deleted.
        root_dir = get_config_value("local.container_root", self.local_session.config)
        if root_dir:
            root_dir = os.path.abspath(root_dir)

        working_dir = tempfile.mkdtemp(dir=root_dir)
        return working_dir

    def _prepare_data_transformation(self, input_data, batch_strategy):
        """Prepares the data for transformation.

        Args:
            input_data: Input data source.
            batch_strategy: Strategy for batch transformation to get.

        Returns:
            A (data source, batch provider) pair.
        """
        input_path = input_data["DataSource"]["S3DataSource"]["S3Uri"]
        data_source = sagemaker.core.local.data.get_data_source_instance(
            input_path, self.local_session
        )

        split_type = input_data.get("SplitType")
        splitter = sagemaker.core.local.data.get_splitter_instance(split_type)

        batch_provider = sagemaker.core.local.data.get_batch_strategy_instance(
            batch_strategy, splitter
        )
        return data_source, batch_provider

    def _perform_batch_inference(self, input_data, output_data, **kwargs):
        """Perform batch inference on the given input data.

        Transforms the input data to feed the serving container. It first gathers
        the files from S3 or Local FileSystem. It then splits the files as required
        (Line, RecordIO, None), and finally, it batches them according to the batch
        strategy and limits the request size.

        The serving container is stopped when this method exits, including
        when inference raises.

        Args:
            input_data: Input data source.
            output_data: Output data source.
            **kwargs: Additional configuration arguments.
        """
        try:
            batch_strategy = kwargs["BatchStrategy"]
            max_payload = int(kwargs["MaxPayloadInMB"])
            data_source, batch_provider = self._prepare_data_transformation(
                input_data, batch_strategy
            )

            # Output settings
            accept = output_data.get("Accept")
            # Does not change between responses, so decide it once.
            assemble_with_line = output_data.get("AssembleWith") == "Line"

            working_dir = self._get_working_directory()
            dataset_dir = data_source.get_root_dir()

            for fn in data_source.get_file_list():

                relative_path = os.path.dirname(os.path.relpath(fn, dataset_dir))
                filename = os.path.basename(fn)
                copy_directory_structure(working_dir, relative_path)
                destination_path = os.path.join(working_dir, relative_path, filename + ".out")

                with open(destination_path, "wb") as f:
                    for item in batch_provider.pad(fn, max_payload):
                        # call the container and add the result to inference.
                        response = self.local_session.sagemaker_runtime_client.invoke_endpoint(
                            item, "", input_data["ContentType"], accept
                        )

                        response_body = response["Body"]
                        try:
                            data = response_body.read()
                        finally:
                            response_body.close()
                        f.write(data)
                        if assemble_with_line:
                            f.write(b"\n")

            move_to_destination(
                working_dir, output_data["S3OutputPath"], self.name, self.local_session
            )
        finally:
            # Never leave the serving container running after a failed transform.
            if self.container is not None:
                self.container.stop_serving()
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
class _LocalModel(object):
    """Local record of a model: its name, primary container and creation time."""

    def __init__(self, model_name, primary_container):
        """Store the model name and container, and stamp the creation time.

        Args:
            model_name: name reported as ``ModelName``.
            primary_container: value reported as ``PrimaryContainer``.
        """
        self.model_name = model_name
        self.primary_container = primary_container
        self.creation_time = datetime.datetime.now()

    def describe(self):
        """Describe this model.

        Returns:
            dict: the model name, creation time, placeholder role and model
                ARNs, and the primary container.
        """
        description = {}
        description["ModelName"] = self.model_name
        description["CreationTime"] = self.creation_time
        # Both ARNs are reported as the same module-level placeholder value.
        description["ExecutionRoleArn"] = _UNUSED_ARN
        description["ModelArn"] = _UNUSED_ARN
        description["PrimaryContainer"] = self.primary_container
        return description
|
|
556
|
+
|
|
557
|
+
|
|
558
|
+
class _LocalEndpointConfig(object):
    """Local record of an endpoint configuration."""

    def __init__(self, config_name, production_variants, tags=None):
        """Store the configuration and stamp its creation time.

        Args:
            config_name: name reported as ``EndpointConfigName``.
            production_variants: value reported as ``ProductionVariants``.
            tags: optional tags; stored after being passed through
                ``format_tags``.
        """
        self.name = config_name
        self.production_variants = production_variants
        self.tags = format_tags(tags)
        self.creation_time = datetime.datetime.now()

    def describe(self):
        """Describe this endpoint configuration.

        Returns:
            dict: the config name, a placeholder ARN, tags, creation time
                and production variants.
        """
        fields = [
            ("EndpointConfigName", self.name),
            ("EndpointConfigArn", _UNUSED_ARN),
            ("Tags", self.tags),
            ("CreationTime", self.creation_time),
            ("ProductionVariants", self.production_variants),
        ]
        return dict(fields)
|
|
577
|
+
|
|
578
|
+
|
|
579
|
+
class _LocalEndpoint(object):
    """Endpoint backed by a locally served model container.

    The endpoint resolves its configuration and model through the local
    SageMaker client. Its status starts as ``Creating``, becomes
    ``InService`` once the container answers the health check, and becomes
    ``Failed`` when serving raises.
    """

    _CREATING = "Creating"
    _IN_SERVICE = "InService"
    _FAILED = "Failed"

    def __init__(self, endpoint_name, endpoint_config_name, tags=None, local_session=None):
        """Resolve the endpoint config and the primary container of its model.

        Args:
            endpoint_name: name reported as ``EndpointName``.
            endpoint_config_name: name of the endpoint config to look up with
                the local client; only its first production variant is used.
            tags: optional tags; stored after being passed through
                ``format_tags``.
            local_session: session to use; a new ``LocalSession`` is created
                when omitted.
        """
        # runtime import since there is a cyclic dependency between entities and local_session
        from sagemaker.core.local.local_session import LocalSession

        self.local_session = local_session or LocalSession()
        local_client = self.local_session.sagemaker_client

        self.name = endpoint_name
        self.endpoint_config = local_client.describe_endpoint_config(endpoint_config_name)
        self.production_variant = self.endpoint_config["ProductionVariants"][0]
        self.tags = format_tags(tags)

        model_name = self.production_variant["ModelName"]
        self.primary_container = local_client.describe_model(model_name)["PrimaryContainer"]

        self.container = None
        self.create_time = None
        self.state = _LocalEndpoint._CREATING

    def serve(self):
        """Serve the model container and wait until it passes the health check.

        On success the status becomes ``InService``. If anything raises
        while the container is being started, the status is set to
        ``Failed`` and the exception is re-raised.
        """
        try:
            image = self.primary_container["Image"]
            instance_type = self.production_variant["InstanceType"]
            instance_count = self.production_variant["InitialInstanceCount"]

            # The primary container may not define an Environment; fall back to
            # an empty one without adding a key to the stored container.
            environment = self.primary_container.get("Environment")
            if environment is None:
                environment = {}

            accelerator_type = self.production_variant.get("AcceleratorType")
            if accelerator_type == "local_sagemaker_notebook":
                environment["SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT"] = "true"

            self.create_time = datetime.datetime.now()
            self.container = _SageMakerContainer(
                instance_type, instance_count, image, self.local_session
            )
            self.container.serve(self.primary_container["ModelDataUrl"], environment)

            serving_port = (
                get_config_value("local.serving_port", self.local_session.config) or 8080
            )
            _wait_for_serving_container(serving_port)
        except Exception:
            # Without this, a failed launch would report "Creating" forever.
            self.state = _LocalEndpoint._FAILED
            raise

        # the container is running and it passed the healthcheck status is now InService
        self.state = _LocalEndpoint._IN_SERVICE

    def stop(self):
        """Stop the serving container, if one was started."""
        if self.container:
            self.container.stop_serving()

    def describe(self):
        """Describe this endpoint.

        Returns:
            dict: the endpoint config name, creation time, production
                variants, tags, endpoint name, a placeholder ARN and the
                current status.
        """
        response = {
            "EndpointConfigName": self.endpoint_config["EndpointConfigName"],
            "CreationTime": self.create_time,
            "ProductionVariants": self.endpoint_config["ProductionVariants"],
            "Tags": self.tags,
            "EndpointName": self.name,
            "EndpointArn": _UNUSED_ARN,
            "EndpointStatus": self.state,
        }
        return response
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
def _wait_for_serving_container(serving_port):
    """Poll the serving container's ``/ping`` route until it returns 200.

    A request is made every 5 seconds. The wait is abandoned once the
    elapsed counter reaches ``HEALTH_CHECK_TIMEOUT_LIMIT``.

    Args:
        serving_port (int): port the serving container listens on.

    Raises:
        RuntimeError: if the container did not answer 200 in time.
    """
    pool = urllib3.PoolManager()
    ping_url = "http://%s:%d/ping" % (get_docker_host(), serving_port)

    # The counter advances in steps of 5, matching the sleep between polls.
    waited = 5
    while waited < HEALTH_CHECK_TIMEOUT_LIMIT:
        logger.info("Checking if serving container is up, attempt: %s", waited)
        _, status = _perform_request(ping_url, pool)
        if status == 200:
            return
        logger.info("Container still not up, got: %s", status)

        time.sleep(5)
        waited += 5

    raise RuntimeError("Giving up, endpoint didn't launch correctly")
|
|
668
|
+
|
|
669
|
+
|
|
670
|
+
def _perform_request(endpoint_url, pool_manager=None):
    """Issue a GET request and report the response together with its status.

    Args:
        endpoint_url (str): URL to request.
        pool_manager: object whose ``request`` method performs the call; a
            new ``urllib3.PoolManager`` is created when omitted.

    Returns:
        tuple: ``(response, status)`` on success, or ``(None, -1)`` when the
            request raises ``urllib3.exceptions.RequestError``.
    """
    client = pool_manager if pool_manager else urllib3.PoolManager()
    try:
        reply = client.request("GET", endpoint_url)
        outcome = (reply, reply.status)
    except urllib3.exceptions.RequestError:
        # A failed request is reported through the -1 status code.
        outcome = (None, -1)
    return outcome
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Custom Exceptions for local mode."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
# StepExecutionException has been moved to sagemaker.mlops.local.exceptions
|
|
17
|
+
# as it's specific to pipeline execution which is now in MLOps
|