sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,3044 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
# pylint: skip-file
|
|
14
|
+
"""This module stores types related to SageMaker JumpStart."""
|
|
15
|
+
from __future__ import absolute_import
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
import re
|
|
18
|
+
from copy import deepcopy
|
|
19
|
+
from enum import Enum
|
|
20
|
+
from typing import Any, Callable, Dict, List, Optional, Set, Union, TYPE_CHECKING
|
|
21
|
+
from sagemaker.core.shapes import ModelAccessConfig as CoreModelAccessConfig
|
|
22
|
+
from sagemaker.core.shapes.model_card_shapes import ModelCardContent as ModelCard
|
|
23
|
+
from sagemaker.core.shapes.model_card_shapes import ModelCardContent as ModelPackageModelCard
|
|
24
|
+
from sagemaker.core.common_utils import (
|
|
25
|
+
S3_PREFIX,
|
|
26
|
+
get_instance_type_family,
|
|
27
|
+
format_tags,
|
|
28
|
+
Tags,
|
|
29
|
+
deep_override_dict,
|
|
30
|
+
camel_to_snake,
|
|
31
|
+
walk_and_apply_json,
|
|
32
|
+
)
|
|
33
|
+
from sagemaker.core.model_metrics import ModelMetrics
|
|
34
|
+
from sagemaker.core.metadata_properties import MetadataProperties
|
|
35
|
+
from sagemaker.core.drift_check_baselines import DriftCheckBaselines
|
|
36
|
+
from sagemaker.core.jumpstart.enums import (
|
|
37
|
+
JumpStartModelType,
|
|
38
|
+
JumpStartScriptScope,
|
|
39
|
+
JumpStartConfigRankingName,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
from sagemaker.core.helper.session_helper import Session
|
|
43
|
+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
|
|
44
|
+
from sagemaker.core.enums import EndpointType
|
|
45
|
+
|
|
46
|
+
from sagemaker.core.model_life_cycle import ModelLifeCycle
|
|
47
|
+
|
|
48
|
+
if TYPE_CHECKING:
|
|
49
|
+
from sagemaker.core.resource_requirements import ResourceRequirements
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class JumpStartDataHolderType:
    """Base class for many JumpStart types.

    Allows objects to be added to dicts and sets, and improves string
    representation. This class overrides the ``__eq__`` and ``__hash__``
    methods so that different objects with the same attributes/types
    can be compared.

    Subclasses declare their fields in ``__slots__``. A slot may be left
    unassigned (e.g. optional fields only set when present in the source
    json); equality, hashing and the string representations all tolerate
    unassigned slots. Slots named in ``_non_serializable_slots`` are left out
    of the string representations.
    """

    __slots__: List[str] = []

    _non_serializable_slots: List[str] = []

    def __eq__(self, other: Any) -> bool:
        """Returns True if ``other`` is of the same type and has all attributes equal.

        Two objects are equal only if the same set of slots is assigned on both
        and every assigned slot compares equal.

        Args:
            other (Any): Other object to which to compare this object.
        """

        if not isinstance(other, type(self)):
            return False
        if getattr(other, "__slots__", None) is None:
            return False
        if self.__slots__ != other.__slots__:
            return False
        for attribute in self.__slots__:
            self_has_attribute = hasattr(self, attribute)
            # One side has the slot assigned and the other does not.
            if self_has_attribute != hasattr(other, attribute):
                return False
            if self_has_attribute and getattr(self, attribute) != getattr(other, attribute):
                return False
        return True

    def __hash__(self) -> int:
        """Makes hash of object.

        Maps the object to a tuple of its type and its slot values, which then
        gets hashed. Unassigned slots contribute ``None`` so that hashing does
        not raise ``AttributeError`` (``__eq__`` already tolerates unassigned
        slots). List, tuple, set and dict values are converted to hashable
        equivalents, so objects holding such fields can still be put in sets
        and used as dict keys. Objects that compare equal always hash equal.
        """

        return hash(
            (type(self),)
            + tuple(self._to_hashable(getattr(self, att, None)) for att in self.__slots__)
        )

    @staticmethod
    def _to_hashable(value: Any) -> Any:
        """Recursively converts lists/tuples, sets and dicts into hashable equivalents.

        Other values are returned unchanged.
        """
        if isinstance(value, dict):
            return frozenset(
                (key, JumpStartDataHolderType._to_hashable(item)) for key, item in value.items()
            )
        if isinstance(value, (list, tuple)):
            return tuple(JumpStartDataHolderType._to_hashable(item) for item in value)
        if isinstance(value, (set, frozenset)):
            return frozenset(value)
        return value

    def _printable_slot_values(self) -> Dict[str, Any]:
        """Returns the assigned slots, excluding those in ``_non_serializable_slots``."""
        return {
            att: getattr(self, att)
            for att in self.__slots__
            if hasattr(self, att) and att not in self._non_serializable_slots
        }

    def __str__(self) -> str:
        """Returns string representation of object. Example:

        "JumpStartLaunchedRegionInfo:
        {'content_bucket': 'bucket', 'region_name': 'us-west-2'}"
        """

        return f"{type(self).__name__}: {str(self._printable_slot_values())}"

    def __repr__(self) -> str:
        """Returns ``__repr__`` string of object. Example:

        "JumpStartLaunchedRegionInfo at 0x7f664529efa0:
        {'content_bucket': 'bucket', 'region_name': 'us-west-2'}"
        """

        return f"{type(self).__name__} at {hex(id(self))}: {str(self._printable_slot_values())}"
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class JumpStartS3FileType(str, Enum):
    """Kinds of files that JumpStart publishes to its S3 distribution buckets."""

    # Open-weight manifest and specs file types.
    OPEN_WEIGHT_MANIFEST = "manifest"
    OPEN_WEIGHT_SPECS = "specs"

    # Proprietary manifest and specs file types.
    PROPRIETARY_MANIFEST = "proprietary_manifest"
    PROPRIETARY_SPECS = "proprietary_specs"
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class HubType(str, Enum):
    """String-valued enumeration of hub object types."""

    # The hub itself.
    HUB = "Hub"
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class HubContentType(str, Enum):
    """String-valued enumeration of the kinds of content stored in a hub."""

    MODEL = "Model"
    NOTEBOOK = "Notebook"
    # A reference to a model rather than the model itself.
    MODEL_REFERENCE = "ModelReference"
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
# Any of the enum types that can describe a piece of JumpStart content: an S3
# distribution file type, a hub object type, or a hub content type.
JumpStartContentDataType = Union[JumpStartS3FileType, HubType, HubContentType]
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class JumpStartLaunchedRegionInfo(JumpStartDataHolderType):
    """Bucket configuration for a single region in which JumpStart is launched."""

    __slots__ = ["content_bucket", "region_name", "gated_content_bucket", "neo_content_bucket"]

    def __init__(
        self,
        content_bucket: str,
        region_name: str,
        gated_content_bucket: Optional[str] = None,
        neo_content_bucket: Optional[str] = None,
    ):
        """Creates the region info holder.

        Args:
            content_bucket (str): JumpStart S3 content bucket for the region.
            region_name (str): Name of the region JumpStart is launched in.
            gated_content_bucket (Optional[str]): JumpStart gated S3 content bucket
                for the region, if there is one (default: None).
            neo_content_bucket (Optional[str]): Neo service S3 content bucket
                for the region, if there is one (default: None).
        """
        self.region_name = region_name
        self.content_bucket = content_bucket
        self.gated_content_bucket = gated_content_bucket
        self.neo_content_bucket = neo_content_bucket
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
class JumpStartModelHeader(JumpStartDataHolderType):
    """Data class JumpStart model header."""

    __slots__ = ["model_id", "version", "min_version", "spec_key", "search_keywords"]

    def __init__(self, header: Dict[str, Any]):
        """Initializes a JumpStartModelHeader object from its json representation.

        Args:
            header (Dict[str, Any]): Dictionary representation of header. Most values
                are strings, but ``search_keywords`` (if present) is a list of strings.

        Raises:
            KeyError: If a required header field is missing (see ``from_json``).
        """
        self.from_json(header)

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartModelHeader object.

        Attributes whose value is ``None`` (e.g. a header without
        ``search_keywords``) are omitted from the result.
        """
        json_obj = {
            att: getattr(self, att)
            for att in self.__slots__
            if getattr(self, att, None) is not None
        }
        return json_obj

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json of header.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of header.

        Raises:
            KeyError: If ``model_id``, ``version``, ``min_version`` or ``spec_key``
                is missing. ``search_keywords`` is optional and defaults to ``None``.
        """
        self.model_id: str = json_obj["model_id"]
        self.version: str = json_obj["version"]
        self.min_version: str = json_obj["min_version"]
        self.spec_key: str = json_obj["spec_key"]
        self.search_keywords: Optional[List[str]] = json_obj.get("search_keywords")
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
class JumpStartECRSpecs(JumpStartDataHolderType):
    """Framework/version information describing a JumpStart ECR image."""

    __slots__ = [
        "framework",
        "framework_version",
        "py_version",
        "huggingface_transformers_version",
        "_is_hub_content",
    ]

    _non_serializable_slots = ["_is_hub_content"]

    def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False):
        """Builds the ECR specs holder from its json form.

        Args:
            spec (Dict[str, Any]): The ECR specs as a dictionary.
            is_hub_content (Optional[bool]): When True, ``spec`` keys are converted
                with ``camel_to_snake`` before being read (default: False).
        """
        self._is_hub_content = is_hub_content
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Populates the fields from a json dictionary.

        An empty or falsy ``json_obj`` leaves every field unassigned.

        Args:
            json_obj (Dict[str, Any]): The ECR specs as a dictionary.
        """
        if not json_obj:
            return

        source = walk_and_apply_json(json_obj, camel_to_snake) if self._is_hub_content else json_obj

        # These three are always assigned, with None when the key is absent.
        for field_name in ("framework", "framework_version", "py_version"):
            setattr(self, field_name, source.get(field_name))

        # Only assigned when present; otherwise the slot stays unset.
        transformers_version = source.get("huggingface_transformers_version")
        if transformers_version is not None:
            self.huggingface_transformers_version = transformers_version

    def to_json(self) -> Dict[str, Any]:
        """Serializes the assigned, serializable slots into a dictionary."""
        serialized: Dict[str, Any] = {}
        for slot in self.__slots__:
            if slot in self._non_serializable_slots or not hasattr(self, slot):
                continue
            serialized[slot] = getattr(self, slot)
        return serialized
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
class JumpStartHyperparameter(JumpStartDataHolderType):
    """Describes one hyperparameter accepted by a JumpStart training container."""

    __slots__ = [
        "name",
        "type",
        "options",
        "default",
        "scope",
        "min",
        "max",
        "exclusive_min",
        "exclusive_max",
        "_is_hub_content",
    ]

    _non_serializable_slots = ["_is_hub_content"]

    def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False):
        """Builds the hyperparameter definition from its json form.

        Args:
            spec (Dict[str, Any]): The hyperparameter definition as a dictionary.
            is_hub_content (Optional[bool]): When True, ``spec`` keys are converted
                with ``camel_to_snake`` before being read, and exclusive bounds are
                ignored (default: False).
        """
        self._is_hub_content = is_hub_content
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Populates the fields from a json dictionary.

        Args:
            json_obj (Dict[str, Any]): The hyperparameter definition as a dictionary.

        Raises:
            KeyError: If ``name``, ``type``, ``default`` or ``scope`` is missing.
        """
        if self._is_hub_content:
            json_obj = walk_and_apply_json(json_obj, camel_to_snake)

        # Required keys, read in this fixed order.
        for required_key in ("name", "type", "default", "scope"):
            setattr(self, required_key, json_obj[required_key])

        optional_keys = ["options", "min", "max"]
        if not self._is_hub_content:
            # HubContentDocument model schema does not allow exclusive min/max.
            optional_keys += ["exclusive_min", "exclusive_max"]

        # Optional keys are only assigned when present and not None.
        for optional_key in optional_keys:
            value = json_obj.get(optional_key)
            if value is not None:
                setattr(self, optional_key, value)

    def to_json(self) -> Dict[str, Any]:
        """Serializes the assigned, serializable slots into a dictionary."""
        serialized: Dict[str, Any] = {}
        for slot in self.__slots__:
            if slot in self._non_serializable_slots or not hasattr(self, slot):
                continue
            serialized[slot] = getattr(self, slot)
        return serialized
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
class JumpStartEnvironmentVariable(JumpStartDataHolderType):
    """Data class for JumpStart environment variable definitions in the hosting container."""

    __slots__ = [
        "name",
        "type",
        "default",
        "scope",
        "required_for_model_class",
        "_is_hub_content",
    ]

    # Internal flag that must never appear in ``to_json`` output.
    _non_serializable_slots = ["_is_hub_content"]

    def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False):
        """Builds a JumpStartEnvironmentVariable from its dictionary form.

        Args:
            spec (Dict[str, Any]): Dictionary representation of environment variable.
            is_hub_content (Optional[bool]): Whether ``spec`` comes from hub content. Stored
                on the instance; ``from_json`` below does not branch on it.
        """
        self._is_hub_content = is_hub_content
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Populates this environment variable from a dictionary.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of environment variable.

        Raises:
            KeyError: If ``name``, ``type``, ``default`` or ``scope`` is missing.
        """
        # Unlike sibling classes, camel_to_snake is applied unconditionally here (not only
        # for hub content) -- presumably a no-op for keys already in snake_case.
        normalized = walk_and_apply_json(json_obj, camel_to_snake)
        self.name = normalized["name"]
        self.type = normalized["type"]
        self.default = normalized["default"]
        self.scope = normalized["scope"]
        self.required_for_model_class: bool = normalized.get("required_for_model_class", False)

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartEnvironmentVariable object."""
        excluded = getattr(self, "_non_serializable_slots", [])
        result: Dict[str, Any] = {}
        for slot in self.__slots__:
            if hasattr(self, slot) and slot not in excluded:
                result[slot] = getattr(self, slot)
        return result
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
class JumpStartPredictorSpecs(JumpStartDataHolderType):
    """Data class for JumpStart Predictor specs."""

    __slots__ = [
        "default_content_type",
        "supported_content_types",
        "default_accept_type",
        "supported_accept_types",
        "_is_hub_content",
    ]

    # Internal flag that must never appear in ``to_json`` output.
    _non_serializable_slots = ["_is_hub_content"]

    def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: Optional[bool] = False):
        """Builds JumpStartPredictorSpecs from its dictionary form.

        Args:
            spec (Optional[Dict[str, Any]]): Dictionary representation of predictor specs.
                ``None`` leaves every field unset.
            is_hub_content (Optional[bool]): Whether ``spec`` comes from hub content, in which
                case ``camel_to_snake`` is applied via ``walk_and_apply_json`` before reading.
        """
        self._is_hub_content = is_hub_content
        self.from_json(spec)

    def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
        """Populates the predictor specs from a dictionary.

        Args:
            json_obj (Optional[Dict[str, Any]]): Dictionary representation of predictor specs.

        Raises:
            KeyError: If any of the four content/accept type keys is missing.
        """
        if json_obj is None:
            return

        specs = walk_and_apply_json(json_obj, camel_to_snake) if self._is_hub_content else json_obj
        self.default_content_type = specs["default_content_type"]
        self.supported_content_types = specs["supported_content_types"]
        self.default_accept_type = specs["default_accept_type"]
        self.supported_accept_types = specs["supported_accept_types"]

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartPredictorSpecs object."""
        hidden = getattr(self, "_non_serializable_slots", [])
        populated = [slot for slot in self.__slots__ if hasattr(self, slot)]
        return {slot: getattr(self, slot) for slot in populated if slot not in hidden}
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
class JumpStartSerializablePayload(JumpStartDataHolderType):
    """Data class for JumpStart serialized payload specs."""

    __slots__ = [
        "raw_payload",
        "content_type",
        "accept",
        "body",
        "prompt_key",
        "_is_hub_content",
    ]

    _non_serializable_slots = ["raw_payload", "prompt_key", "_is_hub_content"]

    def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: Optional[bool] = False):
        """Builds a JumpStartSerializablePayload from its dictionary form.

        Args:
            spec (Optional[Dict[str, Any]]): Dictionary representation of payload specs.
                ``None`` leaves every field unset.
            is_hub_content (Optional[bool]): Whether ``spec`` comes from hub content, in which
                case ``camel_to_snake`` is applied via ``walk_and_apply_json`` before reading.
        """
        self._is_hub_content = is_hub_content
        self.from_json(spec)

    def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
        """Populates the payload from a dictionary.

        The (possibly key-converted) dictionary itself is kept in ``raw_payload`` and is what
        ``to_json`` returns.

        Args:
            json_obj (Optional[Dict[str, Any]]): Dictionary representation of serializable
                payload specs.

        Raises:
            KeyError: If the dictionary is missing keys.
        """
        if json_obj is None:
            return

        payload = walk_and_apply_json(json_obj, camel_to_snake) if self._is_hub_content else json_obj
        self.raw_payload = payload
        self.content_type = payload["content_type"]
        self.body = payload.get("body")
        accept_value = payload.get("accept")
        self.prompt_key = payload.get("prompt_key")
        # ``accept`` stays unset unless a truthy value was provided.
        if accept_value:
            self.accept = accept_value

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartSerializablePayload object.

        A deep copy of ``raw_payload`` is returned so callers cannot mutate the stored payload.
        """
        payload_copy = deepcopy(self.raw_payload)
        return payload_copy
