sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/__init__.py +0 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/__init__.py +0 -4
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.47.dist-info/RECORD +0 -35
- sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,744 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Placeholder docstring"""
|
|
14
|
+
from __future__ import print_function, absolute_import
|
|
15
|
+
|
|
16
|
+
from abc import ABCMeta, abstractmethod
|
|
17
|
+
from collections import defaultdict, OrderedDict
|
|
18
|
+
import datetime
|
|
19
|
+
import logging
|
|
20
|
+
|
|
21
|
+
from six import with_metaclass
|
|
22
|
+
|
|
23
|
+
from sagemaker.core.helper.session_helper import Session
|
|
24
|
+
from sagemaker.core.common_utils import DeferredError
|
|
25
|
+
from sagemaker.core.lineage import artifact
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)

# pandas is optional at import time: if it is missing, warn once and bind ``pd`` to a
# DeferredError placeholder so that the ImportError is raised later, only when a caller
# actually tries to use pandas.
try:
    import pandas as pd
except ImportError as import_error:
    logger.warning("pandas failed to import. Analytics features will be impaired or broken.")
    pd = DeferredError(import_error)

# Fallback ``period`` (in seconds) used by TrainingJobAnalytics when the caller passes none.
METRICS_PERIOD_DEFAULT = 60
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class AnalyticsMetricsBase(metaclass=ABCMeta):
    """Base class for tuning job or training job analytics classes.

    Understands common functionality like persistence and caching: the
    dataframe produced by the subclass hook ``_fetch_dataframe`` is computed
    lazily on the first call to ``dataframe()``, reused on later calls, and
    discarded by ``clear_cache()``.
    """

    def __init__(self):
        """Initializes ``AnalyticsMetricsBase`` instance with an empty cache."""
        self._dataframe = None

    def export_csv(self, filename, **kwargs):
        """Persists the analytics dataframe to a file.

        Args:
            filename (str): The name of the file to save to.
            **kwargs: Optional keyword arguments forwarded unchanged to the
                dataframe's ``to_csv`` method (for example ``index=False``).
                When none are given, ``to_csv`` is called with ``filename``
                only, as before.
        """
        self.dataframe().to_csv(filename, **kwargs)

    def dataframe(self, force_refresh=False):
        """A pandas dataframe with lots of interesting results about this object.

        Created by calling SageMaker List and Describe APIs and converting them into a
        convenient tabular summary.

        Args:
            force_refresh (bool): Set to True to fetch the latest data from
                SageMaker API.

        Returns:
            The cached dataframe, built via ``_fetch_dataframe`` if no cached
            copy exists (or if ``force_refresh`` cleared it).
        """
        if force_refresh:
            # Subclasses override clear_cache to also drop their API-level caches.
            self.clear_cache()
        if self._dataframe is None:
            self._dataframe = self._fetch_dataframe()
        return self._dataframe

    @abstractmethod
    def _fetch_dataframe(self):
        """Sub-class must calculate the dataframe and return it."""

    def clear_cache(self):
        """Clear the object of all local caches of API methods.

        So that the next time any properties are accessed they will be refreshed from the service.
        """
        self._dataframe = None
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class HyperparameterTuningJobAnalytics(AnalyticsMetricsBase):
    """Fetch results about a hyperparameter tuning job and make them accessible for analytics."""

    # Maximum number of ListTrainingJobsForHyperParameterTuningJob pages fetched by
    # ``training_job_summaries``, and the MaxResults requested per page.
    _MAX_LIST_PAGES = 100
    _LIST_PAGE_SIZE = 100

    def __init__(self, hyperparameter_tuning_job_name, sagemaker_session=None):
        """Initialize a ``HyperparameterTuningJobAnalytics`` instance.

        Args:
            hyperparameter_tuning_job_name (str): name of the
                HyperparameterTuningJob to analyze.
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created using the
                default AWS configuration chain.
        """
        sagemaker_session = sagemaker_session or Session()
        self._sage_client = sagemaker_session.sagemaker_client
        self._tuning_job_name = hyperparameter_tuning_job_name
        self._tuning_job_describe_result = None
        self._training_job_summaries = None
        super(HyperparameterTuningJobAnalytics, self).__init__()
        self.clear_cache()

    @property
    def name(self):
        """Name of the HyperparameterTuningJob being analyzed"""
        return self._tuning_job_name

    def __repr__(self):
        """Human-readable representation override."""
        return "<sagemaker.HyperparameterTuningJobAnalytics for %s>" % self.name

    def clear_cache(self):
        """Clear the object of all local caches of API methods."""
        super(HyperparameterTuningJobAnalytics, self).clear_cache()
        self._tuning_job_describe_result = None
        self._training_job_summaries = None

    def _fetch_dataframe(self):
        """Return a pandas dataframe with all the training jobs.

        Each row describes one training job launched by the tuning job: its tuned
        hyperparameter values (converted to ``float`` where possible),
        ``TrainingJobName``, ``TrainingJobStatus``, ``FinalObjectiveValue``,
        ``TrainingStartTime``, ``TrainingEndTime``, ``TrainingElapsedTimeSeconds``
        (only when both times are present) and ``TrainingJobDefinitionName`` (only
        when the summary carries it).
        """

        def reshape(training_summary):
            # Helper method to reshape a single training job summary into a dataframe record
            out = {}
            for k, v in training_summary["TunedHyperParameters"].items():
                # Something (bokeh?) gets confused with ints so convert to float
                try:
                    v = float(v)
                except (TypeError, ValueError):
                    pass
                out[k] = v
            out["TrainingJobName"] = training_summary["TrainingJobName"]
            out["TrainingJobStatus"] = training_summary["TrainingJobStatus"]
            out["FinalObjectiveValue"] = training_summary.get(
                "FinalHyperParameterTuningJobObjectiveMetric", {}
            ).get("Value")

            start_time = training_summary.get("TrainingStartTime", None)
            end_time = training_summary.get("TrainingEndTime", None)
            out["TrainingStartTime"] = start_time
            out["TrainingEndTime"] = end_time
            if start_time and end_time:
                out["TrainingElapsedTimeSeconds"] = (end_time - start_time).total_seconds()
            if "TrainingJobDefinitionName" in training_summary:
                out["TrainingJobDefinitionName"] = training_summary["TrainingJobDefinitionName"]
            return out

        # Run that helper over all the summaries.
        df = pd.DataFrame([reshape(tjs) for tjs in self.training_job_summaries()])
        return df

    @property
    def tuning_ranges(self):
        """A dictionary describing the ranges of all tuned hyperparameters.

        The keys are the names of the hyperparameter, and the values are the ranges.

        The output can take one of two forms:

            * If the 'TrainingJobDefinition' field is present in the job description, the output
              is a dictionary constructed from 'ParameterRanges' in
              'HyperParameterTuningJobConfig' of the job description. The keys are the
              parameter names, while the values are the parameter ranges.
              Example:
              >>> {
              >>>     "eta": {"MaxValue": "1", "MinValue": "0", "Name": "eta"},
              >>>     "gamma": {"MaxValue": "10", "MinValue": "0", "Name": "gamma"},
              >>>     "iterations": {"MaxValue": "100", "MinValue": "50", "Name": "iterations"},
              >>>     "num_layers": {"MaxValue": "30", "MinValue": "5", "Name": "num_layers"},
              >>> }
            * If the 'TrainingJobDefinitions' field (list) is present in the job description,
              the output is a dictionary with keys as the 'DefinitionName' values from
              all items in 'TrainingJobDefinitions', and each value would be a dictionary
              constructed from 'HyperParameterRanges' in each item in 'TrainingJobDefinitions'
              in the same format as above
              Example:
              >>> {
              >>>     "estimator_1": {
              >>>         "eta": {"MaxValue": "1", "MinValue": "0", "Name": "eta"},
              >>>         "gamma": {"MaxValue": "10", "MinValue": "0", "Name": "gamma"},
              >>>     },
              >>>     "estimator_2": {
              >>>         "framework": {"Values": ["TF", "MXNet"], "Name": "framework"},
              >>>         "gamma": {"MaxValue": "1.0", "MinValue": "0.2", "Name": "gamma"}
              >>>     }
              >>> }

        For more details about the 'TrainingJobDefinition' and 'TrainingJobDefinitions' fields
        in job description, see
        https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_hyper_parameter_tuning_job
        """
        description = self.description()

        if "TrainingJobDefinition" in description:
            return self._prepare_parameter_ranges(
                description["HyperParameterTuningJobConfig"]["ParameterRanges"]
            )

        return {
            training_job_definition["DefinitionName"]: self._prepare_parameter_ranges(
                training_job_definition["HyperParameterRanges"]
            )
            for training_job_definition in description["TrainingJobDefinitions"]
        }

    def _prepare_parameter_ranges(self, parameter_ranges):
        """Convert parameter ranges to a dictionary using the parameter range names as the keys.

        Args:
            parameter_ranges (dict): Mapping whose values are lists of range dicts,
                each carrying a ``"Name"`` entry.

        Returns:
            dict: ``{range["Name"]: range}`` flattened across all range lists.
        """
        out = {}
        # The outer keys (range kinds) are not needed; only the range dicts are.
        for ranges in parameter_ranges.values():
            for param in ranges:
                out[param["Name"]] = param
        return out

    def description(self, force_refresh=False):
        """Call ``DescribeHyperParameterTuningJob`` for the hyperparameter tuning job.

        Args:
            force_refresh (bool): Set to True to fetch the latest data from
                SageMaker API.

        Returns:
            dict: The Amazon SageMaker response for
            ``DescribeHyperParameterTuningJob``.
        """
        if force_refresh:
            self.clear_cache()
        if not self._tuning_job_describe_result:
            self._tuning_job_describe_result = self._sage_client.describe_hyper_parameter_tuning_job(  # noqa: E501 # pylint: disable=line-too-long
                HyperParameterTuningJobName=self.name
            )
        return self._tuning_job_describe_result

    def training_job_summaries(self, force_refresh=False):
        """A (paginated) list of everything from ``ListTrainingJobsForTuningJob``.

        At most ``_MAX_LIST_PAGES`` pages of ``_LIST_PAGE_SIZE`` items are fetched;
        if the service still reports more results after that, a warning is logged
        and the partial list is returned.

        Args:
            force_refresh (bool): Set to True to fetch the latest data from
                SageMaker API.

        Returns:
            dict: The Amazon SageMaker response for
            ``ListTrainingJobsForTuningJob``.
        """
        if force_refresh:
            self.clear_cache()
        if self._training_job_summaries is not None:
            return self._training_job_summaries
        output = []
        next_args = {}
        for count in range(self._MAX_LIST_PAGES):
            logger.debug("Calling list_training_jobs_for_hyper_parameter_tuning_job %d", count)
            raw_result = self._sage_client.list_training_jobs_for_hyper_parameter_tuning_job(
                HyperParameterTuningJobName=self.name,
                MaxResults=self._LIST_PAGE_SIZE,
                **next_args,
            )
            new_output = raw_result["TrainingJobSummaries"]
            output.extend(new_output)
            logger.debug(
                "Got %d more TrainingJobs. Total so far: %d",
                len(new_output),
                len(output),
            )
            if ("NextToken" in raw_result) and (len(new_output) > 0):
                next_args["NextToken"] = raw_result["NextToken"]
            else:
                break
        else:
            # The loop only finishes without ``break`` when the page budget ran out while
            # the service still returned a NextToken: report the truncation instead of
            # silently dropping the remaining training jobs.
            logger.warning(
                "Stopped listing training jobs for tuning job %s after %d pages "
                "(%d jobs collected); results may be incomplete.",
                self.name,
                self._MAX_LIST_PAGES,
                len(output),
            )
        self._training_job_summaries = output
        return output
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
class TrainingJobAnalytics(AnalyticsMetricsBase):
    """Fetch training curve data from CloudWatch Metrics for a specific training job."""

    CLOUDWATCH_NAMESPACE = "/aws/sagemaker/TrainingJobs"

    def __init__(
        self,
        training_job_name,
        metric_names=None,
        sagemaker_session=None,
        start_time=None,
        end_time=None,
        period=None,
    ):
        """Initialize a ``TrainingJobAnalytics`` instance.

        Args:
            training_job_name (str): name of the TrainingJob to analyze.
            metric_names (list, optional): string names of all the metrics to
                collect for this training job. If not specified, then it will
                use all metric names configured for this job.
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is specified using
                the default AWS configuration chain.
            start_time (datetime.datetime, optional): Start of the CloudWatch
                query window. Defaults to the job's ``TrainingStartTime``.
            end_time (datetime.datetime, optional): End of the CloudWatch query
                window, used as given. Defaults to the job's ``TrainingEndTime``
                (or the current UTC time if the job has not ended) plus one minute.
            period (int, optional): CloudWatch aggregation period in seconds for
                each datapoint. Defaults to ``METRICS_PERIOD_DEFAULT``.
        """
        sagemaker_session = sagemaker_session or Session()
        self._sage_client = sagemaker_session.sagemaker_client
        self._cloudwatch = sagemaker_session.boto_session.client("cloudwatch")
        self._training_job_name = training_job_name
        self._start_time = start_time
        self._end_time = end_time
        self._period = period or METRICS_PERIOD_DEFAULT

        if metric_names:
            self._metric_names = metric_names
        else:
            self._metric_names = self._metric_names_for_training_job()

        super(TrainingJobAnalytics, self).__init__()
        self.clear_cache()

    @property
    def name(self):
        """Name of the TrainingJob being analyzed"""
        return self._training_job_name

    def __repr__(self):
        """The human-readable representation override."""
        return "<sagemaker.TrainingJobAnalytics for %s>" % self.name

    def clear_cache(self):
        """Clear the object of all local caches of API methods.

        This is so that the next time any properties are accessed they will be
        refreshed from the service. The query time window is recomputed too, so
        a refresh of a still-running job picks up a later end time.
        """
        super(TrainingJobAnalytics, self).clear_cache()
        self._data = defaultdict(list)
        self._time_interval = self._determine_timeinterval()

    def _determine_timeinterval(self):
        """Return a dict with two datetime objects.

        The dict includes the `start_time` and `end_time`, covering the interval
        of the training job. Caller-supplied bounds take precedence over the
        values reported by ``DescribeTrainingJob``.

        Returns:
            a dict with the `start_time` and `end_time`.
        """
        description = self._sage_client.describe_training_job(TrainingJobName=self.name)
        start_time = self._start_time or description["TrainingStartTime"]  # datetime object

        if self._end_time:
            # An explicit end time is used verbatim (no padding).
            end_time = self._end_time
        else:
            # Incrementing end time by 1 min since CloudWatch drops seconds before finding
            # the logs. This results in logs being searched in the time range in which the
            # correct log line was not present.
            # Example - Log time - 2018-10-22 08:25:55
            # Here calculated end time would also be 2018-10-22 08:25:55 (without 1 min
            # addition). CW will consider end time as 2018-10-22 08:25 and will not be able
            # to search the correct log.
            # A running job has no TrainingEndTime yet, so "now" is used instead; a
            # tz-aware value replaces the deprecated naive datetime.utcnow().
            now_utc = datetime.datetime.now(datetime.timezone.utc)
            end_time = description.get("TrainingEndTime", now_utc) + datetime.timedelta(
                minutes=1
            )

        return {"start_time": start_time, "end_time": end_time}

    def _fetch_dataframe(self):
        """Fetch every configured metric and return them as a long-format DataFrame.

        Columns are ``timestamp`` (seconds since each metric's first datapoint),
        ``metric_name`` and ``value``.
        """
        for metric_name in self._metric_names:
            self._fetch_metric(metric_name)
        return pd.DataFrame(self._data)

    def _fetch_metric(self, metric_name):
        """Fetch all the values of a named metric, and add them to _data

        Args:
            metric_name: The metric name to fetch.
        """
        request = {
            "Namespace": self.CLOUDWATCH_NAMESPACE,
            "MetricName": metric_name,
            "Dimensions": [{"Name": "TrainingJobName", "Value": self.name}],
            "StartTime": self._time_interval["start_time"],
            "EndTime": self._time_interval["end_time"],
            "Period": self._period,
            "Statistics": ["Average"],
        }
        raw_cwm_data = self._cloudwatch.get_metric_statistics(**request)["Datapoints"]
        if len(raw_cwm_data) == 0:
            logger.warning("Warning: No metrics called %s found", metric_name)
            return

        # Process data: normalize to starting time, and sort.
        # CloudWatch does not guarantee datapoint order, hence the explicit sort.
        base_time = min(raw_cwm_data, key=lambda pt: pt["Timestamp"])["Timestamp"]
        all_xy = sorted(
            ([(pt["Timestamp"] - base_time).total_seconds(), pt["Average"]] for pt in raw_cwm_data),
            key=lambda xy: xy[0],
        )

        # Store everything in _data to make a dataframe from
        for elapsed_seconds, value in all_xy:
            self._add_single_metric(elapsed_seconds, metric_name, value)

    def _add_single_metric(self, timestamp, metric_name, value):
        """Store a single metric in the _data dict.

        This can be converted to a dataframe.

        Args:
            timestamp: The timestamp of the metric.
            metric_name: The name of the metric.
            value: The value of the metric.
        """
        # note that this method is built this way to make it possible to
        # support live-refreshing charts in Bokeh at some point in the future.
        self._data["timestamp"].append(timestamp)
        self._data["metric_name"].append(metric_name)
        self._data["value"].append(value)

    def _metric_names_for_training_job(self):
        """Helper method to discover the metrics defined for a training job.

        Returns:
            list: Metric names from the job's ``MetricDefinitions``; empty (with a
            warning) when the job defines none, instead of raising ``KeyError``.
        """
        training_description = self._sage_client.describe_training_job(
            TrainingJobName=self._training_job_name
        )

        algorithm_spec = training_description.get("AlgorithmSpecification", {})
        metric_definitions = algorithm_spec.get("MetricDefinitions", [])
        if not metric_definitions:
            logger.warning(
                "No metric definitions found for training job %s", self._training_job_name
            )

        return [md["Name"] for md in metric_definitions]
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
class ArtifactAnalytics(AnalyticsMetricsBase):
    """Expose a filtered list of SageMaker artifacts as a pandas DataFrame."""

    def __init__(
        self,
        sort_by=None,
        sort_order=None,
        source_uri=None,
        artifact_type=None,
        sagemaker_session=None,
    ):
        """Initialize a ``ArtifactAnalytics`` instance.

        Args:
            sort_by (str, optional): Resource property used to sort the artifacts.
                Only ``"Name"`` is honored; any other value is replaced by None.
            sort_order (str, optional): How artifacts are ordered, valid values are
                Ascending and Descending. The default is Descending.
            source_uri (str, optional): Artifact source URI to filter on.
            artifact_type (str, optional): Artifact type to filter on.
            sagemaker_session (obj, optional): Sagemaker session. Defaults to None.
        """
        # Only "Name" is accepted as a sort key; other values are dropped.
        self._sort_by = None
        if sort_by == "Name":
            self._sort_by = sort_by

        self._sort_order = sort_order
        self._source_uri = source_uri
        self._artifact_type = artifact_type
        self._sagemaker_session = sagemaker_session

        super().__init__()
        self.clear_cache()

    def __repr__(self):
        """Return a short, human-readable description of this object."""
        return "<sagemaker.ArtifactAnalytics>"

    def _reshape_source_type(self, artifact_source_types):
        """Reshape artifact source types into a single-column mapping."""
        reshaped = OrderedDict()
        for source_type in artifact_source_types:
            # NOTE(review): every iteration writes the same key, so only the last
            # source type survives; confirm whether a list was intended.
            reshaped["ArtifactSourceType"] = source_type
        return reshaped

    def _reshape(self, artifact_summary):
        """Flatten one artifact summary into an ordered row for the DataFrame."""
        return OrderedDict(
            [
                ("ArtifactName", artifact_summary.artifact_name),
                ("ArtifactArn", artifact_summary.artifact_arn),
                ("ArtifactType", artifact_summary.artifact_type),
                ("ArtifactSourceUri", artifact_summary.source.source_uri),
                ("CreationTime", artifact_summary.creation_time),
                ("LastModifiedTime", artifact_summary.last_modified_time),
            ]
        )

    def _fetch_dataframe(self):
        """Build a DataFrame with one row per listed artifact."""
        rows = [self._reshape(summary) for summary in self._get_list_artifacts()]
        return pd.DataFrame(rows)

    def _get_list_artifacts(self):
        """Query ``Artifact.list`` with the filters and sort options given at init."""
        return artifact.Artifact.list(
            source_uri=self._source_uri,
            artifact_type=self._artifact_type,
            sort_by=self._sort_by,
            sort_order=self._sort_order,
            sagemaker_session=self._sagemaker_session,
        )
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
class ExperimentAnalytics(AnalyticsMetricsBase):
    """Fetch trial component data and make them accessible for analytics."""

    MAX_TRIAL_COMPONENTS = 10000

    def __init__(
        self,
        experiment_name=None,
        search_expression=None,
        sort_by=None,
        sort_order=None,
        metric_names=None,
        parameter_names=None,
        sagemaker_session=None,
        input_artifact_names=None,
        output_artifact_names=None,
    ):
        """Initialize a ``ExperimentAnalytics`` instance.

        Args:
            experiment_name (str, optional): Name of the experiment if you want to constrain the
                search to only trial components belonging to an experiment.
            search_expression (dict, optional): The search query to find the set of trial components
                to use to populate the data frame. It is not modified by this class.
            sort_by (str, optional): The name of the resource property used to sort
                the set of trial components.
            sort_order(str optional): How trial components are ordered, valid values are Ascending
                and Descending. The default is Descending.
            metric_names (list, optional): string names of all the metrics to be shown in the
                data frame. If not specified, all metrics will be shown of all trials.
            parameter_names (list, optional): string names of the parameters to be shown in the
                data frame. If not specified, all parameters will be shown of all trials.
            sagemaker_session (sagemaker.session.Session): Session object which manages interactions
                with Amazon SageMaker APIs and any other AWS services needed. If not specified,
                one is created using the default AWS configuration chain.
            input_artifact_names(dict optional):The input artifacts for the experiment. Examples of
                input artifacts are datasets, algorithms, hyperparameters, source code, and instance
                types.
            output_artifact_names(dict optional): The output artifacts for the experiment. Examples
                of output artifacts are metrics, snapshots, logs, and images.

        Raises:
            ValueError: If neither ``experiment_name`` nor ``search_expression`` is given.
        """
        sagemaker_session = sagemaker_session or Session()
        self._sage_client = sagemaker_session.sagemaker_client

        if not experiment_name and not search_expression:
            raise ValueError("Either experiment_name or search_expression must be supplied.")

        self._experiment_name = experiment_name
        self._search_expression = search_expression
        self._sort_by = sort_by
        self._sort_order = sort_order
        self._metric_names = metric_names
        self._parameter_names = parameter_names
        self._input_artifact_names = input_artifact_names
        self._output_artifact_names = output_artifact_names
        self._trial_components = None
        super(ExperimentAnalytics, self).__init__()
        self.clear_cache()

    @property
    def name(self):
        """Name of the Experiment being analyzed."""
        return self._experiment_name

    def __repr__(self):
        """The human-readable representation override."""
        return "<sagemaker.ExperimentAnalytics for %s>" % self.name

    def clear_cache(self):
        """Clear the object of all local caches of API methods."""
        super(ExperimentAnalytics, self).clear_cache()
        self._trial_components = None

    def _reshape_parameters(self, parameters):
        """Reshape trial component parameters to a pandas column.

        Args:
            parameters (dict): trial component parameters, keyed by parameter name.
        Returns:
            dict: Key: Parameter name, Value: Parameter value
        """
        out = OrderedDict()
        for name, value in sorted(parameters.items()):
            if self._parameter_names and name not in self._parameter_names:
                continue
            out[name] = value.get("NumberValue", value.get("StringValue"))
        return out

    def _reshape_metrics(self, metrics):
        """Reshape trial component metrics to a pandas column.

        Args:
            metrics (list): trial component metric summaries.
        Returns:
            dict: Key: "<metric name> - <statistic>", Value: statistic value
        """
        statistic_types = ["Min", "Max", "Avg", "StdDev", "Last", "Count"]
        out = OrderedDict()
        for metric_summary in metrics:
            metric_name = metric_summary["MetricName"]
            if self._metric_names and metric_name not in self._metric_names:
                continue

            for stat_type in statistic_types:
                stat_value = metric_summary.get(stat_type)
                if stat_value is not None:
                    out["{} - {}".format(metric_name, stat_type)] = stat_value
        return out

    def _reshape_artifacts(self, artifacts, _artifact_names):
        """Reshape trial component input/output artifacts to a pandas column.

        Args:
            artifacts (dict): trial component input/output artifacts, keyed by name.
            _artifact_names: optional collection of artifact names to keep; when
                falsy, every artifact is kept.
        Returns:
            dict: Key: "<artifact name> - MediaType|Value", Value: the artifact field
        """
        out = OrderedDict()
        for name, value in sorted(artifacts.items()):
            if _artifact_names and (name not in _artifact_names):
                continue
            out["{} - {}".format(name, "MediaType")] = value.get("MediaType")
            out["{} - {}".format(name, "Value")] = value.get("Value")
        return out

    def _reshape_parents(self, parents):
        """Reshape trial component parents to a pandas column.

        Args:
            parents: trial component parents (trials and experiments)
        Returns:
            dict: ``Trials`` and ``Experiments`` keys, each mapping to a list of names
        """
        out = OrderedDict()
        trials = []
        experiments = []
        for parent in parents:
            trials.append(parent["TrialName"])
            experiments.append(parent["ExperimentName"])
        out["Trials"] = trials
        out["Experiments"] = experiments
        return out

    def _reshape(self, trial_component):
        """Reshape trial component data to pandas columns.

        Args:
            trial_component: dict representing a trial component
        Returns:
            dict: Key-Value pair representing the data in the pandas dataframe
        """
        out = OrderedDict()
        for attribute in ["TrialComponentName", "DisplayName"]:
            out[attribute] = trial_component.get(attribute, "")

        source = trial_component.get("Source", "")
        if source:
            out["SourceArn"] = source["SourceArn"]

        # Parameters and artifacts are mappings (the helpers call .items()), so a
        # missing key must default to an empty dict, not a list.
        out.update(self._reshape_parameters(trial_component.get("Parameters", {})))
        out.update(self._reshape_metrics(trial_component.get("Metrics", [])))
        out.update(
            self._reshape_artifacts(
                trial_component.get("InputArtifacts", {}), self._input_artifact_names
            )
        )
        out.update(
            self._reshape_artifacts(
                trial_component.get("OutputArtifacts", {}), self._output_artifact_names
            )
        )
        out.update(self._reshape_parents(trial_component.get("Parents", [])))
        return out

    def _fetch_dataframe(self):
        """Return a pandas dataframe includes all the trial_components."""

        df = pd.DataFrame([self._reshape(component) for component in self._get_trial_components()])
        return df

    def _get_trial_components(self, force_refresh=False):
        """Get all trial components matching the given search query expression.

        The result is cached on the instance until ``clear_cache`` is called.

        Args:
            force_refresh (bool): Set to True to fetch the latest data from SageMaker API.

        Returns:
            list: List of dicts representing the trial components
        """
        if force_refresh:
            self.clear_cache()
        if self._trial_components is not None:
            return self._trial_components

        # Build the query on a copy: the user's search_expression must not be
        # mutated, and repeated refreshes must not stack duplicate filters.
        search_expression = dict(self._search_expression or {})

        if self._experiment_name:
            filters = list(search_expression.get("Filters") or [])
            filters.append(
                {
                    "Name": "Parents.ExperimentName",
                    "Operator": "Equals",
                    "Value": self._experiment_name,
                }
            )
            search_expression["Filters"] = filters

        self._trial_components = self._search(
            search_expression, self._sort_by, self._sort_order
        )
        return self._trial_components

    def _search(self, search_expression, sort_by, sort_order):
        """Perform a search query using SageMaker Search and return the matching trial components.

        Pages are fetched until there is no ``NextToken`` or an empty page is
        returned. Paging stops once at least ``MAX_TRIAL_COMPONENTS`` results have
        accumulated; the final page may push the total slightly past that limit.

        Args:
            search_expression: Search expression to filter trial components.
            sort_by: The name of the resource property used to sort the trial components.
            sort_order: How trial components are ordered, valid values are Ascending
                and Descending. The default is Descending.
        Returns:
            list: List of dict representing trial components.
        """
        trial_components = []

        search_args = {
            "Resource": "ExperimentTrialComponent",
            "SearchExpression": search_expression,
        }

        if sort_by:
            search_args["SortBy"] = sort_by

        if sort_order:
            search_args["SortOrder"] = sort_order

        while len(trial_components) < self.MAX_TRIAL_COMPONENTS:
            search_response = self._sage_client.search(**search_args)
            components = [result["TrialComponent"] for result in search_response["Results"]]
            trial_components.extend(components)
            if "NextToken" in search_response and len(components) > 0:
                search_args["NextToken"] = search_response["NextToken"]
            else:
                break

        return trial_components
|