sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/__init__.py +0 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/__init__.py +0 -4
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.47.dist-info/RECORD +0 -35
- sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
# pylint: skip-file
|
|
14
|
+
"""This module contains utilities related to SageMaker JumpStart Hub."""
|
|
15
|
+
from __future__ import absolute_import
|
|
16
|
+
|
|
17
|
+
import re
|
|
18
|
+
from typing import Any, Dict, List, Optional
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def camel_to_snake(camel_case_string: str) -> str:
    """Converts camelCase to snake_case_string using a regex.

    An underscore is inserted in front of every ASCII capital letter (``A``-``Z``)
    except one at the very start of the string, then the whole result is
    lower-cased, e.g. ``"ModelDataDownloadTimeout"`` -> ``"model_data_download_timeout"``.
    Runs of capitals are split letter by letter (``"ABC"`` -> ``"a_b_c"``).

    This regex cannot handle whitespace ("camelString TwoWords")
    """
    # Zero-width match: "not at the start of the string" followed by a capital letter.
    boundary_pattern = r"(?<!^)(?=[A-Z])"
    with_underscores = re.sub(boundary_pattern, "_", camel_case_string)
    return with_underscores.lower()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def snake_to_upper_camel(snake_case_string: str) -> str:
    """Converts snake_case_string to UpperCamelCaseString.

    The input is split on ``"_"``, each segment is passed through ``str.title``
    and the segments are concatenated, e.g. ``"model_dir"`` -> ``"ModelDir"``.
    Because ``str.title`` lower-cases the rest of each segment, ``"model_ID"``
    becomes ``"ModelId"``.
    """
    segments = snake_case_string.split("_")
    titled_segments = map(str.title, segments)
    return "".join(titled_segments)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def walk_and_apply_json(
    json_obj: Dict[Any, Any], apply, stop_keys: Optional[List[str]] = ["metrics"]
) -> Dict[Any, Any]:
    """Recursively walks a json object and applies a given function to the keys.

    A new object is built; ``json_obj`` itself is not mutated. Values under stop
    keys and nested lists (lists directly inside lists) are shared by reference
    rather than copied.

    Args:
        json_obj (Dict[Any, Any]): The JSON-like object to walk.
        apply (Callable[[Any], Any]): Function applied to every dictionary key visited.
        stop_keys (Optional[list[str]]): List of field keys that should stop the application function.
            Any children of these keys will not have the application function applied to them.
            Keys are compared *after* ``apply`` has been run on them. ``None`` or an
            empty list means the function is applied at every depth.

    Returns:
        Dict[Any, Any]: The transformed copy of ``json_obj``.
    """
    # NOTE: the list default is only read, never mutated, so sharing it across
    # calls is safe.

    def _should_recurse(key: Any) -> bool:
        # An empty or missing stop list means nothing stops the walk.
        # (Previously an empty list stopped recursion at every key.)
        return not stop_keys or key not in stop_keys

    def _walk_and_apply_json(json_obj, new):
        if isinstance(json_obj, dict) and isinstance(new, dict):
            for key, value in json_obj.items():
                new_key = apply(key)
                if _should_recurse(new_key):
                    if isinstance(value, dict):
                        new[new_key] = {}
                        _walk_and_apply_json(value, new=new[new_key])
                    elif isinstance(value, list):
                        new[new_key] = []
                        for item in value:
                            _walk_and_apply_json(item, new=new[new_key])
                    else:
                        new[new_key] = value
                else:
                    # Children of a stop key are copied through untouched.
                    new[new_key] = value
        elif isinstance(json_obj, dict) and isinstance(new, list):
            new.append(_walk_and_apply_json(json_obj, new={}))
        elif isinstance(json_obj, list) and isinstance(new, dict):
            new.update(json_obj)
        elif isinstance(new, list):
            # Nested lists and scalar items of any type (str, int, float, bool,
            # None) are kept as-is; previously only str items survived and
            # other scalars were silently dropped.
            new.append(json_obj)
        return new

    return _walk_and_apply_json(json_obj, new={})
