sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/__init__.py +0 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/__init__.py +0 -4
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.47.dist-info/RECORD +0 -35
- sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1597 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module contains code related to the ``Processor`` class.
|
|
14
|
+
|
|
15
|
+
The ``Processor`` class is used for Amazon SageMaker Processing Jobs. These jobs let users perform
|
|
16
|
+
data pre-processing, post-processing, feature engineering, data validation, model evaluation,
|
|
17
|
+
and interpretation on Amazon SageMaker.
|
|
18
|
+
"""
|
|
19
|
+
from __future__ import absolute_import
|
|
20
|
+
|
|
21
|
+
import json
|
|
22
|
+
import logging
|
|
23
|
+
import os
|
|
24
|
+
import pathlib
|
|
25
|
+
import re
|
|
26
|
+
from typing import Dict, List, Optional, Union
|
|
27
|
+
import time
|
|
28
|
+
from copy import copy
|
|
29
|
+
from textwrap import dedent
|
|
30
|
+
from six.moves.urllib.parse import urlparse
|
|
31
|
+
from six.moves.urllib.request import url2pathname
|
|
32
|
+
from sagemaker.core.network import NetworkConfig
|
|
33
|
+
from sagemaker.core import s3
|
|
34
|
+
from sagemaker.core.apiutils._base_types import ApiObject
|
|
35
|
+
from sagemaker.core.config.config_schema import (
|
|
36
|
+
PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH,
|
|
37
|
+
PROCESSING_JOB_ENVIRONMENT_PATH,
|
|
38
|
+
PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
|
|
39
|
+
PROCESSING_JOB_KMS_KEY_ID_PATH,
|
|
40
|
+
PROCESSING_JOB_ROLE_ARN_PATH,
|
|
41
|
+
PROCESSING_JOB_SECURITY_GROUP_IDS_PATH,
|
|
42
|
+
PROCESSING_JOB_SUBNETS_PATH,
|
|
43
|
+
PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH,
|
|
44
|
+
PROCESSING_JOB_INPUTS_PATH,
|
|
45
|
+
PROCESSING_JOB_NETWORK_CONFIG_PATH,
|
|
46
|
+
PROCESSING_OUTPUT_CONFIG_PATH,
|
|
47
|
+
PROCESSING_JOB_PROCESSING_RESOURCES_PATH,
|
|
48
|
+
SAGEMAKER,
|
|
49
|
+
PROCESSING_JOB,
|
|
50
|
+
TAGS,
|
|
51
|
+
)
|
|
52
|
+
from sagemaker.core.local.local_session import LocalSession
|
|
53
|
+
from sagemaker.core.helper.session_helper import Session
|
|
54
|
+
from sagemaker.core.shapes import ProcessingInput, ProcessingOutput, ProcessingS3Input
|
|
55
|
+
from sagemaker.core.resources import ProcessingJob
|
|
56
|
+
from sagemaker.core.workflow.pipeline_context import PipelineSession
|
|
57
|
+
from sagemaker.core.common_utils import (
|
|
58
|
+
Tags,
|
|
59
|
+
base_name_from_image,
|
|
60
|
+
check_and_get_run_experiment_config,
|
|
61
|
+
format_tags,
|
|
62
|
+
name_from_base,
|
|
63
|
+
resolve_class_attribute_from_config,
|
|
64
|
+
resolve_value_from_config,
|
|
65
|
+
resolve_nested_dict_value_from_config,
|
|
66
|
+
update_list_of_dicts_with_values_from_config,
|
|
67
|
+
update_nested_dictionary_with_values_from_config,
|
|
68
|
+
_get_initial_job_state,
|
|
69
|
+
_wait_until,
|
|
70
|
+
_flush_log_streams,
|
|
71
|
+
_logs_init,
|
|
72
|
+
LogState,
|
|
73
|
+
_check_job_status,
|
|
74
|
+
)
|
|
75
|
+
from sagemaker.core.workflow import is_pipeline_variable
|
|
76
|
+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
|
|
77
|
+
from sagemaker.core.workflow.execution_variables import ExecutionVariables
|
|
78
|
+
from sagemaker.core.workflow.functions import Join
|
|
79
|
+
from sagemaker.core.workflow.pipeline_context import runnable_by_pipeline
|
|
80
|
+
|
|
81
|
+
from sagemaker.core._studio import _append_project_tags
|
|
82
|
+
from sagemaker.core.config.config_utils import _append_sagemaker_config_tags
|
|
83
|
+
from sagemaker.core.utils.utils import serialize
|
|
84
|
+
|
|
85
|
+
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class Processor(object):
|
|
89
|
+
"""Handles Amazon SageMaker Processing tasks."""
|
|
90
|
+
|
|
91
|
+
JOB_CLASS_NAME = "processing-job"
|
|
92
|
+
|
|
93
|
+
def __init__(
    self,
    role: str = None,
    image_uri: Union[str, PipelineVariable] = None,
    instance_count: Union[int, PipelineVariable] = None,
    instance_type: Union[str, PipelineVariable] = None,
    entrypoint: Optional[List[Union[str, PipelineVariable]]] = None,
    volume_size_in_gb: Union[int, PipelineVariable] = 30,
    volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
    output_kms_key: Optional[Union[str, PipelineVariable]] = None,
    max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
    base_job_name: Optional[str] = None,
    sagemaker_session: Optional[Session] = None,
    env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
    tags: Optional[Tags] = None,
    network_config: Optional[NetworkConfig] = None,
):
    """Set up a ``Processor`` for Amazon SageMaker Processing jobs.

    Plain job settings are stored on the instance as given. The KMS keys,
    network settings, role and environment are each passed through the
    config-resolution helpers together with the matching
    ``PROCESSING_JOB_*`` config path before being stored.

    Args:
        role (str): An AWS IAM role name or ARN that Amazon SageMaker
            Processing uses to access AWS resources, such as data stored
            in Amazon S3. May be left as None when the role is resolved
            from the SageMaker config instead.
        image_uri (str or PipelineVariable): The URI of the Docker image
            to use for the processing jobs.
        instance_count (int or PipelineVariable): The number of instances
            to run a processing job with.
        instance_type (str or PipelineVariable): The type of EC2 instance
            to use for processing, for example, 'ml.c4.xlarge'. The values
            'local' and 'local_gpu' select a local session (see
            ``sagemaker_session``).
        entrypoint (list[str] or list[PipelineVariable]): The entrypoint
            for the processing job, as a list of strings that make a
            command (default: None).
        volume_size_in_gb (int or PipelineVariable): Size in GB of the EBS
            volume to use for storing data during processing (default: 30).
        volume_kms_key (str or PipelineVariable): A KMS key for the
            processing volume (default: None).
        output_kms_key (str or PipelineVariable): The KMS key ID for
            processing job outputs (default: None).
        max_runtime_in_seconds (int or PipelineVariable): Timeout in
            seconds (default: None). After this amount of time, Amazon
            SageMaker terminates the job, regardless of its current status.
            If not specified, the default value is 24 hours.
        base_job_name (str): Prefix for processing job name. If not
            specified, the processor generates a default job name, based
            on the processing image name and current timestamp.
        sagemaker_session (:class:`~sagemaker.core.helper.session_helper.Session`):
            Session object which manages interactions with Amazon SageMaker
            and any other AWS services needed. If not specified, a default
            ``Session()`` is created. For the 'local' and 'local_gpu'
            instance types, anything that is not already a ``LocalSession``
            is replaced by ``LocalSession(disable_local_code=True)``.
        env (dict[str, str] or dict[str, PipelineVariable]): Environment
            variables to be passed to the processing jobs (default: None).
        tags (Optional[Tags]): Tags to be passed to the processing job
            (default: None). For more, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
        network_config (:class:`~sagemaker.core.network.NetworkConfig`):
            Object that configures network isolation, encryption of
            inter-container traffic, security group IDs, and subnets.

    Raises:
        ValueError: If no role is available after config resolution.
    """
    # Bookkeeping for jobs started through this processor.
    self.jobs = []
    self.latest_job = None
    self._current_job_name = None
    self.arguments = None

    # Settings that are stored exactly as supplied by the caller.
    self.image_uri = image_uri
    self.entrypoint = entrypoint
    self.instance_type = instance_type
    self.instance_count = instance_count
    self.volume_size_in_gb = volume_size_in_gb
    self.max_runtime_in_seconds = max_runtime_in_seconds
    self.base_job_name = base_job_name
    self.tags = format_tags(tags)

    running_locally = self.instance_type in ("local", "local_gpu")
    if running_locally and not isinstance(sagemaker_session, LocalSession):
        # Until Local Mode Processing supports local code, we need to disable it:
        sagemaker_session = LocalSession(disable_local_code=True)

    if not sagemaker_session:
        sagemaker_session = Session()
    self.sagemaker_session = sagemaker_session

    self.output_kms_key = resolve_value_from_config(
        output_kms_key,
        PROCESSING_JOB_KMS_KEY_ID_PATH,
        sagemaker_session=self.sagemaker_session,
    )
    self.volume_kms_key = resolve_value_from_config(
        volume_kms_key,
        PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH,
        sagemaker_session=self.sagemaker_session,
    )

    # Resolve the NetworkConfig one attribute at a time, in a fixed order;
    # each step receives the result of the previous one.
    network_attribute_paths = (
        ("subnets", PROCESSING_JOB_SUBNETS_PATH),
        ("security_group_ids", PROCESSING_JOB_SECURITY_GROUP_IDS_PATH),
        ("enable_network_isolation", PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH),
        (
            "encrypt_inter_container_traffic",
            PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
        ),
    )
    partial_network_config = network_config
    for attribute_name, config_path in network_attribute_paths:
        partial_network_config = resolve_class_attribute_from_config(
            NetworkConfig,
            partial_network_config,
            attribute_name,
            config_path,
            sagemaker_session=self.sagemaker_session,
        )
        self.network_config = partial_network_config

    self.role = resolve_value_from_config(
        role,
        PROCESSING_JOB_ROLE_ARN_PATH,
        sagemaker_session=self.sagemaker_session,
    )
    if not self.role:
        # The IAM role used to be a required parameter. It became optional
        # because it can be fetched from SageMakerConfig, so it still has to
        # be validated here once the config lookup has been attempted.
        raise ValueError("An AWS IAM role is required to create a Processing job.")

    self.env = resolve_value_from_config(
        env,
        PROCESSING_JOB_ENVIRONMENT_PATH,
        sagemaker_session=self.sagemaker_session,
    )
|
|
222
|
+
|
|
223
|
+
@runnable_by_pipeline
def run(
    self,
    inputs: Optional[List[ProcessingInput]] = None,
    outputs: Optional[List[ProcessingOutput]] = None,
    arguments: Optional[List[Union[str, PipelineVariable]]] = None,
    wait: bool = True,
    logs: bool = True,
    job_name: Optional[str] = None,
    experiment_config: Optional[Dict[str, str]] = None,
    kms_key: Optional[str] = None,
):
    """Start a processing job using this processor's settings.

    Args:
        inputs (list[ProcessingInput]): Input definitions for the job
            (default: None). Every element must be a ``ProcessingInput``.
        outputs (list[ProcessingOutput]): Output definitions for the job
            (default: None). Every element must be a ``ProcessingOutput``;
            anything else is rejected with ``TypeError`` during
            normalization.
        arguments (list[str] or list[PipelineVariable]): Arguments stored on
            the processor and forwarded to the job (default: None).
        wait (bool): Block until the job completes (default: True).
        logs (bool): Show the job's logs while waiting. Requires ``wait``
            to be True (default: True).
        job_name (str): Explicit job name. When None, a name is generated
            (default: None).
        experiment_config (dict[str, str]): Experiment management
            configuration. May contain 'ExperimentName', 'TrialName' and
            'TrialComponentDisplayName'. It is passed through
            ``check_and_get_run_experiment_config`` before use.
        kms_key (str): ARN of the KMS key used to encrypt uploaded files
            (default: None).

    Returns:
        None from this method body. NOTE(review): the
        ``runnable_by_pipeline`` decorator presumably substitutes pipeline
        step arguments when the session is a ``PipelineSession`` - confirm
        against the decorator.

    Raises:
        ValueError: if ``logs`` is True but ``wait`` is False.
    """
    wants_logs_without_waiting = logs and not wait
    if wants_logs_without_waiting:
        raise ValueError(
            "Logs can only be shown if wait is set to True.\n"
            "Please either set wait to True or set logs to False."
        )

    job_inputs, job_outputs = self._normalize_args(
        job_name=job_name,
        arguments=arguments,
        inputs=inputs,
        kms_key=kms_key,
        outputs=outputs,
    )

    resolved_experiment_config = check_and_get_run_experiment_config(experiment_config)
    self.latest_job = self._start_new(
        inputs=job_inputs,
        outputs=job_outputs,
        experiment_config=resolved_experiment_config,
    )

    # Under a PipelineSession nothing was actually launched, so there is no
    # job to record or wait on.
    if isinstance(self.sagemaker_session, PipelineSession):
        return

    self.jobs.append(self.latest_job)
    if wait:
        self.latest_job.wait(logs=logs)
|
|
298
|
+
|
|
299
|
+
def _extend_processing_args(self, inputs, outputs, **kwargs): # pylint: disable=W0613
|
|
300
|
+
"""Extend inputs and outputs based on extra parameters"""
|
|
301
|
+
return inputs, outputs
|
|
302
|
+
|
|
303
|
+
def _normalize_args(
    self,
    job_name=None,
    arguments=None,
    inputs=None,
    outputs=None,
    code=None,
    kms_key=None,
):
    """Prepare job name, inputs, outputs and arguments for a job run.

    Side effects: sets ``self._current_job_name`` and ``self.arguments``.

    Args:
        job_name (str): Name for the processing job. When None, one is
            generated by ``_generate_current_job_name`` (default: None).
        arguments (list[str]): Arguments to store on the processor for the
            job (default: None).
        inputs (list[ProcessingInput]): Inputs to normalize (default: None).
        outputs (list[ProcessingOutput]): Outputs to normalize
            (default: None).
        code (str): S3 URI or local path of the script to run; handed to
            ``_include_code_in_inputs`` (default: None).
        kms_key (str): ARN of the KMS key used to encrypt uploaded files
            (default: None).

    Returns:
        tuple: ``(normalized_inputs, normalized_outputs)``.

    Raises:
        ValueError: if ``code`` is truthy and is a pipeline variable.
    """
    if code:
        if is_pipeline_variable(code):
            raise ValueError(
                "code argument has to be a valid S3 URI or local file path "
                "rather than a pipeline variable"
            )

    # The job name must be fixed first: input/output normalization reads it
    # when building default S3 locations.
    self._current_job_name = self._generate_current_job_name(job_name=job_name)

    prepared_inputs = self._normalize_inputs(
        self._include_code_in_inputs(inputs, code, kms_key), kms_key
    )
    prepared_outputs = self._normalize_outputs(outputs)
    self.arguments = arguments

    return prepared_inputs, prepared_outputs
|
|
345
|
+
|
|
346
|
+
def _include_code_in_inputs(self, inputs, _code, _kms_key):
|
|
347
|
+
"""A no op in the base class to include code in the processing job inputs.
|
|
348
|
+
|
|
349
|
+
Args:
|
|
350
|
+
inputs (list[:class:`~sagemaker.core.shapes.ProcessingInput`]): Input files for
|
|
351
|
+
the processing job. These must be provided as
|
|
352
|
+
:class:`~sagemaker.core.shapes.ProcessingInput` objects.
|
|
353
|
+
_code (str): This can be an S3 URI or a local path to a file with the framework
|
|
354
|
+
script to run (default: None). A no op in the base class.
|
|
355
|
+
kms_key (str): The ARN of the KMS key that is used to encrypt the
|
|
356
|
+
user code file (default: None).
|
|
357
|
+
|
|
358
|
+
Returns:
|
|
359
|
+
list[:class:`~sagemaker.core.shapes.ProcessingInput`]: inputs
|
|
360
|
+
"""
|
|
361
|
+
return inputs
|
|
362
|
+
|
|
363
|
+
def _generate_current_job_name(self, job_name=None):
|
|
364
|
+
"""Generates the job name before running a processing job.
|
|
365
|
+
|
|
366
|
+
Args:
|
|
367
|
+
job_name (str): Name of the processing job to be created. If not
|
|
368
|
+
specified, one is generated, using the base name given to the
|
|
369
|
+
constructor if applicable.
|
|
370
|
+
|
|
371
|
+
Returns:
|
|
372
|
+
str: The supplied or generated job name.
|
|
373
|
+
"""
|
|
374
|
+
if job_name is not None:
|
|
375
|
+
return job_name
|
|
376
|
+
# Honor supplied base_job_name or generate it.
|
|
377
|
+
if self.base_job_name:
|
|
378
|
+
base_name = self.base_job_name
|
|
379
|
+
else:
|
|
380
|
+
base_name = base_name_from_image(
|
|
381
|
+
self.image_uri, default_base_name=Processor.JOB_CLASS_NAME
|
|
382
|
+
)
|
|
383
|
+
|
|
384
|
+
# Replace invalid characters with hyphens to comply with AWS naming constraints
|
|
385
|
+
base_name = re.sub(r"[^a-zA-Z0-9-]", "-", base_name)
|
|
386
|
+
return name_from_base(base_name)
|
|
387
|
+
|
|
388
|
+
def _normalize_inputs(self, inputs=None, kms_key=None):
    """Give every ``ProcessingInput`` a name and, where needed, an S3 URI.

    Inputs with a dataset definition, or whose S3 URI is a pipeline
    variable, are kept as they are. For the remaining inputs, a URI whose
    scheme is not ``s3`` is treated as a local path: it is uploaded with
    ``s3.S3Uploader.upload`` and the input's ``s3_uri`` is replaced by the
    value the upload returns. Inputs are modified in place.

    Args:
        inputs (list[ProcessingInput]): Inputs to normalize (default: None).
        kms_key (str): ARN of the KMS key passed to the uploader
            (default: None).

    Returns:
        list[ProcessingInput]: The normalized inputs; an empty list when
        ``inputs`` is None.

    Raises:
        TypeError: if any element is not a ``ProcessingInput``.
    """
    from sagemaker.core.workflow.utilities import _pipeline_config

    if inputs is None:
        return []

    prepared = []
    for position, entry in enumerate(inputs, 1):
        if not isinstance(entry, ProcessingInput):
            raise TypeError("Your inputs must be provided as ProcessingInput objects.")

        # Unnamed inputs get a positional name: input-1, input-2, ...
        if entry.input_name is None:
            entry.input_name = "input-{}".format(position)

        # Dataset definitions and pipeline-variable URIs need no upload.
        if entry.dataset_definition or (
            entry.s3_input and is_pipeline_variable(entry.s3_input.s3_uri)
        ):
            prepared.append(entry)
            continue

        source_uri = entry.s3_input.s3_uri
        if urlparse(source_uri).scheme != "s3":
            bucket = self.sagemaker_session.default_bucket()
            bucket_prefix = self.sagemaker_session.default_bucket_prefix
            # Inside a pipeline the location is scoped by pipeline and step
            # name; otherwise by the current job name.
            if _pipeline_config:
                scope = [_pipeline_config.pipeline_name, _pipeline_config.step_name]
            else:
                scope = [self._current_job_name]
            upload_target = s3.s3_path_join(
                "s3://",
                bucket,
                bucket_prefix,
                *scope,
                "input",
                entry.input_name,
            )
            entry.s3_input.s3_uri = s3.S3Uploader.upload(
                local_path=source_uri,
                desired_s3_uri=upload_target,
                sagemaker_session=self.sagemaker_session,
                kms_key=kms_key,
            )
        prepared.append(entry)
    return prepared
|
|
455
|
+
|
|
456
|
+
def _normalize_outputs(self, outputs=None):
    """Give every ``ProcessingOutput`` a name and an S3 destination.

    Outputs whose S3 URI is a pipeline variable are kept as they are. For
    the others, a URI whose scheme is not ``s3`` is replaced by a generated
    location under the session's default bucket. Outputs are modified in
    place.

    Args:
        outputs (list[ProcessingOutput]): Outputs to normalize
            (default: None). Every element must be a ``ProcessingOutput``.

    Returns:
        list[ProcessingOutput]: The normalized outputs; an empty list when
        ``outputs`` is None.

    Raises:
        TypeError: if any element is not a ``ProcessingOutput``.
    """
    from sagemaker.core.workflow.utilities import _pipeline_config

    if outputs is None:
        return []

    prepared = []
    for position, entry in enumerate(outputs, 1):
        if not isinstance(entry, ProcessingOutput):
            raise TypeError("Your outputs must be provided as ProcessingOutput objects.")

        # Unnamed outputs get a positional name: output-1, output-2, ...
        if entry.output_name is None:
            entry.output_name = "output-{}".format(position)

        uri_is_variable = entry.s3_output and is_pipeline_variable(entry.s3_output.s3_uri)
        if not uri_is_variable and urlparse(entry.s3_output.s3_uri).scheme != "s3":
            if _pipeline_config:
                # Built with Join because the execution id is only known
                # when the pipeline runs. "s3:/" joined on "/" yields the
                # "s3://" prefix.
                segments = ["s3:/", self.sagemaker_session.default_bucket()]
                bucket_prefix = self.sagemaker_session.default_bucket_prefix
                # don't include default_bucket_prefix if it is None or ""
                if bucket_prefix:
                    segments.append(bucket_prefix)
                segments.extend(
                    [
                        _pipeline_config.pipeline_name,
                        ExecutionVariables.PIPELINE_EXECUTION_ID,
                        _pipeline_config.step_name,
                        "output",
                        entry.output_name,
                    ]
                )
                destination = Join(on="/", values=segments)
            else:
                destination = s3.s3_path_join(
                    "s3://",
                    self.sagemaker_session.default_bucket(),
                    self.sagemaker_session.default_bucket_prefix,
                    self._current_job_name,
                    "output",
                    entry.output_name,
                )
            entry.s3_output.s3_uri = destination
        prepared.append(entry)
    return prepared
|
|
521
|
+
|
|
522
|
+
def _start_new(self, inputs, outputs, experiment_config):
    """Build the create-processing-job request and submit it.

    Request pieces are first filled in from the SageMaker config through the
    ``resolve_*`` / ``update_*`` helpers, then serialized and handed to the
    session's ``_intercept_create_request``.

    Args:
        inputs: Normalized inputs for the job.
        outputs: Normalized outputs for the job.
        experiment_config: Experiment configuration forwarded to the request.

    Returns:
        ProcessingJob: built from the serialized request, or None when the
        session is a ``PipelineSession`` (the request is only intercepted).
    """
    from sagemaker.core.workflow.pipeline_context import PipelineSession

    session = self.sagemaker_session
    args = self._get_process_args(inputs, outputs, experiment_config)

    logger.debug("Job Name: %s", args["job_name"])
    logger.debug("Inputs: %s", args["inputs"])
    logger.debug("Outputs: %s", args["output_config"]["Outputs"])

    tags = _append_project_tags(format_tags(args["tags"]))
    tags = _append_sagemaker_config_tags(
        session, tags, "{}.{}.{}".format(SAGEMAKER, PROCESSING_JOB, TAGS)
    )

    network_with_encryption = resolve_nested_dict_value_from_config(
        args["network_config"],
        ["EnableInterContainerTrafficEncryption"],
        PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
        sagemaker_session=session,
    )

    dataset_union_key_paths = [
        ["DatasetDefinition", "S3Input"],
        [
            "DatasetDefinition.AthenaDatasetDefinition",
            "DatasetDefinition.RedshiftDatasetDefinition",
        ],
    ]
    # Return value is not used; the updated list is read back from
    # args["inputs"] when the request is built.
    update_list_of_dicts_with_values_from_config(
        args["inputs"],
        PROCESSING_JOB_INPUTS_PATH,
        union_key_paths=dataset_union_key_paths,
        sagemaker_session=session,
    )

    role_arn = resolve_value_from_config(
        args["role_arn"],
        PROCESSING_JOB_ROLE_ARN_PATH,
        sagemaker_session=session,
    )

    final_network_config = update_nested_dictionary_with_values_from_config(
        network_with_encryption,
        PROCESSING_JOB_NETWORK_CONFIG_PATH,
        sagemaker_session=session,
    )
    final_output_config = update_nested_dictionary_with_values_from_config(
        args["output_config"],
        PROCESSING_OUTPUT_CONFIG_PATH,
        sagemaker_session=session,
    )
    final_resources = update_nested_dictionary_with_values_from_config(
        args["resources"],
        PROCESSING_JOB_PROCESSING_RESOURCES_PATH,
        sagemaker_session=session,
    )
    environment = resolve_value_from_config(
        direct_input=args["environment"],
        config_path=PROCESSING_JOB_ENVIRONMENT_PATH,
        default_value=None,
        sagemaker_session=session,
    )

    request = _get_process_request(
        inputs=args["inputs"],
        output_config=final_output_config,
        job_name=args["job_name"],
        resources=final_resources,
        stopping_condition=args["stopping_condition"],
        app_specification=args["app_specification"],
        environment=environment,
        network_config=final_network_config,
        role_arn=role_arn,
        tags=tags,
        experiment_config=experiment_config,
    )

    # convert Unassigned() type in sagemaker-core to None
    wire_request = serialize(request)

    if isinstance(session, PipelineSession):
        session._intercept_create_request(wire_request, None, "process")
        return

    def submit(request):
        guide_url = (
            "https://docs.aws.amazon.com/sagemaker/latest/dg/"
            "sagemaker-python-sdk-troubleshooting.html"
            "#sagemaker-python-sdk-troubleshooting-create-processing-job"
        )
        try:
            logger.info("Creating processing-job with name %s", args["job_name"])
            logger.debug("process request: %s", json.dumps(request, indent=4))
            session.sagemaker_client.create_processing_job(**request)
        except Exception as err:
            # Point the user at the troubleshooting guide, then let the
            # original error propagate.
            logger.error(
                "Please check the troubleshooting guide for common errors: %s", guide_url
            )
            raise err

    session._intercept_create_request(wire_request, submit, "process")

    from sagemaker.core.utils.code_injection.codec import transform

    job_fields = transform(wire_request, "CreateProcessingJobRequest")
    return ProcessingJob(**job_fields)
|
|
629
|
+
|
|
630
|
+
def _get_process_args(self, inputs, outputs, experiment_config):
    """Collect the arguments for a new Amazon SageMaker processing job.

    Args:
        inputs: Normalized inputs; each is converted with
            ``_processing_input_to_request_dict``.
        outputs: Normalized outputs; each is converted with
            ``_processing_output_to_request_dict``.
        experiment_config: Stored unchanged under ``"experiment_config"``.

    Returns:
        dict: Keys ``inputs``, ``output_config``, ``experiment_config``,
        ``job_name``, ``resources``, ``stopping_condition``,
        ``app_specification``, ``environment``, ``network_config``,
        ``role_arn`` and ``tags``. ``stopping_condition`` and
        ``network_config`` are None when not configured on the processor.
    """
    input_dicts = [_processing_input_to_request_dict(item) for item in inputs]

    output_config = {
        "Outputs": [_processing_output_to_request_dict(item) for item in outputs]
    }
    if self.output_kms_key is not None:
        output_config["KmsKeyId"] = self.output_kms_key

    job_name = self._current_job_name

    cluster_config = {
        "InstanceType": self.instance_type,
        "InstanceCount": self.instance_count,
        "VolumeSizeInGB": self.volume_size_in_gb,
    }
    if self.volume_kms_key is not None:
        cluster_config["VolumeKmsKeyId"] = self.volume_kms_key

    if self.max_runtime_in_seconds is None:
        stopping_condition = None
    else:
        stopping_condition = {"MaxRuntimeInSeconds": self.max_runtime_in_seconds}

    app_specification = {"ImageUri": self.image_uri}
    if self.arguments is not None:
        app_specification["ContainerArguments"] = self.arguments
    if self.entrypoint is not None:
        app_specification["ContainerEntrypoint"] = self.entrypoint

    environment = self.env

    if self.network_config is None:
        network_config = None
    else:
        network_config = self.network_config._to_request_dict()

    # A pipeline-variable role is passed through untouched; anything else is
    # expanded by the session.
    if is_pipeline_variable(self.role):
        role_arn = self.role
    else:
        role_arn = self.sagemaker_session.expand_role(self.role)

    return {
        "inputs": input_dicts,
        "output_config": output_config,
        "experiment_config": experiment_config,
        "job_name": job_name,
        "resources": {"ClusterConfig": cluster_config},
        "stopping_condition": stopping_condition,
        "app_specification": app_specification,
        "environment": environment,
        "network_config": network_config,
        "role_arn": role_arn,
        "tags": self.tags,
    }
|
|
675
|
+
|
|
676
|
+
|
|
677
|
+
class ScriptProcessor(Processor):
|
|
678
|
+
"""Handles Amazon SageMaker processing tasks for jobs using a machine learning framework."""
|
|
679
|
+
|
|
680
|
+
def __init__(
    self,
    role: Optional[Union[str, PipelineVariable]] = None,
    image_uri: Union[str, PipelineVariable] = None,
    command: List[str] = None,
    instance_count: Union[int, PipelineVariable] = None,
    instance_type: Union[str, PipelineVariable] = None,
    volume_size_in_gb: Union[int, PipelineVariable] = 30,
    volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
    output_kms_key: Optional[Union[str, PipelineVariable]] = None,
    max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
    base_job_name: Optional[str] = None,
    sagemaker_session: Optional[Session] = None,
    env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
    tags: Optional[Tags] = None,
    network_config: Optional[NetworkConfig] = None,
):
    """Create a ``ScriptProcessor``.

    A ``ScriptProcessor`` runs a user-provided script as part of a
    processing job. All arguments except ``command`` are forwarded to the
    parent initializer (``tags`` after passing through ``format_tags``).

    Args:
        role (str or PipelineVariable): An AWS IAM role name or ARN used by
            the processing job to access AWS resources.
        image_uri (str or PipelineVariable): URI of the Docker image to use
            for the processing jobs.
        command ([str]): The command to run, along with any command-line
            flags, for example ``["python3", "-v"]``. When it is empty or
            None and the text of ``image_uri`` contains "sklearn" or
            "scikit-learn", it defaults to ``["python3"]``.
        instance_count (int or PipelineVariable): Number of instances to run
            a processing job with.
        instance_type (str or PipelineVariable): EC2 instance type to use
            for processing, for example 'ml.c4.xlarge'.
        volume_size_in_gb (int or PipelineVariable): Size in GB of the EBS
            volume used during processing (default: 30).
        volume_kms_key (str or PipelineVariable): KMS key for the processing
            volume (default: None).
        output_kms_key (str or PipelineVariable): KMS key ID for processing
            job outputs (default: None).
        max_runtime_in_seconds (int or PipelineVariable): Timeout in seconds
            (default: None).
        base_job_name (str): Prefix for the processing job name
            (default: None).
        sagemaker_session (Session): Session object managing interactions
            with Amazon SageMaker (default: None).
        env (dict[str, str] or dict[str, PipelineVariable]): Environment
            variables passed to the processing jobs (default: None).
        tags (Optional[Tags]): Tags passed to the processing job
            (default: None).
        network_config (NetworkConfig): Network settings for the job
            (default: None).
    """
    self._CODE_CONTAINER_BASE_PATH = "/opt/ml/processing/input/"
    self._CODE_CONTAINER_INPUT_NAME = "code"

    # scikit-learn images get a default interpreter when no command is given.
    if not command and image_uri:
        image_text = str(image_uri)
        if "sklearn" in image_text or "scikit-learn" in image_text:
            command = ["python3"]

    self.command = command

    parent_kwargs = dict(
        role=role,
        image_uri=image_uri,
        instance_count=instance_count,
        instance_type=instance_type,
        volume_size_in_gb=volume_size_in_gb,
        volume_kms_key=volume_kms_key,
        output_kms_key=output_kms_key,
        max_runtime_in_seconds=max_runtime_in_seconds,
        base_job_name=base_job_name,
        sagemaker_session=sagemaker_session,
        env=env,
        tags=format_tags(tags),
        network_config=network_config,
    )
    super(ScriptProcessor, self).__init__(**parent_kwargs)
|
|
768
|
+
|
|
769
|
+
@runnable_by_pipeline
def run(
    self,
    code: str,
    inputs: Optional[List[ProcessingInput]] = None,
    outputs: Optional[List[ProcessingOutput]] = None,
    arguments: Optional[List[Union[str, PipelineVariable]]] = None,
    wait: bool = True,
    logs: bool = True,
    job_name: Optional[str] = None,
    experiment_config: Optional[Dict[str, str]] = None,
    kms_key: Optional[str] = None,
):
    """Start a processing job that runs the given script.

    Args:
        code (str): S3 URI or local path of the script to run.
        inputs (list[ProcessingInput]): Input definitions for the job
            (default: None). Every element must be a ``ProcessingInput``.
        outputs (list[ProcessingOutput]): Output definitions for the job
            (default: None). Every element must be a ``ProcessingOutput``.
        arguments (list[str]): Arguments stored on the processor and
            forwarded to the job (default: None).
        wait (bool): Block until the job completes (default: True).
        logs (bool): Show the job's logs while waiting. Only used when
            ``wait`` is True (default: True).
        job_name (str): Explicit job name. When None, a name is generated
            (default: None).
        experiment_config (dict[str, str]): Experiment management
            configuration. May contain 'ExperimentName', 'TrialName' and
            'TrialComponentDisplayName'. It is passed through
            ``check_and_get_run_experiment_config`` before use.
        kms_key (str): ARN of the KMS key used to encrypt the user code
            file (default: None).

    Returns:
        None from this method body. NOTE(review): the
        ``runnable_by_pipeline`` decorator presumably substitutes pipeline
        step arguments when the session is a ``PipelineSession`` - confirm
        against the decorator.
    """
    job_inputs, job_outputs = self._normalize_args(
        job_name=job_name,
        arguments=arguments,
        inputs=inputs,
        outputs=outputs,
        code=code,
        kms_key=kms_key,
    )

    resolved_experiment_config = check_and_get_run_experiment_config(experiment_config)
    self.latest_job = self._start_new(
        inputs=job_inputs,
        outputs=job_outputs,
        experiment_config=resolved_experiment_config,
    )

    from sagemaker.core.workflow.pipeline_context import PipelineSession

    # Under a PipelineSession nothing was actually launched, so there is no
    # job to record or wait on.
    if isinstance(self.sagemaker_session, PipelineSession):
        return

    self.jobs.append(self.latest_job)
    if wait:
        self.latest_job.wait(logs=logs)
|
|
842
|
+
|
|
843
|
+
def _include_code_in_inputs(self, inputs, code, kms_key=None):
    """Add the user script to the job inputs and set the entrypoint.

    Delegates to ``_handle_user_code_url`` to obtain the script's S3 URI,
    to ``_convert_code_and_add_to_inputs`` to build the combined input list,
    and to ``_set_entrypoint`` with ``self.command`` and the script's file
    name.

    Args:
        inputs (list[ProcessingInput]): Inputs for the processing job.
        code (str): S3 URI or local path of the script to run.
        kms_key (str): ARN of the KMS key used to encrypt the user code
            file (default: None).

    Returns:
        list[ProcessingInput]: The value returned by
        ``_convert_code_and_add_to_inputs``.
    """
    code_s3_uri = self._handle_user_code_url(code, kms_key)
    script_file_name = self._get_user_code_name(code)

    combined_inputs = self._convert_code_and_add_to_inputs(inputs, code_s3_uri)

    self._set_entrypoint(self.command, script_file_name)
    return combined_inputs
|
|
870
|
+
|
|
871
|
+
def _get_user_code_name(self, code):
|
|
872
|
+
"""Gets the basename of the user's code from the URL the customer provided.
|
|
873
|
+
|
|
874
|
+
Args:
|
|
875
|
+
code (str): A URL to the user's code.
|
|
876
|
+
|
|
877
|
+
Returns:
|
|
878
|
+
str: The basename of the user's code.
|
|
879
|
+
|
|
880
|
+
"""
|
|
881
|
+
code_url = urlparse(code)
|
|
882
|
+
return os.path.basename(code_url.path)
|
|
883
|
+
|
|
884
|
+
def _handle_user_code_url(self, code, kms_key=None):
|
|
885
|
+
"""Gets the S3 URL containing the user's code.
|
|
886
|
+
|
|
887
|
+
Inspects the scheme the customer passed in ("s3://" for code in S3, "file://" or nothing
|
|
888
|
+
for absolute or local file paths. Uploads the code to S3 if the code is a local file.
|
|
889
|
+
|
|
890
|
+
Args:
|
|
891
|
+
code (str): A URL to the customer's code.
|
|
892
|
+
kms_key (str): The ARN of the KMS key that is used to encrypt the
|
|
893
|
+
user code file (default: None).
|
|
894
|
+
|
|
895
|
+
Returns:
|
|
896
|
+
str: The S3 URL to the customer's code.
|
|
897
|
+
|
|
898
|
+
Raises:
|
|
899
|
+
ValueError: if the code isn't found, is a directory, or
|
|
900
|
+
does not have a valid URL scheme.
|
|
901
|
+
"""
|
|
902
|
+
code_url = urlparse(code)
|
|
903
|
+
if code_url.scheme == "s3":
|
|
904
|
+
user_code_s3_uri = code
|
|
905
|
+
elif code_url.scheme == "" or code_url.scheme == "file":
|
|
906
|
+
# Validate that the file exists locally and is not a directory.
|
|
907
|
+
code_path = url2pathname(code_url.path)
|
|
908
|
+
if not os.path.exists(code_path):
|
|
909
|
+
raise ValueError(
|
|
910
|
+
"""code {} wasn't found. Please make sure that the file exists.
|
|
911
|
+
""".format(
|
|
912
|
+
code
|
|
913
|
+
)
|
|
914
|
+
)
|
|
915
|
+
if not os.path.isfile(code_path):
|
|
916
|
+
raise ValueError(
|
|
917
|
+
"""code {} must be a file, not a directory. Please pass a path to a file.
|
|
918
|
+
""".format(
|
|
919
|
+
code
|
|
920
|
+
)
|
|
921
|
+
)
|
|
922
|
+
user_code_s3_uri = self._upload_code(code_path, kms_key)
|
|
923
|
+
else:
|
|
924
|
+
raise ValueError(
|
|
925
|
+
"code {} url scheme {} is not recognized. Please pass a file path or S3 url".format(
|
|
926
|
+
code, code_url.scheme
|
|
927
|
+
)
|
|
928
|
+
)
|
|
929
|
+
return user_code_s3_uri
|
|
930
|
+
|
|
931
|
+
def _upload_code(self, code, kms_key=None):
    """Upload local code to S3 and return the S3 URI of the upload.

    When a pipeline configuration with a code hash is active, the destination
    is built from the pipeline name, the code input name and that hash.
    Otherwise it is built from the current job name, ``"input"`` and the code
    input name. In both cases it sits under the session's default bucket and
    default bucket prefix.

    Args:
        code (str): A file or directory to be uploaded to S3.
        kms_key (str): The ARN of the KMS key that is used to encrypt the
            user code file (default: None).

    Returns:
        str: The S3 URI of the uploaded file or directory.
    """
    from sagemaker.core.workflow.utilities import _pipeline_config

    session = self.sagemaker_session
    in_pipeline = bool(_pipeline_config and _pipeline_config.code_hash)

    bucket = session.default_bucket()
    bucket_prefix = session.default_bucket_prefix

    if in_pipeline:
        key_parts = (
            _pipeline_config.pipeline_name,
            self._CODE_CONTAINER_INPUT_NAME,
            _pipeline_config.code_hash,
        )
    else:
        key_parts = (
            self._current_job_name,
            "input",
            self._CODE_CONTAINER_INPUT_NAME,
        )

    destination = s3.s3_path_join("s3://", bucket, bucket_prefix, *key_parts)

    return s3.S3Uploader.upload(
        local_path=code,
        desired_s3_uri=destination,
        kms_key=kms_key,
        sagemaker_session=session,
    )
|
|
969
|
+
|
|
970
|
+
def _convert_code_and_add_to_inputs(self, inputs, s3_uri):
    """Wrap ``s3_uri`` in a ``ProcessingInput`` and return it appended to ``inputs``.

    Args:
        inputs (list[sagemaker.core.shapes.ProcessingInput]):
            List of ``ProcessingInput`` objects; a falsy value is treated as an
            empty list.
        s3_uri (str): S3 URI of the input to be added to inputs.

    Returns:
        list[sagemaker.core.shapes.ProcessingInput]: A new list holding the given
        ``ProcessingInput`` objects followed by the one created from ``s3_uri``.
    """
    # In-container directory the code input is downloaded to.
    container_dir = pathlib.PurePosixPath(
        self._CODE_CONTAINER_BASE_PATH, self._CODE_CONTAINER_INPUT_NAME
    )
    code_input = ProcessingInput(
        input_name=self._CODE_CONTAINER_INPUT_NAME,
        s3_input=ProcessingS3Input(
            s3_uri=s3_uri,
            local_path=str(container_dir),
            s3_data_type="S3Prefix",
            s3_input_mode="File",
        ),
    )

    existing_inputs = inputs if inputs else []
    return existing_inputs + [code_input]
|
|
998
|
+
|
|
999
|
+
def _set_entrypoint(self, command, user_script_name):
|
|
1000
|
+
"""Sets the entrypoint based on the user's script and corresponding executable.
|
|
1001
|
+
|
|
1002
|
+
Args:
|
|
1003
|
+
user_script_name (str): A filename with an extension.
|
|
1004
|
+
"""
|
|
1005
|
+
user_script_location = str(
|
|
1006
|
+
pathlib.PurePosixPath(
|
|
1007
|
+
self._CODE_CONTAINER_BASE_PATH,
|
|
1008
|
+
self._CODE_CONTAINER_INPUT_NAME,
|
|
1009
|
+
user_script_name,
|
|
1010
|
+
)
|
|
1011
|
+
)
|
|
1012
|
+
self.entrypoint = command + [user_script_location]
|
|
1013
|
+
|
|
1014
|
+
|
|
1015
|
+
class FrameworkProcessor(ScriptProcessor):
    """Handles Amazon SageMaker processing tasks using ModelTrainer for code packaging."""

    # Interpreter used to launch the generated ``runproc.sh`` entrypoint script.
    framework_entrypoint_command = ["/bin/bash"]

    def __init__(
        self,
        image_uri: Union[str, PipelineVariable],
        role: Optional[Union[str, PipelineVariable]] = None,
        instance_count: Union[int, PipelineVariable] = None,
        instance_type: Union[str, PipelineVariable] = None,
        command: Optional[List[str]] = None,
        volume_size_in_gb: Union[int, PipelineVariable] = 30,
        volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
        output_kms_key: Optional[Union[str, PipelineVariable]] = None,
        code_location: Optional[str] = None,
        max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
        base_job_name: Optional[str] = None,
        sagemaker_session: Optional[Session] = None,
        env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
        tags: Optional[Tags] = None,
        network_config: Optional[NetworkConfig] = None,
    ):
        """Initializes a ``FrameworkProcessor`` instance.

        The ``FrameworkProcessor`` handles Amazon SageMaker Processing tasks using
        ModelTrainer for code packaging instead of Framework estimators.

        Args:
            image_uri (str or PipelineVariable): The URI of the Docker image to use for the
                processing jobs.
            role (str or PipelineVariable): An AWS IAM role name or ARN. Amazon SageMaker
                Processing uses this role to access AWS resources, such as data stored
                in Amazon S3.
            instance_count (int or PipelineVariable): The number of instances to run a
                processing job with.
            instance_type (str or PipelineVariable): The type of EC2 instance to use for
                processing, for example, 'ml.c4.xlarge'.
            command ([str]): The command to run, along with any command-line flags
                to *precede* the ```code script```. Example: ["python3", "-v"]. If not
                provided, ["python"] will be chosen (default: None).
            volume_size_in_gb (int or PipelineVariable): Size in GB of the EBS volume
                to use for storing data during processing (default: 30).
            volume_kms_key (str or PipelineVariable): A KMS key for the processing volume
                (default: None).
            output_kms_key (str or PipelineVariable): The KMS key ID for processing job outputs
                (default: None).
            code_location (str): The S3 prefix URI where custom code will be
                uploaded (default: None). The code file uploaded to S3 is
                'code_location/job-name/source/sourcedir.tar.gz'. If not specified, the
                default ``code location`` is 's3://{sagemaker-default-bucket}'.
                NOTE(review): ``_package_code`` in this class builds its destination
                from the session's default bucket and prefix and does not read
                ``code_location`` -- confirm whether that is intended.
            max_runtime_in_seconds (int or PipelineVariable): Timeout in seconds (default: None).
                After this amount of time, Amazon SageMaker terminates the job,
                regardless of its current status. If `max_runtime_in_seconds` is not
                specified, the default value is 24 hours.
            base_job_name (str): Prefix for processing name. If not specified,
                the processor generates a default job name, based on the
                processing image name and current timestamp (default: None).
            sagemaker_session (:class:`~sagemaker.session.Session`):
                Session object which manages interactions with Amazon SageMaker and
                any other AWS services needed. If not specified, the processor creates
                one using the default AWS configuration chain (default: None).
            env (dict[str, str] or dict[str, PipelineVariable]): Environment variables to
                be passed to the processing jobs (default: None).
            tags (Optional[Tags]): Tags to be passed to the processing job (default: None).
                For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
            network_config (:class:`~sagemaker.network.NetworkConfig`):
                A :class:`~sagemaker.network.NetworkConfig`
                object that configures network isolation, encryption of
                inter-container traffic, security group IDs, and subnets (default: None).
        """
        if not command:
            command = ["python"]

        super().__init__(
            role=role,
            image_uri=image_uri,
            command=command,
            instance_count=instance_count,
            instance_type=instance_type,
            volume_size_in_gb=volume_size_in_gb,
            volume_kms_key=volume_kms_key,
            output_kms_key=output_kms_key,
            max_runtime_in_seconds=max_runtime_in_seconds,
            base_job_name=base_job_name,
            sagemaker_session=sagemaker_session,
            env=env,
            tags=format_tags(tags),
            network_config=network_config,
        )

        # This subclass uses the "code" input for actual payload and the ScriptProcessor parent's
        # functionality for uploading just a small entrypoint script to invoke it.
        self._CODE_CONTAINER_INPUT_NAME = "entrypoint"

        # Store the location without a single trailing slash.
        self.code_location = (
            code_location[:-1] if (code_location and code_location.endswith("/")) else code_location
        )

    def _package_code(
        self,
        entry_point,
        source_dir,
        requirements,
        job_name,
        kms_key,
    ):
        """Package ``source_dir`` as ``sourcedir.tar.gz`` and upload it to S3.

        Args:
            entry_point (str): Local path of the user's script. Only used to derive
                ``source_dir`` when none is given.
            source_dir (str): Local directory whose contents are archived, or None
                to use the directory containing ``entry_point``.
            requirements (str): Accepted for interface compatibility; not read here.
            job_name (str): Job name used as part of the destination S3 key.
            kms_key (str): KMS key handed to the uploader.

        Returns:
            str: S3 URI of the uploaded ``sourcedir.tar.gz``.

        Raises:
            ValueError: If the resolved ``source_dir`` does not exist.
        """
        import tarfile
        import tempfile

        # If source_dir is not provided, use the directory containing entry_point.
        if source_dir is None:
            source_dir = os.path.dirname(
                entry_point if os.path.isabs(entry_point) else os.path.abspath(entry_point)
            )

        # Resolve source_dir to an absolute path.
        if not os.path.isabs(source_dir):
            source_dir = os.path.abspath(source_dir)

        if not os.path.exists(source_dir):
            raise ValueError(f"source_dir does not exist: {source_dir}")

        # Reserve a temp file name, then close the handle right away so the archive
        # can be re-opened by name on every platform. delete=False means removal is
        # our job, hence the try/finally below.
        tmp = tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False)
        tmp.close()
        try:
            with tarfile.open(tmp.name, "w:gz") as tar:
                # Add every entry of source_dir at the root of the archive.
                for item in os.listdir(source_dir):
                    tar.add(os.path.join(source_dir, item), arcname=item)

            s3_uri = s3.s3_path_join(
                "s3://",
                self.sagemaker_session.default_bucket(),
                self.sagemaker_session.default_bucket_prefix or "",
                job_name,
                "source",
                "sourcedir.tar.gz",
            )

            # Read the archive through a context manager so the handle is closed
            # even if the upload fails.
            with open(tmp.name, "rb") as archive:
                payload = archive.read()

            s3.S3Uploader.upload_string_as_file_body(
                body=payload,
                desired_s3_uri=s3_uri,
                kms_key=kms_key,
                sagemaker_session=self.sagemaker_session,
            )
        finally:
            # Always remove the temporary archive, on success and on failure.
            os.unlink(tmp.name)

        return s3_uri

    @runnable_by_pipeline
    def run(
        self,
        code: str,
        source_dir: Optional[str] = None,
        requirements: Optional[str] = None,
        inputs: Optional[List[ProcessingInput]] = None,
        outputs: Optional[List["ProcessingOutput"]] = None,
        arguments: Optional[List[Union[str, PipelineVariable]]] = None,
        wait: bool = True,
        logs: bool = True,
        job_name: Optional[str] = None,
        experiment_config: Optional[Dict[str, str]] = None,
        kms_key: Optional[str] = None,
    ):
        """Runs a processing job.

        Args:
            code (str): This can be an S3 URI or a local path to a file with the
                framework script to run.
            source_dir (str): Path (absolute, relative or an S3 URI) to a directory
                with any other processing source code dependencies aside from the entry
                point file (default: None).
            requirements (str): Path to a requirements.txt file relative to source_dir
                (default: None). NOTE(review): the value is forwarded to the packaging
                step, which does not read it; the generated entrypoint script installs
                a file named ``requirements.txt`` if one is present after extraction.
            inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
                the processing job. These must be provided as
                :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
            outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
                the processing job. These can be specified as either path strings or
                :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
            arguments (list[str] or list[PipelineVariable]): A list of string arguments
                to be passed to a processing job (default: None).
            wait (bool): Whether the call should wait until the job completes (default: True).
            logs (bool): Whether to show the logs produced by the job.
                Only meaningful when wait is True (default: True).
            job_name (str): Processing job name. If not specified, the processor generates
                a default job name, based on the base job name and current timestamp.
            experiment_config (dict[str, str]): Experiment management configuration.
            kms_key (str): The ARN of the KMS key that is used to encrypt the
                user code file (default: None).

        Returns:
            None or pipeline step arguments in case the Processor instance is built with
            :class:`~sagemaker.workflow.pipeline_context.PipelineSession`
        """
        s3_runproc_sh, inputs, job_name = self._pack_and_upload_code(
            code,
            source_dir,
            requirements,
            job_name,
            inputs,
            kms_key,
        )

        # Submit a processing job.
        return super().run(
            code=s3_runproc_sh,
            inputs=inputs,
            outputs=outputs,
            arguments=arguments,
            wait=wait,
            logs=logs,
            job_name=job_name,
            experiment_config=experiment_config,
            kms_key=kms_key,
        )

    def _pack_and_upload_code(
        self,
        code,
        source_dir,
        requirements,
        job_name,
        inputs,
        kms_key=None,
    ):
        """Pack local code bundle and upload to Amazon S3.

        Args:
            code (str): S3 URI or local path of the user's script.
            source_dir (str): Directory to package, or None.
            requirements (str): Forwarded to ``_package_code``.
            job_name (str): Job name; generated when None.
            inputs (list): Processing inputs supplied by the caller, or None.
            kms_key (str): KMS key for the uploads (default: None). The entrypoint
                script falls back to ``self.output_kms_key`` when this is falsy.

        Returns:
            tuple: ``(code_uri, inputs, job_name)``. For an ``s3://`` ``code`` value
            the three arguments are returned unchanged; otherwise ``code_uri`` is
            the URI of the uploaded ``runproc.sh`` and ``inputs`` additionally
            contains the packaged payload.
        """
        if code.startswith("s3://"):
            return code, inputs, job_name

        if job_name is None:
            job_name = self._generate_current_job_name(job_name)

        # Package and upload code
        s3_payload = self._package_code(
            entry_point=code,
            source_dir=source_dir,
            requirements=requirements,
            job_name=job_name,
            kms_key=kms_key,
        )

        inputs = self._patch_inputs_with_payload(inputs, s3_payload)

        # The entrypoint script lives next to the payload archive.
        entrypoint_s3_uri = s3_payload.replace("sourcedir.tar.gz", "runproc.sh")

        script = os.path.basename(code)
        evaluated_kms_key = kms_key if kms_key else self.output_kms_key
        s3_runproc_sh = self._create_and_upload_runproc(
            script, evaluated_kms_key, entrypoint_s3_uri
        )

        return s3_runproc_sh, inputs, job_name

    def _patch_inputs_with_payload(self, inputs, s3_payload) -> List[ProcessingInput]:
        """Add payload sourcedir.tar.gz to processing input.

        Args:
            inputs (list): Caller-supplied processing inputs, or None.
            s3_payload (str): S3 URI of the uploaded ``sourcedir.tar.gz``.

        Returns:
            list: A shallow copy of ``inputs`` with a ``"code"`` input appended that
            points at the S3 directory containing the payload.
        """
        if inputs is None:
            inputs = []

        # make a shallow copy of user inputs
        patched_inputs = copy(inputs)

        # Extract the directory path from the s3_payload (remove the filename)
        s3_code_dir = s3_payload.rsplit("/", 1)[0] + "/"

        patched_inputs.append(
            ProcessingInput(
                input_name="code",
                s3_input=ProcessingS3Input(
                    s3_uri=s3_code_dir,
                    local_path="/opt/ml/processing/input/code/",
                    s3_data_type="S3Prefix",
                    s3_input_mode="File",
                ),
            )
        )
        return patched_inputs

    def _set_entrypoint(self, command, user_script_name):
        """Framework processor override for setting processing job entrypoint.

        ``command`` is ignored here: the entrypoint is always
        ``framework_entrypoint_command`` followed by the in-container path of
        ``user_script_name``.
        """
        user_script_location = str(
            pathlib.PurePosixPath(
                self._CODE_CONTAINER_BASE_PATH, self._CODE_CONTAINER_INPUT_NAME, user_script_name
            )
        )
        self.entrypoint = self.framework_entrypoint_command + [user_script_location]

    def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
        """Create runproc shell script and upload to S3 bucket.

        Args:
            user_script (str): File name of the user's script inside the payload.
            kms_key (str): KMS key handed to the uploader.
            entrypoint_s3_uri (str): Destination used outside of a pipeline.

        Returns:
            The value returned by ``S3Uploader.upload_string_as_file_body``.
        """
        from sagemaker.core.workflow.utilities import _pipeline_config, hash_object

        runproc_file_str = self._generate_framework_script(user_script)

        if _pipeline_config and _pipeline_config.pipeline_name:
            # Inside a pipeline the destination is keyed by a hash of the script text.
            desired_s3_uri = s3.s3_path_join(
                "s3://",
                self.sagemaker_session.default_bucket(),
                self.sagemaker_session.default_bucket_prefix,
                _pipeline_config.pipeline_name,
                "code",
                hash_object(runproc_file_str),
                "runproc.sh",
            )
        else:
            desired_s3_uri = entrypoint_s3_uri

        return s3.S3Uploader.upload_string_as_file_body(
            runproc_file_str,
            desired_s3_uri=desired_s3_uri,
            kms_key=kms_key,
            sagemaker_session=self.sagemaker_session,
        )

    def _generate_framework_script(self, user_script: str) -> str:
        """Generate the framework entrypoint file (as text) for a processing job.

        Args:
            user_script (str): File name of the user's script, relative to the
                extracted source directory.

        Returns:
            str: A bash script that extracts ``sourcedir.tar.gz``, installs
            ``requirements.txt`` when present, and runs ``self.command`` on the
            user script with the job's arguments.
        """
        import shlex

        return dedent(
            """\
            #!/bin/bash

            # Exit on any error. SageMaker uses error code to mark failed job.
            set -e

            cd /opt/ml/processing/input/code/

            # Debug: List files before extraction
            echo "Files in /opt/ml/processing/input/code/ before extraction:"
            ls -la

            # Extract source code
            if [ -f sourcedir.tar.gz ]; then
                tar -xzf sourcedir.tar.gz
                echo "Files after extraction:"
                ls -la
            else
                echo "ERROR: sourcedir.tar.gz not found!"
                exit 1
            fi

            if [[ -f 'requirements.txt' ]]; then
                # Some py3 containers has typing, which may breaks pip install
                pip uninstall --yes typing

                pip install -r requirements.txt
            fi

            {entry_point_command} {entry_point} "$@"
        """
        ).format(
            entry_point_command=" ".join(self.command),
            # Quote the script name so names containing spaces or shell
            # metacharacters stay a single argument; plain names are unchanged.
            entry_point=shlex.quote(user_script),
        )
|
|
1376
|
+
|
|
1377
|
+
|
|
1378
|
+
class FeatureStoreOutput(ApiObject):
    """Processing job output configuration that targets Amazon SageMaker Feature Store."""

    # Feature group name for this output; None when unset.
    feature_group_name: Optional[str] = None
|
|
1382
|
+
|
|
1383
|
+
|
|
1384
|
+
def _processing_input_to_request_dict(processing_input):
|
|
1385
|
+
"""Convert ProcessingInput to request dictionary format."""
|
|
1386
|
+
app_managed = getattr(processing_input, "app_managed", False)
|
|
1387
|
+
request_dict = {
|
|
1388
|
+
"InputName": processing_input.input_name,
|
|
1389
|
+
"AppManaged": app_managed if app_managed is not None else False,
|
|
1390
|
+
}
|
|
1391
|
+
|
|
1392
|
+
if processing_input.s3_input:
|
|
1393
|
+
request_dict["S3Input"] = {
|
|
1394
|
+
"S3Uri": processing_input.s3_input.s3_uri,
|
|
1395
|
+
"LocalPath": processing_input.s3_input.local_path,
|
|
1396
|
+
"S3DataType": processing_input.s3_input.s3_data_type or "S3Prefix",
|
|
1397
|
+
"S3InputMode": processing_input.s3_input.s3_input_mode or "File",
|
|
1398
|
+
"S3DataDistributionType": processing_input.s3_input.s3_data_distribution_type
|
|
1399
|
+
or "FullyReplicated",
|
|
1400
|
+
"S3CompressionType": processing_input.s3_input.s3_compression_type or "None",
|
|
1401
|
+
}
|
|
1402
|
+
|
|
1403
|
+
return request_dict
|
|
1404
|
+
|
|
1405
|
+
|
|
1406
|
+
def _processing_output_to_request_dict(processing_output):
|
|
1407
|
+
"""Convert ProcessingOutput to request dictionary format."""
|
|
1408
|
+
app_managed = getattr(processing_output, "app_managed", False)
|
|
1409
|
+
request_dict = {
|
|
1410
|
+
"OutputName": processing_output.output_name,
|
|
1411
|
+
"AppManaged": app_managed if app_managed is not None else False,
|
|
1412
|
+
}
|
|
1413
|
+
|
|
1414
|
+
if processing_output.s3_output:
|
|
1415
|
+
request_dict["S3Output"] = {
|
|
1416
|
+
"S3Uri": processing_output.s3_output.s3_uri,
|
|
1417
|
+
"LocalPath": processing_output.s3_output.local_path,
|
|
1418
|
+
"S3UploadMode": processing_output.s3_output.s3_upload_mode,
|
|
1419
|
+
}
|
|
1420
|
+
|
|
1421
|
+
return request_dict
|
|
1422
|
+
|
|
1423
|
+
|
|
1424
|
+
def _get_process_request(
|
|
1425
|
+
inputs,
|
|
1426
|
+
output_config,
|
|
1427
|
+
job_name,
|
|
1428
|
+
resources,
|
|
1429
|
+
stopping_condition,
|
|
1430
|
+
app_specification,
|
|
1431
|
+
environment,
|
|
1432
|
+
network_config,
|
|
1433
|
+
role_arn,
|
|
1434
|
+
tags,
|
|
1435
|
+
experiment_config=None,
|
|
1436
|
+
):
|
|
1437
|
+
"""Constructs a request compatible for an Amazon SageMaker processing job.
|
|
1438
|
+
|
|
1439
|
+
Args:
|
|
1440
|
+
inputs ([dict]): List of up to 10 ProcessingInput dictionaries.
|
|
1441
|
+
output_config (dict): A config dictionary, which contains a list of up
|
|
1442
|
+
to 10 ProcessingOutput dictionaries, as well as an optional KMS key ID.
|
|
1443
|
+
job_name (str): The name of the processing job. The name must be unique
|
|
1444
|
+
within an AWS Region in an AWS account. Names should have minimum
|
|
1445
|
+
length of 1 and maximum length of 63 characters.
|
|
1446
|
+
resources (dict): Encapsulates the resources, including ML instances
|
|
1447
|
+
and storage, to use for the processing job.
|
|
1448
|
+
stopping_condition (dict[str,int]): Specifies a limit to how long
|
|
1449
|
+
the processing job can run, in seconds.
|
|
1450
|
+
app_specification (dict[str,str]): Configures the processing job to
|
|
1451
|
+
run the given image. Details are in the processing container
|
|
1452
|
+
specification.
|
|
1453
|
+
environment (dict): Environment variables to start the processing
|
|
1454
|
+
container with.
|
|
1455
|
+
network_config (dict): Specifies networking options, such as network
|
|
1456
|
+
traffic encryption between processing containers, whether to allow
|
|
1457
|
+
inbound and outbound network calls to and from processing containers,
|
|
1458
|
+
and VPC subnets and security groups to use for VPC-enabled processing
|
|
1459
|
+
jobs.
|
|
1460
|
+
role_arn (str): The Amazon Resource Name (ARN) of an IAM role that
|
|
1461
|
+
Amazon SageMaker can assume to perform tasks on your behalf.
|
|
1462
|
+
tags ([dict[str,str]]): A list of dictionaries containing key-value
|
|
1463
|
+
pairs.
|
|
1464
|
+
experiment_config (dict[str, str]): Experiment management configuration.
|
|
1465
|
+
Optionally, the dict can contain three keys:
|
|
1466
|
+
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
|
|
1467
|
+
The behavior of setting these keys is as follows:
|
|
1468
|
+
* If `ExperimentName` is supplied but `TrialName` is not a Trial will be
|
|
1469
|
+
automatically created and the job's Trial Component associated with the Trial.
|
|
1470
|
+
* If `TrialName` is supplied and the Trial already exists the job's Trial Component
|
|
1471
|
+
will be associated with the Trial.
|
|
1472
|
+
* If both `ExperimentName` and `TrialName` are not supplied the trial component
|
|
1473
|
+
will be unassociated.
|
|
1474
|
+
* `TrialComponentDisplayName` is used for display in Studio.
|
|
1475
|
+
|
|
1476
|
+
Returns:
|
|
1477
|
+
Dict: a processing job request dict
|
|
1478
|
+
"""
|
|
1479
|
+
process_request = {
|
|
1480
|
+
"ProcessingJobName": job_name,
|
|
1481
|
+
"ProcessingResources": resources,
|
|
1482
|
+
"AppSpecification": app_specification,
|
|
1483
|
+
"RoleArn": role_arn,
|
|
1484
|
+
}
|
|
1485
|
+
|
|
1486
|
+
if inputs:
|
|
1487
|
+
process_request["ProcessingInputs"] = inputs
|
|
1488
|
+
|
|
1489
|
+
if output_config["Outputs"]:
|
|
1490
|
+
process_request["ProcessingOutputConfig"] = output_config
|
|
1491
|
+
|
|
1492
|
+
if environment is not None:
|
|
1493
|
+
process_request["Environment"] = environment
|
|
1494
|
+
|
|
1495
|
+
if network_config is not None:
|
|
1496
|
+
process_request["NetworkConfig"] = network_config
|
|
1497
|
+
|
|
1498
|
+
if stopping_condition is not None:
|
|
1499
|
+
process_request["StoppingCondition"] = stopping_condition
|
|
1500
|
+
|
|
1501
|
+
if tags is not None:
|
|
1502
|
+
process_request["Tags"] = tags
|
|
1503
|
+
|
|
1504
|
+
if experiment_config:
|
|
1505
|
+
process_request["ExperimentConfig"] = experiment_config
|
|
1506
|
+
|
|
1507
|
+
return process_request
|
|
1508
|
+
|
|
1509
|
+
|
|
1510
|
+
def logs_for_processing_job(sagemaker_session, job_name, wait=False, poll=10):
|
|
1511
|
+
"""Display logs for a given processing job, optionally tailing them until the job is complete.
|
|
1512
|
+
|
|
1513
|
+
Args:
|
|
1514
|
+
job_name (str): Name of the processing job to display the logs for.
|
|
1515
|
+
wait (bool): Whether to keep looking for new log entries until the job completes
|
|
1516
|
+
(default: False).
|
|
1517
|
+
poll (int): The interval in seconds between polling for new log entries and job
|
|
1518
|
+
completion (default: 10).
|
|
1519
|
+
|
|
1520
|
+
Raises:
|
|
1521
|
+
ValueError: If the processing job fails.
|
|
1522
|
+
"""
|
|
1523
|
+
|
|
1524
|
+
description = _wait_until(
|
|
1525
|
+
lambda: ProcessingJob.get(
|
|
1526
|
+
processing_job_name=job_name, session=sagemaker_session.boto_session
|
|
1527
|
+
)
|
|
1528
|
+
.refresh()
|
|
1529
|
+
.__dict__,
|
|
1530
|
+
poll,
|
|
1531
|
+
)
|
|
1532
|
+
|
|
1533
|
+
instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
|
|
1534
|
+
sagemaker_session.boto_session, description, job="Processing"
|
|
1535
|
+
)
|
|
1536
|
+
|
|
1537
|
+
state = _get_initial_job_state(description, "ProcessingJobStatus", wait)
|
|
1538
|
+
|
|
1539
|
+
# The loop below implements a state machine that alternates between checking the job status
|
|
1540
|
+
# and reading whatever is available in the logs at this point. Note, that if we were
|
|
1541
|
+
# called with wait == False, we never check the job status.
|
|
1542
|
+
#
|
|
1543
|
+
# If wait == TRUE and job is not completed, the initial state is TAILING
|
|
1544
|
+
# If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is
|
|
1545
|
+
# complete).
|
|
1546
|
+
#
|
|
1547
|
+
# The state table:
|
|
1548
|
+
#
|
|
1549
|
+
# STATE ACTIONS CONDITION NEW STATE
|
|
1550
|
+
# ---------------- ---------------- ----------------- ----------------
|
|
1551
|
+
# TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE
|
|
1552
|
+
# Else TAILING
|
|
1553
|
+
# JOB_COMPLETE Read logs, Pause Any COMPLETE
|
|
1554
|
+
# COMPLETE Read logs, Exit N/A
|
|
1555
|
+
#
|
|
1556
|
+
# Notes:
|
|
1557
|
+
# - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to
|
|
1558
|
+
# Cloudwatch after the job was marked complete.
|
|
1559
|
+
last_describe_job_call = time.time()
|
|
1560
|
+
while True:
|
|
1561
|
+
_flush_log_streams(
|
|
1562
|
+
stream_names,
|
|
1563
|
+
instance_count,
|
|
1564
|
+
client,
|
|
1565
|
+
log_group,
|
|
1566
|
+
job_name,
|
|
1567
|
+
positions,
|
|
1568
|
+
dot,
|
|
1569
|
+
color_wrap,
|
|
1570
|
+
)
|
|
1571
|
+
if state == LogState.COMPLETE:
|
|
1572
|
+
break
|
|
1573
|
+
|
|
1574
|
+
time.sleep(poll)
|
|
1575
|
+
|
|
1576
|
+
if state == LogState.JOB_COMPLETE:
|
|
1577
|
+
state = LogState.COMPLETE
|
|
1578
|
+
elif time.time() - last_describe_job_call >= 30:
|
|
1579
|
+
description = (
|
|
1580
|
+
ProcessingJob.get(
|
|
1581
|
+
processing_job_name=job_name, session=sagemaker_session.boto_session
|
|
1582
|
+
)
|
|
1583
|
+
.refresh()
|
|
1584
|
+
.__dict__
|
|
1585
|
+
)
|
|
1586
|
+
last_describe_job_call = time.time()
|
|
1587
|
+
|
|
1588
|
+
status = description["ProcessingJobStatus"]
|
|
1589
|
+
|
|
1590
|
+
if status in ("Completed", "Failed", "Stopped"):
|
|
1591
|
+
print()
|
|
1592
|
+
state = LogState.JOB_COMPLETE
|
|
1593
|
+
|
|
1594
|
+
if wait:
|
|
1595
|
+
_check_job_status(job_name, description, "ProcessingJobStatus")
|
|
1596
|
+
if dot:
|
|
1597
|
+
print()
|