sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/__init__.py +0 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/__init__.py +0 -4
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.47.dist-info/RECORD +0 -35
- sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
sagemaker/core/enums.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Defines enum values."""
|
|
14
|
+
|
|
15
|
+
from __future__ import absolute_import
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from enum import Enum
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Shared "sagemaker" logger for this module.
LOGGER = logging.getLogger(name="sagemaker")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class EndpointType(Enum):
    """Enumerates the kinds of SageMaker endpoint: model based or inference-component based."""

    # Amazon SageMaker Model Based Endpoint
    MODEL_BASED = "ModelBased"

    # Amazon SageMaker Inference Component Based Endpoint
    INFERENCE_COMPONENT_BASED = "InferenceComponentBased"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class RoutingStrategy(Enum):
    """Strategy an endpoint uses to route HTTPS traffic to its instances."""

    # Each request goes to a randomly chosen instance.
    RANDOM = "RANDOM"

    # Requests go to the specific instances that have more capacity to
    # process them.
    LEAST_OUTSTANDING_REQUESTS = "LEAST_OUTSTANDING_REQUESTS"
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class Tag(str, Enum):
    """Tag keys that the SDK applies to models.

    Members mix in ``str``, so each one compares equal to its raw key string.
    Every key shares the ``sagemaker-sdk:`` prefix.
    """

    OPTIMIZATION_JOB_NAME = "sagemaker-sdk:optimization-job-name"
    SPECULATIVE_DRAFT_MODEL_PROVIDER = "sagemaker-sdk:speculative-draft-model-provider"
    FINE_TUNING_MODEL_PATH = "sagemaker-sdk:fine-tuning-model-path"
    FINE_TUNING_JOB_NAME = "sagemaker-sdk:fine-tuning-job-name"
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Accessors to retrieve environment variables for hosting containers."""
|
|
14
|
+
|
|
15
|
+
from __future__ import absolute_import
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from typing import Dict, Optional
|
|
19
|
+
|
|
20
|
+
from sagemaker.core.jumpstart import utils as jumpstart_utils
|
|
21
|
+
from sagemaker.core.jumpstart import artifacts
|
|
22
|
+
from sagemaker.core.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
|
|
23
|
+
from sagemaker.core.jumpstart.enums import JumpStartModelType, JumpStartScriptScope
|
|
24
|
+
from sagemaker.core.helper.session_helper import Session
|
|
25
|
+
|
|
26
|
+
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def retrieve_default(
    region: Optional[str] = None,
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    include_aws_sdk_env_vars: bool = True,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    instance_type: Optional[str] = None,
    script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Dict[str, str]:
    """Retrieves the default container environment variables for the model matching the arguments.

    Args:
        region (str): Optional. The AWS Region for which to retrieve the default environment
            variables. (Default: None).
        model_id (str): Optional. The model ID of the model for which to
            retrieve the default environment variables. (Default: None).
        model_version (str): Optional. The version of the model for which to retrieve the
            default environment variables. (Default: None).
        hub_arn (str): Optional. The ARN of the SageMaker Hub from which to
            retrieve model details. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated models should be tolerated
            (exception not raised). False if these models should raise an exception.
            (Default: False).
        include_aws_sdk_env_vars (bool): True if environment variables for low-level AWS API call
            should be included. The `Model` class of the SageMaker Python SDK inserts environment
            variables that would be required when making the low-level AWS API call.
            (Default: True).
        sagemaker_session (sagemaker.core.helper.session_helper.Session): A SageMaker Session
            object, used for SageMaker interactions. If not specified, the module-level
            default session is used. Note that this default is a single object bound when the
            function is defined, so it is shared by every call that omits this argument.
            (Default: sagemaker.core.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        instance_type (str): An instance type to optionally supply in order to get environment
            variables specific for the instance type. (Default: None).
        script (JumpStartScriptScope): The JumpStart script for which to retrieve environment
            variables. (Default: JumpStartScriptScope.INFERENCE).
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
        model_type (JumpStartModelType): The type of the model, can be open weights model
            or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).

    Returns:
        dict: The variables to use for the model.

    Raises:
        ValueError: If ``jumpstart_utils.is_jumpstart_model_input`` rejects the given
            ``model_id`` / ``model_version``, or if the combination of arguments specified
            is otherwise not supported.
    """
    # Fail fast with a clear message before doing any (potentially remote) artifact lookup.
    if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
        raise ValueError(
            "Must specify JumpStart `model_id` and `model_version` "
            "when retrieving environment variables."
        )

    return artifacts._retrieve_default_environment_variables(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        include_aws_sdk_env_vars=include_aws_sdk_env_vars,
        sagemaker_session=sagemaker_session,
        instance_type=instance_type,
        script=script,
        config_name=config_name,
        model_type=model_type,
    )
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Custom exception classes for Sagemaker SDK"""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class UnexpectedStatusException(ValueError):
    """Raised when resource status is not expected and thus not allowed for further execution.

    Attributes:
        allowed_statuses: The ``allowed_statuses`` value passed to the constructor
            (the statuses that would have been accepted).
        actual_status: The ``actual_status`` value passed to the constructor
            (the status that was actually observed).
    """

    def __init__(self, message, allowed_statuses, actual_status):
        self.allowed_statuses = allowed_statuses
        self.actual_status = actual_status
        super().__init__(message)

    def __reduce__(self):
        """Make instances (and subclasses) picklable and copyable.

        ``BaseException.__reduce__`` would re-invoke the class with only
        ``self.args`` (just the message). That raises ``TypeError``, because
        ``allowed_statuses`` and ``actual_status`` are required ``__init__``
        arguments. Pass all three explicitly, and carry ``__dict__`` as state
        so any extra attributes survive the round trip.
        """
        message = self.args[0] if self.args else ""
        return (
            type(self),
            (message, self.allowed_statuses, self.actual_status),
            self.__dict__,
        )
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class CapacityError(UnexpectedStatusException):
    """Unexpected-status failure whose failure reason is a capacity error.

    It adds no behavior of its own and carries the same attributes as its base class.
    """
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class AsyncInferenceError(Exception):
    """The base exception class for Async Inference exceptions.

    Subclasses override ``fmt``. The keyword arguments given to ``__init__`` are
    substituted into ``fmt`` with ``str.format`` to build the message, and are
    kept on ``self.kwargs``.
    """

    fmt = "An unspecified error occurred"

    def __init__(self, **kwargs):
        msg = self.fmt.format(**kwargs)
        Exception.__init__(self, msg)
        self.kwargs = kwargs

    def __reduce__(self):
        """Rebuild from ``self.kwargs`` so instances survive pickling and copying.

        ``BaseException.__reduce__`` would call the class with the formatted
        message as a positional argument, and neither this ``__init__`` nor the
        subclasses' ``__init__`` accept that, so unpickling raised ``TypeError``.
        Reconstruction assumes the class's ``__init__`` accepts its ``fmt``
        fields as keyword arguments, which is true of the subclasses defined
        alongside this class.
        """
        import functools  # local import: keeps the module's import block unchanged

        return (functools.partial(type(self), **self.kwargs), (), self.__dict__)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class ObjectNotExistedError(AsyncInferenceError):
    """Signals that no Amazon S3 object exists at the expected output path."""

    fmt = "Object not exist at {output_path}. {message}"

    def __init__(self, message, output_path):
        # The base class formats ``fmt`` from these keyword fields.
        fields = {"message": message, "output_path": output_path}
        super().__init__(**fields)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class PollingTimeoutError(AsyncInferenceError):
    """Signals that polling exceeded its time budget with no result object in Amazon S3."""

    fmt = "No result at {output_path} after polling for {seconds} seconds. {message}"

    def __init__(self, message, output_path, seconds):
        # The base class formats ``fmt`` from these keyword fields.
        fields = {"message": message, "output_path": output_path, "seconds": seconds}
        super().__init__(**fields)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class UnexpectedClientError(AsyncInferenceError):
    """Signals a ClientError whose error code was not one the caller expected."""

    fmt = "Encountered unexpected client error: {message}"

    def __init__(self, message):
        # Forwarded by keyword so the base class can substitute it into ``fmt``.
        fields = {"message": message}
        super().__init__(**fields)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class AutoMLStepInvalidModeError(Exception):
    """Raised when the AutoML mode passed into AutoMLStep is invalid."""

    fmt = (
        "Mode in AutoMLJobConfig must be defined for AutoMLStep. "
        "AutoMLStep currently only supports ENSEMBLING mode"
    )

    def __init__(self, **kwargs):
        msg = self.fmt.format(**kwargs)
        Exception.__init__(self, msg)
        self.kwargs = kwargs

    def __reduce__(self):
        """Rebuild from ``self.kwargs`` so instances survive pickling and copying.

        ``BaseException.__reduce__`` would call the class with the formatted
        message as a positional argument, which this keyword-only ``__init__``
        rejects with ``TypeError``.
        """
        import functools  # local import: keeps the module's import block unchanged

        return (functools.partial(type(self), **self.kwargs), (), self.__dict__)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class AsyncInferenceModelError(AsyncInferenceError):
    """Signals that the model itself returned an error for a failed request."""

    fmt = "Model returned error: {message} "

    def __init__(self, message):
        # Forwarded by keyword so the base class can substitute it into ``fmt``.
        fields = {"message": message}
        super().__init__(**fields)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class ModelStreamError(Exception):
    """Signals a ModelStreamError event in an invoke_endpoint_with_response_stream response.

    ``message`` and ``code`` are kept as attributes. When a code is supplied,
    it is appended to the exception text as ``(Code: ...)``.
    """

    def __init__(self, message="An error occurred", code=None):
        self.message = message
        self.code = code
        text = message if code is None else f"{message} (Code: {code})"
        super().__init__(text)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class InternalStreamFailure(Exception):
    """Raised when an invoke_endpoint_with_response_stream response returns InternalStreamFailure.

    Args:
        message (str): Description of the failure; also stored on ``self.message``.
    """

    def __init__(self, message="An error occurred"):
        super().__init__(message)
        self.message = message
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""SageMaker Experiments module for tracking experiments, trials, and runs."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
# The names exported below are resolved lazily by this module's ``__getattr__``
# so that importing the package does not create circular dependencies during
# initialization. Importing straight from the defining modules works as well:
#   from sagemaker.core.experiments.experiment import Experiment
#   from sagemaker.core.experiments.run import Run
__all__ = [
    "Experiment",
    "Run",
    "_RunContext",
    "_Trial",
    "_TrialComponent",
]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def __getattr__(name):
    """Lazy import to avoid circular dependencies.

    Resolves the exported experiment classes on first access instead of at
    package-initialization time.

    Args:
        name (str): The attribute requested from this package.

    Returns:
        The class exported under ``name``.

    Raises:
        AttributeError: If ``name`` is not one of the lazily exported names.
    """
    if name == "Experiment":
        from sagemaker.core.experiments.experiment import Experiment as resolved

        return resolved
    if name == "Run":
        from sagemaker.core.experiments.run import Run as resolved

        return resolved
    if name == "_RunContext":
        from sagemaker.core.experiments._run_context import _RunContext as resolved

        return resolved
    if name == "_Trial":
        from sagemaker.core.experiments.trial import _Trial as resolved

        return resolved
    if name == "_TrialComponent":
        from sagemaker.core.experiments.trial_component import _TrialComponent as resolved

        return resolved
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Contains API objects for SageMaker experiments."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import enum
|
|
17
|
+
import numbers
|
|
18
|
+
|
|
19
|
+
from sagemaker.core.apiutils import _base_types
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class TrialComponentMetricSummary(_base_types.ApiObject):
    """Summary statistics of a single metric recorded for a trial component.

    Attributes:
        metric_name (str): The name of the metric.
        source_arn (str): The ARN of the source.
        time_stamp (datetime): Metric last updated value.
        max (float): The max value of the metric.
        min (float): The min value of the metric.
        last (float): The last value of the metric.
        count (float): The number of samples used to generate the metric.
        avg (float): The average value of the metric.
        std_dev (float): The standard deviation of the metric.
    """

    # Identity of the metric.
    metric_name = None
    source_arn = None
    time_stamp = None
    # Aggregated statistics.
    max = None
    min = None
    last = None
    count = None
    avg = None
    std_dev = None

    def __init__(self, metric_name=None, source_arn=None, **kwargs):
        super().__init__(metric_name=metric_name, source_arn=source_arn, **kwargs)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class TrialComponentParameters(_base_types.ApiObject):
    """Converts trial component parameters between plain dicts and boto format."""

    @classmethod
    def from_boto(cls, boto_dict, **kwargs):
        """Converts a boto dict to a dictionary of TrialComponentParameterValues

        Args:
            boto_dict (dict): boto response dictionary.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            dict: Parameter name mapped to its ``NumberValue`` when present,
                otherwise its ``StringValue``, otherwise None.
        """
        return {
            param_name: param_value.get("NumberValue", param_value.get("StringValue", None))
            for param_name, param_value in boto_dict.items()
        }

    @classmethod
    def to_boto(cls, parameters):
        """Converts TrialComponentParameters to dict.

        Numbers are sent as ``NumberValue``; everything else is converted with
        ``str()`` and sent as ``StringValue``. Note that ``bool`` is a
        ``numbers.Number``, so booleans are sent as ``NumberValue``.

        Args:
            parameters (TrialComponentParameters): Dictionary to convert.

        Returns:
            dict: Dictionary of trial component parameters in boto format.
        """
        return {
            param_name: (
                {"NumberValue": param_value}
                if isinstance(param_value, numbers.Number)
                else {"StringValue": str(param_value)}
            )
            for param_name, param_value in parameters.items()
        }
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class TrialComponentArtifact(_base_types.ApiObject):
    """Trial component artifact.

    Attributes:
        value (str): The artifact value.
        media_type (str): The media type.
    """

    value = None
    media_type = None

    def __init__(self, value=None, media_type=None, **kwargs):
        super().__init__(value=value, media_type=media_type, **kwargs)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class _TrialComponentStatusType(enum.Enum):
|
|
107
|
+
"""The type of trial component status"""
|
|
108
|
+
|
|
109
|
+
InProgress = "InProgress"
|
|
110
|
+
Completed = "Completed"
|
|
111
|
+
Failed = "Failed"
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class TrialComponentStatus(_base_types.ApiObject):
    """Status of the trial component.

    Attributes:
        primary_status (str): The status of a trial component.
        message (str): Status message.
    """

    primary_status = None
    message = None

    def __init__(self, primary_status=None, message=None, **kwargs):
        super().__init__(primary_status=primary_status, message=message, **kwargs)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class TrialComponentSummary(_base_types.ApiObject):
    """Summary model of a trial component.

    Attributes:
        trial_component_name (str): Name of trial component.
        trial_component_arn (str): ARN of the trial component.
        display_name (str): Friendly display name in UI.
        source_arn (str): ARN of the trial component source.
        status (TrialComponentStatus): Status (mapped through ``_custom_boto_types``).
        start_time (datetime): Start time.
        end_time (datetime): End time.
        creation_time (datetime): Creation time.
        created_by (str): Created by.
        last_modified_time (datetime): Date last modified.
        last_modified_by: Who last modified the trial component.
    """

    # "status" is a single (non-list) TrialComponentStatus.
    _custom_boto_types = {
        "status": (TrialComponentStatus, False),
    }
    trial_component_name = None
    trial_component_arn = None
    display_name = None
    source_arn = None
    status = None
    start_time = None
    end_time = None
    creation_time = None
    created_by = None
    last_modified_time = None
    last_modified_by = None
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class TrialComponentSource(_base_types.ApiObject):
    """Source of a trial component.

    Attributes:
        source_arn (str): The ARN of the source.
    """

    source_arn = None

    def __init__(self, source_arn=None, **kwargs):
        super().__init__(source_arn=source_arn, **kwargs)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class Parent(_base_types.ApiObject):
    """The trial, experiment, and run that a trial component is associated with.

    Attributes:
        trial_name (str): Name of the trial.
        experiment_name (str): Name of the experiment.
        run_name (str): Name of the run.
    """

    trial_name = None
    experiment_name = None
    run_name = None
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class TrialComponentSearchResult(_base_types.ApiObject):
    """Summary model of a trial component search result.

    Attributes:
        trial_component_arn (str): ARN of the trial component.
        trial_component_name (str): Name of the trial component.
        display_name (str): Display name of the trial component for UI display.
        source (dict): The source of the trial component.
        status (dict): The status of the trial component.
        start_time (datetime): Start time.
        end_time (datetime): End time.
        creation_time (datetime): Creation time.
        created_by (str): Created by.
        last_modified_time (datetime): Date last modified.
        last_modified_by: Who last modified the trial component.
        parameters (dict): The hyperparameters of the component.
        input_artifacts (dict): The input artifacts of the component.
        output_artifacts (dict): The output artifacts of the component.
        metrics (list): The metrics for the component.
        source_detail (dict): The source of the trial component.
        tags (list): The list of tags that are associated with the trial component.
        parents (list[Parent]): The parents of the trial component.
    """

    # "parents" is a list of Parent objects (the True flag marks a collection).
    _custom_boto_types = {
        "parents": (Parent, True),
    }
    trial_component_arn = None
    trial_component_name = None
    display_name = None
    source = None
    status = None
    start_time = None
    end_time = None
    creation_time = None
    created_by = None
    last_modified_time = None
    last_modified_by = None
    parameters = None
    input_artifacts = None
    output_artifacts = None
    metrics = None
    source_detail = None
    tags = None
    parents = None
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
class TrialSummary(_base_types.ApiObject):
    """Summary model of a trial.

    Attributes:
        trial_arn (str): The ARN of the trial.
        trial_name (str): The name of the trial.
        creation_time (datetime): When the trial was created.
        last_modified_time (datetime): When the trial was last modified.
    """

    trial_arn = None
    trial_name = None
    creation_time = None
    last_modified_time = None
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Contains the _RunEnvironment class."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import enum
|
|
17
|
+
import json
|
|
18
|
+
import logging
|
|
19
|
+
import os
|
|
20
|
+
|
|
21
|
+
from sagemaker.core.helper.session_helper import Session
|
|
22
|
+
from sagemaker.core.experiments import trial_component
|
|
23
|
+
from sagemaker.core.common_utils import retry_with_backoff
|
|
24
|
+
|
|
25
|
+
# Environment variable that holds the ARN of the current training job.
TRAINING_JOB_ARN_ENV: str = "TRAINING_JOB_ARN"
# Config file whose "ProcessingJobArn" key identifies the current processing job.
PROCESSING_JOB_CONFIG_PATH: str = "/opt/ml/config/processingjobconfig.json"
# Environment variable that holds the ARN of the current transform job.
TRANSFORM_JOB_ARN_ENV: str = "TRANSFORM_JOB_ARN"
# Retry budget handed to retry_with_backoff when looking up the job's trial component.
MAX_RETRY_ATTEMPTS: int = 7

logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class _EnvironmentType(enum.Enum):
|
|
34
|
+
"""SageMaker jobs which data can be pulled from the environment."""
|
|
35
|
+
|
|
36
|
+
SageMakerTrainingJob = 1
|
|
37
|
+
SageMakerProcessingJob = 2
|
|
38
|
+
SageMakerTransformJob = 3
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class _RunEnvironment(object):
    """Retrieves job specific data from the environment.

    Records which kind of SageMaker job the process is running in and that
    job's ARN, and can look up the trial component associated with the job.
    """

    def __init__(self, environment_type: _EnvironmentType, source_arn: str):
        """Init for _RunEnvironment.

        Args:
            environment_type (_EnvironmentType): The environment type.
            source_arn (str): The ARN of the current job.
        """
        self.environment_type = environment_type
        self.source_arn = source_arn

    @classmethod
    def load(
        cls,
        training_job_arn_env: str = TRAINING_JOB_ARN_ENV,
        processing_job_config_path: str = PROCESSING_JOB_CONFIG_PATH,
        transform_job_arn_env: str = TRANSFORM_JOB_ARN_ENV,
    ):
        """Loads source arn of current job from environment.

        The sources are checked in order: training job (environment variable),
        processing job (config file), then transform job (environment
        variable). The first one found wins.

        Args:
            training_job_arn_env (str): The environment key for training job ARN
                (default: `TRAINING_JOB_ARN`).
            processing_job_config_path (str): The processing job config path
                (default: `/opt/ml/config/processingjobconfig.json`).
            transform_job_arn_env (str): The environment key for transform job ARN
                (default: `TRANSFORM_JOB_ARN`).

        Returns:
            _RunEnvironment: Job data loaded from the environment. None if no
                supported job environment is detected.

        Raises:
            OSError: If the processing job config file exists but cannot be read.
            ValueError: If the processing job config file is not valid JSON.
            KeyError: If the processing job config lacks ``ProcessingJobArn``.
        """
        if training_job_arn_env in os.environ:
            return _RunEnvironment(
                _EnvironmentType.SageMakerTrainingJob, os.environ[training_job_arn_env]
            )
        if os.path.exists(processing_job_config_path):
            # Read through a context manager so the file handle is always closed.
            with open(processing_job_config_path, encoding="utf-8") as config_file:
                source_arn = json.load(config_file)["ProcessingJobArn"]
            return _RunEnvironment(_EnvironmentType.SageMakerProcessingJob, source_arn)
        if transform_job_arn_env in os.environ:
            # TODO: need to update to get source_arn from config file once Transform side ready
            return _RunEnvironment(
                _EnvironmentType.SageMakerTransformJob, os.environ[transform_job_arn_env]
            )
        return None

    def get_trial_component(self, sagemaker_session: Session):
        """Retrieves the trial component from the job in the environment.

        Trial components whose source ARN matches this job are listed, and the
        first match is loaded. The lookup runs through ``retry_with_backoff``.
        Any exception that still escapes is logged and swallowed, which makes
        this a best-effort lookup.

        Args:
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. It is passed to ``_TrialComponent.list``
                and ``_TrialComponent.load``.

        Returns:
            _TrialComponent: The trial component created from the job. None if
                not found or if the lookup failed.
        """

        def _get_trial_component():
            # NOTE(review): the ARN is lowercased before filtering, presumably
            # because stored source ARNs are lowercase -- confirm.
            summaries = trial_component._TrialComponent.list(
                source_arn=self.source_arn.lower(), sagemaker_session=sagemaker_session
            )
            # Only the first match is used, so stop iterating after it.
            summary = next(iter(summaries), None)
            if summary is None:
                return None
            return trial_component._TrialComponent.load(
                trial_component_name=summary.trial_component_name,
                sagemaker_session=sagemaker_session,
            )

        job_tc = None
        try:
            job_tc = retry_with_backoff(_get_trial_component, MAX_RETRY_ATTEMPTS)
        except Exception as ex:  # pylint: disable=broad-except
            logger.error(
                "Failed to get trial component in the current environment due to %s", str(ex)
            )
        return job_tc
|