sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,663 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module defines the JumpStartModelsCache class."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
import datetime
|
|
16
|
+
from difflib import get_close_matches
|
|
17
|
+
import os
|
|
18
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
19
|
+
import json
|
|
20
|
+
import boto3
|
|
21
|
+
import botocore
|
|
22
|
+
from packaging.version import Version
|
|
23
|
+
from packaging.specifiers import SpecifierSet, InvalidSpecifier
|
|
24
|
+
from sagemaker.core.jumpstart.constants import (
|
|
25
|
+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
|
|
26
|
+
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
|
|
27
|
+
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
|
|
28
|
+
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
|
|
29
|
+
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
|
|
30
|
+
JUMPSTART_LOGGER,
|
|
31
|
+
MODEL_ID_LIST_WEB_URL,
|
|
32
|
+
MODEL_TYPE_TO_MANIFEST_MAP,
|
|
33
|
+
MODEL_TYPE_TO_SPECS_MAP,
|
|
34
|
+
)
|
|
35
|
+
from sagemaker.core.jumpstart.exceptions import (
|
|
36
|
+
get_wildcard_model_version_msg,
|
|
37
|
+
get_wildcard_proprietary_model_version_msg,
|
|
38
|
+
)
|
|
39
|
+
from sagemaker.core.jumpstart.parameters import (
|
|
40
|
+
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
|
|
41
|
+
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
|
|
42
|
+
JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON,
|
|
43
|
+
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
|
|
44
|
+
)
|
|
45
|
+
from sagemaker.core.jumpstart.types import (
|
|
46
|
+
JumpStartCachedContentKey,
|
|
47
|
+
JumpStartCachedContentValue,
|
|
48
|
+
JumpStartModelHeader,
|
|
49
|
+
JumpStartModelSpecs,
|
|
50
|
+
JumpStartS3FileType,
|
|
51
|
+
JumpStartVersionedModelId,
|
|
52
|
+
HubContentType,
|
|
53
|
+
)
|
|
54
|
+
from sagemaker.core.jumpstart.hub import utils as hub_utils
|
|
55
|
+
from sagemaker.core.jumpstart.hub.interfaces import DescribeHubContentResponse
|
|
56
|
+
from sagemaker.core.jumpstart.hub.parsers import (
|
|
57
|
+
make_model_specs_from_describe_hub_content_response,
|
|
58
|
+
)
|
|
59
|
+
from sagemaker.core.jumpstart.enums import JumpStartModelType
|
|
60
|
+
from sagemaker.core.jumpstart import utils
|
|
61
|
+
from sagemaker.core.utilities.cache import LRUCache
|
|
62
|
+
from sagemaker.core.helper.session_helper import Session
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class JumpStartModelsCache:
|
|
66
|
+
"""Class that implements a cache for JumpStart models manifests and specs.
|
|
67
|
+
|
|
68
|
+
The manifest and specs associated with JumpStart models provide the information necessary
|
|
69
|
+
for launching JumpStart models from the SageMaker SDK.
|
|
70
|
+
"""
|
|
71
|
+
|
|
72
|
+
def __init__(
    self,
    region: Optional[str] = None,
    max_s3_cache_items: int = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
    s3_cache_expiration_horizon: datetime.timedelta = (
        JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON
    ),
    max_semantic_version_cache_items: int = JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
    semantic_version_cache_expiration_horizon: datetime.timedelta = (
        JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON
    ),
    manifest_file_s3_key: str = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
    proprietary_manifest_s3_key: str = JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
    s3_bucket_name: Optional[str] = None,
    s3_client_config: Optional[botocore.config.Config] = None,
    s3_client: Optional[boto3.client] = None,
    sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> None:
    """Build the caches and AWS handles backing a ``JumpStartModelsCache``.

    Args:
        region (Optional[str]): AWS region to associate with the cache. When falsy,
            the region is resolved with ``utils.get_region_fallback`` from the
            supplied bucket name and s3 client.
        max_s3_cache_items (int): Maximum number of items held by the content cache.
            Default: ``JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS``.
        s3_cache_expiration_horizon (datetime.timedelta): Expiration horizon of the
            content cache. Default: ``JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON``.
        max_semantic_version_cache_items (int): Maximum number of items held by each
            of the two model-ID/version caches (open weight and proprietary).
            Default: ``JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS``.
        semantic_version_cache_expiration_horizon (datetime.timedelta): Expiration
            horizon of the two model-ID/version caches.
            Default: ``JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON``.
        manifest_file_s3_key (str): S3 key of the open weight models manifest.
            Default: ``JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY``.
        proprietary_manifest_s3_key (str): S3 key of the proprietary models manifest.
            Default: ``JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY``.
        s3_bucket_name (Optional[str]): S3 bucket to associate with the cache. When
            ``None``, ``utils.get_jumpstart_content_bucket`` is called with the
            resolved region.
        s3_client_config (Optional[botocore.config.Config]): Config passed to
            ``boto3.client`` when no ``s3_client`` is supplied. Default: None.
        s3_client (Optional[boto3.client]): s3 client to use. When falsy, a new
            client is created for the resolved region.
        sagemaker_session (Optional[Session]): Session object to use. A falsy value
            falls back to ``DEFAULT_JUMPSTART_SAGEMAKER_SESSION``.
    """
    if region:
        self._region = region
    else:
        self._region = utils.get_region_fallback(
            s3_bucket_name=s3_bucket_name, s3_client=s3_client
        )

    # Cache of fetched content (manifests/specs), filled through ``_retrieval_function``.
    self._content_cache = LRUCache[JumpStartCachedContentKey, JumpStartCachedContentValue](
        max_cache_items=max_s3_cache_items,
        expiration_horizon=s3_cache_expiration_horizon,
        retrieval_function=self._retrieval_function,
    )

    # Both model-ID/version caches share the same sizing and expiry settings;
    # only the retrieval function differs between them.
    version_cache_settings = {
        "max_cache_items": max_semantic_version_cache_items,
        "expiration_horizon": semantic_version_cache_expiration_horizon,
    }
    self._open_weight_model_id_manifest_key_cache = LRUCache[
        JumpStartVersionedModelId, JumpStartVersionedModelId
    ](
        retrieval_function=self._get_open_weight_manifest_key_from_model_id,
        **version_cache_settings,
    )
    self._proprietary_model_id_manifest_key_cache = LRUCache[
        JumpStartVersionedModelId, JumpStartVersionedModelId
    ](
        retrieval_function=self._get_proprietary_manifest_key_from_model_id,
        **version_cache_settings,
    )

    self._manifest_file_s3_key = manifest_file_s3_key
    self._proprietary_manifest_s3_key = proprietary_manifest_s3_key
    # Lookup of manifest S3 key by model type.
    self._manifest_file_s3_map = {
        JumpStartModelType.OPEN_WEIGHTS: self._manifest_file_s3_key,
        JumpStartModelType.PROPRIETARY: self._proprietary_manifest_s3_key,
    }

    if s3_bucket_name is None:
        self.s3_bucket_name = utils.get_jumpstart_content_bucket(self._region)
    else:
        self.s3_bucket_name = s3_bucket_name

    if s3_client:
        self._s3_client = s3_client
    elif s3_client_config:
        self._s3_client = boto3.client(
            "s3", region_name=self._region, config=s3_client_config
        )
    else:
        self._s3_client = boto3.client("s3", region_name=self._region)

    # Fallback in case a caller overrides sagemaker_session to None
    if sagemaker_session:
        self._sagemaker_session = sagemaker_session
    else:
        self._sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
|
|
155
|
+
|
|
156
|
+
def set_region(self, region: str) -> None:
    """Switch the cache to ``region``.

    Nothing happens when ``region`` equals the current region; otherwise the new
    region is stored and ``self.clear()`` is called.
    """
    if region == self._region:
        return
    self._region = region
    self.clear()
|
|
161
|
+
|
|
162
|
+
def get_region(self) -> str:
    """Return the region currently associated with the cache."""
    current_region = self._region
    return current_region
|
|
165
|
+
|
|
166
|
+
def set_manifest_file_s3_key(
    self,
    key: str,
    file_type: JumpStartS3FileType = JumpStartS3FileType.OPEN_WEIGHT_MANIFEST,
) -> None:
    """Set manifest file s3 key, clear cache after new key is set.

    Args:
        key (str): New S3 key for the manifest file.
        file_type (JumpStartS3FileType): Which manifest the key belongs to; must be
            ``OPEN_WEIGHT_MANIFEST`` or ``PROPRIETARY_MANIFEST``.

    Raises:
        ValueError: if the file type is not recognized
    """
    # Map each manifest type to the NAME of the attribute that stores its key.
    # Mapping to the attribute's current value (as was done previously) made
    # ``setattr`` create an attribute named after the old S3 key and left the
    # real key untouched.
    attribute_names = {
        JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: "_manifest_file_s3_key",
        JumpStartS3FileType.PROPRIETARY_MANIFEST: "_proprietary_manifest_s3_key",
    }
    attribute_name = attribute_names.get(file_type)
    if attribute_name is None:
        raise ValueError(self._file_type_error_msg(file_type, manifest_only=True))

    if key == getattr(self, attribute_name):
        # Key unchanged: keep the cached content.
        return

    setattr(self, attribute_name, key)
    # Rebuild the per-model-type lookup so it reflects the new key; it is
    # constructed from these two attributes in ``__init__``.
    self._manifest_file_s3_map = {
        JumpStartModelType.OPEN_WEIGHTS: self._manifest_file_s3_key,
        JumpStartModelType.PROPRIETARY: self._proprietary_manifest_s3_key,
    }
    self.clear()
|
|
186
|
+
|
|
187
|
+
def get_manifest_file_s3_key(
    self, file_type: JumpStartS3FileType = JumpStartS3FileType.OPEN_WEIGHT_MANIFEST
) -> str:
    """Return the S3 key of the manifest identified by ``file_type``.

    Raises:
        ValueError: if ``file_type`` is neither ``OPEN_WEIGHT_MANIFEST`` nor
            ``PROPRIETARY_MANIFEST``.
    """
    if file_type == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST:
        manifest_key = self._manifest_file_s3_key
    elif file_type == JumpStartS3FileType.PROPRIETARY_MANIFEST:
        manifest_key = self._proprietary_manifest_s3_key
    else:
        raise ValueError(self._file_type_error_msg(file_type, manifest_only=True))
    return manifest_key
|
|
196
|
+
|
|
197
|
+
def set_s3_bucket_name(self, s3_bucket_name: str) -> None:
    """Point the cache at a different S3 bucket.

    Nothing happens when the name equals the current bucket; otherwise the new
    name is stored and ``self.clear()`` is called.
    """
    if s3_bucket_name == self.s3_bucket_name:
        return
    self.s3_bucket_name = s3_bucket_name
    self.clear()
|
|
202
|
+
|
|
203
|
+
def get_bucket(self) -> str:
    """Return the name of the S3 bucket currently used by the cache."""
    bucket_name = self.s3_bucket_name
    return bucket_name
|
|
206
|
+
|
|
207
|
+
def _file_type_error_msg(self, file_type: str, manifest_only: bool = False) -> str:
    """Build the error message for an unrecognized file type.

    Args:
        file_type: The offending value, echoed in the message.
        manifest_only: When True, list only the two manifest file types as the
            accepted values; otherwise list the names of every
            ``JumpStartS3FileType`` member.
    """
    if manifest_only:
        accepted = (
            f"{JumpStartS3FileType.OPEN_WEIGHT_MANIFEST} "
            f"{JumpStartS3FileType.PROPRIETARY_MANIFEST}"
        )
    else:
        member_names = " ".join([member.name for member in JumpStartS3FileType])
        accepted = f"'{member_names}'"
    return f"Bad value when getting manifest '{file_type}': must be in {accepted}."
|
|
219
|
+
|
|
220
|
+
def _model_id_retrieval_function(
    self,
    key: JumpStartVersionedModelId,
    value: Optional[JumpStartVersionedModelId],  # pylint: disable=W0613
    model_type: JumpStartModelType,
) -> JumpStartVersionedModelId:
    """Return model ID and version in manifest that matches semantic version/id.

    Uses ``packaging.version`` to perform version comparison. The highest model version
    matching the semantic version is used, which is compatible with the SageMaker
    version.

    Args:
        key (JumpStartVersionedModelId): Key for which to fetch versioned model ID.
        value (Optional[JumpStartVersionedModelId]): Unused variable for current value of
            old cached model ID/version.
        model_type (JumpStartModelType): JumpStart model type to indicate whether it is
            open weights model or proprietary (Marketplace) model.

    Returns:
        JumpStartVersionedModelId: The model ID paired with the concrete version selected
        from the manifest.

    Raises:
        KeyError: If the semantic version is not found in the manifest, or is found but
            the SageMaker version needs to be upgraded in order for the model to be used.
        RuntimeError: If the manifest does not hold exactly one header for the model
            version that was selected as requiring a newer SageMaker version.
    """

    model_id, version = key.model_id, key.version
    sm_version = utils.get_sagemaker_version()
    manifest = self._content_cache.get(
        JumpStartCachedContentKey(
            MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
        )
    )[0].formatted_content

    # First pass: only versions whose minimum SDK requirement is already satisfied.
    versions_compatible_with_sagemaker = [
        header.version
        for header in manifest.values()  # type: ignore
        if header.model_id == model_id and Version(header.min_version) <= Version(sm_version)
    ]

    sm_compatible_model_version = self._select_version(
        model_id, version, versions_compatible_with_sagemaker, model_type
    )

    if sm_compatible_model_version is not None:
        return JumpStartVersionedModelId(model_id, sm_compatible_model_version)

    # Second pass: every manifest version of this model, regardless of SDK requirement.
    # Anything that matches here but not above needs a newer SageMaker library.
    versions_incompatible_with_sagemaker = [
        header.version
        for header in manifest.values()  # type: ignore
        if header.model_id == model_id
    ]
    sm_incompatible_model_version = self._select_version(
        model_id, version, versions_incompatible_with_sagemaker, model_type
    )

    if sm_incompatible_model_version is not None:
        model_version_to_use_incompatible_with_sagemaker = sm_incompatible_model_version
        sm_version_to_use_list = [
            header.min_version
            for header in manifest.values()  # type: ignore
            if header.model_id == model_id
            and header.version == model_version_to_use_incompatible_with_sagemaker
        ]
        if len(sm_version_to_use_list) != 1:
            # ``manifest`` dict should already enforce this
            raise RuntimeError("Found more than one incompatible SageMaker version to use.")
        sm_version_to_use = sm_version_to_use_list[0]

        error_msg = (
            f"Unable to find model manifest for '{model_id}' with version '{version}' "
            f"compatible with your SageMaker version ('{sm_version}'). "
            f"Consider upgrading your SageMaker library to at least version "
            f"'{sm_version_to_use}' so you can use version "
            f"'{model_version_to_use_incompatible_with_sagemaker}' of '{model_id}'."
        )
        raise KeyError(error_msg)

    error_msg = f"Unable to find model manifest for '{model_id}' with version '{version}'. "
    error_msg += "Specify a different model ID or try a different AWS Region. "
    error_msg += f"For a list of available models, see {MODEL_ID_LIST_WEB_URL}. "

    other_model_id_version = None
    if model_type == JumpStartModelType.OPEN_WEIGHTS:
        other_model_id_version = self._select_version(
            model_id, "*", versions_incompatible_with_sagemaker, model_type
        )  # all versions here are incompatible with sagemaker
    elif model_type == JumpStartModelType.PROPRIETARY:
        # The list built above already holds every manifest version of this model,
        # so it is reused rather than rebuilt from the manifest.
        other_model_id_version = (
            versions_incompatible_with_sagemaker[0]
            if versions_incompatible_with_sagemaker
            else None
        )

    if other_model_id_version is not None:
        error_msg += (
            f"Consider using model ID '{model_id}' with version " f"'{other_model_id_version}'."
        )
    else:
        possible_model_ids = [header.model_id for header in manifest.values()]  # type: ignore
        closest_matches = get_close_matches(model_id, possible_model_ids, n=1, cutoff=0)
        # An empty manifest yields no candidates; skip the hint instead of letting an
        # IndexError replace the KeyError that callers expect.
        if closest_matches:
            error_msg += f"Did you mean to use model ID '{closest_matches[0]}'?"

    raise KeyError(error_msg)
|
|
325
|
+
|
|
326
|
+
def _get_open_weight_manifest_key_from_model_id(
    self,
    key: JumpStartVersionedModelId,
    value: Optional[JumpStartVersionedModelId],  # pylint: disable=W0613
) -> JumpStartVersionedModelId:
    """Resolve the manifest key of an open weights model.

    Delegates to ``_model_id_retrieval_function`` with the model type fixed to
    ``JumpStartModelType.OPEN_WEIGHTS``.

    Args:
        key (JumpStartVersionedModelId): Model ID and version to resolve.
        value (Optional[JumpStartVersionedModelId]): Previously cached value, forwarded
            unchanged.
    """
    open_weights_type = JumpStartModelType.OPEN_WEIGHTS
    return self._model_id_retrieval_function(key, value, model_type=open_weights_type)
|
|
338
|
+
|
|
339
|
+
def _get_proprietary_manifest_key_from_model_id(
    self,
    key: JumpStartVersionedModelId,
    value: Optional[JumpStartVersionedModelId],  # pylint: disable=W0613
) -> JumpStartVersionedModelId:
    """Resolve the manifest key of a proprietary model.

    Delegates to ``_model_id_retrieval_function`` with the model type fixed to
    ``JumpStartModelType.PROPRIETARY``.

    Args:
        key (JumpStartVersionedModelId): Model ID and version to resolve.
        value (Optional[JumpStartVersionedModelId]): Previously cached value, forwarded
            unchanged.
    """
    proprietary_type = JumpStartModelType.PROPRIETARY
    return self._model_id_retrieval_function(key, value, model_type=proprietary_type)
|
|
351
|
+
|
|
352
|
+
def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list], str]:
    """Fetch an object from the cache bucket and parse it as JSON.

    Args:
        key (str): S3 key of the object inside ``self.s3_bucket_name``.

    Returns:
        Tuple[Union[dict, list], str]: The parsed JSON body and the ``ETag`` value
        from the ``get_object`` response.
    """
    s3_object = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=key)
    body_text = s3_object["Body"].read().decode("utf-8")
    parsed_body = json.loads(body_text)
    return parsed_body, s3_object["ETag"]