|
|
485
|
+
|
|
486
|
+
|
|
487
|
+
class JumpStartInstanceTypeVariants(JumpStartDataHolderType):
    """Data class for JumpStart instance type variants.

    ``variants`` maps an instance type, or an instance type family (as returned by
    ``get_instance_type_family``), to a dict that may carry ``properties`` and/or
    ``regional_properties``. Most lookups consult the exact instance type first and
    fall back to its family.
    """

    __slots__ = [
        "regional_aliases",
        "aliases",
        "variants",
        "_is_hub_content",
    ]

    # Internal flag that must never appear in ``to_json`` output.
    _non_serializable_slots = ["_is_hub_content"]

    def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: Optional[bool] = False):
        """Initializes a JumpStartInstanceTypeVariants object from its json representation.

        Args:
            spec (Dict[str, Any]): Dictionary representation of instance type variants.
            is_hub_content (Optional[bool]): When True, ``spec`` is parsed with
                ``from_describe_hub_content_response``; otherwise with ``from_json``.
        """

        self._is_hub_content = is_hub_content

        if self._is_hub_content:
            self.from_describe_hub_content_response(spec)
        else:
            self.from_json(spec)

    def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
        """Sets fields in object based on json.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of instance type variants.
                ``None`` leaves every field unset.
        """

        if json_obj is None:
            return

        # The metadata path resolves aliases per region via ``regional_aliases``.
        self.aliases = None
        self.regional_aliases: Optional[dict] = json_obj.get("regional_aliases")
        self.variants: Optional[dict] = json_obj.get("variants")

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartInstanceTypeVariants object."""
        json_obj = {
            att: getattr(self, att)
            for att in self.__slots__
            if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", [])
        }
        return json_obj

    def from_describe_hub_content_response(self, response: Optional[Dict[str, Any]]) -> None:
        """Sets fields in object based on DescribeHubContent response.

        Args:
            response (Dict[str, Any]): Dictionary representation of instance type variants.
                ``None`` leaves every field unset.
        """

        if response is None:
            return

        response = walk_and_apply_json(response, camel_to_snake)
        # Hub content carries already-regionalized ``aliases``; no per-region table.
        self.aliases: Optional[dict] = response.get("aliases")
        self.regional_aliases = None
        self.variants: Optional[dict] = response.get("variants")

    def regionalize(  # pylint: disable=inconsistent-return-statements
        self, region: str
    ) -> Optional[Dict[str, Any]]:
        """Returns regionalized instance type variants.

        Builds ``{"Aliases": <regional_aliases[region]>, "Variants": {instance: props}}``,
        where ``props`` is the variant's ``properties`` if present, else its
        ``regional_properties``.

        Returns None when there are no ``regional_aliases`` or ``aliases`` is already set.
        """

        if self.regional_aliases is None or self.aliases is not None:
            return
        aliases = self.regional_aliases.get(region, {})
        variants = {}
        # ``variants`` may be missing from the metadata even when aliases are present.
        for instance_name, properties in (self.variants or {}).items():
            if properties.get("regional_properties") is not None:
                variants.update({instance_name: properties.get("regional_properties")})
            # Applied second so ``properties`` wins when both keys exist.
            if properties.get("properties") is not None:
                variants.update({instance_name: properties.get("properties")})
        return {"Aliases": aliases, "Variants": variants}

    def get_instance_specific_metric_definitions(
        self, instance_type: str
    ) -> List[Dict[str, Any]]:
        """Returns instance specific metric definitions.

        Family-level definitions are appended unless the instance type already defines
        a metric with the same ``Name``.

        Returns empty list if a model, instance type tuple does not have specific
        metric definitions.
        """

        if self.variants is None:
            return []

        instance_specific_metric_definitions: List[Dict[str, Union[str, Any]]] = (
            self.variants.get(instance_type, {}).get("properties", {}).get("metrics", [])
        )

        instance_type_family = get_instance_type_family(instance_type)

        instance_family_metric_definitions: List[Dict[str, Union[str, Any]]] = (
            self.variants.get(instance_type_family, {}).get("properties", {}).get("metrics", [])
            if instance_type_family not in {"", None}
            else []
        )

        instance_specific_metric_names = {
            metric_definition["Name"] for metric_definition in instance_specific_metric_definitions
        }

        metric_definitions_to_return = deepcopy(instance_specific_metric_definitions)

        for instance_family_metric_definition in instance_family_metric_definitions:
            if instance_family_metric_definition["Name"] not in instance_specific_metric_names:
                metric_definitions_to_return.append(instance_family_metric_definition)

        return metric_definitions_to_return

    def get_instance_specific_prepacked_artifact_key(self, instance_type: str) -> Optional[str]:
        """Returns instance specific prepacked model artifact key.

        Returns None if a model, instance type tuple does not have specific
        artifact key.
        """

        return self._get_instance_specific_property(
            instance_type=instance_type, property_name="prepacked_artifact_key"
        )

    def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str]:
        """Returns instance specific model artifact key.

        Returns None if a model, instance type tuple does not have specific
        artifact key.
        """

        return self._get_instance_specific_property(
            instance_type=instance_type, property_name="artifact_key"
        )

    def get_instance_specific_training_artifact_key(self, instance_type: str) -> Optional[str]:
        """Returns instance specific training artifact key.

        ``training_artifact_uri`` is preferred; ``training_artifact_key`` is the fallback.

        Returns None if a model, instance type tuple does not have specific
        training artifact key.
        """

        return self._get_instance_specific_property(
            instance_type=instance_type, property_name="training_artifact_uri"
        ) or self._get_instance_specific_property(
            instance_type=instance_type, property_name="training_artifact_key"
        )

    def get_instance_specific_resource_requirements(self, instance_type: str) -> Dict[str, Any]:
        """Returns instance specific resource requirements.

        If a value exists for both the instance family and instance type, the instance type value
        is chosen. Returns an empty dict when there are no variants or no matching entries.
        """

        # Previously this dereferenced ``self.variants`` unconditionally (AttributeError on None).
        if self.variants is None:
            return {}

        instance_specific_resource_requirements: dict = (
            self.variants.get(instance_type, {})
            .get("properties", {})
            .get("resource_requirements", {})
        )

        instance_type_family = get_instance_type_family(instance_type)

        instance_family_resource_requirements: dict = (
            self.variants.get(instance_type_family, {})
            .get("properties", {})
            .get("resource_requirements", {})
            if instance_type_family not in {"", None}
            else {}
        )

        # A fresh dict: the stored variant metadata must not be mutated.
        return {**instance_family_resource_requirements, **instance_specific_resource_requirements}

    def _get_instance_specific_property(
        self, instance_type: str, property_name: str
    ) -> Optional[str]:
        """Returns instance specific property.

        If a value exists for both the instance family and instance type,
        the instance type value is chosen (a falsy instance-type value falls back
        to the family).

        Returns None if a (model, instance type, property name) tuple does not have
        the property.
        """

        if self.variants is None:
            return None

        instance_specific_property: Optional[str] = (
            self.variants.get(instance_type, {}).get("properties", {}).get(property_name, None)
        )

        if instance_specific_property:
            return instance_specific_property

        instance_type_family = get_instance_type_family(instance_type)

        instance_family_property: Optional[str] = (
            self.variants.get(instance_type_family, {})
            .get("properties", {})
            .get(property_name, None)
            if instance_type_family not in {"", None}
            else None
        )

        return instance_family_property

    def get_instance_specific_hyperparameters(
        self, instance_type: str
    ) -> List[JumpStartHyperparameter]:
        """Returns instance specific hyperparameters.

        Family-level hyperparameters are appended unless the instance type already
        defines one with the same name.

        Returns empty list if a model, instance type tuple does not have specific
        hyperparameters.
        """

        if self.variants is None:
            return []

        instance_specific_hyperparameters: List[JumpStartHyperparameter] = [
            JumpStartHyperparameter(hyperparameter_spec)
            for hyperparameter_spec in self.variants.get(instance_type, {})
            .get("properties", {})
            .get("hyperparameters", [])
        ]

        instance_type_family = get_instance_type_family(instance_type)

        instance_family_hyperparameters: List[JumpStartHyperparameter] = [
            JumpStartHyperparameter(hyperparameter_spec)
            for hyperparameter_spec in (
                self.variants.get(instance_type_family, {})
                .get("properties", {})
                .get("hyperparameters", [])
                if instance_type_family not in {"", None}
                else []
            )
        ]

        instance_specific_hyperparameter_names = {
            hyperparameter.name for hyperparameter in instance_specific_hyperparameters
        }

        hyperparams_to_return = deepcopy(instance_specific_hyperparameters)

        for hyperparameter in instance_family_hyperparameters:
            if hyperparameter.name not in instance_specific_hyperparameter_names:
                hyperparams_to_return.append(hyperparameter)

        return hyperparams_to_return

    def get_instance_specific_environment_variables(self, instance_type: str) -> Dict[str, str]:
        """Returns instance specific environment variables.

        Instance-type values override family values with the same name.

        Returns empty dict if a model, instance type tuple does not have specific
        environment variables.
        """

        if self.variants is None:
            return {}

        instance_specific_environment_variables: Dict[str, str] = (
            self.variants.get(instance_type, {})
            .get("properties", {})
            .get("environment_variables", {})
        )

        instance_type_family = get_instance_type_family(instance_type)

        instance_family_environment_variables: dict = (
            self.variants.get(instance_type_family, {})
            .get("properties", {})
            .get("environment_variables", {})
            if instance_type_family not in {"", None}
            else {}
        )

        # Merge into a new dict. Calling ``.update`` on the family dict (as before) mutated the
        # dict stored inside ``self.variants``, leaking this instance type's values into
        # subsequent lookups for other instance types of the same family.
        return {**instance_family_environment_variables, **instance_specific_environment_variables}

    def get_instance_specific_gated_model_key_env_var_value(
        self, instance_type: str
    ) -> Optional[str]:
        """Returns instance specific gated model env var s3 key.

        Hub content stores it under ``gated_model_env_var_uri``; metadata under
        ``gated_model_key_env_var_value``.

        Returns None if a model, instance type tuple does not have instance
        specific property.
        """

        gated_model_key_env_var_value = (
            "gated_model_env_var_uri" if self._is_hub_content else "gated_model_key_env_var_value"
        )

        return self._get_instance_specific_property(instance_type, gated_model_key_env_var_value)

    def get_instance_specific_default_inference_instance_type(
        self, instance_type: str
    ) -> Optional[str]:
        """Returns instance specific default inference instance type.

        Returns None if a model, instance type tuple does not have instance
        specific inference instance types.
        """

        return self._get_instance_specific_property(
            instance_type, "default_inference_instance_type"
        )

    def get_instance_specific_supported_inference_instance_types(
        self, instance_type: str
    ) -> List[str]:
        """Returns instance specific supported inference instance types.

        The result is the sorted, de-duplicated union of instance-type and family values.

        Returns empty list if a model, instance type tuple does not have instance
        specific inference instance types.
        """

        if self.variants is None:
            return []

        instance_specific_inference_instance_types: List[str] = (
            self.variants.get(instance_type, {})
            .get("properties", {})
            .get("supported_inference_instance_types", [])
        )

        instance_type_family = get_instance_type_family(instance_type)

        instance_family_inference_instance_types: List[str] = (
            self.variants.get(instance_type_family, {})
            .get("properties", {})
            .get("supported_inference_instance_types", [])
            if instance_type_family not in {"", None}
            else []
        )

        return sorted(
            list(
                set(
                    instance_specific_inference_instance_types
                    + instance_family_inference_instance_types
                )
            )
        )

    def get_image_uri(self, instance_type: str, region: Optional[str] = None) -> Optional[str]:
        """Returns image uri from instance type and region.

        Returns None if no instance type is available or found.
        None is also returned if the metadata is improperly formatted.
        """
        return self._get_regional_property(
            instance_type=instance_type, region=region, property_name="image_uri"
        )

    def get_model_package_arn(self, instance_type: str, region: str) -> Optional[str]:
        """Returns model package arn from instance type and region.

        Returns None if no instance type is available or found.
        None is also returned if the metadata is improperly formatted.
        """
        return self._get_regional_property(
            instance_type=instance_type, region=region, property_name="model_package_arn"
        )

    def _get_regional_property(
        self, instance_type: str, region: Optional[str], property_name: str
    ) -> Optional[str]:
        """Returns regional property from instance type and region.

        With ``regional_aliases``, the variant's ``regional_properties`` entry is a
        ``$``-prefixed alias name resolved against ``regional_aliases[region]``. Without
        them (e.g. hub content, where ``regional_aliases`` is None), the value is read
        directly from the variant's ``properties``. The instance type is consulted first,
        then its family.

        Returns None if no instance type is available or found.
        None is also returned if the metadata is improperly formatted.
        """
        # pylint: disable=too-many-return-statements

        if self.variants is None:
            return None

        if region is None and self.regional_aliases is not None:
            return None

        regional_property_alias: Optional[str] = None
        regional_property_value: Optional[str] = None

        if self.regional_aliases:
            regional_property_alias = (
                self.variants.get(instance_type, {})
                .get("regional_properties", {})
                .get(property_name)
            )
        else:
            regional_property_value = (
                self.variants.get(instance_type, {}).get("properties", {}).get(property_name)
            )

        if regional_property_alias is None and regional_property_value is None:
            instance_type_family = get_instance_type_family(instance_type)
            if instance_type_family in {"", None}:
                return None
            if self.regional_aliases:
                regional_property_alias = (
                    self.variants.get(instance_type_family, {})
                    .get("regional_properties", {})
                    .get(property_name)
                )
            else:
                # if reading from HubContent, aliases are already regionalized
                regional_property_value = (
                    self.variants.get(instance_type_family, {})
                    .get("properties", {})
                    .get(property_name)
                )

        if (regional_property_alias is None or len(regional_property_alias) == 0) and (
            regional_property_value is None or len(regional_property_value) == 0
        ):
            return None

        if regional_property_alias and not regional_property_alias.startswith("$"):
            # No leading '$' indicates bad metadata.
            # There are tests to ensure this never happens.
            # However, to allow for fallback options in the unlikely event
            # of a regression, we do not raise an exception here.
            # We return None, indicating the field does not exist.
            return None

        if self.regional_aliases and region not in self.regional_aliases:
            return None

        if self.regional_aliases:
            # Strip the leading '$' to get the alias name.
            alias_value = self.regional_aliases[region].get(regional_property_alias[1:], None)
            return alias_value
        return regional_property_value
|
|
924
|
+
|
|
925
|
+
|
|
926
|
+
class JumpStartAdditionalDataSources(JumpStartDataHolderType):
    """Data class of additional data sources."""

    __slots__ = ["speculative_decoding", "scripts"]

    def __init__(self, spec: Dict[str, Any]):
        """Builds an AdditionalDataSources object from its dictionary form.

        Args:
            spec (Dict[str, Any]): Dictionary representation of data source.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Populates the data-source lists from a dictionary.

        Each of ``speculative_decoding`` and ``scripts`` becomes a list of
        ``JumpStartModelDataSource`` when present and non-empty, otherwise ``None``.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of data source.
        """
        speculative_specs = json_obj.get("speculative_decoding")
        if speculative_specs:
            self.speculative_decoding: Optional[List[JumpStartModelDataSource]] = [
                JumpStartModelDataSource(source_spec) for source_spec in speculative_specs
            ]
        else:
            self.speculative_decoding = None

        script_specs = json_obj.get("scripts")
        if script_specs:
            self.scripts: Optional[List[JumpStartModelDataSource]] = [
                JumpStartModelDataSource(source_spec) for source_spec in script_specs
            ]
        else:
            self.scripts = None

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of AdditionalDataSources object.

        List members that are data holders are serialized via their own ``to_json``;
        other values are emitted unchanged.
        """
        serialized: Dict[str, Any] = {}
        for slot in self.__slots__:
            if not hasattr(self, slot):
                continue
            value = getattr(self, slot)
            if not isinstance(value, list):
                serialized[slot] = value
                continue
            serialized[slot] = [
                item.to_json() if issubclass(type(item), JumpStartDataHolderType) else item
                for item in value
            ]
        return serialized
