sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/__init__.py +0 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/__init__.py +0 -4
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.47.dist-info/RECORD +0 -35
- sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,739 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Placeholder docstring"""
|
|
14
|
+
from __future__ import absolute_import, annotations
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
import platform
|
|
18
|
+
from typing import Dict
|
|
19
|
+
|
|
20
|
+
import boto3
|
|
21
|
+
from botocore.exceptions import ClientError
|
|
22
|
+
import jsonschema
|
|
23
|
+
|
|
24
|
+
from sagemaker.core.config.config_schema import (
|
|
25
|
+
SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA,
|
|
26
|
+
SESSION_DEFAULT_S3_BUCKET_PATH,
|
|
27
|
+
SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,
|
|
28
|
+
)
|
|
29
|
+
from sagemaker.core.config.config import (
|
|
30
|
+
load_local_mode_config,
|
|
31
|
+
load_sagemaker_config,
|
|
32
|
+
validate_sagemaker_config,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
from sagemaker.core.local.image import _SageMakerContainer
|
|
36
|
+
from sagemaker.core.local.utils import get_docker_host
|
|
37
|
+
from sagemaker.core.local.entities import (
|
|
38
|
+
_LocalEndpointConfig,
|
|
39
|
+
_LocalEndpoint,
|
|
40
|
+
_LocalModel,
|
|
41
|
+
_LocalProcessingJob,
|
|
42
|
+
_LocalTrainingJob,
|
|
43
|
+
_LocalTransformJob,
|
|
44
|
+
)
|
|
45
|
+
from sagemaker.core.helper.session_helper import Session
|
|
46
|
+
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
|
|
47
|
+
from sagemaker.core.telemetry.constants import Feature
|
|
48
|
+
from sagemaker.core.common_utils import (
|
|
49
|
+
get_config_value,
|
|
50
|
+
_module_import_error,
|
|
51
|
+
resolve_value_from_config,
|
|
52
|
+
format_tags,
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
logger = logging.getLogger(__name__)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class LocalSagemakerClient(object): # pylint: disable=too-many-public-methods
|
|
59
|
+
"""A SageMakerClient that implements the API calls locally.
|
|
60
|
+
|
|
61
|
+
Used for doing local training and hosting local endpoints. It still needs access to
|
|
62
|
+
a boto client to interact with S3 but it won't perform any SageMaker call.
|
|
63
|
+
|
|
64
|
+
Implements the methods with the same signature as the boto SageMakerClient.
|
|
65
|
+
|
|
66
|
+
Args:
|
|
67
|
+
|
|
68
|
+
Returns:
|
|
69
|
+
|
|
70
|
+
"""
|
|
71
|
+
|
|
72
|
+
_processing_jobs = {}
|
|
73
|
+
_training_jobs = {}
|
|
74
|
+
_transform_jobs = {}
|
|
75
|
+
_models = {}
|
|
76
|
+
_endpoint_configs = {}
|
|
77
|
+
_endpoints = {}
|
|
78
|
+
|
|
79
|
+
def __init__(self, sagemaker_session=None):
    """Initialize a LocalSagemakerClient.

    Args:
        sagemaker_session (sagemaker.core.helper.session_helper.Session): session
            stored on the client as ``sagemaker_session``. When omitted (or
            falsy) a new ``LocalSession`` is created and used instead.
    """
    if not sagemaker_session:
        # No usable session supplied: fall back to a fresh LocalSession.
        sagemaker_session = LocalSession()
    self.sagemaker_session = sagemaker_session
|
|
87
|
+
|
|
88
|
+
@_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_processing_job")
def create_processing_job(
    self,
    ProcessingJobName,
    AppSpecification,
    ProcessingResources,
    Environment=None,
    ProcessingInputs=None,
    ProcessingOutputConfig=None,
    **kwargs,
):
    """Start a processing job in Local Mode and register it under its name.

    Args:
        ProcessingJobName (str): local processing job name; also the key under
            which the job is stored in the class-level job registry.
        AppSpecification (dict): identifies the container and application to
            run. ``ImageUri`` is required; ``ContainerEntrypoint`` and
            ``ContainerArguments`` are forwarded to the container when present.
        ProcessingResources (dict): identifies the resources to use. The
            instance type and count are read from its ``ClusterConfig`` entry.
        Environment (dict, optional): environment variables to pass to the
            container. A falsy value is replaced by an empty dict.
        ProcessingInputs (list, optional): describes the processing input data.
            A falsy value is replaced by an empty list.
        ProcessingOutputConfig (dict, optional): describes the processing
            output configuration. A falsy value is replaced by an empty dict.
        **kwargs: additional request fields. ``ExperimentConfig``,
            ``NetworkConfig`` and ``StoppingCondition`` are not supported in
            local mode and only produce a warning.

    Returns:
        None
    """
    entrypoint = (
        AppSpecification["ContainerEntrypoint"]
        if "ContainerEntrypoint" in AppSpecification
        else None
    )
    arguments = (
        AppSpecification["ContainerArguments"]
        if "ContainerArguments" in AppSpecification
        else None
    )

    # Request fields that local mode ignores; each one present is reported once,
    # in this order.
    ignored_options = (
        ("ExperimentConfig", "Experiment configuration is not supported in local mode."),
        ("NetworkConfig", "Network configuration is not supported in local mode."),
        ("StoppingCondition", "Stopping condition is not supported in local mode."),
    )
    for option, message in ignored_options:
        if option in kwargs:
            logger.warning(message)

    cluster_config = ProcessingResources["ClusterConfig"]
    container = _SageMakerContainer(
        cluster_config["InstanceType"],
        cluster_config["InstanceCount"],
        AppSpecification["ImageUri"],
        sagemaker_session=self.sagemaker_session,
        container_entrypoint=entrypoint,
        container_arguments=arguments,
    )

    job = _LocalProcessingJob(container)
    logger.info("Starting processing job")
    job.start(
        ProcessingInputs or [],
        ProcessingOutputConfig or {},
        Environment or {},
        ProcessingJobName,
    )

    # Only reached when start() returned without raising.
    LocalSagemakerClient._processing_jobs[ProcessingJobName] = job
|
|
150
|
+
|
|
151
|
+
def describe_processing_job(self, ProcessingJobName):
    """Describe a local processing job.

    Args:
        ProcessingJobName (str): name of the processing job to describe.

    Returns:
        (dict) DescribeProcessingJob response, as produced by the stored job's
        ``describe()`` method.

    Raises:
        botocore.exceptions.ClientError: with code ``ValidationException`` when
            no local processing job is registered under ``ProcessingJobName``.
    """
    known_jobs = LocalSagemakerClient._processing_jobs
    if ProcessingJobName in known_jobs:
        return known_jobs[ProcessingJobName].describe()

    not_found = {
        "Error": {
            "Code": "ValidationException",
            "Message": "Could not find local processing job",
        }
    }
    raise ClientError(not_found, "describe_processing_job")
|
|
170
|
+
|
|
171
|
+
@_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_training_job")
def create_training_job(
    self,
    TrainingJobName,
    AlgorithmSpecification,
    OutputDataConfig,
    ResourceConfig,
    InputDataConfig=None,
    Environment=None,
    **kwargs,
):
    """Start a training job in Local Mode and register it under its name.

    Args:
        TrainingJobName (str): local training job name; also the key under
            which the job is stored in the class-level job registry.
        AlgorithmSpecification (dict): identifies the training algorithm to use.
            ``TrainingImage`` is required; truthy ``ContainerEntrypoint`` and
            ``ContainerArguments`` values are set on the container.
        OutputDataConfig (dict): identifies the location where the results of
            model training are saved.
        ResourceConfig (dict): identifies the resources to use for local model
            training (``InstanceType`` and ``InstanceCount`` are read).
        InputDataConfig (optional): describes the training dataset and the
            location where it is stored. A falsy value is replaced by ``{}``.
        Environment (dict, optional): environment variables to pass to the
            container. A falsy value is replaced by an empty dict.
        **kwargs: additional request fields. Only ``HyperParameters`` (dict of
            algorithm-specific parameters) is read; it defaults to ``{}``.

    Returns:
        None
    """
    container = _SageMakerContainer(
        ResourceConfig["InstanceType"],
        ResourceConfig["InstanceCount"],
        AlgorithmSpecification["TrainingImage"],
        sagemaker_session=self.sagemaker_session,
    )

    # Optional overrides are applied only when a truthy value was supplied.
    entrypoint = AlgorithmSpecification.get("ContainerEntrypoint")
    if entrypoint:
        container.container_entrypoint = entrypoint
    arguments = AlgorithmSpecification.get("ContainerArguments")
    if arguments:
        container.container_arguments = arguments

    job = _LocalTrainingJob(container)
    hyperparameters = kwargs.get("HyperParameters", {})
    logger.info("Starting training job")
    job.start(
        InputDataConfig or {},
        OutputDataConfig,
        hyperparameters,
        Environment or {},
        TrainingJobName,
    )

    # Only reached when start() returned without raising.
    LocalSagemakerClient._training_jobs[TrainingJobName] = job
|
|
221
|
+
|
|
222
|
+
def describe_training_job(self, TrainingJobName):
    """Describe a local training job.

    Args:
        TrainingJobName (str): name of the training job to describe.

    Returns:
        (dict) DescribeTrainingJob response, as produced by the stored job's
        ``describe()`` method.

    Raises:
        botocore.exceptions.ClientError: with code ``ValidationException`` when
            no local training job is registered under ``TrainingJobName``.
    """
    known_jobs = LocalSagemakerClient._training_jobs
    if TrainingJobName in known_jobs:
        return known_jobs[TrainingJobName].describe()

    not_found = {
        "Error": {
            "Code": "ValidationException",
            "Message": "Could not find local training job",
        }
    }
    raise ClientError(not_found, "describe_training_job")
|
|
241
|
+
|
|
242
|
+
@_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_transform_job")
def create_transform_job(
    self,
    TransformJobName,
    ModelName,
    TransformInput,
    TransformOutput,
    TransformResources,
    **kwargs,
):
    """Create and start a local transform job.

    Args:
        TransformJobName: name of the transform job; also the key under which
            the job is stored in the class-level job registry.
        ModelName: name of the model, passed to the local transform job.
        TransformInput: input description, forwarded to the job's ``start``.
        TransformOutput: output description, forwarded to the job's ``start``.
        TransformResources: resource description, forwarded to the job's
            ``start``.
        **kwargs: extra keyword arguments, forwarded unchanged to ``start``.

    Returns:
        None
    """
    job = _LocalTransformJob(TransformJobName, ModelName, self.sagemaker_session)
    # The job is registered before start() is invoked, so it stays in the
    # registry even if start() raises.
    LocalSagemakerClient._transform_jobs[TransformJobName] = job
    job.start(TransformInput, TransformOutput, TransformResources, **kwargs)
|
|
268
|
+
|
|
269
|
+
def describe_transform_job(self, TransformJobName):
    """Describe a local transform job.

    Args:
        TransformJobName: name of the transform job to describe.

    Returns:
        The result of the stored job's ``describe()`` method.

    Raises:
        botocore.exceptions.ClientError: with code ``ValidationException`` when
            no local transform job is registered under ``TransformJobName``.
    """
    known_jobs = LocalSagemakerClient._transform_jobs
    if TransformJobName in known_jobs:
        return known_jobs[TransformJobName].describe()

    not_found = {
        "Error": {
            "Code": "ValidationException",
            "Message": "Could not find local transform job",
        }
    }
    raise ClientError(not_found, "describe_transform_job")
|
|
287
|
+
|
|
288
|
+
@_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_model")
def create_model(
    self, ModelName, PrimaryContainer, *args, **kwargs
):  # pylint: disable=unused-argument
    """Create a local model object and register it by name.

    Args:
        ModelName (str): the Model Name
        PrimaryContainer (dict): a SageMaker primary container definition
        *args: Accepted for signature compatibility; ignored.
        **kwargs: Accepted for signature compatibility; ignored.

    Returns:
        None
    """
    local_model = _LocalModel(ModelName, PrimaryContainer)
    LocalSagemakerClient._models[ModelName] = local_model
|
|
303
|
+
|
|
304
|
+
def describe_model(self, ModelName):
    """Describe a previously created local model.

    Args:
        ModelName: Name the model was registered under.

    Returns:
        Whatever the registered model's ``describe()`` method returns.

    Raises:
        ClientError: With code ``ValidationException`` when no local model is
            registered under ``ModelName``.
    """
    if ModelName in LocalSagemakerClient._models:
        return LocalSagemakerClient._models[ModelName].describe()

    raise ClientError(
        {"Error": {"Code": "ValidationException", "Message": "Could not find local model"}},
        "describe_model",
    )
|
|
318
|
+
|
|
319
|
+
def describe_endpoint_config(self, EndpointConfigName):
    """Describe a previously created local endpoint configuration.

    Args:
        EndpointConfigName: Name the endpoint configuration was registered under.

    Returns:
        Whatever the registered configuration's ``describe()`` method returns.

    Raises:
        ClientError: With code ``ValidationException`` when no local endpoint
            configuration is registered under ``EndpointConfigName``.
    """
    if EndpointConfigName in LocalSagemakerClient._endpoint_configs:
        return LocalSagemakerClient._endpoint_configs[EndpointConfigName].describe()

    raise ClientError(
        {
            "Error": {
                "Code": "ValidationException",
                "Message": "Could not find local endpoint config",
            }
        },
        "describe_endpoint_config",
    )
|
|
337
|
+
|
|
338
|
+
@_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_endpoint_config")
def create_endpoint_config(self, EndpointConfigName, ProductionVariants, Tags=None):
    """Create a local endpoint configuration and register it by name.

    Args:
        EndpointConfigName: Key under which the configuration is registered.
        ProductionVariants: Passed through to ``_LocalEndpointConfig``.
        Tags: Passed through ``format_tags`` before being handed to
            ``_LocalEndpointConfig`` (Default value = None)

    Returns:
        None
    """
    formatted_tags = format_tags(Tags)
    endpoint_config = _LocalEndpointConfig(
        EndpointConfigName, ProductionVariants, formatted_tags
    )
    LocalSagemakerClient._endpoint_configs[EndpointConfigName] = endpoint_config
|
|
353
|
+
|
|
354
|
+
def describe_endpoint(self, EndpointName):
    """Describe a previously created local endpoint.

    Args:
        EndpointName: Name the endpoint was registered under.

    Returns:
        Whatever the registered endpoint's ``describe()`` method returns.

    Raises:
        ClientError: With code ``ValidationException`` when no local endpoint is
            registered under ``EndpointName``.
    """
    if EndpointName in LocalSagemakerClient._endpoints:
        return LocalSagemakerClient._endpoints[EndpointName].describe()

    raise ClientError(
        {"Error": {"Code": "ValidationException", "Message": "Could not find local endpoint"}},
        "describe_endpoint",
    )
|
|
369
|
+
|
|
370
|
+
@_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_endpoint")
def create_endpoint(self, EndpointName, EndpointConfigName, Tags=None):
    """Create, register and serve a local endpoint.

    The endpoint is stored in the class-level ``LocalSagemakerClient._endpoints``
    mapping *before* ``serve()`` is called, so it stays registered even if
    serving fails.

    Args:
        EndpointName: Key under which the endpoint is registered.
        EndpointConfigName: Name of the endpoint configuration handed to
            ``_LocalEndpoint``.
        Tags: Passed through ``format_tags`` before being handed to
            ``_LocalEndpoint`` (Default value = None)

    Returns:
        None
    """
    formatted_tags = format_tags(Tags)
    local_endpoint = _LocalEndpoint(
        EndpointName, EndpointConfigName, formatted_tags, self.sagemaker_session
    )
    LocalSagemakerClient._endpoints[EndpointName] = local_endpoint
    local_endpoint.serve()
|
|
390
|
+
|
|
391
|
+
def update_endpoint(self, EndpointName, EndpointConfigName):  # pylint: disable=unused-argument
    """Reject endpoint updates, which local mode does not support.

    Args:
        EndpointName: Unused.
        EndpointConfigName: Unused.

    Raises:
        NotImplementedError: Always.
    """
    message = "Update endpoint name is not supported in local session."
    raise NotImplementedError(message)
|
|
402
|
+
|
|
403
|
+
def delete_endpoint(self, EndpointName):
    """Stop the local endpoint registered under ``EndpointName``, if any.

    Unknown names are ignored silently. The endpoint object is only stopped;
    this method does not remove its entry from ``LocalSagemakerClient._endpoints``.

    Args:
        EndpointName: Name the endpoint was registered under.

    Returns:
        None
    """
    if EndpointName not in LocalSagemakerClient._endpoints:
        return
    LocalSagemakerClient._endpoints[EndpointName].stop()
|
|
414
|
+
|
|
415
|
+
def delete_endpoint_config(self, EndpointConfigName):
    """Remove the local endpoint configuration registered under the given name.

    Unknown names are ignored silently.

    Args:
        EndpointConfigName: Name the endpoint configuration was registered under.

    Returns:
        None
    """
    if EndpointConfigName not in LocalSagemakerClient._endpoint_configs:
        return
    del LocalSagemakerClient._endpoint_configs[EndpointConfigName]
|
|
426
|
+
|
|
427
|
+
def delete_model(self, ModelName):
    """Remove the local model registered under ``ModelName``.

    Unknown names are ignored silently.

    Args:
        ModelName: Name the model was registered under.

    Returns:
        None
    """
    if ModelName not in LocalSagemakerClient._models:
        return
    del LocalSagemakerClient._models[ModelName]
|
|
438
|
+
|
|
439
|
+
# Pipeline methods have been moved to sagemaker.mlops.local.LocalPipelineSession
|
|
440
|
+
# For backward compatibility, see sagemaker.mlops.local package
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
class LocalSagemakerRuntimeClient(object):
    """A SageMaker Runtime client that calls a local endpoint only."""

    def __init__(self, config=None):
        """Initializes a LocalSageMakerRuntimeClient.

        Args:
            config (dict): Optional configuration for this client. In particular only
                the local port is read.

        Raises:
            ImportError: If ``urllib3`` is not installed (an error is logged first).
        """
        try:
            import urllib3
        except ImportError:
            logger.error(_module_import_error("urllib3", "Local mode", "local"))
            raise

        self.http = urllib3.PoolManager()
        # Default port; the ``config`` setter below recomputes it from the config.
        self.serving_port = 8080
        self.config = config

    @property
    def config(self) -> dict:
        """Local config getter"""
        return self._config

    @config.setter
    def config(self, value: dict):
        """Local config setter, this method also updates the `serving_port` attribute.

        Args:
            value (dict): the new config value
        """
        self._config = value
        configured_port = get_config_value("local.serving_port", self._config)
        # Fall back to 8080 when the lookup yields a falsy value.
        self.serving_port = configured_port or 8080

    def invoke_endpoint(
        self,
        Body,
        EndpointName,  # pylint: disable=unused-argument
        ContentType=None,
        Accept=None,
        CustomAttributes=None,
        TargetModel=None,
        TargetVariant=None,
        InferenceId=None,
    ):
        """Invoke the endpoint.

        Sends a POST request to ``/invocations`` on the docker host at
        ``self.serving_port``. Only arguments that are not ``None`` are sent as
        HTTP headers.

        Args:
            Body: Input data for which you want the model to provide inference.
            EndpointName: The name of the endpoint that you specified when you
                created the endpoint using the CreateEndpoint API. Not used here.
            ContentType: The MIME type of the input data in the request body (Default value = None)
            Accept: The desired MIME type of the inference in the response (Default value = None)
            CustomAttributes: Provides additional information about a request for an inference
                submitted to a model hosted at an Amazon SageMaker endpoint (Default value = None)
            TargetModel: The model to request for inference when invoking a multi-model endpoint
                (Default value = None)
            TargetVariant: Specify the production variant to send the inference request to when
                invoking an endpoint that is running two or more variants (Default value = None)
            InferenceId: If you provide a value, it is added to the captured data when you enable
                data capture on the endpoint (Default value = None)

        Returns:
            dict: ``{"Body": <un-preloaded HTTP response>, "ContentType": Accept}``.
        """
        endpoint_url = "http://%s:%d/invocations" % (get_docker_host(), self.serving_port)

        optional_headers = (
            ("Content-type", ContentType),
            ("Accept", Accept),
            ("X-Amzn-SageMaker-Custom-Attributes", CustomAttributes),
            ("X-Amzn-SageMaker-Target-Model", TargetModel),
            ("X-Amzn-SageMaker-Target-Variant", TargetVariant),
            ("X-Amzn-SageMaker-Inference-Id", InferenceId),
        )
        headers = {name: value for name, value in optional_headers if value is not None}

        # The http client encodes all strings using latin-1, which is not what we want.
        payload = Body.encode("utf-8") if isinstance(Body, str) else Body
        response = self.http.request(
            "POST", endpoint_url, body=payload, preload_content=False, headers=headers
        )

        return {"Body": response, "ContentType": Accept}
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
class LocalSession(Session):
    """A SageMaker ``Session`` class for Local Mode.

    This class provides alternative Local Mode implementations for the functionality of
    :class:`~sagemaker.core.helper.session.Session`.
    """

    def __init__(
        self,
        boto_session=None,
        default_bucket=None,
        s3_endpoint_url=None,
        disable_local_code=False,
        sagemaker_config: dict = None,
        default_bucket_prefix=None,
    ):
        """Create a Local SageMaker Session.

        Args:
            boto_session (boto3.session.Session): The underlying Boto3 session which AWS service
                calls are delegated to (default: None). If not provided, one is created with
                default AWS configuration chain.
            default_bucket: Forwarded unchanged to the parent ``Session`` constructor
                (default: None).
            s3_endpoint_url (str): Override the default endpoint URL for Amazon S3, if set
                (default: None).
            disable_local_code (bool): Set ``True`` to override the default AWS configuration
                chain to disable the ``local.local_code`` setting, which may not be supported for
                some SDK features (default: False).
            sagemaker_config: A dictionary containing default values for the
                SageMaker Python SDK. (default: None). The dictionary must adhere to the schema
                defined at `~sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA`.
                If sagemaker_config is not provided and configuration files exist (at the default
                paths for admins and users, or paths set through the environment variables
                SAGEMAKER_ADMIN_CONFIG_OVERRIDE and SAGEMAKER_USER_CONFIG_OVERRIDE),
                a new dictionary will be generated from those configuration files. Alternatively,
                this dictionary can be generated by calling
                :func:`~sagemaker.config.load_sagemaker_config` and then be provided to the
                Session.
            default_bucket_prefix (str): The default prefix to use for S3 Object Keys. When
                objects are saved to the Session's default_bucket, the Object Key used will
                start with the default_bucket_prefix. If not provided here or within
                sagemaker_config, no additional prefix will be added.
        """
        self.s3_endpoint_url = s3_endpoint_url
        # We use this local variable to avoid disrupting the __init__->_initialize API of the
        # parent class... But overwriting it after constructor won't do anything, so prefix _ to
        # discourage external use:
        self._disable_local_code = disable_local_code

        super(LocalSession, self).__init__(
            boto_session=boto_session,
            default_bucket=default_bucket,
            sagemaker_config=sagemaker_config,
            default_bucket_prefix=default_bucket_prefix,
        )

        if platform.system() == "Windows":
            logger.warning("Windows Support for Local Mode is Experimental")

    def _initialize(
        self, boto_session, sagemaker_client, sagemaker_runtime_client, **kwargs
    ):  # pylint: disable=unused-argument
        """Initialize this Local SageMaker Session.

        Sets up the boto session, the local SageMaker and runtime clients, the S3
        resource/client, the SDK config and finally the local mode config.

        Args:
            boto_session: Boto3 session to use; a default ``boto3.Session()`` is created
                when ``None``.
            sagemaker_client: Ignored; a ``LocalSagemakerClient`` is always used.
            sagemaker_runtime_client: Ignored; a ``LocalSagemakerRuntimeClient`` is always
                used.
            kwargs: Only ``sagemaker_config`` is read.

        Raises:
            ValueError: If the boto session has no region configured.
        """
        self.boto_session = boto3.Session() if boto_session is None else boto_session

        self._region_name = self.boto_session.region_name
        if self._region_name is None:
            raise ValueError(
                "Must setup local AWS configuration with a region supported by SageMaker."
            )

        self.sagemaker_client = LocalSagemakerClient(self)
        self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)

        self.local_mode = True

        # Always go through ``self.boto_session``: the ``boto_session`` argument may be
        # None, in which case the default session created above must be used.
        if self.s3_endpoint_url is not None:
            self.s3_resource = self.boto_session.resource(
                "s3", endpoint_url=self.s3_endpoint_url
            )
            self.s3_client = self.boto_session.client("s3", endpoint_url=self.s3_endpoint_url)
        else:
            self.s3_resource = self.boto_session.resource("s3", region_name=self._region_name)
            self.s3_client = self.boto_session.client("s3", region_name=self._region_name)

        # Validate a caller-supplied config once; otherwise load it once, using the
        # S3 resource configured above.
        sagemaker_config = kwargs.get("sagemaker_config", None)
        if sagemaker_config:
            validate_sagemaker_config(sagemaker_config)
            self.sagemaker_config = sagemaker_config
        else:
            self.sagemaker_config = load_sagemaker_config(s3_resource=self.s3_resource)

        # after sagemaker_config initialization, update self._default_bucket_name_override if needed
        self._default_bucket_name_override = resolve_value_from_config(
            direct_input=self._default_bucket_name_override,
            config_path=SESSION_DEFAULT_S3_BUCKET_PATH,
            sagemaker_session=self,
        )
        # after sagemaker_config initialization, update self.default_bucket_prefix if needed
        self.default_bucket_prefix = resolve_value_from_config(
            direct_input=self.default_bucket_prefix,
            config_path=SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,
            sagemaker_session=self,
        )

        # Assigning through the property setter also pushes the config to the runtime client.
        self.config = load_local_mode_config()
        if self._disable_local_code and self.config and "local" in self.config:
            self.config["local"]["local_code"] = False

    @Session.config.setter
    def config(self, value: Dict | None):
        """Setter of the local mode config

        A non-``None`` value is validated against
        ``SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA`` before being stored. If a
        runtime client already exists, the stored config is propagated to it.

        Args:
            value: The new local mode config, or ``None``.

        Raises:
            jsonschema.ValidationError: If ``value`` does not match the schema.
        """
        if value is not None:
            try:
                jsonschema.validate(value, SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA)
            except jsonschema.ValidationError:
                logger.error("Failed to validate the local mode config")
                raise
        self._config = value

        # update the runtime client on config changed
        if getattr(self, "sagemaker_runtime_client", None):
            self.sagemaker_runtime_client.config = self._config

    def logs_for_job(self, job_name, wait=False, poll=5, log_type="All"):
        """A no-op method meant to override the sagemaker client.

        Args:
            job_name: Unused.
            wait: Unused. (Default value = False)
            poll: Unused. (Default value = 5)
            log_type: Unused. (Default value = "All")

        Returns:
            None
        """
        # override logs_for_job() as it doesn't need to perform any action
        # on local mode.

    def logs_for_processing_job(self, job_name, wait=False, poll=10):
        """A no-op method meant to override the sagemaker client.

        Args:
            job_name: Unused.
            wait: Unused. (Default value = False)
            poll: Unused. (Default value = 10)

        Returns:
            None
        """
        # override logs_for_processing_job() as it doesn't need to perform any action
        # on local mode.
|
|
718
|
+
|
|
719
|
+
|
|
720
|
+
class FileInput(object):
    """Amazon SageMaker channel configuration for FILE data sources, used in local mode."""

    def __init__(self, fileUri, content_type=None):
        """Create a definition for input data used by an SageMaker training job in local mode.

        Args:
            fileUri: Stored as ``FileUri`` of the ``FileDataSource`` entry.
            content_type: When not ``None``, stored under the ``ContentType`` key
                of ``self.config`` (Default value = None)
        """
        file_data_source = {
            "FileDataDistributionType": "FullyReplicated",
            "FileUri": fileUri,
        }
        channel_config = {"DataSource": {"FileDataSource": file_data_source}}

        if content_type is not None:
            channel_config["ContentType"] = content_type

        self.config = channel_config
|
|
736
|
+
|
|
737
|
+
|
|
738
|
+
# Backward compatibility alias
|
|
739
|
+
file_input = FileInput
|