sagemaker-core 1.0.62__py3-none-any.whl → 2.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/__init__.py +2 -0
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2399 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +247 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1220 -0
- sagemaker/core/git_utils.py +415 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2977 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +509 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +770 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1267 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +686 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
- sagemaker/core/image_uri_config/huggingface.json +2287 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +252 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +494 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +972 -0
- sagemaker/core/image_uris.py +816 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +197 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +501 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +171 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +423 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +246 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1497 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +239 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +205 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1599 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1310 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +347 -0
- sagemaker/core/remote_function/core/serialization.py +410 -0
- sagemaker/core/remote_function/core/stored_function.py +223 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +102 -0
- sagemaker/core/remote_function/invoke_function.py +167 -0
- sagemaker/core/remote_function/job.py +2121 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +82 -0
- sagemaker/core/telemetry/telemetry_logging.py +285 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +345 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +514 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.3.1.dist-info/RECORD +351 -0
- sagemaker_core-2.3.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/modules/local_core}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1731 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module contains utils for JumpStart."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
from typing import Optional
|
|
17
|
+
from sagemaker.core.helper.session_helper import Session
|
|
18
|
+
from sagemaker.core.jumpstart.models import HubContentDocument
|
|
19
|
+
|
|
20
|
+
from copy import copy
|
|
21
|
+
import logging
|
|
22
|
+
import os
|
|
23
|
+
from functools import lru_cache, wraps
|
|
24
|
+
from typing import Any, Dict, List, Set, Optional, Tuple, Union
|
|
25
|
+
from urllib.parse import urlparse
|
|
26
|
+
import boto3
|
|
27
|
+
from botocore.exceptions import ClientError
|
|
28
|
+
from packaging.version import Version, InvalidVersion
|
|
29
|
+
import botocore
|
|
30
|
+
from sagemaker.core.shapes import ModelAccessConfig
|
|
31
|
+
import sagemaker
|
|
32
|
+
from sagemaker.core.config.config_schema import (
|
|
33
|
+
MODEL_ENABLE_NETWORK_ISOLATION_PATH,
|
|
34
|
+
MODEL_EXECUTION_ROLE_ARN_PATH,
|
|
35
|
+
TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH,
|
|
36
|
+
TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
|
|
37
|
+
TRAINING_JOB_ROLE_ARN_PATH,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
from sagemaker.core.jumpstart import constants, enums
|
|
41
|
+
from sagemaker.core.jumpstart import accessors
|
|
42
|
+
from sagemaker.core.jumpstart.hub.parser_utils import camel_to_snake, snake_to_upper_camel
|
|
43
|
+
from sagemaker.core.s3 import parse_s3_url
|
|
44
|
+
from sagemaker.core.jumpstart.exceptions import (
|
|
45
|
+
DeprecatedJumpStartModelError,
|
|
46
|
+
VulnerableJumpStartModelError,
|
|
47
|
+
get_old_model_version_msg,
|
|
48
|
+
)
|
|
49
|
+
from sagemaker.core.jumpstart.types import (
|
|
50
|
+
JumpStartBenchmarkStat,
|
|
51
|
+
JumpStartMetadataConfig,
|
|
52
|
+
JumpStartModelHeader,
|
|
53
|
+
JumpStartModelSpecs,
|
|
54
|
+
JumpStartVersionedModelId,
|
|
55
|
+
DeploymentConfigMetadata,
|
|
56
|
+
)
|
|
57
|
+
from sagemaker.core.helper.session_helper import Session
|
|
58
|
+
from sagemaker.core.config.config import load_sagemaker_config
|
|
59
|
+
from sagemaker.core.common_utils import (
|
|
60
|
+
resolve_value_from_config,
|
|
61
|
+
TagsDict,
|
|
62
|
+
get_instance_rate_per_hour,
|
|
63
|
+
get_domain_for_region,
|
|
64
|
+
camel_case_to_pascal_case,
|
|
65
|
+
)
|
|
66
|
+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def is_pipeline_variable(var: object) -> bool:
    """Report whether ``var`` is a ``PipelineVariable`` instance.

    Args:
        var (object): Any value to inspect.

    Returns:
        bool: ``True`` when ``var`` is an instance of ``PipelineVariable``
            (subclasses included), ``False`` otherwise.
    """
    matches_pipeline_type = isinstance(var, PipelineVariable)
    return matches_pipeline_type
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
from sagemaker.core.utils.user_agent import get_user_agent_extra_suffix
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def get_eula_url(document: HubContentDocument, sagemaker_session: Optional[Session] = None) -> str:
    """Build the HTTPS URL for the hosting EULA referenced by a hub content document.

    Args:
        document (HubContentDocument): Document whose ``HostingEulaUri`` is converted.
        sagemaker_session (Optional[Session]): Session that supplies the region name and
            the botocore endpoint resolver. A new ``Session()`` is created when omitted.

    Returns:
        str: A URL of the form ``https://<bucket>.s3.<region>.<dns suffix>/<key>``, or an
            empty string when the document has no ``HostingEulaUri``.
    """
    eula_uri = document.HostingEulaUri
    if not eula_uri:
        return ""

    session = Session() if sagemaker_session is None else sagemaker_session

    # Text before the first "/" is the bucket; everything after it is the object key.
    bucket, _, key = eula_uri.replace("s3://", "").partition("/")
    region = session.boto_region_name

    # The DNS suffix depends on the partition the region belongs to, so ask botocore.
    resolver = session.boto_session._session.get_component("endpoint_resolver")
    partition = resolver.get_partition_for_region(region)
    dns_suffix = resolver.get_partition_dns_suffix(partition)

    return f"https://{bucket}.s3.{region}.{dns_suffix}/{key}"
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def get_jumpstart_launched_regions_message() -> str:
    """Return a sentence listing the regions in ``constants.JUMPSTART_REGION_NAME_SET``.

    Region names are sorted before formatting. The wording depends on how many
    regions there are: none, exactly one, exactly two (joined by "and"), or three
    or more (comma-separated, with "and " placed before the last one).

    Returns:
        str: The formatted availability message.
    """
    launched = constants.JUMPSTART_REGION_NAME_SET
    if not launched:
        return "JumpStart is not available in any region."

    ordered = sorted(launched)
    count = len(ordered)

    if count == 1:
        return f"JumpStart is available in {ordered[0]} region."

    if count == 2:
        first, second = ordered
        return f"JumpStart is available in {first} and {second} regions."

    # Three or more: "a, b, and c".
    listing = ", ".join([*ordered[:-1], "and " + ordered[-1]])
    return f"JumpStart is available in {listing} regions."
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def get_jumpstart_gated_content_bucket(
    region: str = constants.JUMPSTART_DEFAULT_REGION_NAME,
) -> str:
    """Resolve the gated (private) content bucket name for JumpStart in ``region``.

    A non-empty value of the environment variable named by
    ``constants.ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE`` takes
    precedence over the per-region lookup. The resolved bucket is stored on
    ``JumpStartModelsAccessor``. If it differs from the previously stored bucket,
    the override message (if any) is logged, and the accessor cache is reset when
    a previous bucket had been set.

    Args:
        region (str): Region name used for the per-region lookup.

    Returns:
        str: The gated content bucket name.

    Raises:
        ValueError: If ``region`` is not a key of
            ``constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT``, or if that
            region's ``gated_content_bucket`` is ``None``.
    """
    accessor = accessors.JumpStartModelsAccessor
    # Read the stored bucket before overwriting it so a change can be detected below.
    previous_bucket: Optional[str] = accessor.get_jumpstart_gated_content_bucket()

    pending_logs: List[str] = []
    override = os.environ.get(constants.ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE)

    if override:
        gated_bucket: Optional[str] = override
        pending_logs.append(f"Using JumpStart gated bucket override: '{gated_bucket}'")
    else:
        try:
            gated_bucket = constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[
                region
            ].gated_content_bucket
        except KeyError:
            launched_regions_msg = get_jumpstart_launched_regions_message()
            raise ValueError(
                f"Unable to get private content bucket for JumpStart in {region} region. "
                f"{launched_regions_msg}"
            )
        if gated_bucket is None:
            raise ValueError(
                f"No private content bucket for JumpStart exists in {region} region."
            )

    accessor.set_jumpstart_gated_content_bucket(gated_bucket)

    if gated_bucket == previous_bucket:
        return gated_bucket

    # The bucket changed: drop cached content only if something was cached before.
    if previous_bucket is not None:
        accessor.reset_cache()
    for message in pending_logs:
        constants.JUMPSTART_LOGGER.info(message)

    return gated_bucket
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def get_jumpstart_content_bucket(
    region: str = constants.JUMPSTART_DEFAULT_REGION_NAME,
) -> str:
    """Resolve the regionalized content bucket name for JumpStart.

    A non-empty value of the environment variable named by
    ``constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE`` takes precedence
    over the per-region lookup. The resolved bucket is stored on
    ``JumpStartModelsAccessor``. If it differs from the previously stored bucket,
    the override message (if any) is logged, and the accessor cache is reset when
    a previous bucket had been set.

    Args:
        region (str): Region name used for the per-region lookup.

    Returns:
        str: The content bucket name.

    Raises:
        ValueError: If ``region`` is not a key of
            ``constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT``.
    """
    # Read the stored bucket before overwriting it so a change can be detected below.
    old_content_bucket: Optional[str] = (
        accessors.JumpStartModelsAccessor.get_jumpstart_content_bucket()
    )

    info_logs: List[str] = []

    bucket_to_return: Optional[str] = os.environ.get(
        constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE
    )
    if bucket_to_return:
        info_logs.append(f"Using JumpStart bucket override: '{bucket_to_return}'")
    else:
        try:
            bucket_to_return = constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[
                region
            ].content_bucket
        except KeyError:
            formatted_launched_regions_str = get_jumpstart_launched_regions_message()
            # The message previously said "Neo" (copied from the Neo helper); this
            # function resolves the JumpStart bucket, so name the right service.
            raise ValueError(
                f"Unable to get content bucket for JumpStart in {region} region. "
                f"{formatted_launched_regions_str}"
            )

    accessors.JumpStartModelsAccessor.set_jumpstart_content_bucket(bucket_to_return)

    if bucket_to_return != old_content_bucket:
        # The bucket changed: drop cached content only if something was cached before.
        if old_content_bucket is not None:
            accessors.JumpStartModelsAccessor.reset_cache()
        for info_log in info_logs:
            constants.JUMPSTART_LOGGER.info(info_log)
    return bucket_to_return
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def get_neo_content_bucket(
    region: str = constants.NEO_DEFAULT_REGION_NAME,
) -> str:
    """Resolve the regionalized S3 bucket name for the Neo service.

    A non-empty value of the environment variable named by
    ``constants.ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE`` is returned as-is, and
    its use is logged. Otherwise the bucket is read from the ``neo_content_bucket``
    attribute of the region's entry in
    ``constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT``.

    Args:
        region (str): Region name used for the per-region lookup.

    Returns:
        str: The Neo content bucket name.

    Raises:
        ValueError: If ``region`` is not a key of the launched-region dictionary.
    """
    override = os.environ.get(constants.ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE)
    if override:
        override_msg = f"Using Neo bucket override: '{override}'"
        constants.JUMPSTART_LOGGER.info(override_msg)
        return override

    try:
        launched_region = constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[region]
        return launched_region.neo_content_bucket
    except KeyError:
        raise ValueError(f"Unable to get content bucket for Neo in {region} region.")
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def get_formatted_manifest(
    manifest: List[Dict],
) -> Dict[JumpStartVersionedModelId, JumpStartModelHeader]:
    """Index a raw manifest by model id and version.

    Args:
        manifest (List[Dict]): Raw header dictionaries; each one is passed to
            ``JumpStartModelHeader``.

    Returns:
        Dict[JumpStartVersionedModelId, JumpStartModelHeader]: Header objects keyed by
            ``JumpStartVersionedModelId(model_id, version)``. When several entries
            produce an equal key, the last one in ``manifest`` is kept.
    """
    parsed_headers = map(JumpStartModelHeader, manifest)
    return {
        JumpStartVersionedModelId(entry.model_id, entry.version): entry
        for entry in parsed_headers
    }
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def get_sagemaker_version() -> str:
    """Return the sagemaker library version, resolving it on first use.

    While ``accessors.SageMakerSettings`` still holds an empty version string,
    ``parse_sagemaker_version`` is called and its result is stored. The stored
    value is then returned.
    """
    settings = accessors.SageMakerSettings
    version_not_set_yet = settings.get_sagemaker_version() == ""
    if version_not_set_yet:
        settings.set_sagemaker_version(parse_sagemaker_version())
    return settings.get_sagemaker_version()
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def parse_sagemaker_version() -> str:
    """Read and normalize the sagemaker library version. Intended to be called once.

    The value comes from ``sagemaker.__version__``. A version with exactly two
    periods is used unchanged; a version with three periods has its last
    period and everything after it dropped. The result is passed to
    ``packaging.version.Version`` as a validity check before being returned.

    Raises:
        RuntimeError: If the version string has fewer than 2 or more than 3 periods.
            Whatever ``Version`` raises for an unparsable string propagates as well.
    """
    try:
        raw_version = sagemaker.__version__
    except AttributeError:
        # Development environments may not expose __version__; fall back to a
        # value that has a valid version format.
        return "3.0.0"

    period_count = raw_version.count(".")
    if period_count not in (2, 3):
        raise RuntimeError(f"Bad value for SageMaker version: {sagemaker.__version__}")

    if period_count == 2:
        normalized_version = raw_version
    else:
        normalized_version = raw_version[: raw_version.rfind(".")]

    # Called only for its validation side effect; the parsed object is discarded.
    Version(normalized_version)

    return normalized_version
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def is_jumpstart_model_input(model_id: Optional[str], version: Optional[str]) -> bool:
    """Tell whether ``model_id`` and ``version`` describe a JumpStart model.

    Returns True when both arguments are given, and False when both are None.

    Args:
        model_id (str): Optional. Model ID of the JumpStart model.
        version (str): Optional. Version of the JumpStart model.

    Raises:
        ValueError: If only one of the two arguments is None.
    """
    supplied = (model_id is not None, version is not None)
    if not any(supplied):
        return False
    if not all(supplied):
        raise ValueError(
            "Must specify JumpStart `model_id` and `model_version` when getting specs for "
            "JumpStart models."
        )
    return True
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def is_jumpstart_model_uri(uri: Optional[str]) -> bool:
    """Tell whether ``uri`` points into a JumpStart-hosted bucket.

    Args:
        uri (Optional[str]): uri for inference/training job. Anything that is
            not a string (including None) yields False.
    """
    # Callers may hand over non-string objects (e.g. a SourceCode object).
    if not isinstance(uri, str):
        return False

    if urlparse(uri).scheme == "s3":
        s3_bucket, _ = parse_s3_url(uri)
    else:
        # Non-S3 URIs have no bucket; None is still checked against the set below.
        s3_bucket = None

    return s3_bucket in constants.JUMPSTART_GATED_AND_PUBLIC_BUCKET_NAME_SET
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def tag_key_in_array(tag_key: str, tag_array: List[Dict[str, str]]) -> bool:
    """Tell whether any tag in ``tag_array`` has ``tag_key`` as its ``"Key"``.

    Args:
        tag_key (str): the tag key to look for in ``tag_array``.
        tag_array (List[Dict[str, str]]): array of tags to search; each tag is
            read through its ``"Key"`` entry.
    """
    return any(tag_key == existing_tag["Key"] for existing_tag in tag_array)
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
def get_tag_value(tag_key: str, tag_array: List[Dict[str, str]]) -> str:
    """Return the ``"Value"`` of the single tag whose ``"Key"`` equals ``tag_key``.

    Args:
        tag_key (str): AWS tag for which to search.
        tag_array (List[Dict[str, str]]): List of AWS tags, each formatted as dicts.

    Raises:
        KeyError: If the number of matches for the ``tag_key`` is not equal to 1.
    """
    matches = [tag["Value"] for tag in tag_array if tag_key == tag["Key"]]
    if len(matches) == 1:
        return matches[0]

    raise KeyError(
        f"Cannot get value of tag for tag key '{tag_key}' -- found {len(matches)} "
        f"number of matches in the tag list."
    )
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
def add_single_jumpstart_tag(
    tag_value: str,
    tag_key: enums.JumpStartTag,
    curr_tags: Optional[List[Dict[str, str]]],
    is_uri=False,
) -> Optional[List]:
    """Append a ``{"Key": tag_key, "Value": tag_value}`` tag to ``curr_tags`` when applicable.

    Nothing is appended when ``tag_key`` is already present. For URI tags
    (``is_uri=True``) nothing is appended either when ``tag_value`` is not a
    JumpStart model URI, or when any of the model ID, model version, model type,
    inference config name or training config name tags is already present.

    Args:
        tag_value (str): Value of the tag. When ``is_uri`` is True this is the URI
            which may correspond to a JumpStart model.
        tag_key (enums.JumpStartTag): Custom tag to apply to current tags.
        curr_tags (Optional[List]): Current tags associated with ``Estimator`` or ``Model``.
            A provided list is appended to in place.
        is_uri (boolean): Set to True to indicate a s3 uri is to be tagged. Set to False to indicate
            tags for JumpStart model id / version are being added. (Default: False).

    Returns:
        Optional[List]: ``curr_tags``, or a list in its place when ``curr_tags``
            was None and the URI check passed.
    """
    if is_uri and not is_jumpstart_model_uri(tag_value):
        return curr_tags

    if curr_tags is None:
        curr_tags = []

    if tag_key_in_array(tag_key, curr_tags):
        return curr_tags

    if is_uri:
        # URI tags are skipped once any model-identifying tag is already present.
        model_info_keys = (
            enums.JumpStartTag.MODEL_ID,
            enums.JumpStartTag.MODEL_VERSION,
            enums.JumpStartTag.MODEL_TYPE,
            enums.JumpStartTag.INFERENCE_CONFIG_NAME,
            enums.JumpStartTag.TRAINING_CONFIG_NAME,
        )
        if any(tag_key_in_array(info_key, curr_tags) for info_key in model_info_keys):
            return curr_tags

    curr_tags.append({"Key": tag_key, "Value": tag_value})
    return curr_tags
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def get_jumpstart_base_name_if_jumpstart_model(
    *uris: Optional[str],
) -> Optional[str]:
    """Return the default JumpStart base name when at least one URI belongs to JumpStart.

    URIs are checked in order with ``is_jumpstart_model_uri`` and checking stops
    at the first match. None is returned when no URI matches.

    Args:
        *uris (Optional[str]): URI to test for association with JumpStart.
    """
    if any(is_jumpstart_model_uri(candidate_uri) for candidate_uri in uris):
        return constants.JUMPSTART_RESOURCE_BASE_NAME
    return None
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
def add_jumpstart_model_info_tags(
    tags: Optional[List[TagsDict]],
    model_id: str,
    model_version: str,
    model_type: Optional[enums.JumpStartModelType] = None,
    config_name: Optional[str] = None,
    scope: enums.JumpStartScriptScope = None,
) -> List[TagsDict]:
    """Add model ID, version, type and config-name tags to JumpStart related resources.

    ``tags`` is returned untouched when ``model_id`` or ``model_version`` is None.
    Otherwise each applicable tag is added through ``add_single_jumpstart_tag``
    with ``is_uri=False``, in this order: model ID, model version, model type
    (proprietary models only), inference config name, training config name.

    Args:
        tags (Optional[List[TagsDict]]): Current tags.
        model_id (str): JumpStart model ID.
        model_version (str): JumpStart model version. The wildcard ``"*"`` is not tagged.
        model_type (Optional[enums.JumpStartModelType]): Model type. (Default: None).
        config_name (Optional[str]): Config name to tag when ``scope`` is the
            inference or training scope. (Default: None).
        scope (enums.JumpStartScriptScope): Selects which config-name tag key is
            used. (Default: None).
    """
    if model_id is None or model_version is None:
        return tags

    tags_to_add = [(model_id, enums.JumpStartTag.MODEL_ID)]

    # Only add version tag if it's not a wildcard (*)
    # Wildcard is used for version resolution but is not a valid AWS tag value
    if model_version != "*":
        tags_to_add.append((model_version, enums.JumpStartTag.MODEL_VERSION))

    if model_type == enums.JumpStartModelType.PROPRIETARY:
        tags_to_add.append(
            (enums.JumpStartModelType.PROPRIETARY.value, enums.JumpStartTag.MODEL_TYPE)
        )

    if config_name and scope == enums.JumpStartScriptScope.INFERENCE:
        tags_to_add.append((config_name, enums.JumpStartTag.INFERENCE_CONFIG_NAME))

    if config_name and scope == enums.JumpStartScriptScope.TRAINING:
        tags_to_add.append((config_name, enums.JumpStartTag.TRAINING_CONFIG_NAME))

    for value, key in tags_to_add:
        tags = add_single_jumpstart_tag(value, key, tags, is_uri=False)
    return tags
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
def add_hub_content_arn_tags(
    tags: Optional[List[TagsDict]],
    hub_content_arn: str,
) -> Optional[List[TagsDict]]:
    """Add the Hub content arn tag to JumpStart related resources.

    Delegates to ``add_single_jumpstart_tag`` with key
    ``enums.JumpStartTag.HUB_CONTENT_ARN`` and ``is_uri=False``.

    Args:
        tags (Optional[List[TagsDict]]): Current tags.
        hub_content_arn (str): Value stored under the Hub content arn tag.
    """
    return add_single_jumpstart_tag(
        hub_content_arn,
        enums.JumpStartTag.HUB_CONTENT_ARN,
        tags,
        is_uri=False,
    )
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
def add_bedrock_store_tags(
    tags: Optional[List[TagsDict]],
    compatibility: str,
) -> Optional[List[TagsDict]]:
    """Add the Bedrock tag to JumpStart related resources.

    Delegates to ``add_single_jumpstart_tag`` with key
    ``enums.JumpStartTag.BEDROCK`` and ``is_uri=False``.

    Args:
        tags (Optional[List[TagsDict]]): Current tags.
        compatibility (str): Value stored under the Bedrock tag.
    """
    return add_single_jumpstart_tag(
        compatibility,
        enums.JumpStartTag.BEDROCK,
        tags,
        is_uri=False,
    )
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
def add_jumpstart_uri_tags(
    tags: Optional[List[TagsDict]] = None,
    inference_model_uri: Optional[Union[str, dict]] = None,
    inference_script_uri: Optional[str] = None,
    training_model_uri: Optional[str] = None,
    training_script_uri: Optional[str] = None,
) -> Optional[List[TagsDict]]:
    """Add custom uri tags to JumpStart models, return the updated tags.

    Each non-empty URI is handed to ``add_single_jumpstart_tag`` with
    ``is_uri=True``. URIs that are pipeline variables are not tagged; a warning
    is logged for them instead.

    Args:
        tags (Optional[List[Dict[str,str]]): Current tags for JumpStart inference
            or training job. (Default: None).
        inference_model_uri (Optional[Union[dict, str]]): S3 URI for inference model artifact.
            A dict is reduced to its ``["S3DataSource"]["S3Uri"]`` entry, or None
            when that entry is absent. (Default: None).
        inference_script_uri (Optional[str]): S3 URI for inference script tarball.
            (Default: None).
        training_model_uri (Optional[str]): S3 URI for training model artifact.
            (Default: None).
        training_script_uri (Optional[str]): S3 URI for training script tarball.
            (Default: None).
    """
    warn_msg = (
        "The URI (%s) is a pipeline variable which is only interpreted at execution time. "
        "As a result, the JumpStart resources will not be tagged."
    )

    if isinstance(inference_model_uri, dict):
        inference_model_uri = inference_model_uri.get("S3DataSource", {}).get("S3Uri", None)

    # (argument name for the warning, URI value, tag key), in tagging order.
    uri_tag_plan = (
        ("inference_model_uri", inference_model_uri, enums.JumpStartTag.INFERENCE_MODEL_URI),
        ("inference_script_uri", inference_script_uri, enums.JumpStartTag.INFERENCE_SCRIPT_URI),
        ("training_model_uri", training_model_uri, enums.JumpStartTag.TRAINING_MODEL_URI),
        ("training_script_uri", training_script_uri, enums.JumpStartTag.TRAINING_SCRIPT_URI),
    )

    for argument_name, uri_value, uri_tag_key in uri_tag_plan:
        if not uri_value:
            continue
        if is_pipeline_variable(uri_value):
            logging.warning(warn_msg, argument_name)
            continue
        tags = add_single_jumpstart_tag(
            uri_value,
            uri_tag_key,
            tags,
            is_uri=True,
        )

    return tags
|
|
599
|
+
|
|
600
|
+
|
|
601
|
+
def update_inference_tags_with_jumpstart_training_tags(
    inference_tags: Optional[List[Dict[str, str]]], training_tags: Optional[List[Dict[str, str]]]
) -> Optional[List[Dict[str, str]]]:
    """Carry JumpStart tags from a training job over to the inference tags.

    For every ``enums.JumpStartTag`` key present in ``training_tags``, the tag is
    appended to ``inference_tags`` unless that key is already there. A list is
    created for ``inference_tags`` only once a JumpStart tag is found.

    Args:
        inference_tags (Optional[List[Dict[str, str]]]): Custom tags to apply to inference job.
        training_tags (Optional[List[Dict[str, str]]]): Tags from training job.
    """
    if not training_tags:
        return inference_tags

    for jumpstart_tag in enums.JumpStartTag:
        if not tag_key_in_array(jumpstart_tag, training_tags):
            continue

        training_value = get_tag_value(jumpstart_tag, training_tags)
        if inference_tags is None:
            inference_tags = []
        if not tag_key_in_array(jumpstart_tag, inference_tags):
            inference_tags.append({"Key": jumpstart_tag, "Value": training_value})

    return inference_tags
|
|
620
|
+
|
|
621
|
+
|
|
622
|
+
def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str:
    """Return the EULA message for ``model_specs``, or an empty string when it has no EULA key."""
    eula_key = model_specs.hosting_eula_key
    if eula_key is None:
        return ""
    return get_formatted_eula_message_template(
        model_id=model_specs.model_id,
        region=region,
        hosting_eula_key=model_specs.hosting_eula_key,
    )
|
|
629
|
+
|
|
630
|
+
|
|
631
|
+
def get_formatted_eula_message_template(model_id: str, region: str, hosting_eula_key: str) -> str:
    """Return a formatted EULA message containing the URL of the terms of use.

    The URL is assembled from the JumpStart content bucket for ``region``, the
    domain returned by ``get_domain_for_region`` and ``hosting_eula_key``.
    """
    content_bucket = get_jumpstart_content_bucket(region=region)
    region_domain = get_domain_for_region(region)
    eula_url = f"https://{content_bucket}.s3.{region}.{region_domain}/{hosting_eula_key}"
    return (
        f"Model '{model_id}' requires accepting end-user license agreement (EULA). "
        f"See {eula_url} for terms of use."
    )
|
|
639
|
+
|
|
640
|
+
|
|
641
|
+
def emit_logs_based_on_model_specs(
    model_specs: JumpStartModelSpecs, region: str, s3_client: boto3.client
) -> None:
    """Emit informational and warning logs derived from ``model_specs``.

    The following are logged through ``constants.JUMPSTART_LOGGER`` when applicable:
    the EULA message, a notice that a newer version of the model is listed in
    the manifest, the deprecation message, the deprecation warning message, the
    usage info message, and a warning for vulnerable models.

    Args:
        model_specs (JumpStartModelSpecs): Specs of the model in use.
        region (str): Region used for the EULA message and the manifest lookup.
        s3_client (boto3.client): Client passed through to the manifest accessor.
    """
    if model_specs.hosting_eula_key:
        constants.JUMPSTART_LOGGER.info(get_eula_message(model_specs, region))

    full_version: str = model_specs.version

    models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(
        region=region, s3_client=s3_client
    )
    # Highest version listed for this model ID, or None if the manifest does not list it.
    max_version_for_model_id: Optional[str] = max(
        (
            header.version
            for header in models_manifest_list
            if header.model_id == model_specs.model_id
        ),
        key=Version,
        default=None,
    )

    # Without a manifest entry there is no newer version to point at, so the
    # upgrade notice is skipped instead of being logged with ``None`` as its target.
    if max_version_for_model_id is not None and full_version != max_version_for_model_id:
        constants.JUMPSTART_LOGGER.info(
            get_old_model_version_msg(model_specs.model_id, full_version, max_version_for_model_id)
        )

    if model_specs.deprecated:
        deprecated_message = model_specs.deprecated_message or (
            "Using deprecated JumpStart model "
            f"'{model_specs.model_id}' and version '{model_specs.version}'."
        )

        constants.JUMPSTART_LOGGER.warning(deprecated_message)

    if model_specs.deprecate_warn_message:
        constants.JUMPSTART_LOGGER.warning(model_specs.deprecate_warn_message)

    if model_specs.usage_info_message:
        constants.JUMPSTART_LOGGER.info(model_specs.usage_info_message)

    if model_specs.inference_vulnerable or model_specs.training_vulnerable:
        constants.JUMPSTART_LOGGER.warning(
            "Using vulnerable JumpStart model '%s' and version '%s'.",
            model_specs.model_id,
            model_specs.version,
        )
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
def verify_model_region_and_return_specs(
    model_id: Optional[str],
    version: Optional[str],
    scope: Optional[str],
    region: Optional[str] = None,
    hub_arn: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS,
    config_name: Optional[str] = None,
) -> JumpStartModelSpecs:
    """Check the model_id, version, scope and region combination and return the model specs.

    Args:
        model_id (Optional[str]): model ID of the JumpStart model to verify and
            obtains specs.
        version (Optional[str]): version of the JumpStart model to verify and
            obtains specs.
        scope (Optional[str]): scope of the JumpStart model to verify.
        region (Optional[str]): region of the JumpStart model to verify and
            obtains specs. When not given, ``get_region_fallback`` supplies it.
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated models should be tolerated
            (exception not raised). False if these models should raise an exception.
            (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (enums.JumpStartModelType): Model type forwarded to the specs
            lookup. (Default: JumpStartModelType.OPEN_WEIGHTS).
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).

    Raises:
        NotImplementedError: If the scope is not supported.
        ValueError: If the combination of arguments specified is not supported.
        VulnerableJumpStartModelError: If any of the dependencies required by the script have
            known security vulnerabilities.
        DeprecatedJumpStartModelError: If the version of the model is deprecated.
    """
    if not region:
        region = get_region_fallback(
            sagemaker_session=sagemaker_session,
        )

    if scope is None:
        raise ValueError(
            "Must specify `model_scope` argument to retrieve model "
            "artifact uri for JumpStart models."
        )

    if scope not in constants.SUPPORTED_JUMPSTART_SCOPES:
        raise NotImplementedError(
            "JumpStart models only support scopes: "
            f"{', '.join(constants.SUPPORTED_JUMPSTART_SCOPES)}."
        )

    model_specs = accessors.JumpStartModelsAccessor.get_model_specs(  # type: ignore
        region=region,
        model_id=model_id,
        hub_arn=hub_arn,
        version=version,
        s3_client=sagemaker_session.s3_client,
        model_type=model_type,
        sagemaker_session=sagemaker_session,
    )

    training_requested = scope == constants.JumpStartScriptScope.TRAINING.value
    if training_requested and not model_specs.training_supported:
        raise ValueError(
            f"JumpStart model ID '{model_id}' and version '{version}' " "does not support training."
        )

    if model_specs.deprecated and not tolerate_deprecated_model:
        raise DeprecatedJumpStartModelError(
            model_id=model_id, version=version, message=model_specs.deprecated_message
        )

    inference_requested = scope == constants.JumpStartScriptScope.INFERENCE.value
    if (
        inference_requested
        and model_specs.inference_vulnerable
        and not tolerate_vulnerable_model
    ):
        raise VulnerableJumpStartModelError(
            model_id=model_id,
            version=version,
            vulnerabilities=model_specs.inference_vulnerabilities,
            scope=constants.JumpStartScriptScope.INFERENCE,
        )

    if training_requested and model_specs.training_vulnerable and not tolerate_vulnerable_model:
        raise VulnerableJumpStartModelError(
            model_id=model_id,
            version=version,
            vulnerabilities=model_specs.training_vulnerabilities,
            scope=constants.JumpStartScriptScope.TRAINING,
        )

    # Apply the requested config last, after all checks have passed.
    if model_specs and config_name:
        model_specs.set_config(config_name, scope)

    return model_specs
|
|
796
|
+
|
|
797
|
+
|
|
798
|
+
def update_dict_if_key_not_present(
    dict_to_update: Optional[dict], key_to_add: Any, value_to_add: Any
) -> Optional[dict]:
    """Add the ``(key_to_add, value_to_add)`` pair unless the key is present, and return the dict.

    An existing value under ``key_to_add`` is never overwritten. A provided
    dict is updated in place and returned; when ``dict_to_update`` is None a new
    dict holding only the new pair is returned.

    The result always contains ``key_to_add`` and so is never empty. The former
    "empty dict becomes None" branch could not be reached and has been removed.

    Args:
        dict_to_update (Optional[dict]): Dict to update, or None to start a new one.
        key_to_add (Any): Key to add when missing.
        value_to_add (Any): Value stored under ``key_to_add`` when the key is missing.

    Returns:
        Optional[dict]: The updated dict.
    """
    if dict_to_update is None:
        dict_to_update = {}
    if key_to_add not in dict_to_update:
        dict_to_update[key_to_add] = value_to_add
    return dict_to_update
|
|
813
|
+
|
|
814
|
+
|
|
815
|
+
def resolve_model_sagemaker_config_field(
    field_name: str,
    field_val: Optional[Any],
    sagemaker_session: Session,
    default_value: Optional[str] = None,
) -> Any:
    """Given a field name, checks if there is a sagemaker config value to set.

    For the role field, which is customer-supplied, we allow ``field_val`` to take precedence
    over sagemaker config values. For all other fields, sagemaker config values take precedence
    over the JumpStart default fields.

    Args:
        field_name (str): Name of the field to resolve. Only ``"role"`` and
            ``"enable_network_isolation"`` are resolved against the config here.
        field_val (Optional[Any]): Value supplied for the field.
        sagemaker_session (Session): Session used for the config lookup. May be None,
            in which case the config is loaded with ``load_sagemaker_config()``.
        default_value (Optional[str]): Default handed to ``resolve_value_from_config``.
            (Default: None).

    Returns:
        Any: The resolved value, or ``field_val`` unchanged for fields that are
            not covered by sagemaker config.
    """
    # In case, sagemaker_session is None, get sagemaker_config from load_sagemaker_config()
    # to resolve value from config for the respective field_name parameter
    _sagemaker_config = load_sagemaker_config() if (sagemaker_session is None) else None

    # We allow customers to define a role which takes precedence
    # over the one defined in sagemaker config
    if field_name == "role":
        role_default = default_value
        # The caller identity can only be queried through a session. With no
        # session, keep ``default_value`` as is instead of dereferencing None.
        if not role_default and sagemaker_session is not None:
            role_default = sagemaker_session.get_caller_identity_arn()
        return resolve_value_from_config(
            direct_input=field_val,
            config_path=MODEL_EXECUTION_ROLE_ARN_PATH,
            default_value=role_default,
            sagemaker_session=sagemaker_session,
            sagemaker_config=_sagemaker_config,
        )

    # JumpStart Models have certain default field values. We want
    # sagemaker config values to take priority over the model-specific defaults.
    if field_name == "enable_network_isolation":
        resolved_val = resolve_value_from_config(
            direct_input=None,
            config_path=MODEL_ENABLE_NETWORK_ISOLATION_PATH,
            sagemaker_session=sagemaker_session,
            default_value=default_value,
            sagemaker_config=_sagemaker_config,
        )
        return resolved_val if resolved_val is not None else field_val

    # field is not covered by sagemaker config so return as is
    return field_val
|
|
856
|
+
|
|
857
|
+
|
|
858
|
+
def resolve_estimator_sagemaker_config_field(
    field_name: str,
    field_val: Optional[Any],
    sagemaker_session: Session,
    default_value: Optional[str] = None,
) -> Any:
    """Given a field name, checks if there is a sagemaker config value to set.

    For the role field, which is customer-supplied, we allow ``field_val`` to take precedence
    over sagemaker config values. For all other fields, sagemaker config values take precedence
    over the JumpStart default fields.

    Args:
        field_name (str): Name of the field to resolve. Only ``"role"``,
            ``"enable_network_isolation"`` and ``"encrypt_inter_container_traffic"``
            are resolved against the config here.
        field_val (Optional[Any]): Value supplied for the field.
        sagemaker_session (Session): Session used for the config lookup. May be None,
            in which case the config is loaded with ``load_sagemaker_config()``.
        default_value (Optional[str]): Default handed to ``resolve_value_from_config``.
            (Default: None).

    Returns:
        Any: The resolved value, or ``field_val`` unchanged for fields that are
            not covered by sagemaker config.
    """

    # Workaround for config injection if sagemaker_session is None, since in
    # that case sagemaker_session will not be initialized until
    # `_init_sagemaker_session_if_does_not_exist` is called later
    _sagemaker_config = load_sagemaker_config() if (sagemaker_session is None) else None

    # We allow customers to define a role which takes precedence
    # over the one defined in sagemaker config
    if field_name == "role":
        role_default = default_value
        # The caller identity can only be queried through a session. With no
        # session, keep ``default_value`` as is instead of dereferencing None.
        if not role_default and sagemaker_session is not None:
            role_default = sagemaker_session.get_caller_identity_arn()
        return resolve_value_from_config(
            direct_input=field_val,
            config_path=TRAINING_JOB_ROLE_ARN_PATH,
            default_value=role_default,
            sagemaker_session=sagemaker_session,
            sagemaker_config=_sagemaker_config,
        )

    # JumpStart Estimators have certain default field values. We want
    # sagemaker config values to take priority over the model-specific defaults.
    if field_name == "enable_network_isolation":

        resolved_val = resolve_value_from_config(
            direct_input=None,
            config_path=TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH,
            sagemaker_session=sagemaker_session,
            default_value=default_value,
            sagemaker_config=_sagemaker_config,
        )
        return resolved_val if resolved_val is not None else field_val

    if field_name == "encrypt_inter_container_traffic":

        resolved_val = resolve_value_from_config(
            direct_input=None,
            config_path=TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
            sagemaker_session=sagemaker_session,
            default_value=default_value,
            sagemaker_config=_sagemaker_config,
        )
        return resolved_val if resolved_val is not None else field_val

    # field is not covered by sagemaker config so return as is
    return field_val
|
|
913
|
+
|
|
914
|
+
|
|
915
|
+
def validate_model_id_and_get_type(
    model_id: Optional[str],
    region: Optional[str] = None,
    model_version: Optional[str] = None,
    script: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE,
    sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    hub_arn: Optional[str] = None,
) -> Optional[enums.JumpStartModelType]:
    """Returns model type if the model ID is supported for the given script.

    With ``hub_arn`` set, the type is taken from the hub content (first entry
    only). Otherwise the open-weights manifest is consulted first and the
    proprietary manifest second.

    Args:
        model_id (Optional[str]): Model ID to look up. None, the empty string and
            any non-string value yield None.
        region (Optional[str]): Region of the manifests. (Default: None, which
            resolves to ``constants.JUMPSTART_DEFAULT_REGION_NAME`` for manifest lookups).
        model_version (Optional[str]): Model version, used only for the hub
            lookup. (Default: None).
        script (enums.JumpStartScriptScope): Script scope the model is needed for.
            (Default: JumpStartScriptScope.INFERENCE).
        sagemaker_session (Optional[Session]): Session whose ``s3_client`` is used
            for manifest lookups.
        hub_arn (Optional[str]): Arn of the hub to query instead of the manifests.
            (Default: None).

    Returns:
        Optional[enums.JumpStartModelType]: The model type, or None when the model
            ID is not found.

    Raises:
        ValueError: If the model is proprietary and ``script`` is not the inference scope.
    """
    # Type check first: the previous ``model_id in {None, ""}`` test raised
    # TypeError for unhashable non-string inputs before they could be rejected.
    if not isinstance(model_id, str) or model_id == "":
        return None

    if hub_arn:
        model_types = _validate_hub_service_model_id_and_get_type(
            model_id=model_id,
            hub_arn=hub_arn,
            region=region,
            model_version=model_version,
            sagemaker_session=sagemaker_session,
        )
        return (
            model_types[0] if model_types else None
        )  # Currently this function only supports one model type

    s3_client = sagemaker_session.s3_client if sagemaker_session else None
    region = region or constants.JUMPSTART_DEFAULT_REGION_NAME
    models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(
        region=region, s3_client=s3_client, model_type=enums.JumpStartModelType.OPEN_WEIGHTS
    )
    open_weight_model_id_set = {model.model_id for model in models_manifest_list}

    if model_id in open_weight_model_id_set:
        return enums.JumpStartModelType.OPEN_WEIGHTS

    proprietary_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(
        region=region, s3_client=s3_client, model_type=enums.JumpStartModelType.PROPRIETARY
    )

    proprietary_model_id_set = {model.model_id for model in proprietary_manifest_list}
    if model_id in proprietary_model_id_set:
        if script == enums.JumpStartScriptScope.INFERENCE:
            return enums.JumpStartModelType.PROPRIETARY
        raise ValueError(f"Unsupported script for Proprietary models: {script}")
    return None
|
|
966
|
+
|
|
967
|
+
|
|
968
|
+
def _validate_hub_service_model_id_and_get_type(
    model_id: Optional[str],
    hub_arn: str,
    region: Optional[str] = None,
    model_version: Optional[str] = None,
    sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[enums.JumpStartModelType]:
    """Returns a list of JumpStartModelType based off the HubContent.

    Only returns valid JumpStartModelType. Returns an empty array if none are found.

    Args:
        model_id (Optional[str]): Model ID to look up in the hub.
        hub_arn (str): Arn of the hub to query.
        region (Optional[str]): Region forwarded to the specs lookup. (Default: None).
        model_version (Optional[str]): Version forwarded to the specs lookup. (Default: None).
        sagemaker_session (Optional[Session]): Session forwarded to the specs lookup.
    """
    hub_model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
        region=region,
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        sagemaker_session=sagemaker_session,
    )

    # A missing attribute and an explicit None are both treated as "no model types".
    model_types: List[str] = getattr(hub_model_specs, "model_types", []) or []

    hub_content_model_types = []
    for model_type in model_types:
        try:
            hub_content_model_types.append(enums.JumpStartModelType[model_type])
        except (KeyError, ValueError):
            # Lookup by member name raises KeyError for unknown names. The former
            # ``except ValueError`` never matched, so one unknown type aborted the call.
            continue

    return hub_content_model_types
|
|
997
|
+
|
|
998
|
+
|
|
999
|
+
def _extract_value_from_list_of_tags(
    tag_keys: List[str],
    list_tags_result: List[str],
    resource_name: str,
    resource_arn: str,
):
    """Resolve a single value from several candidate tag keys.

    Each key in ``tag_keys`` is looked up with ``get_tag_value``; keys that
    raise ``KeyError`` or yield ``None`` are ignored. If two keys yield
    different values, a warning is logged and nothing is resolved.

    Returns None if no value is found, or if conflicting values are found.
    """
    found = None
    for key in tag_keys:
        try:
            candidate = get_tag_value(key, list_tags_result)
        except KeyError:
            continue

        if candidate is None:
            continue

        if found is not None and candidate != found:
            # Conflicting values across the candidate keys: give up.
            constants.JUMPSTART_LOGGER.warning(
                "Found multiple %s tags on the following resource: %s",
                resource_name,
                resource_arn,
            )
            return None

        found = candidate
    return found
|
|
1026
|
+
|
|
1027
|
+
|
|
1028
|
+
def get_jumpstart_model_info_from_resource_arn(
    resource_arn: str,
    sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    """Returns JumpStart model info read from the tags of a resource.

    The result is the tuple ``(model_id, model_version, inference_config_name,
    training_config_name)``. Any element is 'None' when it cannot be inferred
    from the resource tags.
    """
    tags = sagemaker_session.list_tags(resource_arn)

    # (label used in the warning message, candidate tag keys), in return order.
    lookups = (
        ("model ID", [enums.JumpStartTag.MODEL_ID, *constants.EXTRA_MODEL_ID_TAGS]),
        (
            "model version",
            [enums.JumpStartTag.MODEL_VERSION, *constants.EXTRA_MODEL_VERSION_TAGS],
        ),
        ("inference config name", [enums.JumpStartTag.INFERENCE_CONFIG_NAME]),
        ("training config name", [enums.JumpStartTag.TRAINING_CONFIG_NAME]),
    )

    model_id, model_version, inference_config_name, training_config_name = (
        _extract_value_from_list_of_tags(
            tag_keys=keys,
            list_tags_result=tags,
            resource_name=label,
            resource_arn=resource_arn,
        )
        for label, keys in lookups
    )

    return model_id, model_version, inference_config_name, training_config_name
|
|
1073
|
+
|
|
1074
|
+
|
|
1075
|
+
def get_region_fallback(
    s3_bucket_name: Optional[str] = None,
    s3_client: Optional[boto3.client] = None,
    sagemaker_session: Optional[Session] = None,
) -> str:
    """Returns region to use for JumpStart functionality implicitly via session objects.

    A known JumpStart region is a candidate when it appears in the bucket name,
    appears in the S3 client's endpoint host, or equals the session's
    ``boto_region_name``. With no candidate the default JumpStart region is
    returned.

    Raises:
        ValueError: If more than one distinct region is found.
    """
    candidates: Set[str] = set()
    for region in constants.JUMPSTART_REGION_NAME_SET:
        in_bucket_name = s3_bucket_name is not None and region in s3_bucket_name
        in_endpoint_host = s3_client is not None and region in s3_client._endpoint.host
        is_session_region = bool(sagemaker_session) and (
            region == sagemaker_session.boto_region_name
        )
        if in_bucket_name or in_endpoint_host or is_session_region:
            candidates.add(region)

    if len(candidates) > 1:
        raise ValueError("Unable to resolve a region name from the s3 bucket and client provided.")

    if not candidates:
        return constants.JUMPSTART_DEFAULT_REGION_NAME

    # Exactly one region left.
    return next(iter(candidates))
|
|
1112
|
+
|
|
1113
|
+
|
|
1114
|
+
def get_config_names(
    region: str,
    model_id: str,
    model_version: str,
    sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE,
    model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS,
) -> List[str]:
    """Returns a list of config names for the given model ID and region.

    The names are the keys of the inference or training configs of the model
    specs, depending on ``scope``. An empty list is returned when the specs
    carry no configs for that scope.

    Raises:
        ValueError: If the script scope is not supported by JumpStart.
    """
    specs = verify_model_region_and_return_specs(
        region=region,
        model_id=model_id,
        version=model_version,
        sagemaker_session=sagemaker_session,
        scope=scope,
        model_type=model_type,
    )

    if scope == enums.JumpStartScriptScope.INFERENCE:
        scoped_configs = specs.inference_configs
    elif scope == enums.JumpStartScriptScope.TRAINING:
        scoped_configs = specs.training_configs
    else:
        raise ValueError(f"Unknown script scope: {scope}.")

    if not scoped_configs:
        return []
    return list(scoped_configs.configs.keys())
|
|
1144
|
+
|
|
1145
|
+
|
|
1146
|
+
def get_benchmark_stats(
    region: str,
    model_id: str,
    model_version: str,
    config_names: Optional[List[str]] = None,
    hub_arn: Optional[str] = None,
    sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE,
    model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS,
) -> Dict[str, List[JumpStartBenchmarkStat]]:
    """Returns benchmark stats for the given model ID and region.

    The result maps each requested config name to that config's
    ``benchmark_metrics``. When ``config_names`` is empty or None, every config
    available for ``scope`` is used.

    Raises:
        ValueError: If the script scope is not supported by JumpStart, or if a
            requested config name does not exist for the model (including the
            case where the model has no configs at all for ``scope``).
    """
    model_specs = verify_model_region_and_return_specs(
        region=region,
        model_id=model_id,
        version=model_version,
        hub_arn=hub_arn,
        sagemaker_session=sagemaker_session,
        scope=scope,
        model_type=model_type,
    )

    if scope == enums.JumpStartScriptScope.INFERENCE:
        metadata_configs = model_specs.inference_configs
    elif scope == enums.JumpStartScriptScope.TRAINING:
        metadata_configs = model_specs.training_configs
    else:
        raise ValueError(f"Unknown script scope: {scope}.")

    # Models without configs for this scope expose no metadata configs; treat
    # that as an empty mapping so explicitly requested names fail with the
    # documented ValueError rather than an AttributeError on None.
    available_configs = (metadata_configs.configs if metadata_configs else None) or {}

    if not config_names:
        config_names = available_configs.keys()

    benchmark_stats = {}
    for config_name in config_names:
        if config_name not in available_configs:
            raise ValueError(f"Unknown config name: {config_name}")
        benchmark_stats[config_name] = available_configs[config_name].benchmark_metrics

    return benchmark_stats
|
|
1188
|
+
|
|
1189
|
+
|
|
1190
|
+
def get_jumpstart_configs(
    region: str,
    model_id: str,
    model_version: str,
    config_names: Optional[List[str]] = None,
    sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE,
    model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS,
    hub_arn: Optional[str] = None,
) -> Dict[str, JumpStartMetadataConfig]:
    """Returns metadata configs for the given model ID and region.

    When ``config_names`` is empty or None, the names listed in the "overall"
    config ranking are used. When ``hub_arn`` is given, each name is passed
    through ``camel_to_snake(snake_to_upper_camel(name))`` before it is looked
    up, while the returned dict stays keyed by the requested name. An empty
    dict is returned when the model has no configs for ``scope``.

    Raises:
        ValueError: If the script scope is not supported by JumpStart.
    """
    specs = verify_model_region_and_return_specs(
        region=region,
        model_id=model_id,
        version=model_version,
        sagemaker_session=sagemaker_session,
        scope=scope,
        model_type=model_type,
        hub_arn=hub_arn,
    )

    if scope == enums.JumpStartScriptScope.INFERENCE:
        scoped_configs = specs.inference_configs
    elif scope == enums.JumpStartScriptScope.TRAINING:
        scoped_configs = specs.training_configs
    else:
        raise ValueError(f"Unknown script scope: {scope}.")

    # Nothing to look up in: both the hub and the non-hub paths yield {}.
    if not scoped_configs:
        return {}

    if not config_names:
        config_names = scoped_configs.config_rankings.get("overall").rankings

    if hub_arn:
        return {
            name: scoped_configs.configs[camel_to_snake(snake_to_upper_camel(name))]
            for name in config_names
        }

    return {name: scoped_configs.configs[name] for name in config_names}
|
|
1243
|
+
|
|
1244
|
+
|
|
1245
|
+
def get_jumpstart_user_agent_extra_suffix(
    model_id: Optional[str],
    model_version: Optional[str],
    config_name: Optional[str],
    is_hub_content: Optional[bool],
) -> str:
    """Returns the model-specific user agent string to be added to requests.

    The string starts with the generic SDK suffix. Unless the JumpStart
    telemetry opt-out environment variable is set, the model id/version and
    (for hub content) the hub marker are appended; the model id/version part is
    left out for hub content when both are None. A config marker is appended
    whenever ``config_name`` is truthy.
    """
    headers = get_user_agent_extra_suffix()

    telemetry_disabled = os.getenv(constants.ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY, None)
    if not telemetry_disabled:
        from_hub = is_hub_content is True
        anonymous_hub_content = from_hub and model_id is None and model_version is None

        if not anonymous_hub_content:
            headers = f"{headers} md/js_model_id#{model_id} md/js_model_ver#{model_version}"
        if from_hub:
            headers = f"{headers} md/js_is_hub_content#{is_hub_content}"

    if config_name:
        headers = f"{headers} md/js_config#{config_name}"

    return headers
|
|
1273
|
+
|
|
1274
|
+
|
|
1275
|
+
def get_top_ranked_config_name(
    region: str,
    model_id: str,
    model_version: str,
    sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE,
    model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS,
    tolerate_deprecated_model: bool = False,
    tolerate_vulnerable_model: bool = False,
    hub_arn: Optional[str] = None,
    ranking_name: enums.JumpStartConfigRankingName = enums.JumpStartConfigRankingName.DEFAULT,
) -> Optional[str]:
    """Returns the top ranked config name for the given model ID and region.

    Returns None when the model specs carry no configs for ``scope``.

    Raises:
        ValueError: If the script scope is not supported by JumpStart.
    """
    specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        scope=scope,
        region=region,
        hub_arn=hub_arn,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
    )

    if scope == enums.JumpStartScriptScope.INFERENCE:
        scoped_configs = specs.inference_configs
    elif scope == enums.JumpStartScriptScope.TRAINING:
        scoped_configs = specs.training_configs
    else:
        raise ValueError(f"Unsupported script scope: {scope}.")

    if not scoped_configs:
        return None

    top_config = scoped_configs.get_top_config_from_ranking(ranking_name=ranking_name)
    return top_config.config_name
|
|
1321
|
+
|
|
1322
|
+
|
|
1323
|
+
def get_default_jumpstart_session_with_user_agent_suffix(
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
    config_name: Optional[str] = None,
    is_hub_content: Optional[bool] = False,
) -> Session:
    """Returns default JumpStart SageMaker Session with model-specific user agent suffix.

    A shallow copy of the default JumpStart session is returned whose boto
    session, SageMaker client and SageMaker runtime client are rebuilt for the
    default JumpStart region with the model-specific ``user_agent_extra``.
    """
    botocore_session = botocore.session.get_session()

    user_agent_extra = get_jumpstart_user_agent_extra_suffix(
        model_id=model_id,
        model_version=model_version,
        config_name=config_name,
        is_hub_content=is_hub_content,
    )
    client_config = botocore.config.Config(user_agent_extra=user_agent_extra)
    botocore_session.set_default_client_config(client_config)

    # shallow copy to not affect default session constant
    session = copy(constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION)
    default_region = constants.JUMPSTART_DEFAULT_REGION_NAME

    session.boto_session = boto3.Session(
        region_name=default_region, botocore_session=botocore_session
    )
    session.sagemaker_client = boto3.client(
        "sagemaker", region_name=default_region, config=client_config
    )
    session.sagemaker_runtime_client = boto3.client(
        "sagemaker-runtime", region_name=default_region, config=client_config
    )
    return session
|
|
1354
|
+
|
|
1355
|
+
|
|
1356
|
+
def add_instance_rate_stats_to_benchmark_metrics(
    region: str,
    benchmark_metrics: Optional[Dict[str, List[JumpStartBenchmarkStat]]],
) -> Optional[Tuple[Dict[str, str], Dict[str, List[JumpStartBenchmarkStat]]]]:
    """Adds instance types metric stats to the given benchmark_metrics dict.

    Instance type keys are normalized to carry the ``ml.`` prefix. For every
    instance type that lacks an instance rate stat, the hourly rate is looked
    up and appended to its stats. After the first ``ClientError`` no further
    lookups are attempted.

    Args:
        region (str): AWS region.
        benchmark_metrics (Optional[Dict[str, List[JumpStartBenchmarkStat]]]):
            Benchmark stats keyed by instance type.
    Returns:
        Optional[Tuple[Dict[str, str], Dict[str, List[JumpStartBenchmarkStat]]]]:
        Contains Error and metrics. The error is None when no ``ClientError``
        occurred; the whole result is None when ``benchmark_metrics`` is empty.
    """
    if not benchmark_metrics:
        return None

    err_message = None
    final_benchmark_metrics = {}
    for raw_instance_type, stats in benchmark_metrics.items():
        if raw_instance_type.startswith("ml."):
            instance_type = raw_instance_type
        else:
            instance_type = f"ml.{raw_instance_type}"

        # Rate already present, or an earlier lookup failed: pass through as is.
        if has_instance_rate_stat(stats) or err_message:
            final_benchmark_metrics[instance_type] = stats
            continue

        try:
            rate = get_instance_rate_per_hour(instance_type=instance_type, region=region)
            stats = stats or []
            stats.append(JumpStartBenchmarkStat({"concurrency": None, **rate}))
        except ClientError as e:
            err_message = e.response["Error"]
        except Exception:  # pylint: disable=W0703
            # Best effort: keep the stats without an instance rate.
            pass

        final_benchmark_metrics[instance_type] = stats

    return err_message, final_benchmark_metrics
|
|
1399
|
+
|
|
1400
|
+
|
|
1401
|
+
def has_instance_rate_stat(benchmark_metric_stats: Optional[List[JumpStartBenchmarkStat]]) -> bool:
    """Tells whether a list of benchmark stats includes the instance rate stat.

    Args:
        benchmark_metric_stats (Optional[List[JumpStartBenchmarkStat]]):
            List of benchmark metric stats.
    Returns:
        bool: True when a stat named "instance rate" (case-insensitive) is
        present. None is also reported as True; an empty list as False.
    """
    if benchmark_metric_stats is None:
        return True
    return any(stat.name.lower() == "instance rate" for stat in benchmark_metric_stats)
|
|
1416
|
+
|
|
1417
|
+
|
|
1418
|
+
def get_metrics_from_deployment_configs(
    deployment_configs: Optional[List[DeploymentConfigMetadata]],
) -> Dict[str, List[str]]:
    """Extracts benchmark metrics from deployment configs metadata.

    Builds a column-oriented dict with one row per (config, instance type,
    concurrency) combination. Instance rate columns are collected separately
    and merged in last, so they end up after the other columns.

    Args:
        deployment_configs (Optional[List[DeploymentConfigMetadata]]):
            List of deployment configs metadata.
    Returns:
        Dict[str, List[str]]: Deployment configs bench metrics dict.
    """
    if not deployment_configs:
        return {}

    data = {"Instance Type": [], "Config Name": [], "Concurrent Users": []}
    instance_rate_data = {}

    for position, config in enumerate(deployment_configs):
        metrics_by_instance_type = config.benchmark_metrics
        if not (config.deployment_args and metrics_by_instance_type):
            continue

        for instance_type, instance_type_metrics in metrics_by_instance_type.items():
            rate_stat, metrics_by_concurrency = _normalize_benchmark_metrics(
                instance_type_metrics
            )

            for concurrent_user, metrics in metrics_by_concurrency.items():
                # Only the single-user row of the first config's default
                # instance type is labelled as the default.
                is_default_row = (
                    position == 0
                    and concurrent_user
                    and int(concurrent_user) == 1
                    and instance_type == config.deployment_args.default_instance_type
                )
                if is_default_row:
                    instance_type_to_display = f"{instance_type} (Default)"
                else:
                    instance_type_to_display = instance_type

                data["Config Name"].append(config.deployment_config_name)
                data["Instance Type"].append(instance_type_to_display)
                data["Concurrent Users"].append(concurrent_user)

                if rate_stat:
                    rate_column = f"{rate_stat.name} ({rate_stat.unit})"
                    instance_rate_data.setdefault(rate_column, []).append(rate_stat.value)

                for metric in metrics:
                    column = _normalize_benchmark_metric_column_name(metric.name, metric.unit)
                    data.setdefault(column, []).append(metric.value)

    return {**data, **instance_rate_data}
|
|
1475
|
+
|
|
1476
|
+
|
|
1477
|
+
def _normalize_benchmark_metric_column_name(name: str, unit: str) -> str:
|
|
1478
|
+
"""Normalizes benchmark metric column name.
|
|
1479
|
+
|
|
1480
|
+
Args:
|
|
1481
|
+
name (str): Name of the metric.
|
|
1482
|
+
unit (str): Unit of the metric.
|
|
1483
|
+
Returns:
|
|
1484
|
+
str: Normalized metric column name.
|
|
1485
|
+
"""
|
|
1486
|
+
if "latency" in name.lower():
|
|
1487
|
+
name = f"Latency, TTFT (P50 in {unit.lower()})"
|
|
1488
|
+
elif "throughput" in name.lower():
|
|
1489
|
+
name = f"Throughput (P50 in {unit.lower()}/user)"
|
|
1490
|
+
return name
|
|
1491
|
+
|
|
1492
|
+
|
|
1493
|
+
def _normalize_benchmark_metrics(
    benchmark_metric_stats: List[JumpStartBenchmarkStat],
) -> Tuple[JumpStartBenchmarkStat, Dict[str, List[JumpStartBenchmarkStat]]]:
    """Splits benchmark stats into the instance rate stat and per-concurrency groups.

    Args:
        benchmark_metric_stats (List[JumpStartBenchmarkStat]):
            List of benchmark metrics stats.
    Returns:
        Tuple[JumpStartBenchmarkStat, Dict[str, List[JumpStartBenchmarkStat]]]:
        The last stat whose name contains "instance rate" (None if there is
        none), and the remaining stats grouped by their ``concurrency`` value.
    """
    instance_type_rate = None
    stats_by_concurrency = {}
    for stat in benchmark_metric_stats:
        if "instance rate" in stat.name.lower():
            instance_type_rate = stat
            continue
        stats_by_concurrency.setdefault(stat.concurrency, []).append(stat)

    return instance_type_rate, stats_by_concurrency
|
|
1520
|
+
|
|
1521
|
+
|
|
1522
|
+
def deployment_config_response_data(
    deployment_configs: Optional[List[DeploymentConfigMetadata]],
) -> List[Dict[str, Any]]:
    """Deployment config api response data.

    Each config is serialized with ``to_json()``. When the config has both
    benchmark metrics and deployment args, "BenchmarkMetrics" is narrowed to
    the entry of the deployment args' instance type.

    Args:
        deployment_configs (Optional[List[DeploymentConfigMetadata]]):
            List of deployment configs metadata.
    Returns:
        List[Dict[str, Any]]: List of deployment config api response data.
    """
    if not deployment_configs:
        return []

    response = []
    for config in deployment_configs:
        payload = config.to_json()
        all_metrics = payload.get("BenchmarkMetrics")
        if all_metrics and config.deployment_args:
            instance_type = config.deployment_args.instance_type
            payload["BenchmarkMetrics"] = {instance_type: all_metrics.get(instance_type)}
        response.append(payload)
    return response
|
|
1549
|
+
|
|
1550
|
+
|
|
1551
|
+
def _deployment_config_lru_cache(_func=None, *, maxsize: int = 128, typed: bool = False):
    """LRU cache for deployment configs.

    Usable both bare (``@_deployment_config_lru_cache``) and with arguments
    (``@_deployment_config_lru_cache(maxsize=...)``). The first uncached result
    is dropped from the cache again when it lacks instance rate metrics.
    """

    def has_instance_rate_metric(config: DeploymentConfigMetadata) -> bool:
        """Determines whether metadata config contains instance rate metric stat.

        Args:
            config (DeploymentConfigMetadata): Metadata config metadata.
        Returns:
            bool: Whether every benchmark stats list of the config contains the
            instance rate stat (True when the config has no benchmark metrics).
        """
        if config.benchmark_metrics is None:
            return True
        return all(has_instance_rate_stat(stats) for stats in config.benchmark_metrics.values())

    def wrapper_cache(f):
        f = lru_cache(maxsize=maxsize, typed=typed)(f)

        @wraps(f)
        def wrapped_f(*args, **kwargs):
            res = f(*args, **kwargs)

            # Clear cache on first call if
            #   - The output does not contain instance rate metrics
            #   as this is caused by missing policy.
            info = f.cache_info()
            if not (info.hits == 0 and info.misses == 1):
                return res

            if isinstance(res, list):
                missing_rate = any(
                    isinstance(item, DeploymentConfigMetadata)
                    and not has_instance_rate_metric(item)
                    for item in res
                )
                if missing_rate:
                    f.cache_clear()
            elif isinstance(res, dict):
                columns = list(res.keys())
                if len(columns) == 0 or "Instance Rate" not in columns[-1]:
                    f.cache_clear()
                elif len(res[columns[1]]) > len(res[columns[-1]]):
                    # The trailing instance rate column is shorter than the
                    # other columns: drop it from the result.
                    del res[columns[-1]]
                    f.cache_clear()
            return res

        wrapped_f.cache_info = f.cache_info
        wrapped_f.cache_clear = f.cache_clear
        return wrapped_f

    if _func is None:
        return wrapper_cache
    return wrapper_cache(_func)
|
|
1603
|
+
|
|
1604
|
+
|
|
1605
|
+
def _add_model_access_configs_to_model_data_sources(
    model_data_sources: List[Dict[str, Any]],
    model_access_configs: Dict[str, ModelAccessConfig],
    model_id: str,
    region: str,
) -> List[Dict[str, Any]]:
    """Iterate over the accept EULA configs to ensure all channels are matched

    Every data source carrying a truthy "HostingEulaKey" requires an accepted
    EULA for ``model_id``. For those sources the model access config is written
    under ``S3DataSource.ModelAccessConfig``. "HostingEulaKey" is removed from
    every returned source. The input list and its dicts are left unmodified.

    Args:
        model_data_sources (List[Dict[str, Any]]): Model data sources that will be updated
        model_access_configs (Dict[str, ModelAccessConfig]): Configs holding the
            accept_eula field, keyed by model id.
        model_id (str): Jumpstart model id.
        region (str): Region where the user is operating in.
    Returns:
        List[Dict[str, Any]]: List of model data sources with accept EULA configs applied
    Raise:
        ValueError if at least one channel that requires EULA acceptance was not passed.
    """
    if not model_data_sources:
        return model_data_sources

    access_config = model_access_configs.get(model_id) if model_access_configs else None

    acked_model_data_sources = []
    for model_data_source in model_data_sources:
        hosting_eula_key = model_data_source.get("HostingEulaKey")

        # Work on a copy; the EULA key is never part of the returned source,
        # whether or not a model access config applies.
        mutable_model_data_source = model_data_source.copy()
        mutable_model_data_source.pop("HostingEulaKey", None)

        if hosting_eula_key:
            if not access_config or not access_config.accept_eula:
                model_source = "Additional " if model_data_source.get("ChannelName") else ""
                base_eula_message = get_formatted_eula_message_template(
                    model_id=model_id, region=region, hosting_eula_key=hosting_eula_key
                )
                model_access_config_entry = f'"{model_id}":ModelAccessConfig(accept_eula=True)'
                raise ValueError(
                    f"{model_source}{base_eula_message}"
                    "Please add a ModelAccessConfig entry:"
                    f" {model_access_config_entry} "
                    "to model_access_configs to accept the EULA."
                )

            # ``dict.copy`` above is shallow, so the nested S3DataSource dict is
            # still shared with the caller's input; copy it before writing.
            s3_data_source = dict(mutable_model_data_source["S3DataSource"])
            s3_data_source["ModelAccessConfig"] = camel_case_to_pascal_case(
                access_config.model_dump()
            )
            mutable_model_data_source["S3DataSource"] = s3_data_source

        acked_model_data_sources.append(mutable_model_data_source)

    return acked_model_data_sources
|
|
1669
|
+
|
|
1670
|
+
|
|
1671
|
+
def get_draft_model_content_bucket(provider: Dict, region: str) -> str:
    """Returns the correct content bucket for a 1p draft model.

    The Neo content bucket is used unless the provider's "name" is
    "JumpStart"; JumpStart providers get the gated content bucket when their
    "classification" is "gated" and the regular JumpStart bucket otherwise.
    """
    neo_bucket = get_neo_content_bucket(region=region)

    if not provider or provider.get("name", "") != "JumpStart":
        return neo_bucket

    if provider.get("classification", "ungated") == "gated":
        return get_jumpstart_gated_content_bucket(region=region)
    return get_jumpstart_content_bucket(region=region)
|
|
1683
|
+
|
|
1684
|
+
|
|
1685
|
+
def remove_env_var_from_estimator_kwargs_if_accept_eula_present(
    init_kwargs: dict, accept_eula: Optional[bool]
):
    """Remove env vars if access configs are used

    When ``accept_eula`` is given (True or False), the gated-model S3 URI
    training environment variable is removed in place from
    ``init_kwargs["environment"]``. Nothing happens when ``accept_eula`` is
    None, when there is no (or an empty) environment, or when the variable is
    not set.

    Args:
        init_kwargs (dict): Dictionary of kwargs when Estimator is instantiated.
        accept_eula (Optional[bool]): Whether or not the EULA was accepted,
            optionally passed in to Estimator.fit().
    """
    if accept_eula is None:
        return

    # Both the "environment" entry and the variable itself are optional, so
    # tolerate their absence instead of raising KeyError.
    environment = init_kwargs.get("environment")
    if environment:
        environment.pop(constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, None)
|
|
1696
|
+
|
|
1697
|
+
|
|
1698
|
+
def get_hub_access_config(hub_content_arn: Optional[str]):
    """Get hub access config

    Args:
        hub_content_arn (Optional[str]): Arn of the model reference hub content

    Returns:
        ``{"HubContentArn": hub_content_arn}``, or None when no ARN is given.
    """
    if hub_content_arn is None:
        return None
    return {"HubContentArn": hub_content_arn}
|
|
1710
|
+
|
|
1711
|
+
|
|
1712
|
+
def get_model_access_config(accept_eula: Optional[bool]):
    """Get access configs

    Args:
        accept_eula (Optional[bool]): Whether or not the EULA was accepted,
            optionally passed in to Estimator.fit().

    Returns:
        ``{"AcceptEula": accept_eula}``, or None when ``accept_eula`` is None.
    """
    if accept_eula is None:
        return None
    return {"AcceptEula": accept_eula}
|
|
1724
|
+
|
|
1725
|
+
|
|
1726
|
+
def get_latest_version(versions: List[str]) -> Optional[str]:
    """Returns the latest version using sem-ver when possible.

    Returns None for an empty input. When any entry cannot be parsed as a
    ``Version``, the plain string maximum is returned instead.
    """
    if not versions:
        return None

    try:
        return max(versions, key=Version)
    except InvalidVersion:
        return max(versions)
|