|
|
975
|
+
|
|
976
|
+
|
|
977
|
+
class ModelAccessConfig(JumpStartDataHolderType):
    """Data class of model access config that mirrors CreateModel API."""

    __slots__ = ["accept_eula"]

    def __init__(self, spec: Dict[str, Any]):
        """Builds a ModelAccessConfig from its dictionary form.

        Args:
            spec (Dict[str, Any]): Dictionary holding an ``accept_eula`` entry.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Copies ``accept_eula`` from the dictionary onto this object.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of the access config.

        Raises:
            KeyError: If ``accept_eula`` is absent.
        """
        self.accept_eula: bool = json_obj["accept_eula"]

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of ModelAccessConfig object."""
        serialized: Dict[str, Any] = {}
        for slot in self.__slots__:
            if hasattr(self, slot):
                serialized[slot] = getattr(self, slot)
        return serialized
|
|
1002
|
+
|
|
1003
|
+
|
|
1004
|
+
class HubAccessConfig(JumpStartDataHolderType):
    """Data class of hub access config that mirrors CreateModel API."""

    __slots__ = ["hub_content_arn"]

    def __init__(self, spec: Dict[str, Any]):
        """Initializes a HubAccessConfig object.

        Args:
            spec (Dict[str, Any]): Dictionary representation of the hub access config.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of the hub access config.
                Expected to contain ``hub_content_arn`` (the key ``to_json`` emits). The
                legacy ``accept_eula`` key is still accepted for backward compatibility.

        Raises:
            KeyError: If neither ``hub_content_arn`` nor ``accept_eula`` is present.
        """
        # The ARN used to be read from "accept_eula" (copied from ModelAccessConfig), so
        # this class could not parse its own ``to_json`` output. Prefer the correct key and
        # fall back to the legacy one so existing callers keep working.
        source_key = "hub_content_arn" if "hub_content_arn" in json_obj else "accept_eula"
        self.hub_content_arn: str = json_obj[source_key]

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of HubAccessConfig object."""
        json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
        return json_obj
|
|
1029
|
+
|
|
1030
|
+
|
|
1031
|
+
class S3DataSource(JumpStartDataHolderType):
    """Data class of S3 data source that mirrors CreateModel API."""

    __slots__ = [
        "compression_type",
        "s3_data_type",
        "s3_uri",
        "model_access_config",
        "hub_access_config",
    ]

    def __init__(self, spec: Dict[str, Any]):
        """Initializes a S3DataSource object.

        Args:
            spec (Dict[str, Any]): Dictionary representation of data source.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of data source.

        Raises:
            KeyError: If ``compression_type``, ``s3_data_type`` or ``s3_uri`` is missing.
        """
        self.compression_type: str = json_obj["compression_type"]
        self.s3_data_type: str = json_obj["s3_data_type"]
        self.s3_uri: str = json_obj["s3_uri"]
        # Nested configs are None when absent or empty.
        self.model_access_config: Optional[ModelAccessConfig] = (
            ModelAccessConfig(json_obj["model_access_config"])
            if json_obj.get("model_access_config")
            else None
        )
        self.hub_access_config: Optional[HubAccessConfig] = (
            HubAccessConfig(json_obj["hub_access_config"])
            if json_obj.get("hub_access_config")
            else None
        )

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of S3DataSource object.

        Nested data holders are serialized via their ``to_json``; other falsy values
        (e.g. None configs) are omitted.
        """
        json_obj = {}
        for att in self.__slots__:
            if hasattr(self, att):
                cur_val = getattr(self, att)
                if issubclass(type(cur_val), JumpStartDataHolderType):
                    json_obj[att] = cur_val.to_json()
                elif cur_val:
                    json_obj[att] = cur_val
        return json_obj

    def set_bucket(self, bucket: str) -> None:
        """Sets bucket name in the S3 URI.

        If ``s3_uri`` already starts with ``S3_PREFIX``, its bucket component is replaced and
        the key is kept. Otherwise ``s3_uri`` is treated as a key relative to ``bucket``.
        ``bucket`` may be given with or without a trailing "/".
        """

        if self.s3_uri.startswith(S3_PREFIX):
            s3_path = self.s3_uri[len(S3_PREFIX) :]
            old_bucket = s3_path.split("/")[0]
            # ``key`` keeps its leading "/" (or is empty).
            key = s3_path[len(old_bucket) :]
            # Strip a trailing "/" from ``bucket`` so we never produce "bucket//key";
            # the relative-key branch below already tolerates either form.
            self.s3_uri = f"{S3_PREFIX}{bucket.rstrip('/')}{key}"  # pylint: disable=W0201
            return

        if not bucket.endswith("/"):
            bucket += "/"

        self.s3_uri = f"{S3_PREFIX}{bucket}{self.s3_uri}"  # pylint: disable=W0201
|
|
1096
|
+
|
|
1097
|
+
|
|
1098
|
+
class AdditionalModelDataSource(JumpStartDataHolderType):
    """Data class of additional model data source mirrors CreateModel API."""

    # Slot names that ``to_json`` omits unless called with ``exclude_keys=False``.
    SERIALIZATION_EXCLUSION_SET = {"provider"}

    __slots__ = ["channel_name", "s3_data_source", "hosting_eula_key"]

    def __init__(self, spec: Dict[str, Any]):
        """Initializes a AdditionalModelDataSource object.

        Args:
            spec (Dict[str, Any]): Dictionary representation of data source.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Populates the object from its json representation.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of data source.
                ``channel_name`` and ``s3_data_source`` are required.
        """
        self.channel_name: str = json_obj["channel_name"]
        self.s3_data_source: S3DataSource = S3DataSource(json_obj["s3_data_source"])
        self.hosting_eula_key: str = json_obj.get("hosting_eula_key")
        # NOTE(review): "provider" is not listed in this class's __slots__
        # (subclasses such as JumpStartModelDataSource add it). Confirm this
        # assignment works for direct instances; that depends on whether a base
        # class provides an instance __dict__.
        self.provider: Dict = json_obj.get("provider", {})

    def to_json(self, exclude_keys=True) -> Dict[str, Any]:
        """Serializes the data source to a dictionary.

        Args:
            exclude_keys (bool): When truthy (the default), slots listed in
                ``SERIALIZATION_EXCLUSION_SET`` are left out.

        Returns:
            Dict[str, Any]: Slot name to value, with nested data holders serialized.
        """
        serialized = {}
        for slot_name in self.__slots__:
            if not hasattr(self, slot_name):
                continue
            if exclude_keys and slot_name in self.SERIALIZATION_EXCLUSION_SET:
                continue
            slot_value = getattr(self, slot_name)
            if issubclass(type(slot_value), JumpStartDataHolderType):
                serialized[slot_name] = slot_value.to_json()
            else:
                serialized[slot_name] = slot_value
        return serialized
|
|
1136
|
+
|
|
1137
|
+
|
|
1138
|
+
class JumpStartModelDataSource(AdditionalModelDataSource):
    """Data class JumpStart additional model data source."""

    SERIALIZATION_EXCLUSION_SET = AdditionalModelDataSource.SERIALIZATION_EXCLUSION_SET.union(
        {"artifact_version"}
    )

    # Sort the set so slot order (and therefore ``to_json`` key order) is the
    # same in every process. Iterating a set of strings directly depends on
    # per-process hash randomization.
    __slots__ = sorted(SERIALIZATION_EXCLUSION_SET) + AdditionalModelDataSource.__slots__

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of data source.
                In addition to the parent class's keys, ``artifact_version`` is
                required.

        Raises:
            KeyError: If ``artifact_version`` (or a required parent key) is missing.
        """
        super().from_json(json_obj)
        self.artifact_version: str = json_obj["artifact_version"]
|
|
1155
|
+
|
|
1156
|
+
|
|
1157
|
+
class JumpStartBenchmarkStat(JumpStartDataHolderType):
    """Data class JumpStart benchmark stat."""

    __slots__ = ["name", "value", "unit", "concurrency"]

    def __init__(self, spec: Dict[str, Any]):
        """Builds a benchmark stat from its dictionary form.

        Args:
            spec (Dict[str, Any]): Dictionary representation of benchmark stat.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Copies the stat fields out of ``json_obj``.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of benchmark stats.
                All of ``name``, ``value``, ``unit`` and ``concurrency`` are required.

        Raises:
            KeyError: If any of the four keys is missing.
        """
        self.name: str = json_obj["name"]
        self.value: str = json_obj["value"]
        # NOTE(review): the value/unit annotations may be swapped; confirm
        # against the metadata schema.
        self.unit: Union[int, str] = json_obj["unit"]
        self.concurrency: Union[int, str] = json_obj["concurrency"]

    def to_json(self) -> Dict[str, Any]:
        """Returns the populated slots as a flat dictionary."""
        serialized: Dict[str, Any] = {}
        for slot_name in self.__slots__:
            if hasattr(self, slot_name):
                serialized[slot_name] = getattr(self, slot_name)
        return serialized
