sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,294 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Contains the helper classes for SageMaker Experiment."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
|
|
20
|
+
import botocore
|
|
21
|
+
|
|
22
|
+
from sagemaker.core import s3
|
|
23
|
+
from sagemaker.core.experiments._utils import is_already_exist_error
|
|
24
|
+
|
|
25
|
+
# Module-level logger; used below to warn about skipped S3/Lineage operations.
logger = logging.getLogger(__name__)


# Default S3 key prefix that `_ArtifactUploader` places in front of
# "<trial component name>/<artifact name>" when uploading artifacts.
_DEFAULT_ARTIFACT_PREFIX = "trial-component-artifacts"
# Default `ArtifactType` passed to the SageMaker CreateArtifact API by
# `_LineageArtifactManager`.
_DEFAULT_ARTIFACT_TYPE = "Tracker"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class _ArtifactUploader(object):
    """Upload trial component artifacts (local files or JSON-serializable objects) to S3."""

    def __init__(
        self,
        trial_component_name,
        sagemaker_session,
        artifact_bucket=None,
        artifact_prefix=_DEFAULT_ARTIFACT_PREFIX,
    ):
        """Initialize a `_ArtifactUploader` instance.

        Args:
            trial_component_name (str): The name of the trial component,
                which is used to generate the S3 path to upload the artifact to.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. Its ``boto_session`` is used to create the S3 client.
            artifact_bucket (str): The S3 bucket to upload the artifact to.
                If not specified, the default bucket defined in `sagemaker_session`
                will be used.
            artifact_prefix (str): The S3 key prefix used to generate the S3 path
                to upload the artifact to (default: "trial-component-artifacts").
                If empty or None, artifacts are stored directly under
                "<trial component name>/".
        """
        self.sagemaker_session = sagemaker_session
        self.trial_component_name = trial_component_name
        self.artifact_bucket = artifact_bucket
        self.artifact_prefix = artifact_prefix
        self._s3_client = self.sagemaker_session.boto_session.client("s3")

    def upload_artifact(self, file_path, extra_args=None):
        """Upload an artifact file to S3.

        Args:
            file_path (str): the file path of the artifact. A leading "~" is expanded.
            extra_args (dict): Optional extra arguments that may be passed to the upload operation.
                Similar to ExtraArgs parameter in S3 upload_file function. Please refer to the
                ExtraArgs parameter documentation here:
                https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-uploading-files.html#the-extraargs-parameter

        Returns:
            (str, str): The S3 URI of the uploaded file and its ETag. The ETag is None
                if it could not be read back (e.g. missing read permissions).

        Raises:
            ValueError: If file does not exist.
        """
        file_path = os.path.expanduser(file_path)
        if not os.path.isfile(file_path):
            raise ValueError(
                "{} does not exist or is not a file. Please supply a file path.".format(file_path)
            )

        self._resolve_bucket_and_prefix()

        artifact_s3_key = self._build_artifact_s3_key(os.path.basename(file_path))
        self._s3_client.upload_file(
            file_path,
            self.artifact_bucket,
            artifact_s3_key,
            ExtraArgs=extra_args,
        )
        etag = self._try_get_etag(artifact_s3_key)
        return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag

    def upload_object_artifact(self, artifact_name, artifact_object, file_extension=None):
        """Upload an artifact object to S3 as JSON.

        Args:
            artifact_name (str): the name of the artifact.
            artifact_object (obj): the object of the artifact. It is serialized with
                ``json.dumps`` before upload.
            file_extension (str): Optional file extension appended to ``artifact_name``;
                a leading "." is added if missing.

        Returns:
            (str, str): The S3 URI of the uploaded object and its ETag. The ETag is None
                if it could not be read back (e.g. missing read permissions).

        Raises:
            TypeError: If ``artifact_object`` is not JSON serializable.
        """
        self._resolve_bucket_and_prefix()

        if file_extension:
            artifact_name = (
                artifact_name + ("" if file_extension.startswith(".") else ".") + file_extension
            )
        artifact_s3_key = self._build_artifact_s3_key(artifact_name)
        self._s3_client.put_object(
            Body=json.dumps(artifact_object), Bucket=self.artifact_bucket, Key=artifact_s3_key
        )
        etag = self._try_get_etag(artifact_s3_key)
        return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag

    def _resolve_bucket_and_prefix(self):
        """Resolve the destination bucket and key prefix and store them on the instance.

        If ``self.artifact_bucket`` is falsy it is set to the session's default bucket, and
        ``self.artifact_prefix`` may be combined with the session's default bucket prefix.
        Both attributes are updated together so that the first and later calls to the
        upload methods resolve to the same S3 location.
        """
        self.artifact_bucket, self.artifact_prefix = s3.determine_bucket_and_prefix(
            bucket=self.artifact_bucket,
            key_prefix=self.artifact_prefix,
            sagemaker_session=self.sagemaker_session,
        )

    def _build_artifact_s3_key(self, artifact_name):
        """Return "<prefix>/<trial component name>/<artifact name>", omitting an empty prefix."""
        artifact_s3_key = "{}/{}".format(self.trial_component_name, artifact_name)
        # Skip a None/empty prefix; otherwise the key would start with "None/" or "/".
        if self.artifact_prefix:
            artifact_s3_key = "{}/{}".format(self.artifact_prefix, artifact_s3_key)
        return artifact_s3_key

    def _try_get_etag(self, key):
        """Get ETag of given key and return None if not allowed

        Args:
            key (str): The S3 object key.

        Returns:
            str: The S3 object ETag if it allows, otherwise return None.
        """
        try:
            response = self._s3_client.head_object(Bucket=self.artifact_bucket, Key=key)
            return response["ETag"]
        except botocore.exceptions.ClientError as error:
            # HeadObject requires read permissions; the upload itself already succeeded,
            # so a missing ETag is reported rather than raised.
            logger.warning("Failed to get ETag of %s due to %s", key, error)
            return None
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class _LineageArtifactManager(object):
    """A helper class to create a single Lineage artifact and associate it with a trial component."""

    def __init__(
        self,
        name,
        source_uri,
        etag,
        source_arn=None,
        dest_arn=None,
        artifact_type=_DEFAULT_ARTIFACT_TYPE,
    ):
        """Initialize a `_LineageArtifactManager` instance.

        Args:
            name (str): The name of the Lineage artifact to be created.
            source_uri (str): The source URI used to create the Lineage artifact.
            etag (str): The S3 Etag used to create the Lineage artifact.
            source_arn (str): The source ARN of a trial component to associate
                this Lineage artifact with (default: None).
            dest_arn (str): The destination ARN of a trial component to associate
                this Lineage artifact with (default: None).
            artifact_type (str): The type of the Lineage artifact (default: "Tracker").
        """
        self.name = name
        self.source_uri = source_uri
        self.etag = etag
        self.source_arn = source_arn
        self.dest_arn = dest_arn
        # Populated by create_artifact() when the CreateArtifact call succeeds.
        self.artifact_arn = None
        self.artifact_type = artifact_type

    def create_artifact(self, sagemaker_session):
        """Create the artifact by calling `CreateArtifact` API

        On success ``self.artifact_arn`` is set from the response. If the service reports
        that the artifact already exists, a warning is logged and ``self.artifact_arn``
        is left unchanged. Any other ``ClientError`` is re-raised.

        Args:
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed.
        """
        source_ids = []
        if self.etag:
            source_ids.append({"SourceIdType": "S3ETag", "Value": self.etag})

        try:
            response = sagemaker_session.sagemaker_client.create_artifact(
                ArtifactName=self.name,
                ArtifactType=self.artifact_type,
                Source={"SourceUri": self.source_uri, "SourceTypes": source_ids},
            )
            self.artifact_arn = response["ArtifactArn"]
        except botocore.exceptions.ClientError as err:
            err_info = err.response["Error"]
            if not is_already_exist_error(err_info):
                raise
            logger.warning(
                "Skip creating the artifact since it already exists: %s", err_info["Message"]
            )

    def add_association(self, sagemaker_session):
        """Associate the artifact with a source/destination ARN (e.g. trial component arn)

        Whichever of ``source_arn``/``dest_arn`` was not supplied is filled with the
        artifact's own ARN. If that ARN is unknown (create_artifact() was not called, or it
        skipped creation because the artifact already existed), the association is skipped
        with a warning instead of sending a None ARN to the service.

        Args:
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed.
        """
        source_arn = self.source_arn if self.source_arn else self.artifact_arn
        dest_arn = self.dest_arn if self.dest_arn else self.artifact_arn
        if not source_arn or not dest_arn:
            # AddAssociation requires both ARNs. Skipping keeps save() best-effort for the
            # remaining artifacts instead of failing on a None parameter.
            logger.warning(
                "Skip associating artifact %s since its ARN is unknown", self.name
            )
            return
        # if the trial component (job) is the source then it produced the artifact,
        # otherwise the artifact contributed to the trial component (job)
        association_edge_type = "Produced" if self.source_arn else "ContributedTo"
        try:
            sagemaker_session.sagemaker_client.add_association(
                SourceArn=source_arn, DestinationArn=dest_arn, AssociationType=association_edge_type
            )
        except botocore.exceptions.ClientError as err:
            err_info = err.response["Error"]
            if not is_already_exist_error(err_info):
                raise
            logger.warning(
                "Skip associating since the association already exists: %s", err_info["Message"]
            )
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
class _LineageArtifactTracker(object):
|
|
246
|
+
"""Lineage Artifact Tracker"""
|
|
247
|
+
|
|
248
|
+
def __init__(self, trial_component_arn, sagemaker_session):
|
|
249
|
+
"""Initialize a `_LineageArtifactTracker` instance.
|
|
250
|
+
|
|
251
|
+
Args:
|
|
252
|
+
trial_component_arn (str): The ARN of the trial component to be
|
|
253
|
+
associated with the input/output artifacts.
|
|
254
|
+
sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
|
|
255
|
+
manages interactions with Amazon SageMaker APIs and any other
|
|
256
|
+
AWS services needed.
|
|
257
|
+
"""
|
|
258
|
+
self.trial_component_arn = trial_component_arn
|
|
259
|
+
self.sagemaker_session = sagemaker_session
|
|
260
|
+
self.artifacts = []
|
|
261
|
+
|
|
262
|
+
def add_input_artifact(self, name, source_uri, etag, artifact_type):
|
|
263
|
+
"""Add a Lineage input artifact locally
|
|
264
|
+
|
|
265
|
+
Args:
|
|
266
|
+
name (str): The name of the Lineage input artifact to be added.
|
|
267
|
+
source_uri (str): The source URI used to create the Lineage input artifact.
|
|
268
|
+
etag (str): The S3 Etag used to create the Lineage input artifact.
|
|
269
|
+
artifact_type (str): The type of the Lineage input artifact.
|
|
270
|
+
"""
|
|
271
|
+
artifact = _LineageArtifactManager(
|
|
272
|
+
name, source_uri, etag, dest_arn=self.trial_component_arn, artifact_type=artifact_type
|
|
273
|
+
)
|
|
274
|
+
self.artifacts.append(artifact)
|
|
275
|
+
|
|
276
|
+
def add_output_artifact(self, name, source_uri, etag, artifact_type):
|
|
277
|
+
"""Add a Lineage output artifact locally
|
|
278
|
+
|
|
279
|
+
Args:
|
|
280
|
+
name (str): The name of the Lineage output artifact to be added.
|
|
281
|
+
source_uri (str): The source URI used to create the Lineage output artifact.
|
|
282
|
+
etag (str): The S3 Etag used to create the Lineage output artifact.
|
|
283
|
+
artifact_type (str): The type of the Lineage output artifact.
|
|
284
|
+
"""
|
|
285
|
+
artifact = _LineageArtifactManager(
|
|
286
|
+
name, source_uri, etag, source_arn=self.trial_component_arn, artifact_type=artifact_type
|
|
287
|
+
)
|
|
288
|
+
self.artifacts.append(artifact)
|
|
289
|
+
|
|
290
|
+
def save(self):
|
|
291
|
+
"""Persist any artifact data saved locally"""
|
|
292
|
+
for artifact in self.artifacts:
|
|
293
|
+
artifact.create_artifact(self.sagemaker_session)
|
|
294
|
+
artifact.add_association(self.sagemaker_session)
|
|
@@ -0,0 +1,333 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Contains classes to manage metrics for Sagemaker Experiment"""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import datetime
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
import time
|
|
20
|
+
import threading
|
|
21
|
+
import queue
|
|
22
|
+
|
|
23
|
+
import dateutil.tz
|
|
24
|
+
|
|
25
|
+
from sagemaker.core.helper.session_helper import Session
|
|
26
|
+
|
|
27
|
+
# Value of the SAGEMAKER_METRICS_DIRECTORY environment variable ("" when unset).
METRICS_DIR = os.environ.get("SAGEMAKER_METRICS_DIRECTORY", "")
# Oldest accepted metric timestamp relative to now: two weeks, in seconds.
METRIC_TS_LOWER_BOUND_TO_NOW = 14 * 24 * 60 * 60
# Newest accepted metric timestamp relative to now: two hours, in seconds.
METRIC_TS_UPPER_BOUND_FROM_NOW = 2 * 60 * 60

# Maximum number of metric data points sent in a single BatchPutMetrics request.
BATCH_SIZE = 10

# NOTE(review): configuring the root logger at import time also affects the
# importing application's logging; kept as-is to preserve existing behavior.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class _RawMetricData(object):
|
|
38
|
+
"""A Raw Metric Data Object"""
|
|
39
|
+
|
|
40
|
+
MetricName = None
|
|
41
|
+
Value = None
|
|
42
|
+
Timestamp = None
|
|
43
|
+
Step = None
|
|
44
|
+
|
|
45
|
+
def __init__(self, metric_name, value, timestamp=None, step=None):
|
|
46
|
+
"""Construct a `_RawMetricData` instance.
|
|
47
|
+
|
|
48
|
+
Args:
|
|
49
|
+
metric_name (str): The name of the metric.
|
|
50
|
+
value (float): The value of the metric.
|
|
51
|
+
timestamp (datetime.datetime or float or str): Timestamp of the metric.
|
|
52
|
+
If not specified, the current UTC time will be used.
|
|
53
|
+
step (int): Iteration number of the metric (default: None).
|
|
54
|
+
"""
|
|
55
|
+
if timestamp is None:
|
|
56
|
+
timestamp = time.time()
|
|
57
|
+
elif isinstance(timestamp, datetime.datetime):
|
|
58
|
+
# If the input is a datetime then convert it to UTC time.
|
|
59
|
+
# Assume a naive datetime is in local timezone
|
|
60
|
+
if not timestamp.tzinfo:
|
|
61
|
+
timestamp = timestamp.replace(tzinfo=dateutil.tz.tzlocal())
|
|
62
|
+
timestamp = (timestamp - timestamp.utcoffset()).replace(tzinfo=datetime.timezone.utc)
|
|
63
|
+
timestamp = timestamp.timestamp()
|
|
64
|
+
else:
|
|
65
|
+
timestamp = float(timestamp)
|
|
66
|
+
|
|
67
|
+
if timestamp < (time.time() - METRIC_TS_LOWER_BOUND_TO_NOW) or timestamp > (
|
|
68
|
+
time.time() + METRIC_TS_UPPER_BOUND_FROM_NOW
|
|
69
|
+
):
|
|
70
|
+
raise ValueError(
|
|
71
|
+
"Supplied timestamp %f is invalid."
|
|
72
|
+
" Timestamps must be between two weeks before and two hours from now." % timestamp
|
|
73
|
+
)
|
|
74
|
+
value = float(value)
|
|
75
|
+
|
|
76
|
+
self.MetricName = metric_name
|
|
77
|
+
self.Value = float(value)
|
|
78
|
+
self.Timestamp = timestamp
|
|
79
|
+
if step is not None:
|
|
80
|
+
if not isinstance(step, int):
|
|
81
|
+
raise ValueError("step must be int.")
|
|
82
|
+
self.Step = step
|
|
83
|
+
|
|
84
|
+
def to_record(self):
|
|
85
|
+
"""Convert the `_RawMetricData` object to dict"""
|
|
86
|
+
return self.__dict__
|
|
87
|
+
|
|
88
|
+
def to_raw_metric_data(self):
|
|
89
|
+
"""Converts the metric data to a BatchPutMetrics RawMetricData item"""
|
|
90
|
+
# Convert timestamp from float to timestamp str.
|
|
91
|
+
# Otherwise will get ParamValidationError
|
|
92
|
+
raw_metric_data = {
|
|
93
|
+
"MetricName": self.MetricName,
|
|
94
|
+
"Value": self.Value,
|
|
95
|
+
"Timestamp": str(int(self.Timestamp)),
|
|
96
|
+
}
|
|
97
|
+
if self.Step is not None:
|
|
98
|
+
raw_metric_data["Step"] = int(self.Step)
|
|
99
|
+
return raw_metric_data
|
|
100
|
+
|
|
101
|
+
def __str__(self):
|
|
102
|
+
"""String representation of the `_RawMetricData` object."""
|
|
103
|
+
return repr(self)
|
|
104
|
+
|
|
105
|
+
def __repr__(self):
|
|
106
|
+
"""Return a string representation of this _RawMetricData` object."""
|
|
107
|
+
return "{}({})".format(
|
|
108
|
+
type(self).__name__,
|
|
109
|
+
",".join(["{}={}".format(k, repr(v)) for k, v in vars(self).items()]),
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class _MetricsManager(object):
    """Collects metrics and sends them directly to SageMaker Metrics data plane APIs."""

    def __init__(self, trial_component_name: str, sagemaker_session: Session, sink=None) -> None:
        """Initialize a `_MetricsManager` instance

        Args:
            trial_component_name (str): The Name of the Trial Component to log metrics to
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. It is only used when ``sink`` is None, to get
                ``sagemaker_session.sagemaker_metrics_client`` for the default sink.
            sink (object): The metrics sink to use. It must provide
                ``log_metric(metric_data)`` and ``close()``. Defaults to a
                `_SyncMetricsSink` for ``trial_component_name``.
        """
        if sink is None:
            sink = _SyncMetricsSink(
                trial_component_name, sagemaker_session.sagemaker_metrics_client
            )
        self.sink = sink

    def log_metric(self, metric_name, value, timestamp=None, step=None):
        """Validate one data point and hand it to the configured sink.

        Whether it is sent immediately depends on the sink; the default
        `_SyncMetricsSink` sends in batches.

        Args:
            metric_name (str): The name of the metric.
            value (float): The value of the metric.
            timestamp (datetime.datetime or float or str): Timestamp of the metric
                (default: current time).
            step (int): Iteration number of the metric (default: None).

        Raises:
            ValueError: Propagated from `_RawMetricData` for an out-of-range
                timestamp or a non-int ``step``.
        """
        metric_data = _RawMetricData(metric_name, value, timestamp, step)
        self.sink.log_metric(metric_data)

    def __enter__(self):
        """Return self"""
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Execute self.close() so leaving a ``with`` block flushes the sink.

        Returns None, so exceptions raised inside the ``with`` block propagate.
        """
        # Route through close() (rather than the sink directly) so the documented
        # behavior holds and any override of close() is honored.
        self.close()

    def close(self):
        """Close the metrics object by closing (flushing) its sink."""
        self.sink.close()
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class _SyncMetricsSink(object):
|
|
154
|
+
"""Collects metrics and sends them directly to metrics service."""
|
|
155
|
+
|
|
156
|
+
def __init__(self, trial_component_name, metrics_client) -> None:
|
|
157
|
+
"""Initialize a `_SyncMetricsSink` instance
|
|
158
|
+
|
|
159
|
+
Args:
|
|
160
|
+
trial_component_name (str): The Name of the Trial Component to log metrics.
|
|
161
|
+
metrics_client (boto3.client): boto client for metrics service
|
|
162
|
+
"""
|
|
163
|
+
self._trial_component_name = trial_component_name
|
|
164
|
+
self._metrics_client = metrics_client
|
|
165
|
+
self._buffer = []
|
|
166
|
+
|
|
167
|
+
def log_metric(self, metric_data):
|
|
168
|
+
"""Sends a metric to metrics service."""
|
|
169
|
+
|
|
170
|
+
# this is a simplistic solution which calls BatchPutMetrics
|
|
171
|
+
# on the same thread as the client code
|
|
172
|
+
self._buffer.append(metric_data)
|
|
173
|
+
self._drain()
|
|
174
|
+
|
|
175
|
+
def _drain(self, close=False):
|
|
176
|
+
"""Pops off all metrics in the buffer and starts sending them to metrics service."""
|
|
177
|
+
|
|
178
|
+
if not self._buffer:
|
|
179
|
+
return
|
|
180
|
+
|
|
181
|
+
if len(self._buffer) < BATCH_SIZE and not close:
|
|
182
|
+
return
|
|
183
|
+
|
|
184
|
+
# pop all the available metrics
|
|
185
|
+
available_metrics, self._buffer = self._buffer, []
|
|
186
|
+
|
|
187
|
+
self._send_metrics(available_metrics)
|
|
188
|
+
|
|
189
|
+
def _send_metrics(self, metrics):
|
|
190
|
+
"""Calls BatchPutMetrics directly on the metrics service."""
|
|
191
|
+
while metrics:
|
|
192
|
+
batch, metrics = (
|
|
193
|
+
metrics[:BATCH_SIZE],
|
|
194
|
+
metrics[BATCH_SIZE:],
|
|
195
|
+
)
|
|
196
|
+
request = self._construct_batch_put_metrics_request(batch)
|
|
197
|
+
response = self._metrics_client.batch_put_metrics(**request)
|
|
198
|
+
errors = response["Errors"] if "Errors" in response else None
|
|
199
|
+
if errors:
|
|
200
|
+
error_code = errors[0]["Code"]
|
|
201
|
+
raise Exception(f'{len(errors)} errors with error code "{error_code}"')
|
|
202
|
+
|
|
203
|
+
def _construct_batch_put_metrics_request(self, batch):
|
|
204
|
+
"""Creates dictionary object used as request to metrics service."""
|
|
205
|
+
return {
|
|
206
|
+
"TrialComponentName": self._trial_component_name.lower(),
|
|
207
|
+
"MetricData": list(map(lambda x: x.to_raw_metric_data(), batch)),
|
|
208
|
+
}
|
|
209
|
+
|
|
210
|
+
def close(self):
|
|
211
|
+
"""Drains any remaining metrics."""
|
|
212
|
+
self._drain(close=True)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class _MetricQueue(object):
|
|
216
|
+
"""A thread safe queue for sending metrics to SageMaker.
|
|
217
|
+
|
|
218
|
+
Args:
|
|
219
|
+
trial_component_name (str): the ARN of the resource
|
|
220
|
+
metric_name (str): the name of the metric
|
|
221
|
+
metrics_client (boto_client): the boto client for SageMaker Metrics service
|
|
222
|
+
"""
|
|
223
|
+
|
|
224
|
+
_CONSUMER_SLEEP_SECONDS = 5
|
|
225
|
+
|
|
226
|
+
def __init__(self, trial_component_name, metric_name, metrics_client):
|
|
227
|
+
# infinite queue size
|
|
228
|
+
self._queue = queue.Queue()
|
|
229
|
+
self._buffer = []
|
|
230
|
+
self._thread = threading.Thread(target=self._run)
|
|
231
|
+
self._started = False
|
|
232
|
+
self._finished = False
|
|
233
|
+
self._trial_component_name = trial_component_name
|
|
234
|
+
self._metrics_client = metrics_client
|
|
235
|
+
self._metric_name = metric_name
|
|
236
|
+
self._logged_metrics = 0
|
|
237
|
+
|
|
238
|
+
def log_metric(self, metric_data):
|
|
239
|
+
"""Adds a metric data point to the queue"""
|
|
240
|
+
self._buffer.append(metric_data)
|
|
241
|
+
|
|
242
|
+
if len(self._buffer) < BATCH_SIZE:
|
|
243
|
+
return
|
|
244
|
+
|
|
245
|
+
self._enqueue_all()
|
|
246
|
+
|
|
247
|
+
if not self._started:
|
|
248
|
+
self._thread.start()
|
|
249
|
+
self._started = True
|
|
250
|
+
|
|
251
|
+
def _run(self):
|
|
252
|
+
"""Starts the metric thread which sends metrics to SageMaker in batches"""
|
|
253
|
+
|
|
254
|
+
while not self._queue.empty() or not self._finished:
|
|
255
|
+
if self._queue.empty():
|
|
256
|
+
time.sleep(self._CONSUMER_SLEEP_SECONDS)
|
|
257
|
+
else:
|
|
258
|
+
batch = self._queue.get()
|
|
259
|
+
self._send_metrics(batch)
|
|
260
|
+
|
|
261
|
+
def _send_metrics(self, metrics_batch):
|
|
262
|
+
"""Calls BatchPutMetrics directly on the metrics service."""
|
|
263
|
+
request = self._construct_batch_put_metrics_request(metrics_batch)
|
|
264
|
+
self._logged_metrics += len(metrics_batch)
|
|
265
|
+
self._metrics_client.batch_put_metrics(**request)
|
|
266
|
+
|
|
267
|
+
def _construct_batch_put_metrics_request(self, batch):
|
|
268
|
+
"""Creates dictionary object used as request to metrics service."""
|
|
269
|
+
|
|
270
|
+
return {
|
|
271
|
+
"TrialComponentName": self._trial_component_name,
|
|
272
|
+
"MetricData": list(map(lambda x: x.to_raw_metric_data(), batch)),
|
|
273
|
+
}
|
|
274
|
+
|
|
275
|
+
def _enqueue_all(self):
|
|
276
|
+
"""Enqueue all buffered metrics to be sent to SageMaker"""
|
|
277
|
+
|
|
278
|
+
available_metrics, self._buffer = self._buffer, []
|
|
279
|
+
if available_metrics:
|
|
280
|
+
self._queue.put(available_metrics)
|
|
281
|
+
|
|
282
|
+
def close(self):
|
|
283
|
+
"""Flushes any buffered metrics"""
|
|
284
|
+
|
|
285
|
+
self._enqueue_all()
|
|
286
|
+
self._finished = True
|
|
287
|
+
|
|
288
|
+
def is_active(self):
|
|
289
|
+
"""Is the thread active (still draining metrics to SageMaker)"""
|
|
290
|
+
|
|
291
|
+
return self._thread.is_alive()
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
class _AsyncMetricsSink(object):
|
|
295
|
+
"""Collects metrics and sends them directly to metrics service."""
|
|
296
|
+
|
|
297
|
+
_COMPLETE_SLEEP_SECONDS = 1.0
|
|
298
|
+
|
|
299
|
+
def __init__(self, trial_component_name, metrics_client) -> None:
|
|
300
|
+
"""Initialize a `_AsyncMetricsSink` instance
|
|
301
|
+
|
|
302
|
+
Args:
|
|
303
|
+
trial_component_name (str): The Name of the Trial Component to log metrics to.
|
|
304
|
+
metrics_client (boto3.client): boto client for metrics service
|
|
305
|
+
"""
|
|
306
|
+
self._trial_component_name = trial_component_name
|
|
307
|
+
self._metrics_client = metrics_client
|
|
308
|
+
self._buffer = []
|
|
309
|
+
self._is_draining = False
|
|
310
|
+
self._metric_queues = {}
|
|
311
|
+
|
|
312
|
+
def log_metric(self, metric_data):
|
|
313
|
+
"""Sends a metric to metrics service."""
|
|
314
|
+
|
|
315
|
+
if metric_data.MetricName in self._metric_queues:
|
|
316
|
+
self._metric_queues[metric_data.MetricName].log_metric(metric_data)
|
|
317
|
+
else:
|
|
318
|
+
cur_metric_queue = _MetricQueue(
|
|
319
|
+
self._trial_component_name, metric_data.MetricName, self._metrics_client
|
|
320
|
+
)
|
|
321
|
+
self._metric_queues[metric_data.MetricName] = cur_metric_queue
|
|
322
|
+
cur_metric_queue.log_metric(metric_data)
|
|
323
|
+
|
|
324
|
+
def close(self):
|
|
325
|
+
"""Closes the metric file."""
|
|
326
|
+
logging.debug("Closing")
|
|
327
|
+
for q in self._metric_queues.values():
|
|
328
|
+
q.close()
|
|
329
|
+
|
|
330
|
+
# TODO should probably use join
|
|
331
|
+
while any(map(lambda x: x.is_active(), self._metric_queues.values())):
|
|
332
|
+
time.sleep(self._COMPLETE_SLEEP_SECONDS)
|
|
333
|
+
logging.debug("Closed")
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Contains the SageMaker Experiment _RunContext class."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
from typing import TYPE_CHECKING
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from sagemaker.core.experiments import Run
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class _RunContext:
|
|
23
|
+
"""A static context variable to keep track of the current Run object"""
|
|
24
|
+
|
|
25
|
+
_context_run = None
|
|
26
|
+
|
|
27
|
+
@classmethod
|
|
28
|
+
def add_run_object(cls, run: "Run"):
|
|
29
|
+
"""Keep track of the current executing Run object
|
|
30
|
+
|
|
31
|
+
by adding it to a class static variable.
|
|
32
|
+
|
|
33
|
+
Args:
|
|
34
|
+
run (Run): The current Run object to be tracked.
|
|
35
|
+
"""
|
|
36
|
+
cls._context_run = run
|
|
37
|
+
|
|
38
|
+
@classmethod
|
|
39
|
+
def drop_current_run(cls) -> "Run":
|
|
40
|
+
"""Drop the Run object tracked in the global static variable
|
|
41
|
+
|
|
42
|
+
as its execution finishes (its "with" block ends).
|
|
43
|
+
|
|
44
|
+
Return:
|
|
45
|
+
Run: the dropped Run object.
|
|
46
|
+
"""
|
|
47
|
+
current_run = cls._context_run
|
|
48
|
+
cls._context_run = None
|
|
49
|
+
return current_run
|
|
50
|
+
|
|
51
|
+
@classmethod
|
|
52
|
+
def get_current_run(cls) -> "Run":
|
|
53
|
+
"""Return the current Run object without dropping it.
|
|
54
|
+
|
|
55
|
+
Return:
|
|
56
|
+
Run: the current Run object to be returned.
|
|
57
|
+
"""
|
|
58
|
+
return cls._context_run
|