sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,945 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Amazon SageMaker Debugger provides full visibility into ML training jobs.
|
|
14
|
+
|
|
15
|
+
This module provides SageMaker Debugger high-level methods
|
|
16
|
+
to set up Debugger objects, such as Debugger built-in rules, tensor collections,
|
|
17
|
+
and hook configuration. Use the Debugger objects for parameters when constructing
|
|
18
|
+
a SageMaker estimator to initiate a training job.
|
|
19
|
+
|
|
20
|
+
"""
|
|
21
|
+
from __future__ import absolute_import
|
|
22
|
+
|
|
23
|
+
from abc import ABC
|
|
24
|
+
|
|
25
|
+
from typing import Union, Optional, List, Dict
|
|
26
|
+
|
|
27
|
+
import attr
|
|
28
|
+
|
|
29
|
+
import smdebug_rulesconfig as rule_configs
|
|
30
|
+
|
|
31
|
+
from sagemaker.core import image_uris
|
|
32
|
+
from sagemaker.core.common_utils import build_dict, name_from_base
|
|
33
|
+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
|
|
34
|
+
from sagemaker.core.debugger.profiler_constants import (
|
|
35
|
+
DETAIL_PROF_PROCESSING_DEFAULT_INSTANCE_TYPE,
|
|
36
|
+
DETAIL_PROF_PROCESSING_DEFAULT_VOLUME_SIZE,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
# Image-config key passed to ``image_uris.retrieve`` for the standard Debugger rule image.
framework_name = "debugger"
# Image-config key passed to ``image_uris.retrieve`` for the detailed-profiler processing image.
detailed_framework_name = "detailed-profiler"
# NOTE(review): presumably the name of the flag that toggles smdebug usage; it is not
# referenced in the visible part of this module -- confirm against its consumers.
DEBUGGER_FLAG = "USE_SMDEBUG"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class DetailedProfilerProcessingJobConfig:
    """Rule-config look-alike for the detailed profiler processing job.

    Instances only carry ``rule_name`` and ``rule_parameters`` so that they can be
    handed to code that expects a rule configuration and thereby pass information
    through to the processing instance.

    """

    def __init__(self):
        # The rule name follows the concrete class name, so a subclass reports its own name.
        self.rule_name = type(self).__name__
        # Fixed payload naming the rule that the processing job should invoke.
        self.rule_parameters = dict(rule_to_invoke="DetailedProfilerProcessing")
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def get_rule_container_image_uri(name, region):
    """Look up the rule container image URI for an AWS Region.

    For a full list of rule image URIs,
    see `Use Debugger Docker Images for Built-in or Custom Rules
    <https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-docker-images-rules.html>`_.

    Args:
        name (str): The rule name, or ``None``. A name that starts with
            ``"DetailedProfilerProcessingJobConfig"`` selects the detailed-profiler
            image; any other value (including ``None``) selects the Debugger image.
        region (str): A string of AWS Region. For example, ``'us-east-1'``.

    Returns:
        str: Formatted image URI for the given AWS Region and the rule container type.

    """
    wants_detailed_profiler = name is not None and name.startswith(
        "DetailedProfilerProcessingJobConfig"
    )
    # The detailed-profiler URI should have the format like
    # "123456789012.dkr.ecr.us-west-2.amazonaws.com/detailed-profiler-processing:latest"
    selected_framework = detailed_framework_name if wants_detailed_profiler else framework_name
    return image_uris.retrieve(selected_framework, region)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def get_default_profiler_processing_job(instance_type=None, volume_size_in_gb=None):
    """Build the default profiler processing job (a rule) under a unique name.

    Args:
        instance_type: Optional. Forwarded unchanged to ``ProfilerRule.sagemaker``.
        volume_size_in_gb: Optional. Forwarded unchanged to ``ProfilerRule.sagemaker``.

    Returns:
        sagemaker.debugger.ProfilerRule: The instance of the built-in ProfilerRule.

    """
    processing_config = DetailedProfilerProcessingJobConfig()
    # name_from_base derives the job name from the config's rule name.
    rule_kwargs = {
        "name": name_from_base(processing_config.rule_name),
        "instance_type": instance_type,
        "volume_size_in_gb": volume_size_in_gb,
    }
    return ProfilerRule.sagemaker(processing_config, **rule_kwargs)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@attr.s
class RuleBase(ABC):
    """Shared base class for SageMaker Debugger rules; do not instantiate it directly.

    .. tip::

        Debugger rule classes inheriting this RuleBase class are
        :class:`~sagemaker.debugger.Rule` and :class:`~sagemaker.debugger.ProfilerRule`.
        Do not directly use the rule base class to instantiate a SageMaker Debugger rule.
        Use the :class:`~sagemaker.debugger.Rule` classmethods for debugging
        and the :class:`~sagemaker.debugger.ProfilerRule` classmethods for profiling.

    Attributes:
        name (str): The name of the rule.
        image_uri (str): The image URI to use the rule.
        instance_type (str): Type of EC2 instance to use. For example, 'ml.c4.xlarge'.
        container_local_output_path (str): The local path to store the Rule output.
        s3_output_path (str): The location in S3 to store the output.
        volume_size_in_gb (int): Size in GB of the EBS volume to use for storing data.
        rule_parameters (dict): A dictionary of parameters for the rule.

    """

    # Subclasses pass these positionally to ``__init__`` in exactly this order,
    # so the declaration order must not change.
    name = attr.ib()
    image_uri = attr.ib()
    instance_type = attr.ib()
    container_local_output_path = attr.ib()
    s3_output_path = attr.ib()
    volume_size_in_gb = attr.ib()
    rule_parameters = attr.ib()

    @staticmethod
    def _set_rule_parameters(source, rule_to_invoke, rule_parameters):
        """Assemble the rule parameter dictionary.

        Args:
            source (str): Optional. A source file containing a rule to invoke. If provided,
                you must also provide rule_to_invoke. This can either be an S3 uri or
                a local path.
            rule_to_invoke (str): Optional. The name of the rule to invoke within the source.
                If provided, you must also provide source.
            rule_parameters (dict): Optional. A dictionary of parameters for the rule.
                Applied last, so its entries win over the ``source_s3_uri`` and
                ``rule_to_invoke`` entries when keys collide.

        Returns:
            dict: A dictionary of rule parameters.

        Raises:
            ValueError: If only one of ``source`` and ``rule_to_invoke`` is truthy.

        """
        # Both must be given together or both omitted; a mismatch in truthiness is an error.
        if bool(source) != bool(rule_to_invoke):
            raise ValueError(
                "If you provide a source, you must also provide a rule to invoke (and vice versa)."
            )

        combined_params = {}
        fragments = (
            build_dict("source_s3_uri", source),
            build_dict("rule_to_invoke", rule_to_invoke),
            rule_parameters or {},
        )
        # Later fragments overwrite earlier ones on duplicate keys.
        for fragment in fragments:
            combined_params.update(fragment)

        return combined_params
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class Rule(RuleBase):
|
|
155
|
+
"""The SageMaker Debugger Rule class configures *debugging* rules to debug your training job.
|
|
156
|
+
|
|
157
|
+
The debugging rules analyze tensor outputs from your training job
|
|
158
|
+
and monitor conditions that are critical for the success of the training
|
|
159
|
+
job.
|
|
160
|
+
|
|
161
|
+
SageMaker Debugger comes pre-packaged with built-in *debugging* rules.
|
|
162
|
+
For example, the debugging rules can detect whether gradients are getting too large or
|
|
163
|
+
too small, or if a model is overfitting.
|
|
164
|
+
For a full list of built-in rules for debugging, see
|
|
165
|
+
`List of Debugger Built-in Rules
|
|
166
|
+
<https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-built-in-rules.html>`_.
|
|
167
|
+
You can also write your own rules using the custom rule classmethod.
|
|
168
|
+
|
|
169
|
+
"""
|
|
170
|
+
|
|
171
|
+
def __init__(
|
|
172
|
+
self,
|
|
173
|
+
name,
|
|
174
|
+
image_uri,
|
|
175
|
+
instance_type,
|
|
176
|
+
container_local_output_path,
|
|
177
|
+
s3_output_path,
|
|
178
|
+
volume_size_in_gb,
|
|
179
|
+
rule_parameters,
|
|
180
|
+
collections_to_save,
|
|
181
|
+
actions=None,
|
|
182
|
+
):
|
|
183
|
+
"""Configure the debugging rules using the following classmethods.
|
|
184
|
+
|
|
185
|
+
.. tip::
|
|
186
|
+
Use the following ``Rule.sagemaker`` class method for built-in debugging rules
|
|
187
|
+
or the ``Rule.custom`` class method for custom debugging rules.
|
|
188
|
+
Do not directly use the :class:`~sagemaker.debugger.Rule`
|
|
189
|
+
initialization method.
|
|
190
|
+
|
|
191
|
+
"""
|
|
192
|
+
super(Rule, self).__init__(
|
|
193
|
+
name,
|
|
194
|
+
image_uri,
|
|
195
|
+
instance_type,
|
|
196
|
+
container_local_output_path,
|
|
197
|
+
s3_output_path,
|
|
198
|
+
volume_size_in_gb,
|
|
199
|
+
rule_parameters,
|
|
200
|
+
)
|
|
201
|
+
self.collection_configs = collections_to_save
|
|
202
|
+
self.actions = actions
|
|
203
|
+
|
|
204
|
+
@classmethod
|
|
205
|
+
def sagemaker(
|
|
206
|
+
cls,
|
|
207
|
+
base_config,
|
|
208
|
+
name=None,
|
|
209
|
+
container_local_output_path=None,
|
|
210
|
+
s3_output_path=None,
|
|
211
|
+
other_trials_s3_input_paths=None,
|
|
212
|
+
rule_parameters=None,
|
|
213
|
+
collections_to_save=None,
|
|
214
|
+
actions=None,
|
|
215
|
+
):
|
|
216
|
+
"""Initialize a ``Rule`` object for a *built-in* debugging rule.
|
|
217
|
+
|
|
218
|
+
Args:
|
|
219
|
+
base_config (dict): Required. This is the base rule config dictionary returned from the
|
|
220
|
+
:class:`~sagemaker.debugger.rule_configs` method.
|
|
221
|
+
For example, ``rule_configs.dead_relu()``.
|
|
222
|
+
For a full list of built-in rules for debugging, see
|
|
223
|
+
`List of Debugger Built-in Rules
|
|
224
|
+
<https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-built-in-rules.html>`_.
|
|
225
|
+
name (str): Optional. The name of the debugger rule. If one is not provided,
|
|
226
|
+
the name of the base_config will be used.
|
|
227
|
+
container_local_output_path (str): Optional. The local path in the rule processing
|
|
228
|
+
container.
|
|
229
|
+
s3_output_path (str): Optional. The location in Amazon S3 to store the output tensors.
|
|
230
|
+
The default Debugger output path for debugging data is created under the
|
|
231
|
+
default output path of the :class:`~sagemaker.estimator.Estimator` class.
|
|
232
|
+
For example,
|
|
233
|
+
s3://sagemaker-<region>-<12digit_account_id>/<training-job-name>/debug-output/.
|
|
234
|
+
other_trials_s3_input_paths ([str]): Optional. The Amazon S3 input paths
|
|
235
|
+
of other trials to use the SimilarAcrossRuns rule.
|
|
236
|
+
rule_parameters (dict): Optional. A dictionary of parameters for the rule.
|
|
237
|
+
collections_to_save (:class:`~sagemaker.debugger.CollectionConfig`):
|
|
238
|
+
Optional. A list
|
|
239
|
+
of :class:`~sagemaker.debugger.CollectionConfig` objects to be saved.
|
|
240
|
+
|
|
241
|
+
Returns:
|
|
242
|
+
:class:`~sagemaker.debugger.Rule`: An instance of the built-in rule.
|
|
243
|
+
|
|
244
|
+
**Example of how to create a built-in rule instance:**
|
|
245
|
+
|
|
246
|
+
.. code-block:: python
|
|
247
|
+
|
|
248
|
+
from sagemaker.debugger import Rule, rule_configs
|
|
249
|
+
|
|
250
|
+
built_in_rules = [
|
|
251
|
+
Rule.sagemaker(rule_configs.built_in_rule_name_in_pysdk_format_1()),
|
|
252
|
+
Rule.sagemaker(rule_configs.built_in_rule_name_in_pysdk_format_2()),
|
|
253
|
+
...
|
|
254
|
+
Rule.sagemaker(rule_configs.built_in_rule_name_in_pysdk_format_n())
|
|
255
|
+
]
|
|
256
|
+
|
|
257
|
+
You need to replace the ``built_in_rule_name_in_pysdk_format_*`` with the
|
|
258
|
+
names of built-in rules. You can find the rule names at `List of Debugger Built-in
|
|
259
|
+
Rules <https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-built-in-rules.html>`_.
|
|
260
|
+
|
|
261
|
+
**Example of creating a built-in rule instance with adjusting parameter values:**
|
|
262
|
+
|
|
263
|
+
.. code-block:: python
|
|
264
|
+
|
|
265
|
+
from sagemaker.debugger import Rule, rule_configs
|
|
266
|
+
|
|
267
|
+
built_in_rules = [
|
|
268
|
+
Rule.sagemaker(
|
|
269
|
+
base_config=rule_configs.built_in_rule_name_in_pysdk_format(),
|
|
270
|
+
rule_parameters={
|
|
271
|
+
"key": "value"
|
|
272
|
+
}
|
|
273
|
+
collections_to_save=[
|
|
274
|
+
CollectionConfig(
|
|
275
|
+
name="tensor_collection_name",
|
|
276
|
+
parameters={
|
|
277
|
+
"key": "value"
|
|
278
|
+
}
|
|
279
|
+
)
|
|
280
|
+
]
|
|
281
|
+
)
|
|
282
|
+
]
|
|
283
|
+
|
|
284
|
+
For more information about setting up the ``rule_parameters`` parameter,
|
|
285
|
+
see `List of Debugger Built-in
|
|
286
|
+
Rules <https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-built-in-rules.html>`_.
|
|
287
|
+
|
|
288
|
+
For more information about setting up the ``collections_to_save`` parameter,
|
|
289
|
+
see the :class:`~sagemaker.debugger.CollectionConfig` class.
|
|
290
|
+
|
|
291
|
+
"""
|
|
292
|
+
merged_rule_params = {}
|
|
293
|
+
|
|
294
|
+
if rule_parameters is not None and rule_parameters.get("rule_to_invoke") is not None:
|
|
295
|
+
raise RuntimeError(
|
|
296
|
+
"""You cannot provide a 'rule_to_invoke' for SageMaker rules.
|
|
297
|
+
Either remove the rule_to_invoke or use a custom rule.
|
|
298
|
+
|
|
299
|
+
"""
|
|
300
|
+
)
|
|
301
|
+
|
|
302
|
+
if actions is not None and not rule_configs.is_valid_action_object(actions):
|
|
303
|
+
raise RuntimeError("""`actions` must be of type `Action` or `ActionList`!""")
|
|
304
|
+
|
|
305
|
+
if other_trials_s3_input_paths is not None:
|
|
306
|
+
for index, s3_input_path in enumerate(other_trials_s3_input_paths):
|
|
307
|
+
merged_rule_params["other_trial_{}".format(str(index))] = s3_input_path
|
|
308
|
+
|
|
309
|
+
default_rule_params = base_config["DebugRuleConfiguration"].get("RuleParameters", {})
|
|
310
|
+
merged_rule_params.update(default_rule_params)
|
|
311
|
+
merged_rule_params.update(rule_parameters or {})
|
|
312
|
+
|
|
313
|
+
base_config_collections = []
|
|
314
|
+
for config in base_config.get("CollectionConfigurations", []):
|
|
315
|
+
collection_name = None
|
|
316
|
+
collection_parameters = {}
|
|
317
|
+
for key, value in config.items():
|
|
318
|
+
if key == "CollectionName":
|
|
319
|
+
collection_name = value
|
|
320
|
+
if key == "CollectionParameters":
|
|
321
|
+
collection_parameters = value
|
|
322
|
+
base_config_collections.append(
|
|
323
|
+
CollectionConfig(name=collection_name, parameters=collection_parameters)
|
|
324
|
+
)
|
|
325
|
+
|
|
326
|
+
return cls(
|
|
327
|
+
name=name or base_config["DebugRuleConfiguration"].get("RuleConfigurationName"),
|
|
328
|
+
image_uri="DEFAULT_RULE_EVALUATOR_IMAGE",
|
|
329
|
+
instance_type=None,
|
|
330
|
+
container_local_output_path=container_local_output_path,
|
|
331
|
+
s3_output_path=s3_output_path,
|
|
332
|
+
volume_size_in_gb=None,
|
|
333
|
+
rule_parameters=merged_rule_params,
|
|
334
|
+
collections_to_save=collections_to_save or base_config_collections,
|
|
335
|
+
actions=actions,
|
|
336
|
+
)
|
|
337
|
+
|
|
338
|
+
@classmethod
def custom(
    cls,
    name: str,
    image_uri: Union[str, PipelineVariable],
    instance_type: Union[str, PipelineVariable],
    volume_size_in_gb: Union[int, PipelineVariable],
    source: Optional[str] = None,
    rule_to_invoke: Optional[Union[str, PipelineVariable]] = None,
    container_local_output_path: Optional[Union[str, PipelineVariable]] = None,
    s3_output_path: Optional[Union[str, PipelineVariable]] = None,
    other_trials_s3_input_paths: Optional[List[Union[str, PipelineVariable]]] = None,
    rule_parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
    collections_to_save: Optional[List["CollectionConfig"]] = None,
    actions=None,
):
    """Build a ``Rule`` object for a *custom* debugging rule.

    A custom rule analyzes tensors emitted during the training of a model
    and monitors conditions that are critical for the success of a training
    job. For more information, see `Create Debugger Custom Rules for Training Job
    Analysis
    <https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-custom-rules.html>`_.

    Args:
        name (str): Required. The name of the debugger rule.
        image_uri (str or PipelineVariable): Required. The URI of the image to
            be used by the debugger rule.
        instance_type (str or PipelineVariable): Required. Type of EC2 instance to use,
            for example, 'ml.c4.xlarge'.
        volume_size_in_gb (int or PipelineVariable): Required. Size in GB of the
            EBS volume to use for storing data.
        source (str): Optional. A source file containing a rule to invoke. If provided,
            you must also provide rule_to_invoke. This can either be an S3 uri or
            a local path.
        rule_to_invoke (str or PipelineVariable): Optional. The name of the rule to
            invoke within the source. If provided, you must also provide source.
        container_local_output_path (str or PipelineVariable): Optional. The local path
            in the container.
        s3_output_path (str or PipelineVariable): Optional. The location in Amazon S3
            to store the output tensors.
            The default Debugger output path for debugging data is created under the
            default output path of the :class:`~sagemaker.estimator.Estimator` class.
            For example,
            s3://sagemaker-<region>-<12digit_account_id>/<training-job-name>/debug-output/.
        other_trials_s3_input_paths (list[str] or list[PipelineVariable]): Optional.
            The Amazon S3 input paths of other trials to use the SimilarAcrossRuns rule.
        rule_parameters (dict[str, str] or dict[str, PipelineVariable]): Optional.
            A dictionary of parameters for the rule.
        collections_to_save ([sagemaker.debugger.CollectionConfig]): Optional. A list
            of :class:`~sagemaker.debugger.CollectionConfig` objects to be saved.
            An empty list is used when this is omitted or empty.
        actions: Optional. When given, it must be accepted by
            ``rule_configs.is_valid_action_object``; per the error raised
            otherwise, an ``Action`` or ``ActionList`` is expected.

    Returns:
        :class:`~sagemaker.debugger.Rule`: The instance of the custom rule.

    Raises:
        RuntimeError: If ``actions`` is provided but rejected by
            ``rule_configs.is_valid_action_object``.

    """
    # Validate the actions first so an invalid object fails before any
    # parameter merging is attempted.
    if actions is not None:
        if not rule_configs.is_valid_action_object(actions):
            raise RuntimeError("`actions` must be of type `Action` or `ActionList`!")

    rule_params = cls._set_rule_parameters(
        source, rule_to_invoke, other_trials_s3_input_paths, rule_parameters
    )

    init_kwargs = {
        "name": name,
        "image_uri": image_uri,
        "instance_type": instance_type,
        "container_local_output_path": container_local_output_path,
        "s3_output_path": s3_output_path,
        "volume_size_in_gb": volume_size_in_gb,
        "rule_parameters": rule_params,
        "collections_to_save": collections_to_save or [],
        "actions": actions,
    }
    return cls(**init_kwargs)
|
|
413
|
+
|
|
414
|
+
def prepare_actions(self, training_job_name):
    """Sync the rule parameters with the actions configured on this rule.

    When actions are configured, the training job name is offered to them as
    the default prefix and their serialized form is stored in the rule
    parameters under ``action_json``. When no actions are configured, any
    ``action_json`` entry is dropped from the rule parameters.

    Args:
        training_job_name (str): The training job name. To be set as the default training job
            prefix for the StopTraining action if it is specified.
    """
    configured_actions = self.actions

    if configured_actions is not None:
        configured_actions.update_training_job_prefix_if_not_specified(training_job_name)
        serialized = {"action_json": configured_actions.serialize()}
        self.rule_parameters.update(serialized)
        return

    # Without actions, a user-supplied action_json must not reach the request:
    # it cannot be specified manually through rule_parameters.
    self.rule_parameters.pop("action_json", None)
|
|
429
|
+
|
|
430
|
+
@staticmethod
def _set_rule_parameters(source, rule_to_invoke, other_trials_s3_input_paths, rule_parameters):
    """Assemble the rule parameter dictionary for a Debugger Rule.

    Each entry of ``other_trials_s3_input_paths`` is stored under the key
    ``other_trial_<index>``. The parameters produced by the parent class
    implementation are merged in afterwards, so they take precedence on a
    key clash.

    Args:
        source (str): Optional. A source file containing a rule to invoke. If provided,
            you must also provide rule_to_invoke. This can either be an S3 uri or
            a local path.
        rule_to_invoke (str): Optional. The name of the rule to invoke within the source.
            If provided, you must also provide source.
        other_trials_s3_input_paths ([str]): Optional. S3 input paths for other trials.
        rule_parameters (dict): Optional. A dictionary of parameters for the rule.

    Returns:
        dict: A dictionary of rule parameters.

    """
    combined = {}

    if other_trials_s3_input_paths is not None:
        combined.update(
            ("other_trial_{}".format(str(position)), trial_path)
            for position, trial_path in enumerate(other_trials_s3_input_paths)
        )

    inherited = super(Rule, Rule)._set_rule_parameters(source, rule_to_invoke, rule_parameters)
    combined.update(inherited)
    return combined
|
|
456
|
+
|
|
457
|
+
def to_debugger_rule_config_dict(self):
    """Serialize this rule into a debugger rule configuration dictionary.

    ``RuleConfigurationName`` and ``RuleEvaluatorImage`` are always present.
    The remaining fields are merged in from whatever ``build_dict`` yields
    for each key/value pair.

    Returns:
        dict: A portion of an API request as a dictionary.

    """
    request = {
        "RuleConfigurationName": self.name,
        "RuleEvaluatorImage": self.image_uri,
    }

    # Kept in the same order as the request is expected to be populated.
    remaining_fields = (
        ("InstanceType", self.instance_type),
        ("VolumeSizeInGB", self.volume_size_in_gb),
        ("LocalPath", self.container_local_output_path),
        ("S3OutputPath", self.s3_output_path),
        ("RuleParameters", self.rule_parameters),
    )
    for api_key, attr_value in remaining_fields:
        request.update(build_dict(api_key, attr_value))

    return request
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
class ProfilerRule(RuleBase):
    """The SageMaker Debugger ProfilerRule class configures *profiling* rules.

    SageMaker Debugger profiling rules automatically analyze
    hardware system resource utilization and framework metrics of a
    training job to identify performance bottlenecks.

    SageMaker Debugger comes pre-packaged with built-in *profiling* rules.
    For example, the profiling rules can detect if GPUs are underutilized due to CPU bottlenecks or
    IO bottlenecks.
    For a full list of built-in rules for debugging, see
    `List of Debugger Built-in Rules <https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-built-in-rules.html>`_.
    You can also write your own profiling rules using the Amazon SageMaker
    Debugger APIs.

    .. tip::
        Use the following ``ProfilerRule.sagemaker`` class method for built-in profiling rules
        or the ``ProfilerRule.custom`` class method for custom profiling rules.
        Do not directly use the ``ProfilerRule`` initialization method.

    """

    @classmethod
    def sagemaker(
        cls,
        base_config,
        name=None,
        container_local_output_path=None,
        s3_output_path=None,
        instance_type=None,
        volume_size_in_gb=None,
    ):
        """Initialize a ``ProfilerRule`` object for a *built-in* profiling rule.

        The rule analyzes system and framework metrics of a given
        training job to identify performance bottlenecks.

        Args:
            base_config (rule_configs.ProfilerRule): The base rule configuration object
                returned from the ``rule_configs`` method.
                For example, 'rule_configs.ProfilerReport()'.
                For a full list of built-in rules for debugging, see
                `List of Debugger Built-in Rules
                <https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-built-in-rules.html>`_.

            name (str): The name of the profiler rule. If one is not provided,
                the name of the base_config will be used.
            container_local_output_path (str): The path in the container.
            s3_output_path (str): The location in Amazon S3 to store the profiling output data.
                The default Debugger output path for profiling data is created under the
                default output path of the :class:`~sagemaker.estimator.Estimator` class.
                For example,
                s3://sagemaker-<region>-<12digit_account_id>/<training-job-name>/profiler-output/.
            instance_type (str): Type of instance to use. When omitted and the rule name
                starts with ``DetailedProfilerProcessingJobConfig``,
                ``DETAIL_PROF_PROCESSING_DEFAULT_INSTANCE_TYPE`` is used.
            volume_size_in_gb (int): Size in GB of the volume to use. When omitted and the
                rule name starts with ``DetailedProfilerProcessingJobConfig``,
                ``DETAIL_PROF_PROCESSING_DEFAULT_VOLUME_SIZE`` is used.

        Returns:
            :class:`~sagemaker.debugger.ProfilerRule`:
            The instance of the built-in ProfilerRule.

        """
        used_name = name or base_config.rule_name
        # Only the detailed-profiler processing job gets default compute
        # resources; every other rule keeps whatever the caller passed.
        if used_name.startswith("DetailedProfilerProcessingJobConfig"):
            if volume_size_in_gb is None:
                volume_size_in_gb = DETAIL_PROF_PROCESSING_DEFAULT_VOLUME_SIZE
            if instance_type is None:
                instance_type = DETAIL_PROF_PROCESSING_DEFAULT_INSTANCE_TYPE
        return cls(
            name=used_name,
            image_uri="DEFAULT_RULE_EVALUATOR_IMAGE",
            instance_type=instance_type,
            container_local_output_path=container_local_output_path,
            s3_output_path=s3_output_path,
            volume_size_in_gb=volume_size_in_gb,
            rule_parameters=base_config.rule_parameters,
        )

    @classmethod
    def custom(
        cls,
        name,
        image_uri,
        instance_type,
        volume_size_in_gb,
        source=None,
        rule_to_invoke=None,
        container_local_output_path=None,
        s3_output_path=None,
        rule_parameters=None,
    ):
        """Initialize a ``ProfilerRule`` object for a *custom* profiling rule.

        You can create a rule that
        analyzes system and framework metrics emitted during the training of a model and
        monitors conditions that are critical for the success of a
        training job.

        Args:
            name (str): The name of the profiler rule.
            image_uri (str): The URI of the image to be used by the profiler rule.
            instance_type (str): Type of EC2 instance to use, for example,
                'ml.c4.xlarge'.
            volume_size_in_gb (int): Size in GB of the EBS volume
                to use for storing data.
            source (str): A source file containing a rule to invoke. If provided,
                you must also provide rule_to_invoke. This can either be an S3 uri or
                a local path.
            rule_to_invoke (str): The name of the rule to invoke within the source.
                If provided, you must also provide the source.
            container_local_output_path (str): The path in the container.
            s3_output_path (str): The location in Amazon S3 to store the output.
                The default Debugger output path for profiling data is created under the
                default output path of the :class:`~sagemaker.estimator.Estimator` class.
                For example,
                s3://sagemaker-<region>-<12digit_account_id>/<training-job-name>/profiler-output/.
            rule_parameters (dict): A dictionary of parameters for the rule.

        Returns:
            :class:`~sagemaker.debugger.ProfilerRule`:
            The instance of the custom ProfilerRule.

        """
        merged_rule_params = super()._set_rule_parameters(source, rule_to_invoke, rule_parameters)

        return cls(
            name=name,
            image_uri=image_uri,
            instance_type=instance_type,
            container_local_output_path=container_local_output_path,
            s3_output_path=s3_output_path,
            volume_size_in_gb=volume_size_in_gb,
            rule_parameters=merged_rule_params,
        )

    def to_profiler_rule_config_dict(self):
        """Generates a request dictionary using the parameters provided when initializing object.

        Rule parameter values are converted with ``str`` in the returned
        dictionary only; ``self.rule_parameters`` itself is left untouched.

        Returns:
            dict: A portion of an API request as a dictionary.

        """
        profiler_rule_config_request = {
            "RuleConfigurationName": self.name,
            "RuleEvaluatorImage": self.image_uri,
        }

        profiler_rule_config_request.update(build_dict("InstanceType", self.instance_type))
        profiler_rule_config_request.update(build_dict("VolumeSizeInGB", self.volume_size_in_gb))
        profiler_rule_config_request.update(
            build_dict("LocalPath", self.container_local_output_path)
        )
        profiler_rule_config_request.update(build_dict("S3OutputPath", self.s3_output_path))

        if self.rule_parameters:
            # Build a fresh dict instead of stringifying in place: for built-in
            # rules self.rule_parameters is the base config's own dict, so an
            # in-place rewrite would silently alter that shared object.
            profiler_rule_config_request["RuleParameters"] = {
                key: str(value) for key, value in self.rule_parameters.items()
            }

        return profiler_rule_config_request
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
class DebuggerHookConfig(object):
    """Create a Debugger hook configuration object to save the tensor for debugging.

    DebuggerHookConfig provides options to customize how debugging
    information is emitted and saved. This high-level DebuggerHookConfig class
    runs based on the `smdebug.SaveConfig
    <https://github.com/awslabs/sagemaker-debugger/blob/master/docs/
    api.md#saveconfig>`_ class.

    """

    def __init__(
        self,
        s3_output_path: Optional[Union[str, PipelineVariable]] = None,
        container_local_output_path: Optional[Union[str, PipelineVariable]] = None,
        hook_parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
        collection_configs: Optional[List["CollectionConfig"]] = None,
    ):
        """Store the hook settings on the new DebuggerHookConfig instance.

        Args:
            s3_output_path (str or PipelineVariable): Optional. The location in Amazon S3 to
                store the output tensors. The default Debugger output path is created under the
                default output path of the :class:`~sagemaker.estimator.Estimator` class.
                For example,
                s3://sagemaker-<region>-<12digit_account_id>/<training-job-name>/debug-output/.
            container_local_output_path (str or PipelineVariable): Optional. The local path
                in the container.
            hook_parameters (dict[str, str] or dict[str, PipelineVariable]): Optional.
                A dictionary of parameters.
            collection_configs ([sagemaker.debugger.CollectionConfig]): A list
                of :class:`~sagemaker.debugger.CollectionConfig` objects to be saved
                at the **s3_output_path**. Defaults to ``None``.

        **Example of creating a DebuggerHookConfig object:**

        .. code-block:: python

            from sagemaker.debugger import CollectionConfig, DebuggerHookConfig

            collection_configs=[
                CollectionConfig(name="tensor_collection_1"),
                CollectionConfig(name="tensor_collection_2"),
                ...
                CollectionConfig(name="tensor_collection_n")
            ]

            hook_config = DebuggerHookConfig(
                collection_configs=collection_configs
            )

        """
        self.collection_configs = collection_configs
        self.hook_parameters = hook_parameters
        self.container_local_output_path = container_local_output_path
        self.s3_output_path = s3_output_path

    def _to_request_dict(self):
        """Serialize the stored settings into a request dictionary.

        ``S3OutputPath`` is always emitted, even when it is ``None``. The
        other entries appear only when the matching attribute is not ``None``.

        Returns:
            dict: A portion of an API request as a dictionary.

        """
        request = {"S3OutputPath": self.s3_output_path}

        optional_scalars = (
            ("LocalPath", self.container_local_output_path),
            ("HookParameters", self.hook_parameters),
        )
        for api_key, attr_value in optional_scalars:
            if attr_value is not None:
                request[api_key] = attr_value

        configs = self.collection_configs
        if configs is not None:
            # Each collection config is responsible for its own serialization.
            request["CollectionConfigurations"] = [
                single_config._to_request_dict() for single_config in configs
            ]

        return request
|
|
718
|
+
|
|
719
|
+
|
|
720
|
+
class TensorBoardOutputConfig(object):
    """Create a tensor output configuration object for debugging visualizations on TensorBoard."""

    def __init__(
        self,
        s3_output_path: Union[str, PipelineVariable],
        container_local_output_path: Optional[Union[str, PipelineVariable]] = None,
    ):
        """Store the output locations on the new TensorBoardOutputConfig instance.

        Args:
            s3_output_path (str or PipelineVariable): Required. The location in Amazon S3
                to store the output.
            container_local_output_path (str or PipelineVariable): Optional. The local path
                in the container.

        """
        self.container_local_output_path = container_local_output_path
        self.s3_output_path = s3_output_path

    def _to_request_dict(self):
        """Serialize the instance attributes into a request dictionary.

        ``S3OutputPath`` is always emitted; ``LocalPath`` is added only when a
        container local output path was provided.

        Returns:
            dict: A portion of an API request as a dictionary.

        """
        request = {"S3OutputPath": self.s3_output_path}

        local_path = self.container_local_output_path
        if local_path is None:
            return request

        request["LocalPath"] = local_path
        return request
|
|
753
|
+
|
|
754
|
+
|
|
755
|
+
class CollectionConfig(object):
    """Creates tensor collections for SageMaker Debugger."""

    def __init__(
        self,
        name: Union[str, PipelineVariable],
        parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
    ):
        """Store the name and parameters of a collection configuration.

        Args:
            name (str or PipelineVariable): Required. The name of the collection configuration.
            parameters (dict[str, str] or dict[str, PipelineVariable]): Optional. The parameters
                for the collection configuration.

        **Example of creating a CollectionConfig object:**

        .. code-block:: python

            from sagemaker.debugger import CollectionConfig

            collection_configs=[
                CollectionConfig(name="tensor_collection_1"),
                CollectionConfig(name="tensor_collection_2"),
                ...
                CollectionConfig(name="tensor_collection_n")
            ]

        For a full list of Debugger built-in collection, see
        `Debugger Built in Collections
        <https://github.com/awslabs/sagemaker-debugger/blob/master
        /docs/api.md#built-in-collections>`_.

        **Example of creating a CollectionConfig object with parameter adjustment:**

        You can use the following CollectionConfig template in two ways:
        (1) to adjust the parameters of the built-in tensor collections,
        and (2) to create custom tensor collections.

        If you put the built-in collection names to the ``name`` parameter,
        ``CollectionConfig`` takes it to match the built-in collections and adjust parameters.
        If you specify a new name to the ``name`` parameter,
        ``CollectionConfig`` creates a new tensor collection, and you must use
        ``include_regex`` parameter to specify regex of tensors you want to collect.

        .. code-block:: python

            from sagemaker.debugger import CollectionConfig

            collection_configs=[
                CollectionConfig(
                    name="tensor_collection",
                    parameters={
                        "key_1": "value_1",
                        "key_2": "value_2",
                        ...
                        "key_n": "value_n"
                    }
                )
            ]

        The following list shows the available CollectionConfig parameters.

        * ``include_regex`` -- Specify a list of regex patterns of tensors to save.
          Tensors whose names match these patterns will be saved.
        * ``save_histogram`` -- Set *True* if want to save histogram output data for
          TensorFlow visualization.
        * ``reductions`` -- Specify certain reduction values of tensors.
          This helps reduce the amount of data saved and increase training speed.
          Available values are ``min``, ``max``, ``median``,
          ``mean``, ``std``, ``variance``, ``sum``, and ``prod``.
        * ``save_interval``, ``train.save_interval``, ``eval.save_interval``,
          ``predict.save_interval``, ``global.save_interval`` -- Specify how often
          to save tensors in steps. You can also specify the save intervals
          in TRAIN, EVAL, PREDICT, and GLOBAL modes. The default value is 500 steps.
        * ``save_steps``, ``train.save_steps``, ``eval.save_steps``,
          ``predict.save_steps``, ``global.save_steps`` -- Specify the exact step
          numbers to save tensors. You can also specify the save steps
          in TRAIN, EVAL, PREDICT, and GLOBAL modes.
        * ``start_step``, ``train.start_step``, ``eval.start_step``,
          ``predict.start_step``, ``global.start_step`` -- Specify the exact start
          step to save tensors. You can also specify the start steps
          in TRAIN, EVAL, PREDICT, and GLOBAL modes.
        * ``end_step``, ``train.end_step``, ``eval.end_step``,
          ``predict.end_step``, ``global.end_step`` -- Specify the exact end step
          to save tensors. You can also specify the end steps
          in TRAIN, EVAL, PREDICT, and GLOBAL modes.

        For example, the following code shows how to control the save_interval parameters
        of the built-in ``losses`` tensor collection. With the following collection configuration,
        Debugger collects loss values every 100 steps from training loops and every 10 steps
        from evaluation loops.

        .. code-block:: python

            collection_configs=[
                CollectionConfig(
                    name="losses",
                    parameters={
                        "train.save_interval": "100",
                        "eval.save_interval": "10"
                    }
                )
            ]

        """
        self.parameters = parameters
        self.name = name

    @staticmethod
    def _ensure_comparable(other):
        """Raise ``TypeError`` unless ``other`` is a ``CollectionConfig``."""
        if isinstance(other, CollectionConfig):
            return
        raise TypeError(
            "CollectionConfig is only comparable with other CollectionConfig objects."
        )

    def __eq__(self, other):
        """Report whether ``other`` has the same name and parameters.

        Args:
            other: Object to test equality against.

        Raises:
            TypeError: If ``other`` is not a ``CollectionConfig``.

        """
        self._ensure_comparable(other)
        same_name = self.name == other.name
        return same_name and self.parameters == other.parameters

    def __ne__(self, other):
        """Report whether ``other`` differs in name or parameters.

        Args:
            other: Object to test equality against.

        Raises:
            TypeError: If ``other`` is not a ``CollectionConfig``.

        """
        self._ensure_comparable(other)
        name_differs = self.name != other.name
        return name_differs or self.parameters != other.parameters

    def __hash__(self):
        """Hash the name together with the sorted parameter items."""
        effective_params = self.parameters or {}
        ordered_items = tuple(sorted(effective_params.items()))
        return hash((self.name, ordered_items))

    def _to_request_dict(self):
        """Serialize the collection configuration into a request dictionary.

        ``CollectionParameters`` is included only when parameters were set.

        Returns:
            dict: A portion of an API request as a dictionary.

        """
        request = {"CollectionName": self.name}

        if self.parameters is None:
            return request

        request["CollectionParameters"] = self.parameters
        return request
|