|
|
1185
|
+
|
|
1186
|
+
|
|
1187
|
+
class JumpStartConfigRanking(JumpStartDataHolderType):
    """Data class JumpStart config ranking."""

    __slots__ = ["description", "rankings"]

    def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content=False):
        """Builds a config ranking from its dictionary form.

        Args:
            spec (Dict[str, Any]): Dictionary representation of training config ranking.
            is_hub_content (bool): If True, camelCase keys in ``spec`` are first
                converted to snake_case.
        """
        normalized_spec = walk_and_apply_json(spec, camel_to_snake) if is_hub_content else spec
        self.from_json(normalized_spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Copies ``description`` and ``rankings`` out of ``json_obj``.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of config ranking.

        Raises:
            KeyError: If ``description`` or ``rankings`` is missing.
        """
        self.description: str = json_obj["description"]
        # Ordered list of config names, most preferred first.
        self.rankings: List[str] = json_obj["rankings"]

    def to_json(self) -> Dict[str, Any]:
        """Returns the populated slots as a flat dictionary."""
        serialized: Dict[str, Any] = {}
        for slot_name in self.__slots__:
            if hasattr(self, slot_name):
                serialized[slot_name] = getattr(self, slot_name)
        return serialized
|
|
1215
|
+
|
|
1216
|
+
|
|
1217
|
+
class JumpStartMetadataBaseFields(JumpStartDataHolderType):
    """Data class JumpStart metadata base fields that can be overridden.

    Hub content carries ECR image URIs directly (``*_ecr_uri``), while regular
    JumpStart metadata carries ECR specs (``*_ecr_specs``). ``to_json`` skips
    the slots that belong to the other format. The choice is made per
    instance from ``_is_hub_content``, never by mutating shared class state.
    """

    __slots__ = [
        "model_id",
        "url",
        "version",
        "min_sdk_version",
        "model_types",
        "capabilities",
        "incremental_training_supported",
        "hosting_ecr_specs",
        "hosting_ecr_uri",
        "hosting_artifact_uri",
        "hosting_artifact_key",
        "hosting_script_key",
        "training_supported",
        "training_ecr_specs",
        "training_ecr_uri",
        "training_artifact_key",
        "training_script_key",
        "hyperparameters",
        "inference_environment_variables",
        "inference_vulnerable",
        "inference_dependencies",
        "inference_vulnerabilities",
        "training_vulnerable",
        "training_dependencies",
        "training_vulnerabilities",
        "deprecated",
        "usage_info_message",
        "deprecated_message",
        "deprecate_warn_message",
        "default_inference_instance_type",
        "supported_inference_instance_types",
        "dynamic_container_deployment_supported",
        "hosting_resource_requirements",
        "default_training_instance_type",
        "supported_training_instance_types",
        "metrics",
        "training_prepacked_script_key",
        "training_prepacked_script_version",
        "hosting_prepacked_artifact_key",
        "hosting_prepacked_artifact_version",
        "model_kwargs",
        "deploy_kwargs",
        "estimator_kwargs",
        "fit_kwargs",
        "predictor_specs",
        "inference_volume_size",
        "training_volume_size",
        "inference_enable_network_isolation",
        "training_enable_network_isolation",
        "resource_name_base",
        "hosting_eula_key",
        "hosting_model_package_arns",
        "training_model_package_artifact_uris",
        "hosting_use_script_uri",
        "hosting_instance_type_variants",
        "training_instance_type_variants",
        "default_payloads",
        "gated_bucket",
        "model_subscription_link",
        "hosting_additional_data_sources",
        "hosting_neuron_model_id",
        "hosting_neuron_model_version",
        "hub_content_type",
        "_is_hub_content",
        "default_training_dataset_key",
        "default_training_dataset_uri",
    ]

    # Slots that are never serialized. Treated as read-only: this list is
    # shared by every instance and subclass, so it must not be appended to.
    _non_serializable_slots = ["_is_hub_content"]

    # Slots that belong to the *other* metadata format and are therefore
    # skipped by ``to_json`` for hub content / non-hub content respectively.
    _HUB_CONTENT_EXCLUDED_SLOTS = ["hosting_ecr_specs", "training_ecr_specs"]
    _NON_HUB_CONTENT_EXCLUDED_SLOTS = ["hosting_ecr_uri", "training_ecr_uri"]

    def __init__(self, fields: Dict[str, Any], is_hub_content: Optional[bool] = False):
        """Initializes a JumpStartMetadataFields object.

        Args:
            fields (Dict[str, Any]): Dictionary representation of metadata fields.
            is_hub_content (Optional[bool]): Whether ``fields`` comes from a
                private hub. If so, its camelCase keys are converted to
                snake_case before parsing.
        """
        self._is_hub_content = is_hub_content
        self.from_json(fields)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json of header.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of spec.

        Raises:
            KeyError: If training is supported but ``training_artifact_key`` or
                ``training_script_key`` is missing.
        """
        if self._is_hub_content:
            json_obj = walk_and_apply_json(json_obj, camel_to_snake)
        self.model_id: str = json_obj.get("model_id")
        self.url: str = json_obj.get("url")
        self.version: str = json_obj.get("version")
        self.min_sdk_version: str = json_obj.get("min_sdk_version")
        self.incremental_training_supported: bool = bool(
            json_obj.get("incremental_training_supported", False)
        )
        # Hub content stores the image URI directly; regular metadata stores
        # ECR specs. ``to_json`` decides which of the two to skip.
        if self._is_hub_content:
            self.capabilities: Optional[List[str]] = json_obj.get("capabilities")
            self.model_types: Optional[List[str]] = json_obj.get("model_types")
            self.hosting_ecr_uri: Optional[str] = json_obj.get("hosting_ecr_uri")
        else:
            self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = (
                JumpStartECRSpecs(
                    json_obj["hosting_ecr_specs"], is_hub_content=self._is_hub_content
                )
                if "hosting_ecr_specs" in json_obj
                else None
            )
        self.hosting_artifact_key: Optional[str] = json_obj.get("hosting_artifact_key")
        self.hosting_artifact_uri: Optional[str] = json_obj.get("hosting_artifact_uri")
        self.hosting_script_key: Optional[str] = json_obj.get("hosting_script_key")
        self.training_supported: Optional[bool] = bool(json_obj.get("training_supported", False))
        self.inference_environment_variables = [
            JumpStartEnvironmentVariable(env_variable, is_hub_content=self._is_hub_content)
            for env_variable in json_obj.get("inference_environment_variables", [])
        ]
        self.inference_vulnerable: bool = bool(json_obj.get("inference_vulnerable", False))
        self.inference_dependencies: List[str] = json_obj.get("inference_dependencies", [])
        self.inference_vulnerabilities: List[str] = json_obj.get("inference_vulnerabilities", [])
        self.training_vulnerable: bool = bool(json_obj.get("training_vulnerable", False))
        self.training_dependencies: List[str] = json_obj.get("training_dependencies", [])
        self.training_vulnerabilities: List[str] = json_obj.get("training_vulnerabilities", [])
        self.deprecated: bool = bool(json_obj.get("deprecated", False))
        self.deprecated_message: Optional[str] = json_obj.get("deprecated_message")
        self.deprecate_warn_message: Optional[str] = json_obj.get("deprecate_warn_message")
        self.usage_info_message: Optional[str] = json_obj.get("usage_info_message")
        self.default_inference_instance_type: Optional[str] = json_obj.get(
            "default_inference_instance_type"
        )
        self.default_training_instance_type: Optional[str] = json_obj.get(
            "default_training_instance_type"
        )
        self.supported_inference_instance_types: Optional[List[str]] = json_obj.get(
            "supported_inference_instance_types"
        )
        self.supported_training_instance_types: Optional[List[str]] = json_obj.get(
            "supported_training_instance_types"
        )
        self.dynamic_container_deployment_supported: Optional[bool] = bool(
            json_obj.get("dynamic_container_deployment_supported")
        )
        self.hosting_resource_requirements: Optional[Dict[str, int]] = json_obj.get(
            "hosting_resource_requirements", None
        )
        self.metrics: Optional[List[Dict[str, str]]] = json_obj.get("metrics", None)
        self.training_prepacked_script_key: Optional[str] = json_obj.get(
            "training_prepacked_script_key", None
        )
        self.hosting_prepacked_artifact_key: Optional[str] = json_obj.get(
            "hosting_prepacked_artifact_key", None
        )
        # New fields required for Hub model.
        if self._is_hub_content:
            self.training_prepacked_script_version: Optional[str] = json_obj.get(
                "training_prepacked_script_version"
            )
            self.hosting_prepacked_artifact_version: Optional[str] = json_obj.get(
                "hosting_prepacked_artifact_version"
            )
        # Deep copies so that later mutation of these kwargs does not leak back
        # into the caller's json.
        self.model_kwargs = deepcopy(json_obj.get("model_kwargs", {}))
        self.deploy_kwargs = deepcopy(json_obj.get("deploy_kwargs", {}))
        self.predictor_specs: Optional[JumpStartPredictorSpecs] = (
            JumpStartPredictorSpecs(
                json_obj.get("predictor_specs"),
                is_hub_content=self._is_hub_content,
            )
            if json_obj.get("predictor_specs")
            else None
        )
        self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = (
            {
                alias: JumpStartSerializablePayload(payload, is_hub_content=self._is_hub_content)
                for alias, payload in json_obj["default_payloads"].items()
            }
            if json_obj.get("default_payloads")
            else None
        )
        self.gated_bucket = json_obj.get("gated_bucket", False)
        self.inference_volume_size: Optional[int] = json_obj.get("inference_volume_size")
        self.inference_enable_network_isolation: bool = json_obj.get(
            "inference_enable_network_isolation", False
        )
        self.resource_name_base: bool = json_obj.get("resource_name_base")

        self.hosting_eula_key: Optional[str] = json_obj.get("hosting_eula_key")

        model_package_arns = json_obj.get("hosting_model_package_arns")
        self.hosting_model_package_arns: Optional[Dict] = (
            model_package_arns if model_package_arns is not None else {}
        )

        self.hosting_use_script_uri: bool = json_obj.get("hosting_use_script_uri", True)

        self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (
            JumpStartInstanceTypeVariants(
                json_obj["hosting_instance_type_variants"], self._is_hub_content
            )
            if json_obj.get("hosting_instance_type_variants")
            else None
        )
        self.hosting_additional_data_sources: Optional[JumpStartAdditionalDataSources] = (
            JumpStartAdditionalDataSources(json_obj["hosting_additional_data_sources"])
            if json_obj.get("hosting_additional_data_sources")
            else None
        )
        self.hosting_neuron_model_id: Optional[str] = json_obj.get("hosting_neuron_model_id")
        self.hosting_neuron_model_version: Optional[str] = json_obj.get(
            "hosting_neuron_model_version"
        )

        if self.training_supported:
            if self._is_hub_content:
                self.training_ecr_uri: Optional[str] = json_obj.get("training_ecr_uri")
            else:
                self.training_ecr_specs: Optional[JumpStartECRSpecs] = (
                    JumpStartECRSpecs(json_obj["training_ecr_specs"])
                    if "training_ecr_specs" in json_obj
                    else None
                )
            self.training_artifact_key: str = json_obj["training_artifact_key"]
            self.training_script_key: str = json_obj["training_script_key"]
            hyperparameters: Any = json_obj.get("hyperparameters")
            self.hyperparameters: List[JumpStartHyperparameter] = []
            if hyperparameters is not None:
                self.hyperparameters.extend(
                    [
                        JumpStartHyperparameter(hyperparameter, is_hub_content=self._is_hub_content)
                        for hyperparameter in hyperparameters
                    ]
                )
            self.estimator_kwargs = deepcopy(json_obj.get("estimator_kwargs", {}))
            self.fit_kwargs = deepcopy(json_obj.get("fit_kwargs", {}))
            self.training_volume_size: Optional[int] = json_obj.get("training_volume_size")
            self.training_enable_network_isolation: bool = json_obj.get(
                "training_enable_network_isolation", False
            )
            self.training_model_package_artifact_uris: Optional[Dict] = json_obj.get(
                "training_model_package_artifact_uris"
            )
            self.training_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (
                JumpStartInstanceTypeVariants(
                    json_obj["training_instance_type_variants"], is_hub_content=self._is_hub_content
                )
                if json_obj.get("training_instance_type_variants")
                else None
            )
            self.model_subscription_link = json_obj.get("model_subscription_link")
            self.default_training_dataset_key: Optional[str] = json_obj.get(
                "default_training_dataset_key"
            )
            self.default_training_dataset_uri: Optional[str] = json_obj.get(
                "default_training_dataset_uri"
            )

    def _get_non_serializable_slots(self) -> List[str]:
        """Returns the slot names that ``to_json`` must skip for this instance."""
        excluded = list(getattr(self, "_non_serializable_slots", []))
        if getattr(self, "_is_hub_content", False):
            excluded.extend(self._HUB_CONTENT_EXCLUDED_SLOTS)
        else:
            excluded.extend(self._NON_HUB_CONTENT_EXCLUDED_SLOTS)
        return excluded

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartMetadataBaseFields object.

        Nested data holders, including those inside lists and dict values, are
        serialized recursively. Unset slots and non-serializable slots are
        omitted.
        """
        json_obj = {}
        excluded_slots = self._get_non_serializable_slots()
        for att in self.__slots__:
            if hasattr(self, att) and att not in excluded_slots:
                cur_val = getattr(self, att)
                if issubclass(type(cur_val), JumpStartDataHolderType):
                    json_obj[att] = cur_val.to_json()
                elif isinstance(cur_val, list):
                    json_obj[att] = []
                    for obj in cur_val:
                        if issubclass(type(obj), JumpStartDataHolderType):
                            json_obj[att].append(obj.to_json())
                        else:
                            json_obj[att].append(obj)
                elif isinstance(cur_val, dict):
                    json_obj[att] = {}
                    for key, val in cur_val.items():
                        if issubclass(type(val), JumpStartDataHolderType):
                            json_obj[att][key] = val.to_json()
                        else:
                            json_obj[att][key] = val
                else:
                    json_obj[att] = cur_val
        return json_obj

    def set_hub_content_type(self, hub_content_type: HubContentType) -> None:
        """Sets the hub content type; a no-op for non-hub content."""
        if self._is_hub_content:
            self.hub_content_type = hub_content_type
|
|
1507
|
+
|
|
1508
|
+
|
|
1509
|
+
class JumpStartConfigComponent(JumpStartMetadataBaseFields):
    """Data class of JumpStart config component.

    A component holds a subset of the base metadata fields. Only keys that
    appear in its json are set on the object.
    """

    slots = ["component_name"]

    # Fields a component may NOT override in JumpStartMetadataBaseFields.
    OVERRIDING_DENY_LIST = [
        "model_id",
        "url",
        "version",
        "min_sdk_version",
        "deprecated",
        "deprecated_message",
        "deprecate_warn_message",
        "resource_name_base",
        "gated_bucket",
        "training_supported",
        "incremental_training_supported",
    ]

    # HubContent field names that map onto differently named base-field slots.
    CUSTOM_FIELD_MAP = {"sage_maker_sdk_predictor_specifications": "predictor_specs"}

    __slots__ = slots + JumpStartMetadataBaseFields.__slots__

    def __init__(
        self, component_name: str, component: Optional[Dict[str, Any]], is_hub_content=False
    ):
        """Builds a config component from its json representation.

        Args:
            component_name (str): Name of the component.
            component (Dict[str, Any]):
                Dictionary representation of the config component.
            is_hub_content (bool): If True, camelCase keys are converted to
                snake_case first.
        Raises:
            ValueError: If the component field is invalid.
        """
        normalized = walk_and_apply_json(component, camel_to_snake) if is_hub_content else component
        self.component_name = component_name
        # The parent constructor dispatches to this class's ``from_json``.
        # The explicit call afterwards re-applies the same assignments.
        super().__init__(normalized, is_hub_content)
        self.from_json(normalized)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Copies every key of ``json_obj`` that names a slot onto the object.

        Args:
            json_obj (Dict[str, Any]):
                Dictionary representation of the config component.
        """
        known_slots = self.__slots__
        for field_name, field_value in json_obj.items():
            if field_name in known_slots:
                setattr(self, field_name, field_value)

        # Hub-specific key names are stored under their base-field slot names.
        for hub_field_name, slot_name in self.CUSTOM_FIELD_MAP.items():
            if hub_field_name in json_obj:
                setattr(self, slot_name, json_obj.get(hub_field_name))
|
|
1567
|
+
|
|
1568
|
+
|
|
1569
|
+
class JumpStartMetadataConfig(JumpStartDataHolderType):
    """Data class of JumpStart metadata config."""

    __slots__ = [
        "base_fields",
        "benchmark_metrics",
        "acceleration_configs",
        "config_components",
        "resolved_metadata_config",
        "config_name",
        "default_inference_config",
        "default_incremental_training_config",
        "supported_inference_configs",
        "supported_incremental_training_configs",
    ]

    def __init__(
        self,
        config_name: str,
        config: Dict[str, Any],
        base_fields: Dict[str, Any],
        config_components: Dict[str, JumpStartConfigComponent],
        is_hub_content=False,
    ):
        """Initializes a JumpStartMetadataConfig object from its json representation.

        Args:
            config_name (str): Name of the config,
            config (Dict[str, Any]):
                Dictionary representation of the config. ``None`` is treated as
                an empty config.
            base_fields (Dict[str, Any]):
                The default base fields that are used to construct the resolved config.
            config_components (Dict[str, JumpStartConfigComponent]):
                The list of components that are used to construct the resolved config.
            is_hub_content (bool): If True, camelCase keys in ``config`` and
                ``base_fields`` are converted to snake_case.
        """
        # Without this, a missing config would fail on the ``.get`` calls below.
        config = config or {}
        if is_hub_content:
            config = walk_and_apply_json(config, camel_to_snake)
            base_fields = walk_and_apply_json(base_fields, camel_to_snake)
        self.base_fields = base_fields
        self.config_components: Dict[str, JumpStartConfigComponent] = config_components
        self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = (
            {
                stat_name: [JumpStartBenchmarkStat(stat) for stat in stats]
                for stat_name, stats in config.get("benchmark_metrics").items()
            }
            if config.get("benchmark_metrics")
            else None
        )
        self.acceleration_configs = config.get("acceleration_configs")
        # Cache for ``resolved_config``; populated lazily.
        self.resolved_metadata_config: Optional[Dict[str, Any]] = None
        self.config_name: Optional[str] = config_name
        self.default_inference_config: Optional[str] = config.get("default_inference_config")
        self.default_incremental_training_config: Optional[str] = config.get(
            "default_incremental_training_config"
        )
        self.supported_inference_configs: Optional[List[str]] = config.get(
            "supported_inference_configs"
        )
        self.supported_incremental_training_configs: Optional[List[str]] = config.get(
            "supported_incremental_training_configs"
        )

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartMetadataConfig object."""
        json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
        return json_obj

    @property
    def resolved_config(self) -> Dict[str, Any]:
        """Returns the final config that is resolved from the components map.

        Starts from the json of the base fields and applies each component's
        json on top of it in order, using ``deep_override_dict`` with the
        component's ``OVERRIDING_DENY_LIST``. The result is cached on the
        object.

        Returns:
            Dict[str, Any]: The resolved config. It is always a dict, even when
            there are zero or several components.
        """
        if self.resolved_metadata_config:
            return self.resolved_metadata_config

        # Keep the running value a dict throughout. ``deep_override_dict``
        # returns a dict, so calling ``.to_json()`` on it would fail from the
        # second component onwards.
        resolved_config: Dict[str, Any] = JumpStartMetadataBaseFields(self.base_fields).to_json()
        for component in (self.config_components or {}).values():
            resolved_config = deep_override_dict(
                deepcopy(resolved_config),
                deepcopy(component.to_json()),
                component.OVERRIDING_DENY_LIST,
            )

        # Remove environment variables from resolved config if using model packages
        hosting_model_package_arns = resolved_config.get("hosting_model_package_arns")
        if hosting_model_package_arns is not None and hosting_model_package_arns != {}:
            resolved_config["inference_environment_variables"] = []

        self.resolved_metadata_config = resolved_config

        return resolved_config
|
|
1662
|
+
|
|
1663
|
+
|
|
1664
|
+
class JumpStartMetadataConfigs(JumpStartDataHolderType):
    """Data class to hold the set of JumpStart Metadata configs."""

    __slots__ = ["configs", "config_rankings", "scope"]

    def __init__(
        self,
        configs: Optional[Dict[str, JumpStartMetadataConfig]],
        config_rankings: Optional[Dict[str, JumpStartConfigRanking]],
        scope: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
    ):
        """Initializes a JumpStartMetadataConfigs object.

        Args:
            configs (Dict[str, JumpStartMetadataConfig]):
                The map of JumpStartMetadataConfig object, with config name being the key.
            config_rankings (JumpStartConfigRanking):
                Config ranking class represents the ranking of the configs in the model.
            scope (JumpStartScriptScope):
                The scope of the current config (inference or training)
        """
        self.configs = configs
        self.config_rankings = config_rankings
        self.scope = scope

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartMetadataConfigs object."""
        json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
        return json_obj

    @staticmethod
    def _get_supported_instance_types(resolved_config: Any, attribute_name: str) -> List[str]:
        """Reads a supported-instance-types list from a resolved config.

        ``resolved_config`` is normally a dict (``resolved_config`` is annotated
        as ``Dict[str, Any]``). Attribute access is kept as a fallback for
        object-like configs. A missing or ``None`` value is treated as an
        empty list.
        """
        if isinstance(resolved_config, dict):
            supported = resolved_config.get(attribute_name)
        else:
            supported = getattr(resolved_config, attribute_name, None)
        return supported or []

    def get_top_config_from_ranking(
        self,
        ranking_name: str = JumpStartConfigRankingName.DEFAULT,
        instance_type: Optional[str] = None,
    ) -> Optional[JumpStartMetadataConfig]:
        """Gets the best the config based on config ranking.

        Fallback to use the ordering in config names if
        ranking is not available.
        Args:
            ranking_name (str):
                The ranking name that config priority is based on.
            instance_type (Optional[str]):
                The instance type which the config selection is based on. If
                given, configs that do not list it as supported are skipped.

        Returns:
            Optional[JumpStartMetadataConfig]: The highest ranked matching
            config, or None if there are no configs or none match.

        Raises:
            NotImplementedError: If the scope is unrecognized.
        """

        if self.scope == JumpStartScriptScope.INFERENCE:
            instance_type_attribute = "supported_inference_instance_types"
        elif self.scope == JumpStartScriptScope.TRAINING:
            instance_type_attribute = "supported_training_instance_types"
        else:
            raise NotImplementedError(f"Unknown script scope {self.scope}")

        # With no configs there is nothing to rank. This also avoids
        # dereferencing a missing ``config_rankings``.
        if not self.configs:
            return None

        rankings = self.config_rankings.get(ranking_name) if self.config_rankings else None
        if rankings:
            ranked_config_names = rankings.rankings
        else:
            ranked_config_names = sorted(self.configs.keys())

        for config_name in ranked_config_names:
            resolved_config = self.configs[config_name].resolved_config
            if instance_type and instance_type not in self._get_supported_instance_types(
                resolved_config, instance_type_attribute
            ):
                continue
            return self.configs[config_name]

        return None
|
|
1736
|
+
|
|
1737
|
+
|
|
1738
|
+
class JumpStartModelSpecs(JumpStartMetadataBaseFields):
    """Data class for JumpStart model specs.

    Extends the base metadata fields with inference and training config
    collections: the configs themselves, their components, and their rankings.
    """

    slots = [
        "inference_configs",
        "inference_config_components",
        "inference_config_rankings",
        "training_configs",
        "training_config_components",
        "training_config_rankings",
    ]

    __slots__ = JumpStartMetadataBaseFields.__slots__ + slots

    def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False):
        """Initializes a JumpStartModelSpecs object from its json representation.

        After ``spec`` is parsed, if the inference configs produce a top-ranked
        config, that config's resolved fields are applied on top of the base fields.

        Args:
            spec (Dict[str, Any]): Dictionary representation of spec.
            is_hub_content (Optional[bool]): Whether the model is from a private hub.
        """
        super().__init__(spec, is_hub_content)
        self.from_json(spec)
        if self.inference_configs:
            # Resolve the top-ranked config once and reuse it.
            top_config = self.inference_configs.get_top_config_from_ranking()
            if top_config:
                super().from_json(top_config.resolved_config)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json of header.

        Base fields are parsed by the parent class first. For hub content, the
        remaining keys are converted from camelCase to snake_case before the
        config collections are parsed. Training config collections are only
        parsed when ``training_supported`` is set.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of spec.
        """
        super().from_json(json_obj)
        if self._is_hub_content:
            json_obj = walk_and_apply_json(json_obj, camel_to_snake)
        self.inference_config_components: Optional[Dict[str, JumpStartConfigComponent]] = (
            {
                component_name: JumpStartConfigComponent(component_name, component)
                for component_name, component in json_obj["inference_config_components"].items()
            }
            if json_obj.get("inference_config_components")
            else None
        )
        self.inference_config_rankings: Optional[Dict[str, JumpStartConfigRanking]] = (
            {
                alias: JumpStartConfigRanking(ranking, is_hub_content=self._is_hub_content)
                for alias, ranking in json_obj["inference_config_rankings"].items()
            }
            if json_obj.get("inference_config_rankings")
            else None
        )

        if self._is_hub_content:
            # Hub content nests the configs under a "configs" key, and each config
            # already carries its own ``config_components``.
            inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
                {
                    alias: JumpStartMetadataConfig(
                        alias,
                        config,
                        json_obj,
                        config.config_components,
                        is_hub_content=self._is_hub_content,
                    )
                    for alias, config in json_obj["inference_configs"]["configs"].items()
                }
                if json_obj.get("inference_configs")
                else None
            )
        else:
            # Non-hub specs reference components by name; resolve them against the
            # parsed ``inference_config_components``.
            inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
                {
                    alias: JumpStartMetadataConfig(
                        alias,
                        config,
                        json_obj,
                        (
                            {
                                component_name: self.inference_config_components.get(component_name)
                                for component_name in config.get("component_names")
                            }
                            if config and config.get("component_names")
                            else None
                        ),
                    )
                    for alias, config in json_obj["inference_configs"].items()
                }
                if json_obj.get("inference_configs")
                else None
            )

        self.inference_configs: Optional[JumpStartMetadataConfigs] = (
            JumpStartMetadataConfigs(
                inference_configs_dict,
                self.inference_config_rankings,
            )
            if json_obj.get("inference_configs")
            else None
        )

        if self.training_supported:
            self.training_config_components: Optional[Dict[str, JumpStartConfigComponent]] = (
                {
                    alias: JumpStartConfigComponent(alias, component)
                    for alias, component in json_obj["training_config_components"].items()
                }
                if json_obj.get("training_config_components")
                else None
            )
            self.training_config_rankings: Optional[Dict[str, JumpStartConfigRanking]] = (
                {
                    alias: JumpStartConfigRanking(ranking)
                    for alias, ranking in json_obj["training_config_rankings"].items()
                }
                if json_obj.get("training_config_rankings")
                else None
            )

            if self._is_hub_content:
                training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
                    {
                        alias: JumpStartMetadataConfig(
                            alias,
                            config,
                            json_obj,
                            config.config_components,
                            is_hub_content=self._is_hub_content,
                        )
                        for alias, config in json_obj["training_configs"]["configs"].items()
                    }
                    if json_obj.get("training_configs")
                    else None
                )
            else:
                training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
                    {
                        alias: JumpStartMetadataConfig(
                            alias,
                            config,
                            json_obj,
                            (
                                {
                                    component_name: self.training_config_components.get(
                                        component_name
                                    )
                                    for component_name in config.get("component_names")
                                }
                                if config and config.get("component_names")
                                else None
                            ),
                        )
                        for alias, config in json_obj["training_configs"].items()
                    }
                    if json_obj.get("training_configs")
                    else None
                )

            self.training_configs: Optional[JumpStartMetadataConfigs] = (
                JumpStartMetadataConfigs(
                    training_configs_dict,
                    self.training_config_rankings,
                    JumpStartScriptScope.TRAINING,
                )
                if json_obj.get("training_configs")
                else None
            )
        self.model_subscription_link = json_obj.get("model_subscription_link")

    def set_config(
        self, config_name: str, scope: JumpStartScriptScope = JumpStartScriptScope.INFERENCE
    ) -> None:
        """Apply the selected config and resolve to the current model spec.

        Args:
            config_name (str): Name of the config.
            scope (JumpStartScriptScope, optional):
                Scope of the config. Defaults to JumpStartScriptScope.INFERENCE.

        Raises:
            ValueError: If the scope is not supported, or cannot find config name
                (including when the model defines no configs for the scope).
        """
        if scope == JumpStartScriptScope.INFERENCE:
            metadata_configs = self.inference_configs
        elif scope == JumpStartScriptScope.TRAINING and self.training_supported:
            metadata_configs = self.training_configs
        else:
            raise ValueError(f"Unknown Jumpstart script scope {scope}.")

        # ``metadata_configs`` is None when the spec declared no configs for this
        # scope; treat that as "config not found" instead of crashing on None.
        available_configs = (
            metadata_configs.configs if metadata_configs is not None else None
        ) or {}

        config_object = available_configs.get(config_name)
        if not config_object:
            error_msg = f"Cannot find Jumpstart config name {config_name}. "
            config_names = list(available_configs.keys())
            if config_names:
                error_msg += f"List of config names that is supported by the model: {config_names}"
            raise ValueError(error_msg)

        super().from_json(config_object.resolved_config)

    def supports_prepacked_inference(self) -> bool:
        """Returns True if the model has a prepacked inference artifact."""
        return getattr(self, "hosting_prepacked_artifact_key", None) is not None

    def use_inference_script_uri(self) -> bool:
        """Returns True if the model should use a script uri when deploying inference model."""
        if self.supports_prepacked_inference():
            return False
        return self.hosting_use_script_uri

    def use_training_model_artifact(self) -> bool:
        """Returns True if the model should use a model uri when kicking off training job."""
        # gated model never use training model artifact
        if self.gated_bucket:
            return False

        # otherwise, return true if a training model package is not set
        return len(self.training_model_package_artifact_uris or {}) == 0

    def is_gated_model(self) -> bool:
        """Returns True if the model has a EULA key or the model bucket is gated."""
        return bool(self.gated_bucket or self.hosting_eula_key is not None)

    def supports_incremental_training(self) -> bool:
        """Returns True if the model supports incremental training."""
        return self.incremental_training_supported

    def get_speculative_decoding_s3_data_sources(self) -> List[JumpStartModelDataSource]:
        """Returns data sources for speculative decoding."""
        if not self.hosting_additional_data_sources:
            return []
        return self.hosting_additional_data_sources.speculative_decoding or []

    def get_additional_s3_data_sources(self) -> List[JumpStartModelDataSource]:
        """Returns a flat list of the additional S3 data sources for use by the model.

        Every category listed by ``hosting_additional_data_sources.to_json()`` is
        read from the object and its entries are concatenated in that order.
        Categories that are unset are skipped.
        """
        additional_data_sources = []
        if self.hosting_additional_data_sources:
            for data_source in self.hosting_additional_data_sources.to_json():
                data_sources = getattr(self.hosting_additional_data_sources, data_source) or []
                additional_data_sources.extend(data_sources)
        return additional_data_sources
|
|
1975
|
+
|
|
1976
|
+
|
|
1977
|
+
class JumpStartVersionedModelId(JumpStartDataHolderType):
    """Pairs a JumpStart model ID with one specific version of that model."""

    __slots__ = ["model_id", "version"]

    def __init__(
        self,
        model_id: str,
        version: str,
    ) -> None:
        """Create the identifier for a single version of a JumpStart model.

        Args:
            model_id (str): The JumpStart model ID.
            version (str): The JumpStart model version.
        """
        self.model_id, self.version = model_id, version
|
|
1995
|
+
|
|
1996
|
+
|
|
1997
|
+
class JumpStartCachedContentKey(JumpStartDataHolderType):
    """Key under which a piece of JumpStart content is cached."""

    __slots__ = ["data_type", "id_info"]

    def __init__(
        self,
        data_type: JumpStartContentDataType,
        id_info: str,
    ) -> None:
        """Build a cache key from a content data type and an identifier.

        Args:
            data_type (JumpStartContentDataType): Kind of JumpStart content being cached.
            id_info (str): For S3 content, the object key in S3; for hub content,
                the hub content ARN.
        """
        self.data_type, self.id_info = data_type, id_info
|
|
2015
|
+
|
|
2016
|
+
|
|
2017
|
+
class HubArnExtractedInfo(JumpStartDataHolderType):
    """Data class for info extracted from Hub arn."""

    __slots__ = [
        "partition",
        "region",
        "account_id",
        "hub_name",
        "hub_content_type",
        "hub_content_name",
        "hub_content_version",
    ]

    def __init__(
        self,
        partition: str,
        region: str,
        account_id: str,
        hub_name: str,
        hub_content_type: Optional[str] = None,
        hub_content_name: Optional[str] = None,
        hub_content_version: Optional[str] = None,
    ) -> None:
        """Instantiates HubArnExtractedInfo object.

        Args:
            partition (str): Partition extracted from the ARN.
            region (str): Region extracted from the ARN.
            account_id (str): Account ID extracted from the ARN.
            hub_name (str): Name of the hub.
            hub_content_type (Optional[str]): Hub content type, when the ARN refers
                to hub content. Defaults to None.
            hub_content_name (Optional[str]): Hub content name, when the ARN refers
                to hub content. Defaults to None.
            hub_content_version (Optional[str]): Hub content version, when the ARN
                refers to hub content. Defaults to None.
        """

        self.partition = partition
        self.region = region
        self.account_id = account_id
        self.hub_name = hub_name
        self.hub_content_name = hub_content_name
        self.hub_content_type = hub_content_type
        self.hub_content_version = hub_content_version

    @staticmethod
    def extract_region_from_arn(arn: str) -> Optional[str]:
        """Extracts the region from a SageMaker HubContent ARN or Hub ARN.

        The HubContent form
        (``arn:<partition>:sagemaker:<region>:<account>:hub-content/<a>/<b>/<c>/<d>``)
        is tried first, then the Hub form
        (``arn:<partition>:sagemaker:<region>:<account>:hub/<name>``).

        Args:
            arn (str): The ARN to parse.

        Returns:
            Optional[str]: The region component of the first matching form, or
            None if ``arn`` matches neither.
        """
        hub_content_arn_regex = (
            r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$"
        )
        hub_arn_regex = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$"

        # Order matters: the more specific hub-content form is checked first.
        for pattern in (hub_content_arn_regex, hub_arn_regex):
            match = re.match(pattern, arn)
            if match:
                # Group 2 is the region in both patterns.
                return match.group(2)

        return None
|
|
2071
|
+
|
|
2072
|
+
|
|
2073
|
+
class JumpStartCachedContentValue(JumpStartDataHolderType):
    """Value stored in the JumpStart content cache."""

    __slots__ = ["formatted_content", "md5_hash"]

    def __init__(
        self,
        formatted_content: Union[
            Dict[JumpStartVersionedModelId, JumpStartModelHeader],
            JumpStartModelSpecs,
        ],
        md5_hash: Optional[str] = None,
    ) -> None:
        """Wrap parsed cache content together with its optional content hash.

        Args:
            formatted_content (Union[Dict[JumpStartVersionedModelId, JumpStartModelHeader],
                JumpStartModelSpecs]):
                Either a mapping from versioned model IDs to model headers, or a
                parsed model spec.
            md5_hash (Optional[str]): MD5 hash of the stored file content from S3,
                if known. Defaults to None.
        """
        self.formatted_content, self.md5_hash = formatted_content, md5_hash
|
|
2097
|
+
|
|
2098
|
+
|
|
2099
|
+
class JumpStartKwargs(JumpStartDataHolderType):
    """Data class for JumpStart object kwargs.

    Subclasses declare their fields in ``__slots__`` and list the fields that
    should not be forwarded as kwargs in ``SERIALIZATION_EXCLUSION_SET``.
    """

    # Always excluded from serialization, in addition to the subclass's own set.
    BASE_SERIALIZATION_EXCLUSION_SET: Set[str] = {"specs"}
    SERIALIZATION_EXCLUSION_SET: Set[str] = set()

    def to_kwargs_dict(self, exclude_keys: bool = True):
        """Serializes object to dictionary to be used for kwargs for method arguments.

        Args:
            exclude_keys (bool): If True (default), skip every field listed in
                ``SERIALIZATION_EXCLUSION_SET`` or ``BASE_SERIALIZATION_EXCLUSION_SET``.
                If False, consider every field in ``__slots__``.

        Returns:
            Dict[str, Any]: Field name -> value for each considered field whose
            value is set and not None, in ``__slots__`` order.
        """
        # Compute the exclusion set once rather than once per field; ``set(...)``
        # also tolerates subclasses that declare their exclusions as a list.
        excluded = (
            set(self.SERIALIZATION_EXCLUSION_SET) | set(self.BASE_SERIALIZATION_EXCLUSION_SET)
            if exclude_keys
            else set()
        )

        kwargs_dict = {}
        for field in self.__slots__:
            if field in excluded:
                continue
            # Unset slots raise AttributeError on plain access; default to None.
            att_value = getattr(self, field, None)
            if att_value is not None:
                kwargs_dict[field] = att_value
        return kwargs_dict
|
|
2119
|
+
|
|
2120
|
+
|
|
2121
|
+
class JumpStartModelInitKwargs(JumpStartKwargs):
    """Holds the keyword arguments passed to `JumpStartModel.__init__`.

    Fields named in ``SERIALIZATION_EXCLUSION_SET`` are JumpStart-specific and
    are left out of ``to_kwargs_dict()`` by default.
    """

    __slots__ = [
        "model_id",
        "model_version",
        "hub_arn",
        "model_type",
        "instance_type",
        "tolerate_vulnerable_model",
        "tolerate_deprecated_model",
        "region",
        "image_uri",
        "model_data",
        "source_dir",
        "entry_point",
        "env",
        "predictor_cls",
        "role",
        "name",
        "vpc_config",
        "sagemaker_session",
        "enable_network_isolation",
        "model_kms_key",
        "image_config",
        "code_location",
        "container_log_level",
        "dependencies",
        "git_config",
        "model_package_arn",
        "training_instance_type",
        "resources",
        "config_name",
        "additional_model_data_sources",
        "hub_content_type",
        "model_reference_arn",
        "specs",
    ]

    SERIALIZATION_EXCLUSION_SET = {
        "instance_type",
        "model_id",
        "model_version",
        "hub_arn",
        "model_type",
        "tolerate_vulnerable_model",
        "tolerate_deprecated_model",
        "region",
        "model_package_arn",
        "training_instance_type",
        "config_name",
        "hub_content_type",
    }

    def __init__(
        self,
        model_id: str,
        model_version: Optional[str] = None,
        hub_arn: Optional[str] = None,
        model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
        region: Optional[str] = None,
        instance_type: Optional[str] = None,
        image_uri: Optional[Union[str, Any]] = None,
        model_data: Optional[Union[str, Any, dict]] = None,
        role: Optional[str] = None,
        predictor_cls: Optional[Callable] = None,
        env: Optional[Dict[str, Union[str, Any]]] = None,
        name: Optional[str] = None,
        vpc_config: Optional[Dict[str, List[Union[str, Any]]]] = None,
        sagemaker_session: Optional[Any] = None,
        enable_network_isolation: Union[bool, Any] = None,
        model_kms_key: Optional[str] = None,
        image_config: Optional[Dict[str, Union[str, Any]]] = None,
        source_dir: Optional[str] = None,
        code_location: Optional[str] = None,
        entry_point: Optional[str] = None,
        container_log_level: Optional[Union[int, Any]] = None,
        dependencies: Optional[List[str]] = None,
        git_config: Optional[Dict[str, str]] = None,
        tolerate_vulnerable_model: Optional[bool] = None,
        tolerate_deprecated_model: Optional[bool] = None,
        model_package_arn: Optional[str] = None,
        training_instance_type: Optional[str] = None,
        resources: Optional[ResourceRequirements] = None,
        config_name: Optional[str] = None,
        additional_model_data_sources: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Capture the arguments given to `JumpStartModel.__init__`.

        ``model_data`` and ``env`` are stored as deep copies, so mutating them
        through this object does not affect the caller's objects and vice versa.
        All other arguments are stored as given.
        """
        # Which model is being built, and how it is looked up.
        self.model_id = model_id
        self.model_version = model_version
        self.hub_arn = hub_arn
        self.model_type = model_type
        self.region = region
        self.config_name = config_name
        self.tolerate_vulnerable_model = tolerate_vulnerable_model
        self.tolerate_deprecated_model = tolerate_deprecated_model
        self.model_package_arn = model_package_arn

        # Compute settings.
        self.instance_type = instance_type
        self.training_instance_type = training_instance_type
        self.resources = resources

        # Container image, artifacts, and code.
        self.image_uri = image_uri
        self.image_config = image_config
        self.model_data = deepcopy(model_data)
        self.additional_model_data_sources = additional_model_data_sources
        self.source_dir = source_dir
        self.entry_point = entry_point
        self.code_location = code_location
        self.dependencies = dependencies
        self.git_config = git_config
        self.env = deepcopy(env)
        self.container_log_level = container_log_level

        # Runtime, networking, and security.
        self.predictor_cls = predictor_cls
        self.role = role
        self.name = name
        self.vpc_config = vpc_config
        self.sagemaker_session = sagemaker_session
        self.enable_network_isolation = enable_network_isolation
        self.model_kms_key = model_kms_key