|
|
356
|
+
|
|
357
|
+
def _is_local_metadata_mode(self) -> bool:
    """Tell whether metadata should be read from the local file system.

    Local mode is active only when both the manifest and the specs override
    environment variables are set and each one names an existing directory.
    """
    override_variables = (
        ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
        ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
    )
    for variable_name in override_variables:
        if variable_name not in os.environ:
            return False
        if not os.path.isdir(os.environ[variable_name]):
            return False
    return True
|
|
365
|
+
|
|
366
|
+
def _get_json_file(
    self, key: str, filetype: JumpStartS3FileType
) -> Tuple[Union[dict, list], Optional[str]]:
    """Load a JSON document from the local override directory or from S3.

    Args:
        key (str): S3 key, or path relative to the local override root.
        filetype (JumpStartS3FileType): Kind of file being loaded; only consulted
            in local metadata mode.

    Returns:
        Tuple[Union[dict, list], Optional[str]]: The parsed JSON and its etag. The
        etag is None when the file was read from the local file system.
    """
    if not self._is_local_metadata_mode():
        json_body, etag = self._get_json_file_and_etag_from_s3(key)
        return json_body, etag
    json_body = self._get_json_file_from_local_override(key, filetype)
    return json_body, None
|
|
379
|
+
|
|
380
|
+
def _get_json_md5_hash(self, key: str):
    """Return the ``ETag`` reported by ``s3.head_object`` for an object.

    Args:
        key (str): S3 key of the object inside ``self.s3_bucket_name``.

    Raises:
        ValueError: if the cache should use local metadata mode.
    """
    if self._is_local_metadata_mode():
        raise ValueError("Cannot get md5 hash of local file.")
    head_response = self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=key)
    return head_response["ETag"]