|
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
# pylint: skip-file
|
|
14
|
+
"""This module stores Hub converter utilities for JumpStart."""
|
|
15
|
+
from __future__ import absolute_import
|
|
16
|
+
|
|
17
|
+
from typing import Any, Dict, List
|
|
18
|
+
from sagemaker.core.jumpstart.enums import ModelSpecKwargType, NamingConventionType
|
|
19
|
+
from sagemaker.core.s3 import parse_s3_url
|
|
20
|
+
from sagemaker.core.jumpstart.types import (
|
|
21
|
+
JumpStartModelSpecs,
|
|
22
|
+
HubContentType,
|
|
23
|
+
JumpStartDataHolderType,
|
|
24
|
+
)
|
|
25
|
+
from sagemaker.core.jumpstart.hub.interfaces import (
|
|
26
|
+
DescribeHubContentResponse,
|
|
27
|
+
HubModelDocument,
|
|
28
|
+
)
|
|
29
|
+
from sagemaker.core.jumpstart.hub.parser_utils import (
|
|
30
|
+
camel_to_snake,
|
|
31
|
+
snake_to_upper_camel,
|
|
32
|
+
walk_and_apply_json,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _to_json(dictionary: Dict[Any, Any]) -> Dict[Any, Any]:
    """Convert a nested dictionary of JumpStartDataHolderType into json with UpperCamelCase keys

    Only values of the top-level dictionary are inspected, one level deep:
    a value that is a JumpStartDataHolderType, a list of them, or a dict whose
    values are them is replaced by ``to_json()`` output whose keys are passed
    through ``snake_to_upper_camel``. Other values are left alone.

    Note that ``dictionary`` (and any nested dict values) are modified in place;
    the same object is returned.
    """

    def _is_holder(candidate: Any) -> bool:
        # Type-based check kept as issubclass(type(...)) to match existing behavior.
        return issubclass(type(candidate), JumpStartDataHolderType)

    def _as_camel_json(holder: Any) -> Dict[Any, Any]:
        return walk_and_apply_json(holder.to_json(), snake_to_upper_camel)

    for entry_key, entry in dictionary.items():
        if _is_holder(entry):
            dictionary[entry_key] = _as_camel_json(entry)
        elif isinstance(entry, list):
            # Build a fresh list; non-holder elements pass through unchanged.
            dictionary[entry_key] = [
                _as_camel_json(element) if _is_holder(element) else element
                for element in entry
            ]
        elif isinstance(entry, dict):
            # Nested dict is updated in place (only existing keys are reassigned).
            for inner_key, inner_value in entry.items():
                if _is_holder(inner_value):
                    entry[inner_key] = _as_camel_json(inner_value)
    return dictionary
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def get_model_spec_arg_keys(
    arg_type: ModelSpecKwargType,
    naming_convention: NamingConventionType = NamingConventionType.DEFAULT,
) -> List[str]:
    """Returns a list of arg keys for a specific model spec arg type.

    Keys are defined in UpperCamelCase. For ``NamingConventionType.SNAKE_CASE``
    each key is passed through ``camel_to_snake``. ``MODEL``, ``FIT`` and any
    other arg type not listed below yield an empty list.

    Args:
        arg_type (ModelSpecKwargType): Type of the model spec's kwarg.
        naming_convention (NamingConventionType): Type of naming convention to return.
            NOTE(review): the default ``NamingConventionType.DEFAULT`` is accepted only
            if it compares equal to ``SNAKE_CASE`` or ``UPPER_CAMEL_CASE`` -- confirm
            against the enum definition.

    Returns:
        List[str]: The kwarg key names in the requested naming convention.

    Raises:
        ValueError: If the naming convention is not valid.
    """
    if arg_type == ModelSpecKwargType.DEPLOY:
        camel_keys: List[str] = [
            "ModelDataDownloadTimeout",
            "ContainerStartupHealthCheckTimeout",
            "InferenceAmiVersion",
        ]
    elif arg_type == ModelSpecKwargType.ESTIMATOR:
        camel_keys = [
            "EncryptInterContainerTraffic",
            "MaxRuntimeInSeconds",
            "DisableOutputCompression",
            "ModelDir",
        ]
    else:
        # MODEL, FIT and unrecognised arg types contribute no extra kwargs.
        camel_keys = []

    if naming_convention == NamingConventionType.SNAKE_CASE:
        return [camel_to_snake(camel_key) for camel_key in camel_keys]
    if naming_convention == NamingConventionType.UPPER_CAMEL_CASE:
        return camel_keys
    raise ValueError("Please provide a valid naming convention.")
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def get_model_spec_kwargs_from_hub_model_document(
    arg_type: ModelSpecKwargType,
    hub_content_document: Dict[str, Any],
    naming_convention: NamingConventionType = NamingConventionType.UPPER_CAMEL_CASE,
) -> Dict[str, Any]:
    """Returns a map of kwarg key to value taken from a hub content document.

    The candidate keys come from ``get_model_spec_arg_keys(arg_type, naming_convention)``.
    Each one is looked up in ``hub_content_document``, and only keys whose value is
    not ``None`` (missing keys included) are kept. Key order follows the candidate list.

    Args:
        arg_type (ModelSpecKwargType): Type of the model spec's kwarg.
        hub_content_document: A dictionary representation of hub content document.
            NOTE(review): presumably keyed in the same convention as
            ``naming_convention`` (UpperCamelCase by default) -- confirm with callers.
        naming_convention (NamingConventionType): Type of naming convention to return.

    Returns:
        Dict[str, Any]: The kwargs found in the document.
    """
    wanted_keys = get_model_spec_arg_keys(arg_type, naming_convention=naming_convention)
    lookups = ((wanted, hub_content_document.get(wanted)) for wanted in wanted_keys)
    return {wanted: found for wanted, found in lookups if found is not None}
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def make_model_specs_from_describe_hub_content_response(
    response: DescribeHubContentResponse,
) -> JumpStartModelSpecs:
    """Sets fields in JumpStartModelSpecs based on values in DescribeHubContentResponse

    Builds a plain ``specs`` dict from the response's hub content name/version and its
    ``HubModelDocument``, then wraps it as ``JumpStartModelSpecs(_to_json(specs),
    is_hub_content=True)``. Vulnerability/deprecation/usage-info fields are always set to
    fixed defaults here. Training-related fields are only populated when the document's
    ``training_supported`` is truthy.

    Args:
        response (DescribeHubContentResponse): parsed DescribeHubContentResponse returned
            from SageMaker:DescribeHubContent

    Returns:
        JumpStartModelSpecs: specs built from the hub content document.

    Raises:
        AttributeError: If ``response.hub_content_type`` is neither
            ``HubContentType.MODEL`` nor ``HubContentType.MODEL_REFERENCE``.
    """
    # Only model and model-reference hub content can be converted into model specs.
    if response.hub_content_type not in {HubContentType.MODEL, HubContentType.MODEL_REFERENCE}:
        raise AttributeError(
            "Invalid content type, use either HubContentType.MODEL or HubContentType.MODEL_REFERENCE."
        )
    # Used further down as the key of the per-region model package ARN / artifact URI maps.
    region = response.get_hub_region()
    specs = {}
    model_id = response.hub_content_name
    specs["model_id"] = model_id
    specs["version"] = response.hub_content_version
    hub_model_document: HubModelDocument = response.hub_content_document
    specs["url"] = hub_model_document.url
    specs["min_sdk_version"] = hub_model_document.min_sdk_version
    specs["model_types"] = hub_model_document.model_types
    specs["capabilities"] = hub_model_document.capabilities
    specs["training_supported"] = bool(hub_model_document.training_supported)
    specs["incremental_training_supported"] = bool(
        hub_model_document.incremental_training_supported
    )
    specs["hosting_ecr_uri"] = hub_model_document.hosting_ecr_uri
    specs["inference_configs"] = hub_model_document.inference_configs
    specs["inference_config_components"] = hub_model_document.inference_config_components
    specs["inference_config_rankings"] = hub_model_document.inference_config_rankings

    # Optional S3 URIs are split with parse_s3_url; only the object key is kept in specs.
    if hub_model_document.hosting_artifact_uri:
        _, hosting_artifact_key = parse_s3_url(  # pylint: disable=unused-variable
            hub_model_document.hosting_artifact_uri
        )
        specs["hosting_artifact_key"] = hosting_artifact_key
        specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri

    # NOTE(review): unlike hosting_prepacked_artifact_key / hosting_eula_key below, this key
    # is not defaulted to None when the URI is absent - confirm consumers tolerate a missing key.
    if hub_model_document.hosting_script_uri:
        _, hosting_script_key = parse_s3_url(  # pylint: disable=unused-variable
            hub_model_document.hosting_script_uri
        )
        specs["hosting_script_key"] = hosting_script_key

    specs["inference_environment_variables"] = hub_model_document.inference_environment_variables
    # Vulnerability, deprecation and usage-info fields are not read from the hub document;
    # they are always set to these fixed values.
    specs["inference_vulnerable"] = False
    specs["inference_dependencies"] = hub_model_document.inference_dependencies
    specs["inference_vulnerabilities"] = []
    specs["training_vulnerable"] = False
    specs["training_vulnerabilities"] = []
    specs["deprecated"] = False
    specs["deprecated_message"] = None
    specs["deprecate_warn_message"] = None
    specs["usage_info_message"] = None
    specs["default_inference_instance_type"] = hub_model_document.default_inference_instance_type
    specs["supported_inference_instance_types"] = (
        hub_model_document.supported_inference_instance_types
    )
    specs["dynamic_container_deployment_supported"] = (
        hub_model_document.dynamic_container_deployment_supported
    )
    specs["hosting_resource_requirements"] = hub_model_document.hosting_resource_requirements

    specs["hosting_prepacked_artifact_key"] = None
    if hub_model_document.hosting_prepacked_artifact_uri is not None:
        (
            hosting_prepacked_artifact_bucket,  # pylint: disable=unused-variable
            hosting_prepacked_artifact_key,
        ) = parse_s3_url(hub_model_document.hosting_prepacked_artifact_uri)
        specs["hosting_prepacked_artifact_key"] = hosting_prepacked_artifact_key

    # The kwargs lookups below read from the document's JSON form, using the helper's
    # default naming convention (UPPER_CAMEL_CASE).
    hub_content_document_dict: Dict[str, Any] = hub_model_document.to_json()

    specs["fit_kwargs"] = get_model_spec_kwargs_from_hub_model_document(
        ModelSpecKwargType.FIT, hub_content_document_dict
    )
    specs["model_kwargs"] = get_model_spec_kwargs_from_hub_model_document(
        ModelSpecKwargType.MODEL, hub_content_document_dict
    )
    specs["deploy_kwargs"] = get_model_spec_kwargs_from_hub_model_document(
        ModelSpecKwargType.DEPLOY, hub_content_document_dict
    )
    specs["estimator_kwargs"] = get_model_spec_kwargs_from_hub_model_document(
        ModelSpecKwargType.ESTIMATOR, hub_content_document_dict
    )

    specs["predictor_specs"] = hub_model_document.sage_maker_sdk_predictor_specifications
    default_payloads: Dict[str, Any] = {}
    if hub_model_document.default_payloads is not None:
        # Each payload's JSON goes through walk_and_apply_json with camel_to_snake -
        # presumably snake-casing its keys; confirm against walk_and_apply_json.
        for alias, payload in hub_model_document.default_payloads.items():
            default_payloads[alias] = walk_and_apply_json(payload.to_json(), camel_to_snake)
    specs["default_payloads"] = default_payloads
    specs["gated_bucket"] = hub_model_document.gated_bucket
    specs["inference_volume_size"] = hub_model_document.inference_volume_size
    specs["inference_enable_network_isolation"] = (
        hub_model_document.inference_enable_network_isolation
    )
    specs["resource_name_base"] = hub_model_document.resource_name_base

    specs["hosting_eula_key"] = None
    if hub_model_document.hosting_eula_uri is not None:
        hosting_eula_bucket, hosting_eula_key = parse_s3_url(  # pylint: disable=unused-variable
            hub_model_document.hosting_eula_uri
        )
        specs["hosting_eula_key"] = hosting_eula_key

    if hub_model_document.hosting_model_package_arn:
        specs["hosting_model_package_arns"] = {region: hub_model_document.hosting_model_package_arn}

    specs["model_subscription_link"] = hub_model_document.model_subscription_link

    specs["hosting_use_script_uri"] = hub_model_document.hosting_use_script_uri

    specs["hosting_instance_type_variants"] = hub_model_document.hosting_instance_type_variants

    # Training fields are only filled in when the document reports training support.
    if specs["training_supported"]:
        specs["training_ecr_uri"] = hub_model_document.training_ecr_uri
        # NOTE(review): unlike the optional URIs above, training_artifact_uri and
        # training_script_uri are parsed without a None check - assumes they are always set
        # when training_supported is truthy.
        (
            training_artifact_bucket,  # pylint: disable=unused-variable
            training_artifact_key,
        ) = parse_s3_url(hub_model_document.training_artifact_uri)
        specs["training_artifact_key"] = training_artifact_key
        (
            training_script_bucket,  # pylint: disable=unused-variable
            training_script_key,
        ) = parse_s3_url(hub_model_document.training_script_uri)
        specs["training_script_key"] = training_script_key

        specs["training_configs"] = hub_model_document.training_configs
        specs["training_config_components"] = hub_model_document.training_config_components
        specs["training_config_rankings"] = hub_model_document.training_config_rankings

        specs["training_dependencies"] = hub_model_document.training_dependencies
        specs["default_training_instance_type"] = hub_model_document.default_training_instance_type
        specs["supported_training_instance_types"] = (
            hub_model_document.supported_training_instance_types
        )
        specs["metrics"] = hub_model_document.training_metrics
        specs["training_prepacked_script_key"] = None
        if hub_model_document.training_prepacked_script_uri is not None:
            (
                training_prepacked_script_bucket,  # pylint: disable=unused-variable
                training_prepacked_script_key,
            ) = parse_s3_url(hub_model_document.training_prepacked_script_uri)
            specs["training_prepacked_script_key"] = training_prepacked_script_key

        specs["hyperparameters"] = hub_model_document.hyperparameters
        specs["training_volume_size"] = hub_model_document.training_volume_size
        specs["training_enable_network_isolation"] = (
            hub_model_document.training_enable_network_isolation
        )
        if hub_model_document.training_model_package_artifact_uri:
            specs["training_model_package_artifact_uris"] = {
                region: hub_model_document.training_model_package_artifact_uri
            }
        specs["training_instance_type_variants"] = (
            hub_model_document.training_instance_type_variants
        )
        if hub_model_document.default_training_dataset_uri:
            _, default_training_dataset_key = parse_s3_url(  # pylint: disable=unused-variable
                hub_model_document.default_training_dataset_uri
            )
            specs["default_training_dataset_key"] = default_training_dataset_key
            specs["default_training_dataset_uri"] = hub_model_document.default_training_dataset_uri
    return JumpStartModelSpecs(_to_json(specs), is_hub_content=True)
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module stores types related to SageMaker JumpStart Hub."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
from typing import Dict
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class S3ObjectLocation:
    """Bucket/key pair that identifies a single S3 object."""

    bucket: str
    key: str

    def format_for_s3_copy(self) -> Dict[str, str]:
        """Builds the mapping formatted for S3 copy calls.

        Returns:
            Dict[str, str]: the bucket under ``"Bucket"`` and the key under ``"Key"``.
        """
        return dict(Bucket=self.bucket, Key=self.key)

    def get_uri(self) -> str:
        """Renders this location as an ``s3://<bucket>/<key>`` URI string."""
        return "s3://{}/{}".format(self.bucket, self.key)
|
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
# pylint: skip-file
|
|
14
|
+
"""This module contains utilities related to SageMaker JumpStart Hub."""
|
|
15
|
+
from __future__ import absolute_import
|
|
16
|
+
import re
|
|
17
|
+
from typing import Optional, List, Any
|
|
18
|
+
from sagemaker.core.helper.session_helper import Session
|
|
19
|
+
from sagemaker.core.common_utils import aws_partition
|
|
20
|
+
from sagemaker.core.jumpstart.types import HubContentType, HubArnExtractedInfo
|
|
21
|
+
from sagemaker.core.jumpstart import constants
|
|
22
|
+
from packaging.specifiers import SpecifierSet, InvalidSpecifier
|
|
23
|
+
from packaging import version
|
|
24
|
+
|
|
25
|
+
# Prefix of a HubContentSearchKeywords entry whose remainder is a Marketplace
# (proprietary) model version, i.e. "@marketplace-version:<version>".
PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _convert_str_to_optional(string: str) -> Optional[str]:
|
|
29
|
+
if string == "None":
|
|
30
|
+
string = None
|
|
31
|
+
return string
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def get_info_from_hub_resource_arn(
    arn: str,
) -> HubArnExtractedInfo:
    """Extracts descriptive information from a Hub or HubContent Arn.

    The HubContent pattern (``constants.HUB_CONTENT_ARN_REGEX``) is tried first, then the
    plain Hub pattern (``constants.HUB_ARN_REGEX``). For a HubContent ARN, a version
    segment of the literal text ``"None"`` is reported as ``None``.

    Args:
        arn (str): Hub or HubContent ARN.

    Returns:
        HubArnExtractedInfo: the parsed fields, or ``None`` when ``arn`` matches neither
        pattern (despite the return annotation).
    """
    content_match = re.match(constants.HUB_CONTENT_ARN_REGEX, arn)
    if content_match:
        (
            partition,
            hub_region,
            account_id,
            hub_name,
            hub_content_type,
            hub_content_name,
            raw_content_version,
        ) = content_match.group(1, 2, 3, 4, 5, 6, 7)
        return HubArnExtractedInfo(
            partition=partition,
            region=hub_region,
            account_id=account_id,
            hub_name=hub_name,
            hub_content_type=hub_content_type,
            hub_content_name=hub_content_name,
            hub_content_version=_convert_str_to_optional(raw_content_version),
        )

    hub_match = re.match(constants.HUB_ARN_REGEX, arn)
    if not hub_match:
        return None

    partition, hub_region, account_id, hub_name = hub_match.group(1, 2, 3, 4)
    return HubArnExtractedInfo(
        partition=partition,
        region=hub_region,
        account_id=account_id,
        hub_name=hub_name,
    )
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def construct_hub_arn_from_name(
    hub_name: str,
    region: Optional[str] = None,
    session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    account_id: Optional[str] = None,
) -> str:
    """Constructs a Hub arn from the Hub name using default Session values.

    A falsy ``account_id`` / ``region`` is filled in from the session
    (``session.account_id()`` / ``session.boto_region_name``). A ``None`` session falls
    back to ``constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION``. The ARN partition is
    derived from the resolved region.
    """
    # Some callers explicitly pass session=None, overriding the default argument.
    active_session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION if session is None else session

    resolved_account_id = account_id or active_session.account_id()
    resolved_region = region or active_session.boto_region_name
    return "arn:{}:sagemaker:{}:{}:hub/{}".format(
        aws_partition(resolved_region), resolved_region, resolved_account_id, hub_name
    )
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def construct_hub_model_arn_from_inputs(hub_arn: str, model_name: str, version: str) -> str:
    """Constructs a HubContent model arn from a Hub arn, model name, and model version.

    Partition, region, account id and hub name are taken from ``hub_arn`` (parsed with
    ``get_info_from_hub_resource_arn``); the content-type segment is
    ``HubContentType.MODEL.value``.
    """
    hub_info = get_info_from_hub_resource_arn(hub_arn)
    arn_prefix = f"arn:{hub_info.partition}:sagemaker:{hub_info.region}:{hub_info.account_id}"
    content_path = f"{hub_info.hub_name}/{HubContentType.MODEL.value}/{model_name}/{version}"
    return f"{arn_prefix}:hub-content/{content_path}"
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def construct_hub_model_reference_arn_from_inputs(
    hub_arn: str, model_name: str, version: str
) -> str:
    """Constructs a HubContent model-reference arn from a Hub arn, model name, and version.

    Partition, region, account id and hub name are taken from ``hub_arn`` (parsed with
    ``get_info_from_hub_resource_arn``); the content-type segment is
    ``HubContentType.MODEL_REFERENCE.value``.
    """
    hub_info = get_info_from_hub_resource_arn(hub_arn)
    arn_prefix = f"arn:{hub_info.partition}:sagemaker:{hub_info.region}:{hub_info.account_id}"
    content_path = (
        f"{hub_info.hub_name}/{HubContentType.MODEL_REFERENCE.value}/{model_name}/{version}"
    )
    return f"{arn_prefix}:hub-content/{content_path}"
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def generate_hub_arn_for_init_kwargs(
    hub_name: str, region: Optional[str] = None, session: Optional[Session] = None
):
    """Generates the Hub Arn for JumpStart class args from a HubName or Arn.

    Args:
        hub_name (str): HubName or HubArn from JumpStart class args
        region (str): Region from JumpStart class args
        session (Session): Custom SageMaker Session from JumpStart class args

    Returns:
        Optional[str]: ``None`` when ``hub_name`` is falsy or equals
        ``constants.JUMPSTART_MODEL_HUB_NAME``; ``hub_name`` unchanged when it already
        matches ``constants.HUB_ARN_REGEX``; otherwise the ARN produced by
        ``construct_hub_arn_from_name``.
    """
    if not hub_name or hub_name == constants.JUMPSTART_MODEL_HUB_NAME:
        return None
    if re.match(constants.HUB_ARN_REGEX, hub_name):
        # Already a Hub ARN; pass it through untouched.
        return hub_name
    return construct_hub_arn_from_name(hub_name=hub_name, region=region, session=session)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def is_gated_bucket(bucket_name: str) -> bool:
    """Tells whether ``bucket_name`` is one of the JumpStart gated buckets.

    Membership is checked against ``constants.JUMPSTART_GATED_BUCKET_NAME_SET``.
    """
    gated_bucket_names = constants.JUMPSTART_GATED_BUCKET_NAME_SET
    return bucket_name in gated_bucket_names
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def get_hub_model_version(
    hub_name: str,
    hub_model_name: str,
    hub_model_type: str,
    hub_model_version: Optional[str] = None,
    sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
    """Returns available Jumpstart hub model version.

    It will attempt both a semantic HubContent version search and Marketplace version search.
    If the Marketplace version is also semantic, this function will default to HubContent version.

    Args:
        hub_name (str): Name of the hub to search.
        hub_model_name (str): HubContent name of the model.
        hub_model_type (str): HubContent type of the model.
        hub_model_version (Optional[str]): Requested version (semantic or Marketplace);
            ``None`` or ``"*"`` selects the latest semantic version.
        sagemaker_session (Session): Session used to list hub content versions; ``None``
            falls back to ``constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION``.

    Returns:
        str: The resolved HubContentVersion.

    Raises:
        Exception: If listing the hub content versions fails for any reason (for example
            when the model is not found in the hub). The underlying error is chained as
            ``__cause__``.
        KeyError: If the specified model version is not found.
    """
    if sagemaker_session is None:
        # sagemaker_session is overridden to none by some callers
        sagemaker_session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION

    try:
        hub_content_summaries = _list_hub_content_versions_helper(
            hub_name=hub_name,
            hub_content_name=hub_model_name,
            hub_content_type=hub_model_type,
            sagemaker_session=sagemaker_session,
        )
    except Exception as ex:
        # Keep the original error (and its traceback) reachable via __cause__.
        raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}") from ex

    try:
        return _get_hub_model_version_for_open_weight_version(
            hub_content_summaries, hub_model_version
        )
    except KeyError:
        # Not resolvable as a semantic HubContent version; try it as a Marketplace version.
        marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version(
            hub_content_summaries, hub_model_version
        )
        if marketplace_hub_content_version:
            return marketplace_hub_content_version
        raise
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _list_hub_content_versions_helper(
|
|
189
|
+
hub_name, hub_content_name, hub_content_type, sagemaker_session
|
|
190
|
+
):
|
|
191
|
+
all_hub_content_summaries = []
|
|
192
|
+
list_hub_content_versions_response = sagemaker_session.list_hub_content_versions(
|
|
193
|
+
hub_name=hub_name, hub_content_name=hub_content_name, hub_content_type=hub_content_type
|
|
194
|
+
)
|
|
195
|
+
all_hub_content_summaries.extend(list_hub_content_versions_response.get("HubContentSummaries"))
|
|
196
|
+
while "NextToken" in list_hub_content_versions_response:
|
|
197
|
+
list_hub_content_versions_response = sagemaker_session.list_hub_content_versions(
|
|
198
|
+
hub_name=hub_name,
|
|
199
|
+
hub_content_name=hub_content_name,
|
|
200
|
+
hub_content_type=hub_content_type,
|
|
201
|
+
next_token=list_hub_content_versions_response["NextToken"],
|
|
202
|
+
)
|
|
203
|
+
all_hub_content_summaries.extend(
|
|
204
|
+
list_hub_content_versions_response.get("HubContentSummaries")
|
|
205
|
+
)
|
|
206
|
+
return all_hub_content_summaries
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _get_hub_model_version_for_open_weight_version(
    hub_content_summaries: List[Any], hub_model_version: Optional[str] = None
) -> str:
    """Resolves a semantic (open-weight) model version against the available versions.

    Args:
        hub_content_summaries (List[Any]): Summaries from list_hub_content_versions; each
            entry's ``"HubContentVersion"`` is read.
        hub_model_version (Optional[str]): Requested version. ``None`` or ``"*"`` selects
            the highest available version. Otherwise it is used as a ``==<version>``
            specifier, so PEP 440 prefix wildcards such as ``"1.*"`` are accepted.

    Returns:
        str: The highest matching ``HubContentVersion``, compared as a version (not as a
        string) and returned exactly as listed in the hub.

    Raises:
        KeyError: If no versions are available, the requested version is not a valid
            specifier, or no available version matches it.
    """
    available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries]
    if not available_model_versions:
        raise KeyError("Model version not available in the Hub")

    if hub_model_version == "*" or hub_model_version is None:
        # Return the listed string itself (not a normalized Version) so it matches the hub.
        return max(available_model_versions, key=version.parse)

    try:
        spec = SpecifierSet(f"=={hub_model_version}")
    except InvalidSpecifier as err:
        raise KeyError(f"Bad semantic version: {hub_model_version}") from err
    available_versions_filtered = list(spec.filter(available_model_versions))
    if not available_versions_filtered:
        raise KeyError("Model version not available in the Hub")
    # Compare as versions, not strings: for "1.*", "1.10.0" must win over "1.9.0".
    return max(available_versions_filtered, key=version.parse)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def _get_hub_model_version_for_marketplace_version(
    hub_content_summaries: List[Any], marketplace_version: str
) -> Optional[str]:
    """Returns the HubContent version associated with the Marketplace version.

    Each summary's ``HubContentSearchKeywords`` (an empty list when absent) is checked for
    the proprietary-version keyword naming ``marketplace_version``.

    Returns:
        Optional[str]: ``HubContentVersion`` of the first matching summary, or ``None``
        when no summary matches.
    """
    matching_summaries = (
        summary
        for summary in hub_content_summaries
        if _hub_search_keywords_contains_marketplace_version(
            summary.get("HubContentSearchKeywords", []), marketplace_version
        )
    )
    first_match = next(matching_summaries, None)
    if first_match is None:
        return None
    return first_match.get("HubContentVersion")
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def _hub_search_keywords_contains_marketplace_version(
    model_search_keywords: List[str], marketplace_version: str
) -> bool:
    """Checks whether the proprietary-version search keyword names ``marketplace_version``.

    Only the first keyword starting with ``PROPRIETARY_VERSION_KEYWORD`` is considered;
    the text after that prefix must equal ``marketplace_version`` exactly.

    Args:
        model_search_keywords (List[str]): HubContentSearchKeywords of one hub content.
        marketplace_version (str): Marketplace version to look for.

    Returns:
        bool: ``True`` if the proprietary version equals ``marketplace_version``.
    """
    proprietary_version_keyword = next(
        (
            keyword
            for keyword in model_search_keywords
            if keyword.startswith(PROPRIETARY_VERSION_KEYWORD)
        ),
        None,
    )

    if not proprietary_version_keyword:
        return False

    # Remove the exact prefix. str.lstrip() treats its argument as a *set* of characters and
    # would also eat leading version characters found in the prefix (e.g. "v" in "v1.2").
    proprietary_version = proprietary_version_keyword[len(PROPRIETARY_VERSION_KEYWORD):]
    return proprietary_version == marketplace_version
|