|
|
2240
|
+
|
|
2241
|
+
|
|
2242
|
+
class JumpStartModelDeployKwargs(JumpStartKwargs):
    """Holds the keyword arguments passed to `JumpStartModel.deploy`.

    Fields named in ``SERIALIZATION_EXCLUSION_SET`` are JumpStart-specific and
    are left out of ``to_kwargs_dict()`` by default.
    """

    __slots__ = [
        "model_id",
        "model_version",
        "hub_arn",
        "model_type",
        "initial_instance_count",
        "instance_type",
        "region",
        "serializer",
        "deserializer",
        "accelerator_type",
        "endpoint_name",
        "inference_component_name",
        "tags",
        "kms_key",
        "wait",
        "data_capture_config",
        "async_inference_config",
        "serverless_inference_config",
        "volume_size",
        "model_data_download_timeout",
        "container_startup_health_check_timeout",
        "inference_recommendation_id",
        "explainer_config",
        "tolerate_vulnerable_model",
        "tolerate_deprecated_model",
        "sagemaker_session",
        "training_instance_type",
        "accept_eula",
        "model_reference_arn",
        "endpoint_logging",
        "resources",
        "endpoint_type",
        "config_name",
        "routing_config",
        "specs",
        "model_access_configs",
        "inference_ami_version",
    ]

    SERIALIZATION_EXCLUSION_SET = {
        "model_id",
        "model_version",
        "model_type",
        "hub_arn",
        "region",
        "tolerate_deprecated_model",
        "tolerate_vulnerable_model",
        "sagemaker_session",
        "training_instance_type",
        "config_name",
        "model_access_configs",
    }

    def __init__(
        self,
        model_id: str,
        model_version: Optional[str] = None,
        hub_arn: Optional[str] = None,
        model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
        region: Optional[str] = None,
        initial_instance_count: Optional[int] = None,
        instance_type: Optional[str] = None,
        serializer: Optional[Any] = None,
        deserializer: Optional[Any] = None,
        accelerator_type: Optional[str] = None,
        endpoint_name: Optional[str] = None,
        inference_component_name: Optional[str] = None,
        tags: Optional[Tags] = None,
        kms_key: Optional[str] = None,
        wait: Optional[bool] = None,
        data_capture_config: Optional[Any] = None,
        async_inference_config: Optional[Any] = None,
        serverless_inference_config: Optional[Any] = None,
        volume_size: Optional[int] = None,
        model_data_download_timeout: Optional[int] = None,
        container_startup_health_check_timeout: Optional[int] = None,
        inference_recommendation_id: Optional[str] = None,
        explainer_config: Optional[Any] = None,
        tolerate_deprecated_model: Optional[bool] = None,
        tolerate_vulnerable_model: Optional[bool] = None,
        sagemaker_session: Optional[Session] = None,
        training_instance_type: Optional[str] = None,
        accept_eula: Optional[bool] = None,
        model_reference_arn: Optional[str] = None,
        endpoint_logging: Optional[bool] = None,
        resources: Optional[ResourceRequirements] = None,
        endpoint_type: Optional[EndpointType] = None,
        config_name: Optional[str] = None,
        routing_config: Optional[Dict[str, Any]] = None,
        model_access_configs: Optional[Dict[str, CoreModelAccessConfig]] = None,
        inference_ami_version: Optional[str] = None,
    ) -> None:
        """Capture the arguments given to `JumpStartModel.deploy`.

        ``tags`` is normalized through ``format_tags``; every other argument is
        stored as given.
        """
        # Which model is being deployed, and how it is looked up.
        self.model_id = model_id
        self.model_version = model_version
        self.hub_arn = hub_arn
        self.model_type = model_type
        self.region = region
        self.config_name = config_name
        self.model_reference_arn = model_reference_arn
        self.tolerate_vulnerable_model = tolerate_vulnerable_model
        self.tolerate_deprecated_model = tolerate_deprecated_model
        self.accept_eula = accept_eula
        self.model_access_configs = model_access_configs

        # Compute settings.
        self.initial_instance_count = initial_instance_count
        self.instance_type = instance_type
        self.training_instance_type = training_instance_type
        self.accelerator_type = accelerator_type
        self.resources = resources
        self.volume_size = volume_size
        self.inference_ami_version = inference_ami_version

        # Endpoint settings.
        self.endpoint_name = endpoint_name
        self.endpoint_type = endpoint_type
        self.inference_component_name = inference_component_name
        self.endpoint_logging = endpoint_logging
        self.routing_config = routing_config
        self.kms_key = kms_key
        self.tags = format_tags(tags)
        self.wait = wait

        # Container start-up timeouts.
        self.model_data_download_timeout = model_data_download_timeout
        self.container_startup_health_check_timeout = container_startup_health_check_timeout

        # Inference behaviour.
        self.serializer = serializer
        self.deserializer = deserializer
        self.data_capture_config = data_capture_config
        self.async_inference_config = async_inference_config
        self.serverless_inference_config = serverless_inference_config
        self.inference_recommendation_id = inference_recommendation_id
        self.explainer_config = explainer_config

        # Session used for AWS calls.
        self.sagemaker_session = sagemaker_session