|
|
389
|
+
|
|
390
|
+
def _get_json_file_from_local_override(
    self, key: str, filetype: JumpStartS3FileType
) -> Union[dict, list]:
    """Reads json file from local filesystem and returns data.

    The root directory is taken from the manifest or specs override environment
    variable, depending on ``filetype``, and ``key`` is joined onto it.

    Args:
        key (str): Path of the file relative to the override root directory.
        filetype (JumpStartS3FileType): Kind of file to read. Only the open-weight
            manifest and open-weight specs types are supported here.

    Returns:
        Union[dict, list]: The parsed JSON content.

    Raises:
        ValueError: If ``filetype`` is not a supported type for local override.
        KeyError: If the relevant override environment variable is not set.
    """
    if filetype == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST:
        metadata_local_root = os.environ[
            ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE
        ]
    elif filetype == JumpStartS3FileType.OPEN_WEIGHT_SPECS:
        metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE]
    else:
        raise ValueError(f"Unsupported file type for local override: {filetype}")
    file_path = os.path.join(metadata_local_root, key)
    # Decode explicitly as UTF-8 so local reads match the S3 code path instead of
    # depending on the platform's default locale encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
406
|
+
|
|
407
|
+
def _retrieval_function(
    self,
    key: JumpStartCachedContentKey,
    value: Optional[JumpStartCachedContentValue],
) -> JumpStartCachedContentValue:
    """Return s3 content given a file type and s3_key in ``JumpStartCachedContentKey``.

    If a manifest file is being fetched, we only download the object if the md5 hash in
    ``head_object`` does not match the current md5 hash for the stored value. This prevents
    unnecessarily downloading the full manifest when it hasn't changed.

    Hub content types (notebook, model, model reference) are resolved through
    ``describe_hub_content`` on the SageMaker session rather than through S3.

    Args:
        key (JumpStartCachedContentKey): key for which to fetch s3 content.
        value (Optional[JumpStartCachedContentValue]): Current value of old cached
            s3 content. This is used for the manifest file, so that it is only
            downloaded when its content changes.

    Returns:
        JumpStartCachedContentValue: The freshly loaded content, or ``value`` itself
        when the stored manifest hash still matches.

    Raises:
        ValueError: If ``key.data_type`` is not one of the handled types.
    """

    data_type, id_info = key.data_type, key.id_info

    if data_type in {
        JumpStartS3FileType.OPEN_WEIGHT_MANIFEST,
        JumpStartS3FileType.PROPRIETARY_MANIFEST,
    }:
        # The hash comparison is skipped in local metadata mode, where
        # ``_get_json_md5_hash`` would raise.
        if value is not None and not self._is_local_metadata_mode():
            etag = self._get_json_md5_hash(id_info)
            if etag == value.md5_hash:
                return value
        formatted_body, etag = self._get_json_file(id_info, data_type)
        return JumpStartCachedContentValue(
            formatted_content=utils.get_formatted_manifest(formatted_body),
            md5_hash=etag,
        )

    if data_type in {
        JumpStartS3FileType.OPEN_WEIGHT_SPECS,
        JumpStartS3FileType.PROPRIETARY_SPECS,
    }:
        formatted_body, _ = self._get_json_file(id_info, data_type)
        model_specs = JumpStartModelSpecs(formatted_body)
        utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
        return JumpStartCachedContentValue(formatted_content=model_specs)

    if data_type == HubContentType.NOTEBOOK:
        # Read the parsed ARN by attribute, as the model branch below does. That branch
        # uses five distinct attributes of the same helper's result, so unpacking the
        # result into four names cannot be relied on.
        notebook_arn_info = hub_utils.get_info_from_hub_resource_arn(id_info)
        response: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
            hub_name=notebook_arn_info.hub_name,
            hub_content_name=notebook_arn_info.hub_content_name,
            hub_content_version=notebook_arn_info.hub_content_version,
            # Pass the enum's value, matching the model branch's call below.
            hub_content_type=data_type.value,
        )
        hub_notebook_description = DescribeHubContentResponse(response)
        return JumpStartCachedContentValue(formatted_content=hub_notebook_description)

    if data_type in {
        HubContentType.MODEL,
        HubContentType.MODEL_REFERENCE,
    }:

        hub_resource_arn_extracted_info = hub_utils.get_info_from_hub_resource_arn(id_info)
        hub_arn = hub_utils.construct_hub_arn_from_name(
            hub_name=hub_resource_arn_extracted_info.hub_name,
            region=hub_resource_arn_extracted_info.region,
            account_id=hub_resource_arn_extracted_info.account_id,
        )

        # Resolve the version in the ARN to a concrete version before describing it.
        model_version: str = hub_utils.get_hub_model_version(
            hub_model_name=hub_resource_arn_extracted_info.hub_content_name,
            hub_model_type=data_type.value,
            hub_name=hub_arn,
            sagemaker_session=self._sagemaker_session,
            hub_model_version=hub_resource_arn_extracted_info.hub_content_version,
        )

        hub_model_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
            hub_name=hub_arn,
            hub_content_name=hub_resource_arn_extracted_info.hub_content_name,
            hub_content_version=model_version,
            hub_content_type=data_type.value,
        )

        model_specs = make_model_specs_from_describe_hub_content_response(
            DescribeHubContentResponse(hub_model_description),
        )

        return JumpStartCachedContentValue(formatted_content=model_specs)

    raise ValueError(self._file_type_error_msg(data_type))
