sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/__init__.py +0 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/__init__.py +0 -4
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.47.dist-info/RECORD +0 -35
- sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,970 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Contains the SageMaker Experiment Run class."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import datetime
|
|
17
|
+
import logging
|
|
18
|
+
from enum import Enum
|
|
19
|
+
from math import isnan, isinf
|
|
20
|
+
from numbers import Number
|
|
21
|
+
from typing import Optional, List, Dict, TYPE_CHECKING, Union
|
|
22
|
+
|
|
23
|
+
import dateutil
|
|
24
|
+
from numpy import array
|
|
25
|
+
|
|
26
|
+
from sagemaker.core.apiutils import _utils
|
|
27
|
+
import sagemaker.core.experiments._api_types as _api_types
|
|
28
|
+
from sagemaker.core.experiments._api_types import (
|
|
29
|
+
TrialComponentArtifact,
|
|
30
|
+
_TrialComponentStatusType,
|
|
31
|
+
)
|
|
32
|
+
from sagemaker.core.experiments._helper import (
|
|
33
|
+
_ArtifactUploader,
|
|
34
|
+
_LineageArtifactTracker,
|
|
35
|
+
_DEFAULT_ARTIFACT_PREFIX,
|
|
36
|
+
)
|
|
37
|
+
from sagemaker.core.experiments._environment import _RunEnvironment
|
|
38
|
+
from sagemaker.core.experiments._run_context import _RunContext
|
|
39
|
+
from sagemaker.core.experiments.experiment import Experiment
|
|
40
|
+
from sagemaker.core.experiments._metrics import _MetricsManager
|
|
41
|
+
from sagemaker.core.experiments.trial import _Trial
|
|
42
|
+
from sagemaker.core.experiments.trial_component import _TrialComponent
|
|
43
|
+
|
|
44
|
+
from sagemaker.core.common_utils import (
|
|
45
|
+
get_module,
|
|
46
|
+
unique_name_from_base,
|
|
47
|
+
format_tags,
|
|
48
|
+
Tags,
|
|
49
|
+
TagsDict,
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
from sagemaker.core.experiments._utils import (
|
|
53
|
+
guess_media_type,
|
|
54
|
+
resolve_artifact_name,
|
|
55
|
+
verify_length_of_true_and_predicted,
|
|
56
|
+
validate_invoked_inside_run_context,
|
|
57
|
+
get_tc_and_exp_config_from_job_env,
|
|
58
|
+
verify_load_input_names,
|
|
59
|
+
is_run_trial_component,
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
if TYPE_CHECKING:
|
|
63
|
+
from sagemaker.core.helper.session_helper import Session
|
|
64
|
+
|
|
65
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# --- Naming -----------------------------------------------------------------
# Base string for auto-generated run names (lower-cased at import time).
RUN_NAME_BASE = "Sagemaker-Run".lower()
# Template for a run group (trial) name; the placeholder is filled by the caller.
TRIAL_NAME_TEMPLATE = "Default-Run-Group-{}"
# Separator string; NOTE(review): presumably joins name parts - confirm at use sites.
DELIMITER = "-"

# --- Limits -----------------------------------------------------------------
# NOTE(review): presumably caps on artifact count and on name length accepted
# by the backend - confirm against the code that enforces them.
MAX_RUN_TC_ARTIFACTS_LEN = 30
MAX_NAME_LEN_IN_BACKEND = 120

# --- Keys of the experiment config dict --------------------------------------
EXPERIMENT_NAME = "ExperimentName"
TRIAL_NAME = "TrialName"
RUN_NAME = "RunName"

# --- Tag marking a trial component as created by a run -----------------------
RUN_TC_TAG_KEY = "sagemaker:trial-component-source"
RUN_TC_TAG_VALUE = "run"
RUN_TC_TAG = {
    "Key": RUN_TC_TAG_KEY,
    "Value": RUN_TC_TAG_VALUE,
}
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class SortByType(Enum):
    """Property that the `list_runs` results can be sorted by.

    Each member's value is the string form of that property.
    """

    # Sort by creation time.
    CREATION_TIME = "CreationTime"
    # Sort by name.
    NAME = "Name"
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class SortOrderType(Enum):
    """Direction in which list or search results are ordered.

    Each member's value is the string form of that direction.
    """

    # Ascending order.
    ASCENDING = "Ascending"
    # Descending order.
    DESCENDING = "Descending"
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class Run(object):
|
|
95
|
+
"""A collection of parameters, metrics, and artifacts to create a ML model."""
|
|
96
|
+
|
|
97
|
+
def __init__(
    self,
    experiment_name: str,
    run_name: Optional[str] = None,
    experiment_display_name: Optional[str] = None,
    run_display_name: Optional[str] = None,
    tags: Optional[Tags] = None,
    sagemaker_session: Optional["Session"] = None,
    artifact_bucket: Optional[str] = None,
    artifact_prefix: Optional[str] = None,
):
    """Create a `Run`, loading or creating its backing resources.

    A run records the inputs, parameters, configurations and results of
    one iteration. Runs are grouped into experiments and can be compared
    and evaluated against each other.

    Typical usage: open the run as a context manager, log to it, and
    launch a job inside the ``with`` block so that the run's
    ``experiment_config`` (experiment name, run name, etc.) is passed to
    the job automatically.

    .. code:: python

        with Run(experiment_name="my-exp", run_name="my-run", ...) as run:
            run.log_parameter(...)
            ...
            estimator.fit(job_name="my-job")  # Create a training job

    Note:
        Every log method (``log_parameter``, ``log_metric``, ...) must be
        called inside the run context, i.e. within the ``with`` statement;
        otherwise a ``RuntimeError`` is thrown.

    To add data to a run that already exists, prefer ``load_run`` over
    this constructor. In a job script, constructing ``Run`` may create a
    new run on every launch, while ``load_run`` picks up the run created
    before the job started. Inside a job, ``load_run`` needs neither
    ``run_name`` nor ``experiment_name``; both are read from the
    experiment config in the job environment.

    .. code:: python

        with load_run(sagemaker_session=sagemaker_session) as run:
            run.log_metric(...)
            ...

    Args:
        experiment_name (str): Name of the experiment; must be unique
            within an account. It is lower-cased before use.
        run_name (str): Name of the run. Auto-generated when omitted.
            It is lower-cased before use.
        experiment_display_name (str): Name shown for the experiment in
            UI such as SageMaker Studio (default: None). Used only in the
            create-experiment call, so it has no effect when the
            experiment already exists.
        run_display_name (str): Name shown for the run in UI
            (default: None). Used only in the create-run call, so it has
            no effect when the run already exists.
        tags (Optional[Tags]): Tags applied to every create call, e.g.
            for the experiment and the run group (default: None).
        sagemaker_session (sagemaker.core.helper.session_helper.Session):
            Session that manages interactions with Amazon SageMaker APIs
            and any other AWS services needed. When omitted, one is
            created using the default AWS configuration chain.
        artifact_bucket (str): S3 bucket that artifacts are uploaded to.
            Defaults to the default bucket of `sagemaker_session`.
        artifact_prefix (str): S3 key prefix used to build the artifact
            upload path (default: "trial-component-artifacts").
    """
    # TODO: we should revert the lower casting once backend fix reaches prod
    self.experiment_name = experiment_name.lower()
    session = sagemaker_session or _utils.default_session()

    # Lower-case the run name as well, to avoid confusion due to a
    # mis-match in casing between the run name and the TC name.
    chosen_run_name = run_name or unique_name_from_base(RUN_NAME_BASE)
    self.run_name = chosen_run_name.lower()

    formatted_tags = format_tags(tags)

    tc_name = Run._generate_trial_component_name(
        run_name=self.run_name,
        experiment_name=self.experiment_name,
    )
    self.run_group_name = Run._generate_trial_name(self.experiment_name)

    # Backing resources are resolved in order: experiment, then run
    # group (trial), then the trial component representing this run.
    self._experiment = Experiment._load_or_create(
        experiment_name=self.experiment_name,
        display_name=experiment_display_name,
        tags=formatted_tags,
        sagemaker_session=session,
    )

    self._trial = _Trial._load_or_create(
        experiment_name=self.experiment_name,
        trial_name=self.run_group_name,
        tags=formatted_tags,
        sagemaker_session=session,
    )

    self._trial_component, already_exists = _TrialComponent._load_or_create(
        trial_component_name=tc_name,
        display_name=run_display_name,
        tags=Run._append_run_tc_label_to_tags(formatted_tags),
        sagemaker_session=session,
    )
    if already_exists:
        logger.info(
            "The run (%s) under experiment (%s) already exists. Loading it.",
            self.run_name,
            self.experiment_name,
        )

    # Attach the trial component to the run group only if it is not
    # associated with it yet.
    is_associated = _TrialComponent._trial_component_is_associated_to_trial(
        self._trial_component.trial_component_name,
        self._trial.trial_name,
        session,
    )
    if not is_associated:
        self._trial.add_trial_component(self._trial_component)

    # An explicit empty-string prefix is honoured; only None falls back
    # to the default prefix.
    if artifact_prefix is None:
        effective_prefix = _DEFAULT_ARTIFACT_PREFIX
    else:
        effective_prefix = artifact_prefix

    self._artifact_uploader = _ArtifactUploader(
        trial_component_name=self._trial_component.trial_component_name,
        sagemaker_session=session,
        artifact_bucket=artifact_bucket,
        artifact_prefix=effective_prefix,
    )
    self._lineage_artifact_tracker = _LineageArtifactTracker(
        trial_component_arn=self._trial_component.trial_component_arn,
        sagemaker_session=session,
    )
    self._metrics_manager = _MetricsManager(
        trial_component_name=self._trial_component.trial_component_name,
        sagemaker_session=session,
    )

    # Context-tracking flags; all start cleared.
    self._inside_init_context = False
    self._inside_load_context = False
    self._in_load = False
|
|
238
|
+
|
|
239
|
+
@property
def experiment_config(self) -> dict:
    """Build the experiment config dict from this run's attributes.

    Returns:
        dict: Maps the experiment-name, trial-name and run-name keys to
        this run's experiment name, run group name and trial component
        name respectively.
    """
    config = {}
    config[EXPERIMENT_NAME] = self.experiment_name
    config[TRIAL_NAME] = self.run_group_name
    # The run is identified by the name of its backing trial component.
    config[RUN_NAME] = self._trial_component.trial_component_name
    return config
|
|
247
|
+
|
|
248
|
+
@validate_invoked_inside_run_context
def log_parameter(self, name: str, value: Union[str, int, float]):
    """Store one parameter value on this run.

    A value already recorded under the same name is overwritten. Nothing
    is recorded when ``self._is_input_valid`` rejects the input.

    Args:
        name (str): The name of the parameter.
        value (str or int or float): The value of the parameter.
    """
    if not self._is_input_valid("parameter", name, value):
        return
    self._trial_component.parameters[name] = value
|
|
260
|
+
|
|
261
|
+
@validate_invoked_inside_run_context
def log_parameters(self, parameters: Dict[str, Union[str, int, float]]):
    """Store several parameter values on this run in one call.

    Entries rejected by ``_is_input_valid`` are left out; the remaining entries
    are merged into the parameters already held by the trial component.

    Args:
        parameters (dict[str, str or int or float]): The parameters to record.
    """
    accepted = {}
    for param_name, param_value in parameters.items():
        if self._is_input_valid("parameter", param_name, param_value):
            accepted[param_name] = param_value
    self._trial_component.parameters.update(accepted)
|
|
274
|
+
|
|
275
|
+
@validate_invoked_inside_run_context
def log_metric(
    self,
    name: str,
    value: float,
    timestamp: Optional[datetime.datetime] = None,
    step: Optional[int] = None,
):
    """Log one custom scalar metric value for this run.

    Note:
        This method is for manual custom metrics, for automatic metrics see the
        ``enable_sagemaker_metrics`` parameter on the ``estimator`` class.

    Values rejected by ``_is_input_valid`` are skipped; everything else is handed
    to the run's metrics manager.

    Args:
        name (str): The name of the metric.
        value (float): The value of the metric.
        timestamp (datetime.datetime): The timestamp of the metric.
            If not specified, the current UTC time will be used.
        step (int): The integer iteration number of the metric value (default: None).
    """
    if not self._is_input_valid("metric", name, value):
        return
    self._metrics_manager.log_metric(
        metric_name=name,
        value=value,
        timestamp=timestamp,
        step=step,
    )
|
|
300
|
+
|
|
301
|
+
@validate_invoked_inside_run_context
def log_precision_recall(
    self,
    y_true: Union[list, array],
    predicted_probabilities: Union[list, array],
    positive_label: Optional[Union[str, int]] = None,
    title: Optional[str] = None,
    is_output: bool = True,
    no_skill: Optional[int] = None,
):
    """Compute a precision-recall curve and log it as a graph artifact.

    The curve data is uploaded to S3 and registered as a lineage artifact that is
    associated with the run, so that the Studio UI can render it.

    You can view the artifact in the UI.
    If your job is created by a pipeline execution you can view the artifact
    by selecting the corresponding step in the pipelines UI.
    See also `SageMaker Pipelines <https://aws.amazon.com/sagemaker/pipelines/>`_

    This method requires sklearn library.

    Args:
        y_true (list or array): True labels. If labels are not binary
            then positive_label should be given.
        predicted_probabilities (list or array): Estimated/predicted probabilities.
        positive_label (str or int): Label of the positive class (default: None).
        title (str): Title of the graph (default: None). It is also used as the
            artifact name.
        is_output (bool): Determines direction of association to the
            run. Defaults to True (output artifact).
            If set to False then represented as input association.
        no_skill (int): The precision threshold under which the classifier cannot discriminate
            between the classes and would predict a random class or a constant class in
            all cases (default: None).
    """
    verify_length_of_true_and_predicted(
        true_labels=y_true,
        predicted_attrs=predicted_probabilities,
        predicted_attrs_name="predicted probabilities",
    )

    # Check that sklearn is available before importing from it.
    get_module("sklearn")
    from sklearn.metrics import precision_recall_curve, average_precision_score

    # ``pos_label`` is only forwarded when the caller named a positive class.
    label_kwargs = {} if positive_label is None else {"pos_label": positive_label}

    precision, recall, _ = precision_recall_curve(
        y_true, predicted_probabilities, **label_kwargs
    )
    average_precision = average_precision_score(
        y_true, predicted_probabilities, average="micro", **label_kwargs
    )

    graph_type = "PrecisionRecallCurve"
    payload = {
        "type": graph_type,
        "version": 0,
        "title": title,
        "precision": precision.tolist(),
        "recall": recall.tolist(),
        "averagePrecisionScore": average_precision,
        "noSkill": no_skill,
    }
    self._log_graph_artifact(
        artifact_name=title,
        data=payload,
        graph_type=graph_type,
        is_output=is_output,
    )
|
|
370
|
+
|
|
371
|
+
@validate_invoked_inside_run_context
def log_roc_curve(
    self,
    y_true: Union[list, array],
    y_score: Union[list, array],
    title: Optional[str] = None,
    is_output: bool = True,
):
    """Compute a receiver operating characteristic (ROC) curve and log it as an artifact.

    The curve data is uploaded to S3 and registered as a lineage artifact that is
    associated with the run.

    You can view the artifact in the UI.
    If your job is created by a pipeline execution you can view the artifact
    by selecting the corresponding step in the pipelines UI.
    See also `SageMaker Pipelines <https://aws.amazon.com/sagemaker/pipelines/>`_

    This method requires sklearn library.

    Args:
        y_true (list or array): True labels.
        y_score (list or array): Estimated/predicted probabilities.
        title (str): Title of the graph (default: None). It is also used as the
            artifact name.
        is_output (bool): Determines direction of association to the
            run. Defaults to True (output artifact).
            If set to False then represented as input association.
    """
    verify_length_of_true_and_predicted(
        true_labels=y_true,
        predicted_attrs=y_score,
        predicted_attrs_name="predicted scores",
    )

    # Check that sklearn is available before importing from it.
    get_module("sklearn")
    from sklearn.metrics import roc_curve, auc

    false_positive_rate, true_positive_rate, _ = roc_curve(y_true, y_score)

    # Kept under its own name so the imported ``auc`` function is not rebound.
    area_under_curve = auc(false_positive_rate, true_positive_rate)

    graph_type = "ROCCurve"
    payload = {
        "type": graph_type,
        "version": 0,
        "title": title,
        "falsePositiveRate": false_positive_rate.tolist(),
        "truePositiveRate": true_positive_rate.tolist(),
        "areaUnderCurve": area_under_curve,
    }
    self._log_graph_artifact(
        artifact_name=title,
        data=payload,
        graph_type=graph_type,
        is_output=is_output,
    )
|
|
424
|
+
|
|
425
|
+
@validate_invoked_inside_run_context
def log_confusion_matrix(
    self,
    y_true: Union[list, array],
    y_pred: Union[list, array],
    title: Optional[str] = None,
    is_output: bool = True,
):
    """Compute a confusion matrix and log it as an artifact.

    The matrix data is uploaded to S3 and registered as a lineage artifact that is
    associated with the run.

    You can view the artifact in the UI.
    If your job is created by a pipeline execution you can view the
    artifact by selecting the corresponding step in the pipelines UI.
    See also `SageMaker Pipelines <https://aws.amazon.com/sagemaker/pipelines/>`_

    This method requires sklearn library.

    Args:
        y_true (list or array): True labels.
        y_pred (list or array): Predicted labels.
        title (str): Title of the graph (default: None). It is also used as the
            artifact name.
        is_output (bool): Determines direction of association to the
            run. Defaults to True (output artifact).
            If set to False then represented as input association.
    """
    verify_length_of_true_and_predicted(
        true_labels=y_true,
        predicted_attrs=y_pred,
        predicted_attrs_name="predicted labels",
    )

    # Check that sklearn is available before importing from it.
    get_module("sklearn")
    from sklearn.metrics import confusion_matrix

    counts = confusion_matrix(y_true, y_pred)

    graph_type = "ConfusionMatrix"
    payload = {
        "type": graph_type,
        "version": 0,
        "title": title,
        "confusionMatrix": counts.tolist(),
    }
    self._log_graph_artifact(
        artifact_name=title,
        data=payload,
        graph_type=graph_type,
        is_output=is_output,
    )
|
|
476
|
+
|
|
477
|
+
@validate_invoked_inside_run_context
def log_artifact(
    self,
    name: str,
    value: str,
    media_type: Optional[str] = None,
    is_output: bool = True,
):
    """Store one artifact on this run.

    An artifact previously stored under the same ``name`` in the same direction
    is replaced.

    Args:
        name (str): The name of the artifact.
        value (str): The value.
        media_type (str): The MediaType (MIME type) of the value (default: None).
        is_output (bool): Determines direction of association to the
            run. Defaults to True (output artifact).
            If set to False then represented as input association.

    Raises:
        ValueError: If the run already holds the maximum number of artifacts
            for the chosen direction.
    """
    self._verify_trial_component_artifacts_length(is_output=is_output)

    artifact = TrialComponentArtifact(value, media_type=media_type)
    if is_output:
        self._trial_component.output_artifacts[name] = artifact
    else:
        self._trial_component.input_artifacts[name] = artifact
|
|
506
|
+
|
|
507
|
+
@validate_invoked_inside_run_context
def log_file(
    self,
    file_path: str,
    name: Optional[str] = None,
    media_type: Optional[str] = None,
    is_output: Optional[bool] = True,
    extra_args: Optional[dict] = None,
):
    """Upload a local file to S3 and record it as an input/output artifact of this run.

    Args:
        file_path (str): The path of the local file to upload.
        name (str): The name of the artifact (default: None). When omitted, a name
            is derived from ``file_path``.
        media_type (str): The MediaType (MIME type) of the file.
            If not specified, this library will attempt to infer the media type
            from the file extension of ``file_path``.
        is_output (bool): Determines direction of association to the
            run. Defaults to True (output artifact).
            If set to False then represented as input association.
        extra_args (dict): Optional extra arguments that may be passed to the upload operation.
            Similar to ExtraArgs parameter in S3 upload_file function. Please refer to the
            ExtraArgs parameter documentation here:
            https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-uploading-files.html#the-extraargs-parameter

    Raises:
        ValueError: If the run already holds the maximum number of artifacts
            for the chosen direction.
    """
    # Check the artifact limit first so nothing is uploaded for a rejected artifact.
    self._verify_trial_component_artifacts_length(is_output)

    if not media_type:
        media_type = guess_media_type(file_path)
    if not name:
        name = resolve_artifact_name(file_path)

    s3_uri, _ = self._artifact_uploader.upload_artifact(file_path, extra_args=extra_args)

    artifact = TrialComponentArtifact(value=s3_uri, media_type=media_type)
    if is_output:
        self._trial_component.output_artifacts[name] = artifact
    else:
        self._trial_component.input_artifacts[name] = artifact
|
|
544
|
+
|
|
545
|
+
def close(self):
    """Persist the data this run has accumulated locally.

    Saves the trial component first, then the lineage artifact tracker. The
    metrics manager, when set, is closed even if either save raises.
    """
    try:
        # Push the additions made through the Run object to the trial component.
        self._trial_component.save()
        # Create the lineage entities for the logged artifacts.
        self._lineage_artifact_tracker.save()
    finally:
        metrics_manager = self._metrics_manager
        if metrics_manager:
            metrics_manager.close()
|
|
555
|
+
|
|
556
|
+
@staticmethod
def _generate_trial_name(base_name) -> str:
    """Derive the reserved trial name from an experiment name.

    ``base_name`` is truncated so that, once substituted into
    ``TRIAL_NAME_TEMPLATE``, the result stays within ``MAX_NAME_LEN_IN_BACKEND``.

    Args:
        base_name (str): The ``experiment_name`` of this ``Run`` object.

    Returns:
        str: ``TRIAL_NAME_TEMPLATE`` formatted with the (possibly truncated) base name.
    """
    room_for_base = MAX_NAME_LEN_IN_BACKEND - len(TRIAL_NAME_TEMPLATE)
    truncated_base = base_name[:room_for_base]
    return TRIAL_NAME_TEMPLATE.format(truncated_base)
|
|
565
|
+
|
|
566
|
+
@staticmethod
def _is_input_valid(input_type, field_name, field_value) -> bool:
    """Tell whether a value may be logged.

    Only numeric values are inspected: a ``Number`` that is NaN or infinite is
    rejected with a warning. Every other value is accepted.

    Args:
        input_type (str): The type of the input, one of ``parameter``, ``metric``.
        field_name (str): The name of the field to be checked.
        field_value (str or int or float): The value of the field to be checked.

    Returns:
        bool: False for a NaN or infinite number, True otherwise.
    """
    is_non_finite_number = isinstance(field_value, Number) and (
        isnan(field_value) or isinf(field_value)
    )
    if not is_non_finite_number:
        return True

    logger.warning(
        "Failed to log %s %s. Received invalid value: %s.",
        input_type,
        field_name,
        field_value,
    )
    return False
|
|
584
|
+
|
|
585
|
+
def _log_graph_artifact(self, data, graph_type, is_output, artifact_name=None):
    """Log a graph artifact.

    Uploads ``data`` to S3 as a JSON object, then registers the uploaded object
    with the lineage artifact tracker as an input or output artifact of the run
    trial component.

    Args:
        data (dict): Artifacts data that will be saved to S3.
        graph_type (str): The type of the artifact.
        is_output (bool): Determines direction of association to the
            trial component. True registers an output artifact,
            False registers an input artifact.
        artifact_name (str): Name of the artifact (default: None). When empty,
            a unique name is generated from ``graph_type``.
    """
    # The public graph-logging methods pass their optional ``title`` as the
    # artifact name, so it is often None here. The generated name must be
    # assigned: both the upload and the lineage registration below use it.
    if not artifact_name:
        artifact_name = unique_name_from_base(graph_type)

    # Create a JSON file in S3.
    s3_uri, etag = self._artifact_uploader.upload_object_artifact(
        artifact_name, data, file_extension="json"
    )

    # Create an artifact and its association in the requested direction.
    if is_output:
        self._lineage_artifact_tracker.add_output_artifact(
            name=artifact_name,
            source_uri=s3_uri,
            etag=etag,
            artifact_type=graph_type,
        )
    else:
        self._lineage_artifact_tracker.add_input_artifact(
            name=artifact_name,
            source_uri=s3_uri,
            etag=etag,
            artifact_type=graph_type,
        )
|
|
623
|
+
|
|
624
|
+
def _verify_trial_component_artifacts_length(self, is_output):
    """Ensure there is room for one more artifact in the given direction.

    Args:
        is_output (bool): Determines direction of association to the
            trial component. True checks the output artifacts, False the
            input artifacts.

    Raises:
        ValueError: If the number of artifacts already recorded in that direction
            has reached ``MAX_RUN_TC_ARTIFACTS_LEN``.
    """
    if is_output:
        existing_artifacts = self._trial_component.output_artifacts
        direction = "output"
    else:
        existing_artifacts = self._trial_component.input_artifacts
        direction = "input"

    if len(existing_artifacts) >= MAX_RUN_TC_ARTIFACTS_LEN:
        raise ValueError(
            "Cannot add more than {} {}_artifacts under run".format(
                MAX_RUN_TC_ARTIFACTS_LEN, direction
            )
        )
|
|
641
|
+
|
|
642
|
+
@staticmethod
def _generate_trial_component_name(run_name: str, experiment_name: str) -> str:
    """Compose the TrialComponentName from an experiment name and a run name.

    Args:
        run_name (str): The run_name supplied by the user.
        experiment_name (str): The experiment_name supplied by the user,
            which is prepended to the run_name to generate the TrialComponentName.

    Returns:
        str: The TrialComponentName used to create a trial component
            which is unique in an account. It is the lower-cased
            ``experiment_name``, ``DELIMITER`` and ``run_name`` joined together.

    Raises:
        ValueError: If either the run_name or the experiment_name exceeds
            the length limit.
    """
    delimiter_allowance = 1  # leave length buffers for delimiters
    # Each of the two names may take up to half of the backend limit.
    max_len = int(MAX_NAME_LEN_IN_BACKEND / 2) - delimiter_allowance

    # run_name is validated before experiment_name.
    for label, supplied in (("run_name", run_name), ("experiment_name", experiment_name)):
        if len(supplied) > max_len:
            raise ValueError(
                "The {} (length: {}) must have length less than or equal to {}".format(
                    label, len(supplied), max_len
                )
            )

    # Lower-cased due to mixed-case concerns on the backend.
    return "{}{}{}".format(experiment_name, DELIMITER, run_name).lower()
|
|
672
|
+
|
|
673
|
+
@staticmethod
def _extract_run_name_from_tc_name(trial_component_name: str, experiment_name: str) -> str:
    """Recover the user supplied run name from a trial component name.

    Removes the first occurrence of the lower-cased ``experiment_name`` followed
    by ``DELIMITER`` from ``trial_component_name``.

    Args:
        trial_component_name (str): The name of a run trial component.
        experiment_name (str): The experiment_name supplied by the user,
            which was prepended to the run_name to generate the trial_component_name.

    Returns:
        str: The name of the Run object supplied by a user.
    """
    # TODO: we should revert the lower casting once backend fix reaches prod
    experiment_prefix = "{}{}".format(experiment_name.lower(), DELIMITER)
    return trial_component_name.replace(experiment_prefix, "", 1)
|
|
689
|
+
|
|
690
|
+
@staticmethod
def _append_run_tc_label_to_tags(tags: Optional[List[TagsDict]] = None) -> list:
    """Make sure the run trial component label is among the given tags.

    A non-empty ``tags`` list is modified in place and returned; an empty or
    missing one is replaced by a new list.

    Args:
        tags (List[TagsDict]): The tags supplied by users to initialize a Run object.

    Returns:
        list: The updated tags with the appended run trial component label.
    """
    labelled_tags = tags if tags else []
    if RUN_TC_TAG in labelled_tags:
        return labelled_tags
    labelled_tags.append(RUN_TC_TAG)
    return labelled_tags
|
|
705
|
+
|
|
706
|
+
def __enter__(self):
    """Enter the run context and mark the run as in progress.

    Registers this run in the run context when appropriate, sets the trial
    component's start time if it has none, sets its status to in-progress and
    saves the trial component.

    Returns:
        object: self.

    Raises:
        RuntimeError: If this run is already inside a ``load_run`` context, or if
            a freshly initialized run is entered while another run is current.
    """
    nested_with_err_msg_template = (
        "It is not allowed to use nested 'with' statements on the {}."
    )

    if self._in_load:
        if self._inside_load_context:
            raise RuntimeError(nested_with_err_msg_template.format("load_run"))
        self._inside_load_context = True
        # Add to run context only if the load_run is called separately
        # without under a Run init context
        needs_registration = not self._inside_init_context
    else:
        if _RunContext.get_current_run():
            raise RuntimeError(nested_with_err_msg_template.format("Run"))
        self._inside_init_context = True
        needs_registration = True

    if needs_registration:
        _RunContext.add_run_object(self)

    # Keep an existing start time, e.g. when the run is re-entered.
    if not self._trial_component.start_time:
        self._trial_component.start_time = datetime.datetime.now(dateutil.tz.tzlocal())

    self._trial_component.status = _api_types.TrialComponentStatus(
        primary_status=_TrialComponentStatusType.InProgress.value,
        message="Within a run context",
    )
    # Save the start_time and status changes to backend
    self._trial_component.save()
    return self
|
|
739
|
+
|
|
740
|
+
def __exit__(self, exc_type, exc_value, exc_traceback):
    """Leave the run context, record the final status and persist the run.

    Resets the context flags, drops the run from the run context when appropriate,
    sets the trial component's end time and final status, and then calls ``close``.

    Args:
        exc_type (str): The exception type.
        exc_value (str): The exception value.
        exc_traceback (str): The stack trace of the exception.
    """
    if self._in_load:
        self._inside_load_context = False
        self._in_load = False
        # A run loaded inside its own init context stays registered until
        # that outer context exits.
        should_drop = not self._inside_init_context
    else:
        self._inside_init_context = False
        should_drop = True

    if should_drop:
        _RunContext.drop_current_run()

    self._trial_component.end_time = datetime.datetime.now(dateutil.tz.tzlocal())

    if exc_value:
        final_status = _api_types.TrialComponentStatus(
            primary_status=_TrialComponentStatusType.Failed.value,
            message=str(exc_value),
        )
    else:
        final_status = _api_types.TrialComponentStatus(
            primary_status=_TrialComponentStatusType.Completed.value
        )
    self._trial_component.status = final_status

    self.close()
|
|
770
|
+
|
|
771
|
+
def __getstate__(self):
|
|
772
|
+
"""Overriding this method to prevent instance of Run from being pickled.
|
|
773
|
+
|
|
774
|
+
Raise:
|
|
775
|
+
NotImplementedError: If attempting to pickle this instance.
|
|
776
|
+
"""
|
|
777
|
+
raise NotImplementedError("Instance of Run type is not allowed to be pickled.")
|
|
778
|
+
|
|
779
|
+
|
|
780
|
+
def load_run(
    run_name: Optional[str] = None,
    experiment_name: Optional[str] = None,
    sagemaker_session: Optional["Session"] = None,
    artifact_bucket: Optional[str] = None,
    artifact_prefix: Optional[str] = None,
    tags: Optional[List[Dict[str, str]]] = None,
) -> Run:
    """Load an existing run.

    In order to reuse an existing run to log extra data, ``load_run`` is recommended.
    The run is resolved from the first of these sources that applies:

    1. An explicitly supplied ``run_name`` (together with ``experiment_name``).

    If ``run_name`` and ``experiment_name`` are passed in, they are honored over
    the default experiment config in the job environment or the run context
    (i.e. within the ``with`` block).

    Note:
        Both ``run_name`` and ``experiment_name`` should be supplied to make this usage work.
        Otherwise, you may get a ``ValueError``.

    .. code:: python

        with load_run(experiment_name="my-exp", run_name="my-run") as run:
            run.log_metric(...)
            ...

    2. The current run context (i.e. the ``with`` block), e.g. in a notebook,
    without supplying ``run_name`` and ``experiment_name``.

    Every time we call ``with Run(...) as run1:``, the initialized ``run1`` is tracked
    in the run context. Then when we call ``load_run()`` under this with statement, the ``run1``
    in the context is loaded by default.

    .. code:: python

        # In a notebook
        with Run(experiment_name="my-exp", run_name="my-run", ...) as run1:
            run1.log_parameter(...)

            with load_run() as run2:  # run2 is the same object as run1
                run2.log_metric(...)
                ...

    3. The job environment, in a job script, without supplying ``run_name``
    and ``experiment_name``.

    In this case, the default experiment config (specified when creating the job) is fetched
    from the job environment to load the run.

    .. code:: python

        # In a job script
        with load_run() as run:
            run.log_metric(...)
            ...

    Args:
        run_name (str): The name of the run to be loaded (default: None).
            If it is None, the run is taken from the current run context or,
            failing that, from the ``RunName`` in the ``ExperimentConfig`` of the job.
        experiment_name (str): The name of the Experiment that the to be loaded run
            is associated with (default: None).
            Note: the experiment_name must be supplied along with a valid run_name.
        sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed. If not specified, one is created using the
            default AWS configuration chain.
        artifact_bucket (str): The S3 bucket to upload the artifact to.
            If not specified, the default bucket defined in `sagemaker_session`
            will be used.
        artifact_prefix (str): The S3 key prefix used to generate the S3 path
            to upload the artifact to (default: "trial-component-artifacts").
        tags (List[Dict[str, str]]): A list of tags to be used for all create calls,
            e.g. to create an experiment, a run group, etc. (default: None).

    Returns:
        Run: The loaded Run object, flagged as loaded via ``load_run``.

    Raises:
        RuntimeError: If no ``run_name`` is given and neither a current run context
            nor a job environment is available.
    """
    environment = _RunEnvironment.load()

    verify_load_input_names(run_name=run_name, experiment_name=experiment_name)

    if run_name:
        logger.warning(
            "run_name is explicitly supplied in load_run, "
            "which will be prioritized to load the Run object. "
            "In other words, the run name in the experiment config, fetched from the "
            "job environment or the current run context, will be ignored."
        )
        run_instance = Run(
            experiment_name=experiment_name,
            run_name=run_name,
            sagemaker_session=sagemaker_session or _utils.default_session(),
            artifact_bucket=artifact_bucket,
            artifact_prefix=artifact_prefix,
            tags=tags,
        )
    else:
        # Look the current run up once and reuse the result.
        current_run = _RunContext.get_current_run()
        if current_run:
            run_instance = current_run
        elif environment:
            # Resolve the session once so that the experiment config lookup and the
            # Run use the same session, instead of a separate default session each.
            session = sagemaker_session or _utils.default_session()
            exp_config = get_tc_and_exp_config_from_job_env(
                environment=environment,
                sagemaker_session=session,
            )
            experiment_name = exp_config[EXPERIMENT_NAME]
            run_name = Run._extract_run_name_from_tc_name(
                trial_component_name=exp_config[RUN_NAME],
                experiment_name=experiment_name,
            )
            run_instance = Run(
                experiment_name=experiment_name,
                run_name=run_name,
                sagemaker_session=session,
                artifact_bucket=artifact_bucket,
                artifact_prefix=artifact_prefix,
                tags=tags,
            )
        else:
            raise RuntimeError(
                "Failed to load a Run object. "
                "Please make sure a Run object has been initialized already."
            )

    # Lets the run's context-manager methods treat it as a loaded run.
    run_instance._in_load = True
    return run_instance
|
|
908
|
+
|
|
909
|
+
|
|
910
|
+
def list_runs(
    experiment_name: str,
    created_before: Optional[datetime.datetime] = None,
    created_after: Optional[datetime.datetime] = None,
    sagemaker_session: Optional["Session"] = None,
    max_results: Optional[int] = None,
    next_token: Optional[str] = None,
    sort_by: SortByType = SortByType.CREATION_TIME,
    sort_order: SortOrderType = SortOrderType.DESCENDING,
) -> list:
    """Return a list of ``Run`` objects matching the given criteria.

    Trial components of the experiment are listed first; those that are not run
    trial components are skipped, and a ``Run`` is built for each remaining one.

    Args:
        experiment_name (str): Only Run objects related to the specified experiment
            are returned.
        created_before (datetime.datetime): Return Run objects created before this instant
            (default: None).
        created_after (datetime.datetime): Return Run objects created after this instant
            (default: None).
        sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed. If not specified, one is created using the
            default AWS configuration chain.
        max_results (int): Maximum number of Run objects to retrieve (default: None).
        next_token (str): Token for next page of results (default: None).
        sort_by (SortByType): The property to sort results by. One of NAME, CREATION_TIME
            (default: CREATION_TIME).
        sort_order (SortOrderType): One of ASCENDING, or DESCENDING (default: DESCENDING).

    Returns:
        list: A list of ``Run`` objects.
    """
    # all trial components retrieved by default
    tc_summaries = _TrialComponent.list(
        experiment_name=experiment_name,
        created_before=created_before,
        created_after=created_after,
        sort_by=sort_by.value,
        sort_order=sort_order.value,
        sagemaker_session=sagemaker_session,
        max_results=max_results,
        next_token=next_token,
    )

    run_list = []
    for tc_summary in tc_summaries:
        tc_name = tc_summary.trial_component_name
        # Only trial components that back a Run are turned into Run objects.
        if is_run_trial_component(
            trial_component_name=tc_name,
            sagemaker_session=sagemaker_session,
        ):
            extracted_run_name = Run._extract_run_name_from_tc_name(
                trial_component_name=tc_name,
                experiment_name=experiment_name,
            )
            run_list.append(
                Run(
                    experiment_name=experiment_name,
                    run_name=extracted_run_name,
                    sagemaker_session=sagemaker_session,
                )
            )
    return run_list
|