|
|
2376
|
+
|
|
2377
|
+
|
|
2378
|
+
class JumpStartEstimatorInitKwargs(JumpStartKwargs):
    """Holds the keyword arguments passed to `JumpStartEstimator.__init__`.

    Fields named in ``SERIALIZATION_EXCLUSION_SET`` are JumpStart-specific and
    are left out of ``to_kwargs_dict()`` by default.
    """

    __slots__ = [
        "model_id",
        "model_version",
        "hub_arn",
        "model_type",
        "instance_type",
        "instance_count",
        "region",
        "image_uri",
        "model_uri",
        "source_dir",
        "entry_point",
        "hyperparameters",
        "metric_definitions",
        "role",
        "keep_alive_period_in_seconds",
        "volume_size",
        "volume_kms_key",
        "max_run",
        "input_mode",
        "output_path",
        "output_kms_key",
        "base_job_name",
        "sagemaker_session",
        "tags",
        "subnets",
        "security_group_ids",
        "model_channel_name",
        "encrypt_inter_container_traffic",
        "use_spot_instances",
        "max_wait",
        "checkpoint_s3_uri",
        "checkpoint_local_path",
        "enable_network_isolation",
        "rules",
        "debugger_hook_config",
        "tensorboard_output_config",
        "enable_sagemaker_metrics",
        "profiler_config",
        "disable_profiler",
        "environment",
        "max_retry_attempts",
        "git_config",
        "container_log_level",
        "code_location",
        "dependencies",
        "instance_groups",
        "training_repository_access_mode",
        "training_repository_credentials_provider_arn",
        "tolerate_deprecated_model",
        "tolerate_vulnerable_model",
        "container_entry_point",
        "container_arguments",
        "disable_output_compression",
        "enable_infra_check",
        "enable_remote_debug",
        "config_name",
        "enable_session_tag_chaining",
        "hub_content_type",
        "model_reference_arn",
        "specs",
        "training_plan",
    ]

    SERIALIZATION_EXCLUSION_SET = {
        "region",
        "tolerate_deprecated_model",
        "tolerate_vulnerable_model",
        "model_id",
        "model_version",
        "hub_arn",
        "model_type",
        "hub_content_type",
        "config_name",
    }

    def __init__(
        self,
        model_id: str,
        model_version: Optional[str] = None,
        hub_arn: Optional[str] = None,
        model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
        region: Optional[str] = None,
        image_uri: Optional[Union[str, Any]] = None,
        role: Optional[str] = None,
        instance_count: Optional[Union[int, Any]] = None,
        instance_type: Optional[Union[str, Any]] = None,
        keep_alive_period_in_seconds: Optional[Union[int, Any]] = None,
        volume_size: Optional[Union[int, Any]] = None,
        volume_kms_key: Optional[Union[str, Any]] = None,
        max_run: Optional[Union[int, Any]] = None,
        input_mode: Optional[Union[str, Any]] = None,
        output_path: Optional[Union[str, Any]] = None,
        output_kms_key: Optional[Union[str, Any]] = None,
        base_job_name: Optional[str] = None,
        sagemaker_session: Optional[Any] = None,
        hyperparameters: Optional[Dict[str, Union[str, Any]]] = None,
        tags: Optional[Tags] = None,
        subnets: Optional[List[Union[str, Any]]] = None,
        security_group_ids: Optional[List[Union[str, Any]]] = None,
        model_uri: Optional[str] = None,
        model_channel_name: Optional[Union[str, Any]] = None,
        metric_definitions: Optional[List[Dict[str, Union[str, Any]]]] = None,
        encrypt_inter_container_traffic: Union[bool, Any] = None,
        use_spot_instances: Optional[Union[bool, Any]] = None,
        max_wait: Optional[Union[int, Any]] = None,
        checkpoint_s3_uri: Optional[Union[str, Any]] = None,
        checkpoint_local_path: Optional[Union[str, Any]] = None,
        enable_network_isolation: Union[bool, Any] = None,
        rules: Optional[List[Any]] = None,
        debugger_hook_config: Optional[Union[Any, bool]] = None,
        tensorboard_output_config: Optional[Any] = None,
        enable_sagemaker_metrics: Optional[Union[bool, Any]] = None,
        profiler_config: Optional[Any] = None,
        disable_profiler: Optional[bool] = None,
        environment: Optional[Dict[str, Union[str, Any]]] = None,
        max_retry_attempts: Optional[Union[int, Any]] = None,
        source_dir: Optional[Union[str, Any]] = None,
        git_config: Optional[Dict[str, str]] = None,
        container_log_level: Optional[Union[int, Any]] = None,
        code_location: Optional[str] = None,
        entry_point: Optional[Union[str, Any]] = None,
        dependencies: Optional[List[str]] = None,
        instance_groups: Optional[List[Any]] = None,
        training_repository_access_mode: Optional[Union[str, Any]] = None,
        training_repository_credentials_provider_arn: Optional[Union[str, Any]] = None,
        tolerate_vulnerable_model: Optional[bool] = None,
        tolerate_deprecated_model: Optional[bool] = None,
        container_entry_point: Optional[List[str]] = None,
        container_arguments: Optional[List[str]] = None,
        disable_output_compression: Optional[bool] = None,
        enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
        enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
        config_name: Optional[str] = None,
        enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
        training_plan: Optional[Union[str, PipelineVariable]] = None,
    ) -> None:
        """Capture the arguments given to `JumpStartEstimator.__init__`.

        ``hyperparameters``, ``metric_definitions`` and ``environment`` are stored
        as deep copies, and ``tags`` is normalized through ``format_tags``. Every
        other argument is stored as given.
        """
        # Which model is being trained, and how it is looked up.
        self.model_id = model_id
        self.model_version = model_version
        self.hub_arn = hub_arn
        self.model_type = model_type
        self.region = region
        self.config_name = config_name
        self.tolerate_vulnerable_model = tolerate_vulnerable_model
        self.tolerate_deprecated_model = tolerate_deprecated_model

        # Compute and capacity.
        self.instance_type = instance_type
        self.instance_count = instance_count
        self.instance_groups = instance_groups
        self.volume_size = volume_size
        self.volume_kms_key = volume_kms_key
        self.keep_alive_period_in_seconds = keep_alive_period_in_seconds
        self.use_spot_instances = use_spot_instances
        self.max_wait = max_wait
        self.training_plan = training_plan

        # Container image, model artifact, and training code.
        self.image_uri = image_uri
        self.model_uri = model_uri
        self.model_channel_name = model_channel_name
        self.source_dir = source_dir
        self.entry_point = entry_point
        self.git_config = git_config
        self.code_location = code_location
        self.dependencies = dependencies
        self.container_entry_point = container_entry_point
        self.container_arguments = container_arguments
        self.container_log_level = container_log_level
        self.training_repository_access_mode = training_repository_access_mode
        self.training_repository_credentials_provider_arn = (
            training_repository_credentials_provider_arn
        )
        self.environment = deepcopy(environment)

        # Training job configuration.
        self.hyperparameters = deepcopy(hyperparameters)
        self.metric_definitions = deepcopy(metric_definitions)
        self.max_run = max_run
        self.max_retry_attempts = max_retry_attempts
        self.input_mode = input_mode
        self.base_job_name = base_job_name
        self.enable_infra_check = enable_infra_check

        # Output and checkpointing.
        self.output_path = output_path
        self.output_kms_key = output_kms_key
        self.disable_output_compression = disable_output_compression
        self.checkpoint_s3_uri = checkpoint_s3_uri
        self.checkpoint_local_path = checkpoint_local_path

        # Identity, networking, and tagging.
        self.role = role
        self.subnets = subnets
        self.security_group_ids = security_group_ids
        self.enable_network_isolation = enable_network_isolation
        self.encrypt_inter_container_traffic = encrypt_inter_container_traffic
        self.tags = format_tags(tags)
        self.enable_session_tag_chaining = enable_session_tag_chaining

        # Debugging, profiling, and metrics.
        self.rules = rules
        self.debugger_hook_config = debugger_hook_config
        self.tensorboard_output_config = tensorboard_output_config
        self.enable_sagemaker_metrics = enable_sagemaker_metrics
        self.profiler_config = profiler_config
        self.disable_profiler = disable_profiler
        self.enable_remote_debug = enable_remote_debug

        # Session used for AWS calls.
        self.sagemaker_session = sagemaker_session