|
|
496
|
+
|
|
497
|
+
def get_manifest(
    self,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> List[JumpStartModelHeader]:
    """Return every header in the JumpStart manifest for a model type.

    Args:
        model_type (JumpStartModelType): Selects which manifest to read.

    Returns:
        List[JumpStartModelHeader]: The values of the cached manifest mapping.
    """
    manifest_cache_key = JumpStartCachedContentKey(
        MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
    )
    cached_manifest = self._content_cache.get(manifest_cache_key)[0]
    headers_by_key = cached_manifest.formatted_content
    return list(headers_by_key.values())  # type: ignore
|
|
509
|
+
|
|
510
|
+
def get_header(
    self,
    model_id: str,
    semantic_version_str: str,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> JumpStartModelHeader:
    """Look up the manifest header of a JumpStart model.

    This is the public entry point; the lookup itself is done by
    ``_get_header_impl``.

    Args:
        model_id (str): model ID for which to get a header.
        semantic_version_str (str): The semantic version for which to get a
            header.
        model_type (JumpStartModelType): The type of the model of interest.
    """
    header = self._get_header_impl(
        model_id,
        semantic_version_str=semantic_version_str,
        model_type=model_type,
    )
    return header
|
|
527
|
+
|
|
528
|
+
def _select_version(
    self,
    model_id: str,
    version_str: str,
    available_versions: List[str],
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Optional[str]:
    """Perform semantic version search on available versions.

    Args:
        model_id (str): model ID the versions belong to; only used in error messages.
        version_str (str): the semantic version for which to filter
            available versions. ``"*"`` selects the latest available version.
        available_versions (List[str]): list of available versions.
        model_type (JumpStartModelType): proprietary models require an exact version
            match; any other wildcard form is rejected for them.

    Returns:
        Optional[str]: The selected version, or None if nothing in
        ``available_versions`` matches.

    Raises:
        KeyError: If ``version_str`` is not a valid version specifier, or if a
            proprietary model is requested with a wildcard other than ``"*"``.
    """

    if version_str == "*":
        return utils.get_latest_version(available_versions)

    if model_type == JumpStartModelType.PROPRIETARY:
        if "*" in version_str:
            raise KeyError(
                get_wildcard_proprietary_model_version_msg(
                    model_id, version_str, available_versions
                )
            )
        return version_str if version_str in available_versions else None

    try:
        spec = SpecifierSet(f"=={version_str}")
    except InvalidSpecifier as err:
        raise KeyError(f"Bad semantic version: {version_str}") from err

    available_versions_filtered = list(spec.filter(available_versions))
    if not available_versions_filtered:
        return None

    # ``filter`` yields the items as they were passed in, i.e. plain strings. Comparing
    # strings directly is lexicographic ("1.9.0" > "1.10.0"), so rank by parsed version.
    def _as_version(candidate):
        return candidate if isinstance(candidate, Version) else Version(candidate)

    return str(max(available_versions_filtered, key=_as_version))
|
|
561
|
+
|
|
562
|
+
def _get_header_impl(
    self,
    model_id: str,
    semantic_version_str: str,
    attempt: int = 0,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> JumpStartModelHeader:
    """Resolve a model header, retrying once against a cleared cache.

    The requested ID/version is first resolved to a concrete manifest key through
    the cache matching ``model_type``, then looked up in the manifest. When the
    manifest lacks that key on the first attempt, all caches are cleared and the
    lookup is repeated; a second miss propagates the ``KeyError``.

    Args:
        model_id (str): model ID for which to get a header.
        semantic_version_str (str): The semantic version for which to get a
            header.
        attempt (int): attempt number at retrieving a header.
        model_type (JumpStartModelType): The type of the model of interest.
    """
    if model_type == JumpStartModelType.OPEN_WEIGHTS:
        requested_id = JumpStartVersionedModelId(model_id, semantic_version_str)
        versioned_model_id = self._open_weight_model_id_manifest_key_cache.get(requested_id)[0]
    elif model_type == JumpStartModelType.PROPRIETARY:
        requested_id = JumpStartVersionedModelId(model_id, semantic_version_str)
        versioned_model_id = self._proprietary_model_id_manifest_key_cache.get(requested_id)[0]

    manifest_cache_key = JumpStartCachedContentKey(
        MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
    )
    manifest = self._content_cache.get(manifest_cache_key)[0].formatted_content

    try:
        return manifest[versioned_model_id]  # type: ignore
    except KeyError:
        if attempt <= 0:
            # First miss: the cached manifest may be stale, so start over once.
            self.clear()
            return self._get_header_impl(
                model_id, semantic_version_str, attempt + 1, model_type
            )
        raise
|
|
602
|
+
|
|
603
|
+
def get_specs(
    self,
    model_id: str,
    version_str: str,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> JumpStartModelSpecs:
    """Fetch the specs of a JumpStart model through the content cache.

    A warning is logged when the specs were not already cached and the requested
    version contains a wildcard.

    Args:
        model_id (str): model ID for which to get specs.
        version_str (str): The semantic version for which to get specs.
        model_type (JumpStartModelType): The type of the model of interest.
    """
    header = self.get_header(model_id, version_str, model_type)
    spec_key = header.spec_key
    specs_cache_key = JumpStartCachedContentKey(MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key)
    specs, cache_hit = self._content_cache.get(specs_cache_key)

    if cache_hit or "*" not in version_str:
        return specs.formatted_content

    wildcard_warning = get_wildcard_model_version_msg(
        header.model_id, version_str, header.version
    )
    JUMPSTART_LOGGER.warning(wildcard_warning)
    return specs.formatted_content
|
|
628
|
+
|
|
629
|
+
def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
    """Fetch JumpStart-compatible specs for a Hub model through the content cache.

    Args:
        hub_model_arn (str): Arn for the Hub model to get specs for
    """
    model_cache_key = JumpStartCachedContentKey(HubContentType.MODEL, hub_model_arn)
    cached_model, _ = self._content_cache.get(model_cache_key)
    return cached_model.formatted_content
|
|
643
|
+
|
|
644
|
+
def get_hub_model_reference(self, hub_model_reference_arn: str) -> JumpStartModelSpecs:
    """Fetch JumpStart-compatible specs for a Hub model reference through the content cache.

    Args:
        hub_model_reference_arn (str): Arn for the Hub model reference to get
            specs for
    """
    reference_cache_key = JumpStartCachedContentKey(
        HubContentType.MODEL_REFERENCE, hub_model_reference_arn
    )
    cached_reference, _ = self._content_cache.get(reference_cache_key)
    return cached_reference.formatted_content
|
|
658
|
+
|
|
659
|
+
def clear(self) -> None:
    """Empty the content cache and both model ID/version manifest-key caches."""
    content_cache = self._content_cache
    content_cache.clear()
    open_weight_key_cache = self._open_weight_model_id_manifest_key_cache
    open_weight_key_cache.clear()
    proprietary_key_cache = self._proprietary_model_id_manifest_key_cache
    proprietary_key_cache.clear()
|