sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/__init__.py +0 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/__init__.py +0 -4
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.47.dist-info/RECORD +0 -35
- sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Utils module."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import json
|
|
18
|
+
import subprocess
|
|
19
|
+
import tempfile
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
|
|
22
|
+
from datetime import datetime
|
|
23
|
+
from typing import Literal, Any
|
|
24
|
+
|
|
25
|
+
from sagemaker.core.shapes import Unassigned
|
|
26
|
+
from sagemaker.core.modules import logger
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _is_valid_s3_uri(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool:
|
|
30
|
+
"""Check if the path is a valid S3 URI.
|
|
31
|
+
|
|
32
|
+
This method checks if the path is a valid S3 URI. If the path_type is specified,
|
|
33
|
+
it will also check if the path is a file or a directory.
|
|
34
|
+
This method does not check if the S3 bucket or object exists.
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
path (str): S3 URI to validate
|
|
38
|
+
path_type (Optional(Literal["File", "Directory", "Any"])): The type of the path to validate.
|
|
39
|
+
Defaults to "Any".
|
|
40
|
+
|
|
41
|
+
Returns:
|
|
42
|
+
bool: True if the path is a valid S3 URI, False otherwise
|
|
43
|
+
"""
|
|
44
|
+
# Check if the path is a valid S3 URI
|
|
45
|
+
if not path.startswith("s3://"):
|
|
46
|
+
return False
|
|
47
|
+
|
|
48
|
+
if path_type == "File":
|
|
49
|
+
# If it's a file, it should not end with a slash
|
|
50
|
+
return not path.endswith("/")
|
|
51
|
+
if path_type == "Directory":
|
|
52
|
+
# If it's a directory, it should end with a slash
|
|
53
|
+
return path.endswith("/")
|
|
54
|
+
|
|
55
|
+
return path_type == "Any"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _is_valid_path(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool:
|
|
59
|
+
"""Check if the path is a valid local path.
|
|
60
|
+
|
|
61
|
+
Args:
|
|
62
|
+
path (str): Local path to validate
|
|
63
|
+
path_type (Optional(Literal["File", "Directory", "Any"])): The type of the path to validate.
|
|
64
|
+
Defaults to "Any".
|
|
65
|
+
|
|
66
|
+
Returns:
|
|
67
|
+
bool: True if the path is a valid local path, False otherwise
|
|
68
|
+
"""
|
|
69
|
+
if not os.path.exists(path):
|
|
70
|
+
return False
|
|
71
|
+
|
|
72
|
+
if path_type == "File":
|
|
73
|
+
return os.path.isfile(path)
|
|
74
|
+
if path_type == "Directory":
|
|
75
|
+
return os.path.isdir(path)
|
|
76
|
+
|
|
77
|
+
return path_type == "Any"
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _get_unique_name(base, max_length=63):
|
|
81
|
+
"""Generate a unique name based on the base name.
|
|
82
|
+
|
|
83
|
+
This method generates a unique name based on the base name.
|
|
84
|
+
The unique name is generated by appending the current timestamp
|
|
85
|
+
to the base name.
|
|
86
|
+
|
|
87
|
+
Args:
|
|
88
|
+
base (str): The base name to use
|
|
89
|
+
max_length (int): The maximum length of the unique name. Defaults to 63.
|
|
90
|
+
|
|
91
|
+
Returns:
|
|
92
|
+
str: The unique name
|
|
93
|
+
"""
|
|
94
|
+
current_time = datetime.now().strftime("%Y%m%d%H%M%S")
|
|
95
|
+
base = base.replace("_", "-")
|
|
96
|
+
unique_name = f"{base}-{current_time}"
|
|
97
|
+
unique_name = unique_name[:max_length] # Truncate to max_length
|
|
98
|
+
return unique_name
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _get_repo_name_from_image(image: str) -> str:
|
|
102
|
+
"""Get the repository name from the image URI.
|
|
103
|
+
|
|
104
|
+
Example:
|
|
105
|
+
``` python
|
|
106
|
+
_get_repo_name_from_image("123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repo:latest")
|
|
107
|
+
# Returns "my-repo"
|
|
108
|
+
```
|
|
109
|
+
|
|
110
|
+
Args:
|
|
111
|
+
image (str): The image URI
|
|
112
|
+
|
|
113
|
+
Returns:
|
|
114
|
+
str: The repository name
|
|
115
|
+
"""
|
|
116
|
+
return image.split("/")[-1].split(":")[0]
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def convert_unassigned_to_none(instance: Any) -> Any:
    """Convert Unassigned values to None for any instance.

    Every attribute in ``instance.__dict__`` whose value is an ``Unassigned``
    instance is reassigned to ``None`` via ``setattr``. The object is mutated in
    place.

    Args:
        instance: Any object that has a ``__dict__``.

    Returns:
        The same ``instance`` object, after mutation.
    """
    # NOTE(review): attributes are reassigned while iterating the live ``__dict__``.
    # This is safe as long as ``setattr`` only replaces values of existing keys; a class
    # whose ``__setattr__`` adds/removes ``__dict__`` entries would break iteration — confirm.
    for name, value in instance.__dict__.items():
        if isinstance(value, Unassigned):
            setattr(instance, name, None)
    return instance
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def safe_serialize(data):
    """Serialize the data without wrapping strings in quotes.

    This function handles the following cases:
    1. If `data` is a string, it returns the string as-is without wrapping in quotes.
    2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns
       the JSON-encoded string using `json.dumps()`.
    3. If `data` cannot be serialized (e.g., a custom object, or a container with a
       circular reference), it returns the string representation of the data using
       `str(data)`.

    Args:
        data (Any): The data to serialize.

    Returns:
        str: The serialized JSON-compatible string or the string representation of the input.
    """
    if isinstance(data, str):
        return data
    try:
        return json.dumps(data)
    except (TypeError, ValueError):
        # TypeError: unsupported object/key types; ValueError: circular references.
        return str(data)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def _run_clone_command_silent(repo_url, dest_dir):
|
|
152
|
+
"""Run the 'git clone' command with the repo url and the directory to clone the repo into.
|
|
153
|
+
|
|
154
|
+
Args:
|
|
155
|
+
repo_url (str): Git repo url to be cloned.
|
|
156
|
+
dest_dir: (str): Local path where the repo should be cloned into.
|
|
157
|
+
|
|
158
|
+
Raises:
|
|
159
|
+
CalledProcessError: If failed to clone git repo.
|
|
160
|
+
"""
|
|
161
|
+
my_env = os.environ.copy()
|
|
162
|
+
if repo_url.startswith("https://"):
|
|
163
|
+
try:
|
|
164
|
+
my_env["GIT_TERMINAL_PROMPT"] = "0"
|
|
165
|
+
subprocess.check_call(
|
|
166
|
+
["git", "clone", repo_url, dest_dir],
|
|
167
|
+
env=my_env,
|
|
168
|
+
stdout=subprocess.DEVNULL,
|
|
169
|
+
stderr=subprocess.DEVNULL,
|
|
170
|
+
)
|
|
171
|
+
except subprocess.CalledProcessError as e:
|
|
172
|
+
logger.error(f"Failed to clone repository: {repo_url}")
|
|
173
|
+
logger.error(f"Error output:\n{e}")
|
|
174
|
+
raise
|
|
175
|
+
elif repo_url.startswith("git@") or repo_url.startswith("ssh://"):
|
|
176
|
+
try:
|
|
177
|
+
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
178
|
+
custom_ssh_executable = Path(tmp_dir) / "ssh_batch"
|
|
179
|
+
with open(custom_ssh_executable, "w") as pipe:
|
|
180
|
+
print("#!/bin/sh", file=pipe)
|
|
181
|
+
print("ssh -oBatchMode=yes $@", file=pipe)
|
|
182
|
+
os.chmod(custom_ssh_executable, 0o511)
|
|
183
|
+
my_env["GIT_SSH"] = str(custom_ssh_executable)
|
|
184
|
+
subprocess.check_call(
|
|
185
|
+
["git", "clone", repo_url, dest_dir],
|
|
186
|
+
env=my_env,
|
|
187
|
+
stdout=subprocess.DEVNULL,
|
|
188
|
+
stderr=subprocess.DEVNULL,
|
|
189
|
+
)
|
|
190
|
+
except subprocess.CalledProcessError as e:
|
|
191
|
+
del my_env["GIT_SSH"]
|
|
192
|
+
logger.error(f"Failed to clone repository: {repo_url}")
|
|
193
|
+
logger.error(f"Error output:\n{e}")
|
|
194
|
+
raise
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This file contains code related to network configuration.
|
|
14
|
+
|
|
15
|
+
It also includes encryption, network isolation, and VPC configurations.
|
|
16
|
+
"""
|
|
17
|
+
from __future__ import absolute_import
|
|
18
|
+
|
|
19
|
+
from typing import Union, Optional, List
|
|
20
|
+
|
|
21
|
+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class NetworkConfig(object):
    """Accepts network configuration parameters for conversion to request dict.

    The `_to_request_dict` provides a method to turn the parameters into a dict.
    """

    def __init__(
        self,
        enable_network_isolation: Optional[Union[bool, PipelineVariable]] = None,
        security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
        subnets: Optional[List[Union[str, PipelineVariable]]] = None,
        encrypt_inter_container_traffic: Optional[Union[bool, PipelineVariable]] = None,
    ):
        """Initialize a ``NetworkConfig`` instance.

        NetworkConfig accepts network configuration parameters and provides a method to turn
        these parameters into a dictionary.

        Args:
            enable_network_isolation (bool or PipelineVariable): Boolean that determines
                whether to enable network isolation. Default value is None, which is
                sent as ``False`` by ``_to_request_dict``.
            security_group_ids (list[str] or list[PipelineVariable]): A list of strings representing
                security group IDs.
            subnets (list[str] or list[PipelineVariable]): A list of strings representing subnets.
            encrypt_inter_container_traffic (bool or PipelineVariable): Boolean that determines
                whether to encrypt inter-container traffic. Default value is None.
        """
        self.enable_network_isolation = enable_network_isolation
        self.security_group_ids = security_group_ids
        self.subnets = subnets
        self.encrypt_inter_container_traffic = encrypt_inter_container_traffic

    def _to_request_dict(self):
        """Generates a request dictionary using the parameters provided to the class.

        Returns:
            dict: Always contains ``EnableNetworkIsolation`` (``False`` when unset).
            ``EnableInterContainerTrafficEncryption`` is present only when set.
            ``VpcConfig`` is present only when security group IDs and/or subnets are
            set, and it holds only the keys that were provided.
        """
        # Enable Network Isolation should default to False if it is not provided.
        network_config_request = {
            "EnableNetworkIsolation": (
                False if self.enable_network_isolation is None else self.enable_network_isolation
            )
        }

        if self.encrypt_inter_container_traffic is not None:
            network_config_request["EnableInterContainerTrafficEncryption"] = (
                self.encrypt_inter_container_traffic
            )

        vpc_config = {}
        if self.security_group_ids is not None:
            vpc_config["SecurityGroupIds"] = self.security_group_ids
        if self.subnets is not None:
            vpc_config["Subnets"] = self.subnets
        if vpc_config:
            network_config_request["VpcConfig"] = vpc_config

        return network_config_request
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# VPC Utilities (merged from vpc_utils.py)

# Dictionary keys of a SageMaker ``VpcConfig`` request structure (used by to_dict/from_dict).
SUBNETS_KEY = "Subnets"
SECURITY_GROUP_IDS_KEY = "SecurityGroupIds"
# Key under which the VpcConfig dict is nested in a request.
VPC_CONFIG_KEY = "VpcConfig"

# A global constant value for methods which can optionally override VpcConfig
# NOTE(review): presumably a sentinel meaning "use the default VpcConfig", distinct
# from None (which may mean "no VPC") — confirm against callers.
VPC_CONFIG_DEFAULT = "VPC_CONFIG_DEFAULT"
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def to_dict(subnets, security_group_ids):
    """Build a ``VpcConfig`` dict holding the keys 'Subnets' and 'SecurityGroupIds'.

    The resulting shape is what the SageMaker CreateTrainingJob and CreateModel
    APIs expect; see https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html

    Args:
        subnets (list): Subnet IDs to place under 'Subnets'.
        security_group_ids (list): Security group IDs to place under
            'SecurityGroupIds'.

    Returns:
        dict or None: ``{'Subnets': subnets, 'SecurityGroupIds': security_group_ids}``
        when both arguments are provided. ``None`` when either of them is None.
    """
    if subnets is not None and security_group_ids is not None:
        return {SUBNETS_KEY: subnets, SECURITY_GROUP_IDS_KEY: security_group_ids}
    return None
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def from_dict(vpc_config, do_sanitize=False):
    """Split a ``VpcConfig`` dict into its subnet and security-group ID lists.

    Args:
        vpc_config (dict): A VpcConfig dict with 'Subnets' and
            'SecurityGroupIds' entries, or None.
        do_sanitize (bool): When True, run ``sanitize`` on ``vpc_config`` before
            reading its values (default: False).

    Returns:
        tuple: ``(subnets, security_group_ids)``, or ``(None, None)`` when the
        (possibly sanitized) config is None.

    Raises:
        * ValueError if sanitize enabled and vpc_config is invalid

        * KeyError if sanitize disabled and vpc_config is missing key(s)
    """
    config = sanitize(vpc_config) if do_sanitize else vpc_config
    if config is None:
        return None, None

    # 'Subnets' is looked up first, so a missing 'Subnets' key is the one reported.
    subnets = config[SUBNETS_KEY]
    security_group_ids = config[SECURITY_GROUP_IDS_KEY]
    return subnets, security_group_ids
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def sanitize(vpc_config):
    """Validate a ``VpcConfig`` dict and return a copy holding only the expected keys.

    Every key other than 'Subnets' and 'SecurityGroupIds' is dropped. The
    returned dict is rebuilt through ``to_dict``.

    Args:
        vpc_config (dict): A VpcConfig dict with 'Subnets' and
            'SecurityGroupIds' entries, or None.

    Returns:
        dict or None: A dict containing only 'Subnets' and 'SecurityGroupIds',
        taken from ``vpc_config``. ``None`` when ``vpc_config`` is None.

    Raises:
        ValueError: if any of these expectations is violated:
            * vpc_config must be a non-empty dict
            * vpc_config must have key `Subnets` and the value must be a non-empty list
            * vpc_config must have key `SecurityGroupIds` and the value must be a non-empty list
    """
    if vpc_config is None:
        return None
    if not isinstance(vpc_config, dict):
        raise ValueError("vpc_config is not a dict: {}".format(vpc_config))
    if not vpc_config:
        raise ValueError("vpc_config is empty")

    # Both keys get identical checks. 'Subnets' is validated first, so its
    # errors take precedence over those for 'SecurityGroupIds'.
    checked_values = []
    for key in (SUBNETS_KEY, SECURITY_GROUP_IDS_KEY):
        value = vpc_config.get(key)
        if value is None:
            raise ValueError("vpc_config is missing key: {}".format(key))
        if not isinstance(value, list):
            raise ValueError("vpc_config value for {} is not a list: {}".format(key, value))
        if not value:
            raise ValueError("vpc_config value for {} is empty".format(key))
        checked_values.append(value)

    subnets, security_group_ids = checked_values
    return to_dict(subnets, security_group_ids)
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Placeholder docstring"""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
from typing import Union
|
|
18
|
+
|
|
19
|
+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
|
|
20
|
+
from sagemaker.core.common_utils import to_string
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ParameterRange(object):
    """Common base for the hyperparameter range types.

    A range describes which values of a hyperparameter an Amazon SageMaker
    hyperparameter tuning job may search over. Ranges are also used to check
    hyperparameters supplied to Marketplace algorithms.
    """

    __all_types__ = ("Continuous", "Categorical", "Integer")

    def __init__(
        self,
        min_value: Union[int, float, PipelineVariable],
        max_value: Union[int, float, PipelineVariable],
        scaling_type: Union[str, PipelineVariable] = "Auto",
    ):
        """Record the bounds and scaling strategy of the range.

        Args:
            min_value (float or int or PipelineVariable): Lower bound of the range.
            max_value (float or int or PipelineVariable): Upper bound of the range.
            scaling_type (str or PipelineVariable): How the range is searched
                during tuning (default: 'Auto'). Valid values: 'Auto', 'Linear',
                'Logarithmic' and 'ReverseLogarithmic'.
        """
        self.min_value, self.max_value = min_value, max_value
        self.scaling_type = scaling_type

    def is_valid(self, value):
        """Check whether ``value`` falls within the inclusive bounds of this range.

        Args:
            value (float or int): Candidate value.

        Returns:
            bool: True when ``min_value <= value <= max_value``, False otherwise.
        """
        lower, upper = self.min_value, self.max_value
        return lower <= value <= upper

    @classmethod
    def cast_to_type(cls, value):
        """Convert ``value`` to ``float``, the default native type for a range."""
        return float(value)

    def as_tuning_range(self, name):
        """Render this range as a dictionary for a tuning-job creation request.

        Args:
            name (str): Name of the hyperparameter.

        Returns:
            dict[str, str]: Keys 'Name', 'MinValue', 'MaxValue' and 'ScalingType'.
            The bounds are converted with ``to_string``.
        """
        tuning_range = {"Name": name}
        tuning_range["MinValue"] = to_string(self.min_value)
        tuning_range["MaxValue"] = to_string(self.max_value)
        tuning_range["ScalingType"] = self.scaling_type
        return tuning_range
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class ContinuousParameter(ParameterRange):
    """Hyperparameter whose candidate values form a continuous, real-valued interval.

    Args:
        min_value (float): Lower bound of the interval.
        max_value (float): Upper bound of the interval.
    """

    # Short type label exposed as a class attribute (matches an entry in
    # ParameterRange.__all_types__).
    __name__ = "Continuous"

    @classmethod
    def cast_to_type(cls, value):
        """Convert ``value`` to ``float``, the native type of a continuous range."""
        as_float = float(value)
        return as_float
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class CategoricalParameter(ParameterRange):
    """Hyperparameter that takes one value out of a fixed, discrete set of choices."""

    # Short type label exposed as a class attribute (matches an entry in
    # ParameterRange.__all_types__).
    __name__ = "Categorical"

    def __init__(self, values):  # pylint: disable=super-init-not-called
        """Create a ``CategoricalParameter`` from its allowed choices.

        The base initializer is intentionally skipped: a categorical range has
        no min/max bounds or scaling type.

        Args:
            values (list or object): The allowed choices. A value that is not a
                list is treated as a single choice. Every choice is stored as a
                string produced by ``to_string``.
        """
        candidates = [values] if not isinstance(values, list) else values
        self.values = list(map(to_string, candidates))

    def as_tuning_range(self, name):
        """Render this range as a dictionary for a tuning-job creation request.

        Args:
            name (str): Name of the hyperparameter.

        Returns:
            dict[str, list[str]]: ``{'Name': name, 'Values': <stored choices>}``.
        """
        return dict(Name=name, Values=self.values)

    def as_json_range(self, name):
        """Render this range with each choice JSON-encoded.

        The deep learning framework images expect hyperparameters to be
        serialized as JSON, so this form is meant for tuning jobs that use
        one of those frameworks.

        Args:
            name (str): Name of the hyperparameter.

        Returns:
            dict[str, list[str]]: ``{'Name': name, 'Values': [...]}``, where each
            stored choice has been passed through ``json.dumps``.
        """
        encoded_choices = list(map(json.dumps, self.values))
        return {"Name": name, "Values": encoded_choices}

    def is_valid(self, value):
        """Return True if ``value`` is one of the stored (stringified) choices."""
        return value in self.values

    @classmethod
    def cast_to_type(cls, value):
        """Convert ``value`` to ``str``, the native type of a categorical choice."""
        as_text = str(value)
        return as_text
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class IntegerParameter(ParameterRange):
    """Hyperparameter whose candidate values are the integers within a bounded interval.

    Args:
        min_value (int): Lower bound of the interval.
        max_value (int): Upper bound of the interval.
    """

    # Short type label exposed as a class attribute (matches an entry in
    # ParameterRange.__all_types__).
    __name__ = "Integer"

    @classmethod
    def cast_to_type(cls, value):
        """Convert ``value`` to ``int``, the native type of an integer range."""
        as_int = int(value)
        return as_int
|