|
|
2580
|
+
|
|
2581
|
+
|
|
2582
|
+
class JumpStartEstimatorFitKwargs(JumpStartKwargs):
    """Data class for the inputs to `JumpStartEstimator.fit` method."""

    __slots__ = [
        "model_id",
        "model_version",
        "hub_arn",
        "model_type",
        "region",
        "inputs",
        "wait",
        "logs",
        "job_name",
        "experiment_config",
        "tolerate_deprecated_model",
        "tolerate_vulnerable_model",
        "sagemaker_session",
        "config_name",
        # NOTE(review): ``specs`` is declared here but not assigned in
        # ``__init__``; presumably populated after construction — confirm.
        "specs",
    ]

    # Model-resolution / session fields. Judging by the name, these are left out
    # when the kwargs are serialized; the consumer lives on ``JumpStartKwargs``
    # (not visible here) — TODO confirm.
    SERIALIZATION_EXCLUSION_SET = {
        "model_id",
        "model_version",
        "hub_arn",
        "model_type",
        "region",
        "tolerate_deprecated_model",
        "tolerate_vulnerable_model",
        "sagemaker_session",
        "config_name",
    }

    def __init__(
        self,
        model_id: str,
        model_version: Optional[str] = None,
        hub_arn: Optional[str] = None,
        model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
        region: Optional[str] = None,
        inputs: Optional[Union[str, Dict, Any]] = None,
        wait: Optional[bool] = None,
        logs: Optional[str] = None,
        job_name: Optional[str] = None,
        experiment_config: Optional[Dict[str, str]] = None,
        tolerate_deprecated_model: Optional[bool] = None,
        tolerate_vulnerable_model: Optional[bool] = None,
        sagemaker_session: Optional[Session] = None,
        config_name: Optional[str] = None,
    ) -> None:
        """Instantiates JumpStartEstimatorFitKwargs object.

        Every argument is stored unchanged on the attribute of the same name.

        Args:
            model_id (str): JumpStart model ID.
            model_version (Optional[str]): JumpStart model version.
            hub_arn (Optional[str]): Hub ARN.
            model_type (Optional[JumpStartModelType]): Model type; defaults to
                ``JumpStartModelType.OPEN_WEIGHTS``.
            region (Optional[str]): AWS region.
            inputs (Optional[Union[str, Dict, Any]]): ``fit`` argument of the
                same name (training inputs).
            wait (Optional[bool]): ``fit`` argument of the same name.
            logs (Optional[str]): ``fit`` argument of the same name.
            job_name (Optional[str]): ``fit`` argument of the same name.
            experiment_config (Optional[Dict[str, str]]): ``fit`` argument of
                the same name.
            tolerate_deprecated_model (Optional[bool]): Whether to tolerate a
                deprecated model.
            tolerate_vulnerable_model (Optional[bool]): Whether to tolerate a
                vulnerable model.
            sagemaker_session (Optional[Session]): SageMaker session.
            config_name (Optional[str]): JumpStart config name.
        """

        self.model_id = model_id
        self.model_version = model_version
        self.hub_arn = hub_arn
        self.model_type = model_type
        self.region = region
        self.inputs = inputs
        self.wait = wait
        self.logs = logs
        self.job_name = job_name
        self.experiment_config = experiment_config
        self.tolerate_deprecated_model = tolerate_deprecated_model
        self.tolerate_vulnerable_model = tolerate_vulnerable_model
        self.sagemaker_session = sagemaker_session
        self.config_name = config_name
|
|
2648
|
+
|
|
2649
|
+
|
|
2650
|
+
class JumpStartEstimatorDeployKwargs(JumpStartKwargs):
    """Data class for the inputs to `JumpStartEstimator.deploy` method."""

    __slots__ = [
        "model_id",
        "model_version",
        "hub_arn",
        "instance_type",
        "initial_instance_count",
        "region",
        "image_uri",
        "source_dir",
        "entry_point",
        "env",
        "predictor_cls",
        "serializer",
        "deserializer",
        "accelerator_type",
        "endpoint_name",
        "tags",
        "kms_key",
        "wait",
        "data_capture_config",
        "async_inference_config",
        "serverless_inference_config",
        "volume_size",
        "model_data_download_timeout",
        "container_startup_health_check_timeout",
        "inference_recommendation_id",
        "explainer_config",
        "role",
        "vpc_config",
        "sagemaker_session",
        "enable_network_isolation",
        "model_kms_key",
        "image_config",
        "code_location",
        "container_log_level",
        "dependencies",
        "git_config",
        "tolerate_deprecated_model",
        "tolerate_vulnerable_model",
        "model_name",
        "use_compiled_model",
        "config_name",
        # NOTE(review): ``specs`` is declared here but not assigned in
        # ``__init__``; presumably populated after construction — confirm.
        "specs",
    ]

    # Model-resolution / session fields. Judging by the name, these are left out
    # when the kwargs are serialized; the consumer lives on ``JumpStartKwargs``
    # (not visible here) — TODO confirm.
    SERIALIZATION_EXCLUSION_SET = {
        "tolerate_vulnerable_model",
        "tolerate_deprecated_model",
        "region",
        "model_id",
        "model_version",
        "hub_arn",
        "sagemaker_session",
        "config_name",
    }

    def __init__(
        self,
        model_id: str,
        model_version: Optional[str] = None,
        hub_arn: Optional[str] = None,
        region: Optional[str] = None,
        initial_instance_count: Optional[int] = None,
        instance_type: Optional[str] = None,
        serializer: Optional[Any] = None,
        deserializer: Optional[Any] = None,
        accelerator_type: Optional[str] = None,
        endpoint_name: Optional[str] = None,
        tags: Optional[Tags] = None,
        kms_key: Optional[str] = None,
        wait: Optional[bool] = None,
        data_capture_config: Optional[Any] = None,
        async_inference_config: Optional[Any] = None,
        serverless_inference_config: Optional[Any] = None,
        volume_size: Optional[int] = None,
        model_data_download_timeout: Optional[int] = None,
        container_startup_health_check_timeout: Optional[int] = None,
        inference_recommendation_id: Optional[str] = None,
        explainer_config: Optional[Any] = None,
        image_uri: Optional[Union[str, Any]] = None,
        role: Optional[str] = None,
        predictor_cls: Optional[Callable] = None,
        env: Optional[Dict[str, Union[str, Any]]] = None,
        model_name: Optional[str] = None,
        vpc_config: Optional[Dict[str, List[Union[str, Any]]]] = None,
        sagemaker_session: Optional[Any] = None,
        enable_network_isolation: Optional[Union[bool, Any]] = None,
        model_kms_key: Optional[str] = None,
        image_config: Optional[Dict[str, Union[str, Any]]] = None,
        source_dir: Optional[str] = None,
        code_location: Optional[str] = None,
        entry_point: Optional[str] = None,
        container_log_level: Optional[Union[int, Any]] = None,
        dependencies: Optional[List[str]] = None,
        git_config: Optional[Dict[str, str]] = None,
        tolerate_deprecated_model: Optional[bool] = None,
        tolerate_vulnerable_model: Optional[bool] = None,
        use_compiled_model: bool = False,
        config_name: Optional[str] = None,
    ) -> None:
        """Instantiates JumpStartEstimatorDeployKwargs object.

        Arguments are stored on the attribute of the same name. There are two
        exceptions:

        * ``env`` is deep-copied, so later changes to the caller's dict do not
          leak into this object.
        * ``tags`` is normalized with ``format_tags``.

        Args:
            model_id (str): JumpStart model ID.
            model_version (Optional[str]): JumpStart model version.
            hub_arn (Optional[str]): Hub ARN.
            region (Optional[str]): AWS region.
            use_compiled_model (bool): Defaults to ``False``.
            config_name (Optional[str]): JumpStart config name.

        All remaining arguments correspond to the ``deploy`` argument of the
        same name.
        """

        self.model_id = model_id
        self.model_version = model_version
        self.hub_arn = hub_arn
        self.instance_type = instance_type
        self.initial_instance_count = initial_instance_count
        self.region = region
        self.image_uri = image_uri
        self.source_dir = source_dir
        self.entry_point = entry_point
        # Deep copy so the caller's env dict is never shared with this object.
        self.env = deepcopy(env)
        self.predictor_cls = predictor_cls
        self.serializer = serializer
        self.deserializer = deserializer
        self.accelerator_type = accelerator_type
        self.endpoint_name = endpoint_name
        self.tags = format_tags(tags)
        self.kms_key = kms_key
        self.wait = wait
        self.data_capture_config = data_capture_config
        self.async_inference_config = async_inference_config
        self.serverless_inference_config = serverless_inference_config
        self.volume_size = volume_size
        self.model_data_download_timeout = model_data_download_timeout
        self.container_startup_health_check_timeout = container_startup_health_check_timeout
        self.inference_recommendation_id = inference_recommendation_id
        self.explainer_config = explainer_config
        self.role = role
        self.model_name = model_name
        self.vpc_config = vpc_config
        self.sagemaker_session = sagemaker_session
        self.enable_network_isolation = enable_network_isolation
        self.model_kms_key = model_kms_key
        self.image_config = image_config
        self.code_location = code_location
        self.container_log_level = container_log_level
        self.dependencies = dependencies
        self.git_config = git_config
        self.tolerate_deprecated_model = tolerate_deprecated_model
        self.tolerate_vulnerable_model = tolerate_vulnerable_model
        self.use_compiled_model = use_compiled_model
        self.config_name = config_name
|
|
2796
|
+
|
|
2797
|
+
|
|
2798
|
+
class JumpStartModelRegisterKwargs(JumpStartKwargs):
    """Data class for the inputs to `JumpStartModel.register` method."""

    __slots__ = [
        "tolerate_vulnerable_model",
        "tolerate_deprecated_model",
        "region",
        "model_id",
        "model_type",
        "model_version",
        "hub_arn",
        "sagemaker_session",
        "content_types",
        "response_types",
        "inference_instances",
        "transform_instances",
        "model_package_group_name",
        "image_uri",
        "model_metrics",
        "metadata_properties",
        "approval_status",
        "description",
        "drift_check_baselines",
        "customer_metadata_properties",
        "validation_specification",
        "domain",
        "task",
        "sample_payload_url",
        "framework",
        "framework_version",
        "nearest_model_name",
        "data_input_configuration",
        "skip_model_validation",
        "source_uri",
        "model_life_cycle",
        "config_name",
        "model_card",
        "accept_eula",
        # NOTE(review): ``specs`` is declared here but not assigned in
        # ``__init__``; presumably populated after construction — confirm.
        "specs",
    ]

    # Model-resolution / session fields. Judging by the name, these are left out
    # when the kwargs are serialized; the consumer lives on ``JumpStartKwargs``
    # (not visible here) — TODO confirm.
    SERIALIZATION_EXCLUSION_SET = {
        "tolerate_vulnerable_model",
        "tolerate_deprecated_model",
        "region",
        "model_id",
        "model_version",
        "hub_arn",
        "sagemaker_session",
        "config_name",
    }

    def __init__(
        self,
        model_id: str,
        model_version: Optional[str] = None,
        hub_arn: Optional[str] = None,
        region: Optional[str] = None,
        model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
        tolerate_deprecated_model: Optional[bool] = None,
        tolerate_vulnerable_model: Optional[bool] = None,
        sagemaker_session: Optional[Any] = None,
        content_types: Optional[List[str]] = None,
        response_types: Optional[List[str]] = None,
        inference_instances: Optional[List[str]] = None,
        transform_instances: Optional[List[str]] = None,
        model_package_group_name: Optional[str] = None,
        image_uri: Optional[str] = None,
        model_metrics: Optional[ModelMetrics] = None,
        metadata_properties: Optional[MetadataProperties] = None,
        approval_status: Optional[str] = None,
        description: Optional[str] = None,
        drift_check_baselines: Optional[DriftCheckBaselines] = None,
        customer_metadata_properties: Optional[Dict[str, str]] = None,
        validation_specification: Optional[str] = None,
        domain: Optional[str] = None,
        task: Optional[str] = None,
        sample_payload_url: Optional[str] = None,
        framework: Optional[str] = None,
        framework_version: Optional[str] = None,
        nearest_model_name: Optional[str] = None,
        data_input_configuration: Optional[str] = None,
        skip_model_validation: Optional[str] = None,
        source_uri: Optional[str] = None,
        model_life_cycle: Optional[ModelLifeCycle] = None,
        config_name: Optional[str] = None,
        model_card: Optional[Union[ModelCard, ModelPackageModelCard]] = None,
        accept_eula: Optional[bool] = None,
    ) -> None:
        """Instantiates JumpStartModelRegisterKwargs object.

        Every argument is stored unchanged on the attribute of the same name.

        Args:
            model_id (str): JumpStart model ID.
            model_version (Optional[str]): JumpStart model version.
            hub_arn (Optional[str]): Hub ARN.
            region (Optional[str]): AWS region.
            model_type (Optional[JumpStartModelType]): Model type; defaults to
                ``JumpStartModelType.OPEN_WEIGHTS``.
            tolerate_deprecated_model (Optional[bool]): Whether to tolerate a
                deprecated model.
            tolerate_vulnerable_model (Optional[bool]): Whether to tolerate a
                vulnerable model.
            sagemaker_session (Optional[Any]): SageMaker session.
            model_life_cycle (Optional[ModelLifeCycle]): Model life-cycle
                settings to register with the model package.
            model_card (Optional[Union[ModelCard, ModelPackageModelCard]]):
                Model card to attach.
            config_name (Optional[str]): JumpStart config name.

        All remaining arguments correspond to the ``register`` argument of the
        same name.
        """

        self.model_id = model_id
        self.model_version = model_version
        self.hub_arn = hub_arn
        self.model_type = model_type
        self.region = region
        self.image_uri = image_uri
        self.sagemaker_session = sagemaker_session
        self.tolerate_deprecated_model = tolerate_deprecated_model
        self.tolerate_vulnerable_model = tolerate_vulnerable_model
        self.content_types = content_types
        self.response_types = response_types
        self.inference_instances = inference_instances
        self.transform_instances = transform_instances
        self.model_package_group_name = model_package_group_name
        self.model_metrics = model_metrics
        self.metadata_properties = metadata_properties
        self.approval_status = approval_status
        self.description = description
        self.drift_check_baselines = drift_check_baselines
        self.customer_metadata_properties = customer_metadata_properties
        self.validation_specification = validation_specification
        self.domain = domain
        self.task = task
        self.sample_payload_url = sample_payload_url
        self.framework = framework
        self.framework_version = framework_version
        self.nearest_model_name = nearest_model_name
        self.data_input_configuration = data_input_configuration
        self.skip_model_validation = skip_model_validation
        self.source_uri = source_uri
        # Previously accepted but never stored, which silently dropped it.
        self.model_life_cycle = model_life_cycle
        self.config_name = config_name
        self.model_card = model_card
        self.accept_eula = accept_eula
|
|
2923
|
+
|
|
2924
|
+
|
|
2925
|
+
class BaseDeploymentConfigDataHolder(JumpStartDataHolderType):
    """Base class for Deployment Config Data."""

    def _convert_to_pascal_case(self, attr_name: str) -> str:
        """Converts a snake_case attribute name into a PascalCased string.

        Each underscore-separated word is title-cased, and the words are joined
        without separators, e.g. ``"image_uri"`` -> ``"ImageUri"``.

        Args:
            attr_name (str): The snake_case attribute name.
        Returns:
            str: The PascalCased attribute name.
        """
        return attr_name.replace("_", " ").title().replace(" ", "")

    def to_json(self) -> Dict[str, Any]:
        """Represents ``This`` object as JSON.

        Only names listed in ``__slots__`` that are actually set on the instance
        are emitted. Keys are converted to PascalCase, and values are converted
        recursively with ``_val_to_json``.
        """
        json_obj = {}
        for att in self.__slots__:
            # Unassigned slots are skipped rather than emitted as None.
            if hasattr(self, att):
                cur_val = getattr(self, att)
                json_obj[self._convert_to_pascal_case(att)] = self._val_to_json(cur_val)
        return json_obj

    def _val_to_json(self, val: Any) -> Any:
        """Converts the given value to JSON.

        Conversion rules:

        * Data holders are converted via their own ``to_json``.
        * Lists are converted element-wise.
        * Dicts are converted value-wise. A key is PascalCased only when its
          value is a data holder; other keys are kept unchanged.
        * Any other value is returned as-is.

        Args:
            val (Any): The value to convert.
        Returns:
            Any: The converted json value.
        """
        if isinstance(val, JumpStartDataHolderType):
            if isinstance(val, JumpStartBenchmarkStat):
                # NOTE(review): this rewrites ``val.name`` in place, so the
                # source object keeps the title-cased name after serialization.
                # Kept for backward compatibility; confirm whether any caller
                # relies on it before making this non-mutating.
                val.name = val.name.replace("_", " ").title()
            return val.to_json()
        if isinstance(val, list):
            return [self._val_to_json(obj) for obj in val]
        if isinstance(val, dict):
            dict_obj = {}
            for k, v in val.items():
                if isinstance(v, JumpStartDataHolderType):
                    dict_obj[self._convert_to_pascal_case(k)] = self._val_to_json(v)
                else:
                    dict_obj[k] = self._val_to_json(v)
            return dict_obj
        return val
|
|
2974
|
+
|
|
2975
|
+
|
|
2976
|
+
class DeploymentArgs(BaseDeploymentConfigDataHolder):
    """Dataclass representing a Deployment Args.

    Values are copied from three optional sources:

    * the model init kwargs;
    * the model deploy kwargs;
    * the resolved metadata config.

    When a source is ``None``, the attributes it would provide are left unset.
    """

    __slots__ = [
        "image_uri",
        "model_data",
        "model_package_arn",
        "environment",
        "instance_type",
        "compute_resource_requirements",
        "model_data_download_timeout",
        "container_startup_health_check_timeout",
        "additional_data_sources",
    ]

    def __init__(
        self,
        init_kwargs: Optional[JumpStartModelInitKwargs] = None,
        deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None,
        resolved_config: Optional[Dict[str, Any]] = None,
    ):
        """Instantiates DeploymentArgs object.

        Args:
            init_kwargs (Optional[JumpStartModelInitKwargs]): Source of the
                image URI, model data, model package ARN, instance type,
                environment and (when ``resources`` is set) compute resource
                requirements.
            deploy_kwargs (Optional[JumpStartModelDeployKwargs]): Source of the
                model-data-download and container-startup timeouts.
            resolved_config (Optional[Dict[str, Any]]): Source of the default and
                supported inference instance types and hosting additional data
                sources.
        """
        if init_kwargs is not None:
            # Pairs of (attribute on self, attribute read from init_kwargs).
            for target_name, source_name in (
                ("image_uri", "image_uri"),
                ("model_data", "model_data"),
                ("model_package_arn", "model_package_arn"),
                ("instance_type", "instance_type"),
                ("environment", "env"),
            ):
                setattr(self, target_name, getattr(init_kwargs, source_name))

            model_resources = init_kwargs.resources
            if model_resources is not None:
                self.compute_resource_requirements = (
                    model_resources.get_compute_resource_requirements()
                )

        if deploy_kwargs is not None:
            self.model_data_download_timeout = deploy_kwargs.model_data_download_timeout
            self.container_startup_health_check_timeout = (
                deploy_kwargs.container_startup_health_check_timeout
            )

        if resolved_config is not None:
            # The two instance-type attributes below are not in ``__slots__``,
            # so a ``to_json`` that iterates ``__slots__`` will not emit them.
            self.default_instance_type = resolved_config.get("default_inference_instance_type")
            self.supported_instance_types = resolved_config.get(
                "supported_inference_instance_types"
            )
            self.additional_data_sources = resolved_config.get("hosting_additional_data_sources")
|
|
3019
|
+
|
|
3020
|
+
|
|
3021
|
+
class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder):
|
|
3022
|
+
"""Dataclass representing a Deployment Config Metadata"""
|
|
3023
|
+
|
|
3024
|
+
__slots__ = [
|
|
3025
|
+
"deployment_config_name",
|
|
3026
|
+
"deployment_args",
|
|
3027
|
+
"acceleration_configs",
|
|
3028
|
+
"benchmark_metrics",
|
|
3029
|
+
]
|
|
3030
|
+
|
|
3031
|
+
def __init__(
|
|
3032
|
+
self,
|
|
3033
|
+
config_name: Optional[str] = None,
|
|
3034
|
+
metadata_config: Optional[JumpStartMetadataConfig] = None,
|
|
3035
|
+
init_kwargs: Optional[JumpStartModelInitKwargs] = None,
|
|
3036
|
+
deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None,
|
|
3037
|
+
):
|
|
3038
|
+
"""Instantiates DeploymentConfigMetadata object."""
|
|
3039
|
+
self.deployment_config_name = config_name
|
|
3040
|
+
self.deployment_args = DeploymentArgs(
|
|
3041
|
+
init_kwargs, deploy_kwargs, metadata_config.resolved_config
|
|
3042
|
+
)
|
|
3043
|
+
self.benchmark_metrics = metadata_config.benchmark_metrics
|
|
3044
|
+
self.acceleration_configs = metadata_config.acceleration_configs
|