sagemaker-core 1.0.62__py3-none-any.whl → 2.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/__init__.py +2 -0
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2399 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +247 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1220 -0
- sagemaker/core/git_utils.py +415 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2977 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +509 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +770 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1267 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +686 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
- sagemaker/core/image_uri_config/huggingface.json +2287 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +252 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +494 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +972 -0
- sagemaker/core/image_uris.py +816 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +197 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +501 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +171 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +423 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +246 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1497 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +239 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +205 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1599 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1310 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +347 -0
- sagemaker/core/remote_function/core/serialization.py +410 -0
- sagemaker/core/remote_function/core/stored_function.py +223 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +102 -0
- sagemaker/core/remote_function/invoke_function.py +167 -0
- sagemaker/core/remote_function/job.py +2121 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +82 -0
- sagemaker/core/telemetry/telemetry_logging.py +285 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +345 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +514 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.3.1.dist-info/RECORD +351 -0
- sagemaker_core-2.3.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/modules/local_core}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,575 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module stores notebook utils related to SageMaker JumpStart."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
import copy
|
|
16
|
+
|
|
17
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
18
|
+
|
|
19
|
+
from functools import cmp_to_key
|
|
20
|
+
import json
|
|
21
|
+
import os
|
|
22
|
+
from typing import Any, Generator, List, Optional, Tuple, Union, Set, Dict
|
|
23
|
+
from packaging.version import Version
|
|
24
|
+
from sagemaker.core.jumpstart import accessors
|
|
25
|
+
from sagemaker.core.jumpstart.constants import (
|
|
26
|
+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
|
|
27
|
+
PROPRIETARY_MODEL_SPEC_PREFIX,
|
|
28
|
+
)
|
|
29
|
+
from sagemaker.core.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
|
|
30
|
+
from sagemaker.core.jumpstart.filters import (
|
|
31
|
+
SPECIAL_SUPPORTED_FILTER_KEYS,
|
|
32
|
+
ProprietaryModelFilterIdentifiers,
|
|
33
|
+
BooleanValues,
|
|
34
|
+
Identity,
|
|
35
|
+
SpecialSupportedFilterKeys,
|
|
36
|
+
)
|
|
37
|
+
from sagemaker.core.jumpstart.filters import (
|
|
38
|
+
Constant,
|
|
39
|
+
ModelFilter,
|
|
40
|
+
Operator,
|
|
41
|
+
evaluate_filter_expression,
|
|
42
|
+
)
|
|
43
|
+
from sagemaker.core.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
|
|
44
|
+
from sagemaker.core.jumpstart.utils import (
|
|
45
|
+
get_jumpstart_content_bucket,
|
|
46
|
+
get_region_fallback,
|
|
47
|
+
get_sagemaker_version,
|
|
48
|
+
verify_model_region_and_return_specs,
|
|
49
|
+
validate_model_id_and_get_type,
|
|
50
|
+
)
|
|
51
|
+
from sagemaker.core.helper.session_helper import Session
|
|
52
|
+
|
|
53
|
+
# Worker-count ceiling: twice the reported CPU count (4 is assumed when ``os.cpu_count()``
# returns a falsy value such as ``None``), and never more than 32.
# NOTE(review): presumably sizes the ThreadPoolExecutor used for model searches later in
# this module -- confirm against the use site.
MAX_SEARCH_WORKERS = min(32, (os.cpu_count() or 4) * 2)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _compare_model_version_tuples( # pylint: disable=too-many-return-statements
|
|
57
|
+
model_version_1: Optional[Tuple[str, str]] = None,
|
|
58
|
+
model_version_2: Optional[Tuple[str, str]] = None,
|
|
59
|
+
) -> int:
|
|
60
|
+
"""Performs comparison of sdk specs paths, in order to sort them.
|
|
61
|
+
|
|
62
|
+
Args:
|
|
63
|
+
model_version_1 (Tuple[str, str]): The first model ID and version tuple to compare.
|
|
64
|
+
model_version_2 (Tuple[str, str]): The second model ID and version tuple to compare.
|
|
65
|
+
"""
|
|
66
|
+
if model_version_1 is None or model_version_2 is None:
|
|
67
|
+
if model_version_2 is not None:
|
|
68
|
+
return -1
|
|
69
|
+
if model_version_1 is not None:
|
|
70
|
+
return 1
|
|
71
|
+
return 0
|
|
72
|
+
|
|
73
|
+
model_id_1, version_1 = model_version_1
|
|
74
|
+
|
|
75
|
+
model_id_2, version_2 = model_version_2
|
|
76
|
+
|
|
77
|
+
if model_id_1 < model_id_2:
|
|
78
|
+
return -1
|
|
79
|
+
|
|
80
|
+
if model_id_2 < model_id_1:
|
|
81
|
+
return 1
|
|
82
|
+
|
|
83
|
+
if Version(version_1) < Version(version_2):
|
|
84
|
+
return 1
|
|
85
|
+
|
|
86
|
+
if Version(version_2) < Version(version_1):
|
|
87
|
+
return -1
|
|
88
|
+
|
|
89
|
+
return 0
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _model_filter_in_operator_generator(filter_operator: Operator) -> Generator:
    """Lazily yield the operators of ``filter_operator`` that wrap a ``ModelFilter``.

    Args:
        filter_operator (Operator): The operator to iterate over.

    Yields:
        Each element produced by iterating ``filter_operator`` whose ``unresolved_value``
        is a ``ModelFilter`` instance.
    """
    yield from (
        candidate
        for candidate in filter_operator
        if isinstance(candidate.unresolved_value, ModelFilter)
    )
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _put_resolved_booleans_into_filter(
    filter_operator: Operator, model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues]
) -> None:
    """Write resolved boolean values onto the model-filter operators of a filter, in place.

    Every operator in ``filter_operator`` that wraps a ``ModelFilter`` gets its
    ``resolved_value`` set to the matching entry of ``model_filters_to_resolved_values``,
    or to ``BooleanValues.UNKNOWN`` when the filter has no entry.

    Args:
        filter_operator (Operator): The filter whose operators are updated.
        model_filters_to_resolved_values (Dict[ModelFilter, BooleanValues]): Already
            evaluated results, keyed by model filter.
    """
    lookup_resolved = model_filters_to_resolved_values.get
    for filter_node in _model_filter_in_operator_generator(filter_operator):
        filter_node.resolved_value = lookup_resolved(
            filter_node.unresolved_value, BooleanValues.UNKNOWN
        )
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _populate_model_filters_to_resolved_values(
    manifest_specs_cached_values: Dict[str, Any],
    model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues],
    model_filters: Operator,
) -> None:
    """Evaluate every model filter whose key already has a cached value.

    Results are written into ``model_filters_to_resolved_values`` in place. Filters whose
    key is absent from ``manifest_specs_cached_values`` are left untouched.

    Args:
        manifest_specs_cached_values (Dict[str, Any]): Cached values, keyed by filter key.
        model_filters_to_resolved_values (Dict[ModelFilter, BooleanValues]): Output mapping
            from model filter to its evaluated result.
        model_filters (Operator): Iterable of the model filters to evaluate.
    """
    for candidate in model_filters:
        filter_key = candidate.key
        if filter_key not in manifest_specs_cached_values:
            # Nothing cached for this key yet, so the filter cannot be evaluated here.
            continue
        model_filters_to_resolved_values[candidate] = evaluate_filter_expression(
            candidate, manifest_specs_cached_values[filter_key]
        )
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]:
    """Split a model ID into its framework, task and remaining name.

    The ID is expected to look like ``<framework>-<task>-<name>``, where ``<name>`` may
    itself contain hyphens.

    Args:
        model_id (str): The model ID for which to extract the framework/task/model.

    Returns:
        Tuple[str, str, str]: ``(framework, task, name)``, or three empty strings when the
        ID has fewer than three hyphen-separated parts.
    """
    # Only the first two hyphens delimit fields; any later ones belong to the name.
    pieces = model_id.split("-", 2)
    if len(pieces) != 3:
        return "", "", ""

    framework, task, name = pieces
    return framework, task, name
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def extract_model_type_filter_representation(spec_key: str) -> str:
    """Classify a model spec key as proprietary or open weight.

    Args:
        spec_key (str): The model spec key for which to extract the model type. Only the
            portion before the first ``/`` is inspected.

    Returns:
        str: ``JumpStartModelType.PROPRIETARY.value`` when that leading portion equals
        ``PROPRIETARY_MODEL_SPEC_PREFIX``, otherwise ``JumpStartModelType.OPEN_WEIGHTS.value``.
    """
    leading_segment, _, _ = spec_key.partition("/")
    is_proprietary = leading_segment == PROPRIETARY_MODEL_SPEC_PREFIX
    model_type = (
        JumpStartModelType.PROPRIETARY if is_proprietary else JumpStartModelType.OPEN_WEIGHTS
    )
    return model_type.value
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def list_jumpstart_tasks(  # pylint: disable=redefined-builtin
    filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
    region: Optional[str] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[str]:
    """Return the sorted, de-duplicated tasks of the JumpStart models matching a filter.

    Only open-weights models are searched. The task is the second hyphen-separated field
    of each model ID, as parsed by ``extract_framework_task_model``.

    Args:
        filter (Union[Operator, str]): Optional. The filter to apply to list tasks. This can
            be either an ``Operator`` type filter
            (e.g. ``And("task == ic", "framework == pytorch")``), or simply a string filter
            (e.g. ``"task == ic"``). If this argument is not supplied, all tasks will be
            listed. (Default: Constant(BooleanValues.TRUE)).
        region (str): Optional. The AWS region from which to retrieve JumpStart metadata
            regarding models. When omitted, the region is taken from
            ``get_region_fallback``. (Default: None).
        sagemaker_session (Session): Optional. The SageMaker Session to use to perform the
            model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        List[str]: The distinct tasks, sorted ascending.
    """
    resolved_region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )
    matching_models = _generate_jumpstart_model_versions(
        filter=filter,
        region=resolved_region,
        sagemaker_session=sagemaker_session,
        model_type=JumpStartModelType.OPEN_WEIGHTS,
    )
    unique_tasks: Set[str] = {
        extract_framework_task_model(model_id)[1] for model_id, _ in matching_models
    }
    return sorted(unique_tasks)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def list_jumpstart_frameworks(  # pylint: disable=redefined-builtin
    filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
    region: Optional[str] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[str]:
    """Return the sorted, de-duplicated frameworks of the JumpStart models matching a filter.

    Only open-weights models are searched. The framework is the first hyphen-separated
    field of each model ID, as parsed by ``extract_framework_task_model``.

    Args:
        filter (Union[Operator, str]): Optional. The filter to apply to list frameworks.
            This can be either an ``Operator`` type filter
            (e.g. ``And("task == ic", "framework == pytorch")``), or simply a string filter
            (e.g. ``"task == ic"``). If this argument is not supplied, all frameworks will
            be listed. (Default: Constant(BooleanValues.TRUE)).
        region (str): Optional. The AWS region from which to retrieve JumpStart metadata
            regarding models. When omitted, the region is taken from
            ``get_region_fallback``. (Default: None).
        sagemaker_session (Session): Optional. The SageMaker Session to use to perform the
            model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        List[str]: The distinct frameworks, sorted ascending.
    """
    resolved_region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )
    matching_models = _generate_jumpstart_model_versions(
        filter=filter,
        region=resolved_region,
        sagemaker_session=sagemaker_session,
        model_type=JumpStartModelType.OPEN_WEIGHTS,
    )
    unique_frameworks: Set[str] = {
        extract_framework_task_model(model_id)[0] for model_id, _ in matching_models
    }
    return sorted(unique_frameworks)
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def list_jumpstart_scripts(  # pylint: disable=redefined-builtin
    filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
    region: Optional[str] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[str]:
    """List scripts for JumpStart, and optionally apply filters to result.

    With the always-true filter, every ``JumpStartScriptScope`` value is returned without
    searching. Otherwise the matching open-weights models are scanned: inference is
    recorded for every match, and training is recorded when a model's specs report
    ``training_supported``. The scan stops once every script scope has been seen.

    Both paths return plain ``.value`` strings, so the element type does not depend on
    which filter was supplied.

    Args:
        filter (Union[Operator, str]): Optional. The filter to apply to list scripts. This
            can be either an ``Operator`` type filter
            (e.g. ``And("task == ic", "framework == pytorch")``), or simply a string filter
            (e.g. ``"task == ic"``). If this argument is not supplied, all scripts will be
            listed. (Default: Constant(BooleanValues.TRUE)).
        region (str): Optional. The AWS region from which to retrieve JumpStart metadata
            regarding models. When omitted, the region is taken from
            ``get_region_fallback``. (Default: None).
        sagemaker_session (Session): Optional. The SageMaker Session to use to perform the
            model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        List[str]: The distinct script scope values, sorted ascending.
    """
    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )
    # Loop-invariant: the full set of scope values, used for both the shortcut and early exit.
    all_script_values: Set[str] = {scope.value for scope in JumpStartScriptScope}

    if (isinstance(filter, Constant) and filter.resolved_value == BooleanValues.TRUE) or (
        isinstance(filter, str) and filter.lower() == BooleanValues.TRUE.lower()
    ):
        return sorted(all_script_values)

    scripts: Set[str] = set()
    for model_id, version in _generate_jumpstart_model_versions(
        filter=filter,
        region=region,
        sagemaker_session=sagemaker_session,
        model_type=JumpStartModelType.OPEN_WEIGHTS,
    ):
        # Store ``.value`` rather than the enum member so the set compares equal to
        # ``all_script_values`` and sorts, whether or not the enum mixes in ``str``.
        scripts.add(JumpStartScriptScope.INFERENCE.value)
        model_specs = verify_model_region_and_return_specs(
            region=region,
            model_id=model_id,
            version=version,
            sagemaker_session=sagemaker_session,
            scope=JumpStartScriptScope.INFERENCE,
        )
        if model_specs.training_supported:
            scripts.add(JumpStartScriptScope.TRAINING.value)

        if scripts == all_script_values:
            # Every scope has been found; further models cannot change the result.
            break
    return sorted(scripts)
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def _is_valid_version(version: str) -> bool:
    """Report whether ``version`` can be converted to the ``Version`` class.

    Returns ``True`` when constructing ``Version(version)`` succeeds and ``False``
    when the constructor raises any ``Exception``.
    """
    parseable = True
    try:
        Version(version)
    except Exception:  # pylint: disable=broad-except
        # Any failure to construct a Version means "not a valid version".
        parseable = False
    return parseable
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def list_jumpstart_models(  # pylint: disable=redefined-builtin
    filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
    region: Optional[str] = None,
    list_incomplete_models: bool = False,
    list_old_models: bool = False,
    list_versions: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[Union[Tuple[str], Tuple[str, str]]]:
    """List models for JumpStart, and optionally apply filters to result.

    Args:
        filter (Union[Operator, str]): Optional. The filter to apply to list models. This can be
            either an ``Operator`` type filter (e.g. ``And("task == ic", "framework == pytorch")``),
            or simply a string filter which will get serialized into an Identity filter.
            (e.g. ``"task == ic"``). If this argument is not supplied, all models will be listed.
            (Default: Constant(BooleanValues.TRUE)).
        region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
            models. (Default: None).
        list_incomplete_models (bool): Optional. If a model does not contain metadata fields
            requested by the filter, and the filter cannot be resolved to a include/not include,
            whether the model should be included. By default, these models are omitted from results.
            (Default: False).
        list_old_models (bool): Optional. If there are older versions of a model, whether the older
            versions should be included in the returned result. Only consulted when
            ``list_versions`` is True. (Default: False).
        list_versions (bool): Optional. True if versions for models should be returned in addition
            to the id of the model. (Default: False).
        sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
            to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        When ``list_versions`` is False, the sorted model ids. Otherwise a list of
        ``(model_id, version)`` string tuples ordered by ``_compare_model_version_tuples``.
    """
    region = region or get_region_fallback(sagemaker_session=sagemaker_session)

    # Group every matching version under its model id. Versions that parse are kept
    # as ``Version`` objects so they compare semantically; the rest stay raw.
    versions_by_model: Dict[str, List[str]] = {}
    matches = _generate_jumpstart_model_versions(
        filter=filter,
        region=region,
        list_incomplete_models=list_incomplete_models,
        sagemaker_session=sagemaker_session,
    )
    for model_id, version in matches:
        parsed = Version(version) if _is_valid_version(version) else version
        versions_by_model.setdefault(model_id, []).append(parsed)

    if not list_versions:
        return sorted(versions_by_model)

    if not list_old_models:
        # Collapse each model to its single greatest version.
        for model_id in list(versions_by_model):
            candidates = versions_by_model[model_id]
            try:
                newest = max(candidates)
            except TypeError:
                # Parsed and unparsed versions cannot be compared with each other;
                # fall back to comparing their string forms.
                newest = max([str(candidate) for candidate in candidates])
            versions_by_model[model_id] = {newest}

    id_version_pairs: Set[Tuple[str, str]] = {
        (model_id, str(version))
        for model_id, model_versions in versions_by_model.items()
        for version in model_versions
    }

    return sorted(id_version_pairs, key=cmp_to_key(_compare_model_version_tuples))
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def _generate_jumpstart_model_versions(  # pylint: disable=redefined-builtin
    filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
    region: Optional[str] = None,
    list_incomplete_models: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: Optional[JumpStartModelType] = None,
) -> Generator:
    """Generate models for JumpStart, and optionally apply filters to result.

    Args:
        filter (Union[Operator, str]): Optional. The filter to apply to generate models. This can be
            either an ``Operator`` type filter (e.g. ``And("task == ic", "framework == pytorch")``),
            or simply a string filter which will get serialized into an Identity filter.
            (e.g. ``"task == ic"``). If this argument is not supplied, all models will be generated.
            (Default: Constant(BooleanValues.TRUE)).
        region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
            models. (Default: None).
        list_incomplete_models (bool): Optional. If a model does not contain metadata fields
            requested by the filter, and the filter cannot be resolved to a include/not include,
            whether the model should be included. By default, these models are omitted from
            results. (Default: False).
        sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session
            to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (Optional[JumpStartModelType]): Optional. ``OPEN_WEIGHTS`` or ``PROPRIETARY``
            restricts the search to that manifest; any other value, including None, searches
            the open-weights manifest followed by the proprietary one. (Default: None).

    Yields:
        Tuple[str, str]: ``(model_id, version)`` for every manifest entry accepted by the
            filter, in the order the worker threads complete.

    Raises:
        NotImplementedError: If a filter key contains ``"."`` (multi-level metadata indexing).
        RuntimeError: If the filter is still unevaluated after the manifest values, or after
            the model spec values, have been applied.
    """

    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    # Both manifests are always fetched: even when ``model_type`` restricts the search,
    # the header attribute names of both feed ``manifest_keys`` below.
    prop_models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(
        region=region,
        s3_client=sagemaker_session.s3_client,
        model_type=JumpStartModelType.PROPRIETARY,
    )
    open_weight_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(
        region=region,
        s3_client=sagemaker_session.s3_client,
        model_type=JumpStartModelType.OPEN_WEIGHTS,
    )
    models_manifest_list = (
        open_weight_manifest_list
        if model_type == JumpStartModelType.OPEN_WEIGHTS
        else (
            prop_models_manifest_list
            if model_type == JumpStartModelType.PROPRIETARY
            else open_weight_manifest_list + prop_models_manifest_list
        )
    )

    if isinstance(filter, str):
        filter = Identity(filter)

    # Attribute names available directly on manifest headers, taken from the first
    # header of each manifest. An empty manifest contributes no keys instead of
    # raising IndexError on ``[0]``.
    manifest_keys: Set[str] = set()
    for manifest_list in (open_weight_manifest_list, prop_models_manifest_list):
        if manifest_list:
            manifest_keys.update(manifest_list[0].__slots__)

    all_keys: Set[str] = set()

    model_filters: Set[ModelFilter] = set()

    for operator in _model_filter_in_operator_generator(filter):
        model_filter = operator.unresolved_value
        key = model_filter.key
        all_keys.add(key)
        # Any proprietary identifier used as a model-type value is rewritten to the
        # PROPRIETARY model type value before the filter is evaluated.
        if model_filter.key == SpecialSupportedFilterKeys.MODEL_TYPE and model_filter.value in {
            identifier.value for identifier in ProprietaryModelFilterIdentifiers
        }:
            model_filter.set_value(JumpStartModelType.PROPRIETARY.value)
        model_filters.add(model_filter)

    for key in all_keys:
        if "." in key:
            raise NotImplementedError(f"No support for multiple level metadata indexing ('{key}').")

    metadata_filter_keys = all_keys - SPECIAL_SUPPORTED_FILTER_KEYS

    # Keys answerable from the manifest header alone vs. keys that may need the spec file.
    required_manifest_keys = manifest_keys.intersection(metadata_filter_keys)
    possible_spec_keys = metadata_filter_keys - manifest_keys

    is_task_filter = SpecialSupportedFilterKeys.TASK in all_keys
    is_framework_filter = SpecialSupportedFilterKeys.FRAMEWORK in all_keys
    is_model_type_filter = SpecialSupportedFilterKeys.MODEL_TYPE in all_keys

    def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, str]]:
        """Return ``(model_id, version)`` if the header passes the filter, else None."""

        # Each evaluation works on its own copy so resolved values never leak
        # between models (or between worker threads).
        copied_filter = copy.deepcopy(filter)

        manifest_specs_cached_values: Dict[str, Union[bool, int, float, str, dict, list]] = {}

        model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues] = {}

        for val in required_manifest_keys:
            manifest_specs_cached_values[val] = getattr(model_manifest, val)

        if is_task_filter:
            manifest_specs_cached_values[SpecialSupportedFilterKeys.TASK] = (
                extract_framework_task_model(model_manifest.model_id)[1]
            )

        if is_framework_filter:
            manifest_specs_cached_values[SpecialSupportedFilterKeys.FRAMEWORK] = (
                extract_framework_task_model(model_manifest.model_id)[0]
            )

        if is_model_type_filter:
            manifest_specs_cached_values[SpecialSupportedFilterKeys.MODEL_TYPE] = (
                extract_model_type_filter_representation(model_manifest.spec_key)
            )

        # Skip entries that require a newer SDK than the one currently running.
        if Version(model_manifest.min_version) > Version(get_sagemaker_version()):
            return None

        _populate_model_filters_to_resolved_values(
            manifest_specs_cached_values,
            model_filters_to_resolved_values,
            model_filters,
        )

        _put_resolved_booleans_into_filter(copied_filter, model_filters_to_resolved_values)

        copied_filter.eval()

        # First pass: the manifest header alone may already decide the filter.
        if copied_filter.resolved_value in [BooleanValues.TRUE, BooleanValues.FALSE]:
            if copied_filter.resolved_value == BooleanValues.TRUE:
                return (model_manifest.model_id, model_manifest.version)
            return None

        if copied_filter.resolved_value == BooleanValues.UNEVALUATED:
            raise RuntimeError(
                "Filter expression in unevaluated state after using "
                "values from model manifest. Model ID and version that "
                f"is failing: {(model_manifest.model_id, model_manifest.version)}."
            )
        copied_filter_2 = copy.deepcopy(filter)

        # spec is downloaded to thread's memory. since each thread
        # accesses a unique s3 spec, there is no need to use the JS caching utils.
        # spec only stays in memory for lifecycle of thread.
        model_specs = JumpStartModelSpecs(
            json.loads(
                sagemaker_session.read_s3_file(
                    get_jumpstart_content_bucket(region), model_manifest.spec_key
                )
            )
        )

        for val in possible_spec_keys:
            if hasattr(model_specs, val):
                manifest_specs_cached_values[val] = getattr(model_specs, val)

        _populate_model_filters_to_resolved_values(
            manifest_specs_cached_values,
            model_filters_to_resolved_values,
            model_filters,
        )
        _put_resolved_booleans_into_filter(copied_filter_2, model_filters_to_resolved_values)

        copied_filter_2.eval()

        if copied_filter_2.resolved_value != BooleanValues.UNEVALUATED:
            # Include the model when the filter is TRUE, or when it is still UNKNOWN
            # and the caller asked for incomplete models. The resolved value must be
            # compared explicitly: a bare ``BooleanValues.UNKNOWN`` is always truthy
            # and would also let FALSE results through.
            if copied_filter_2.resolved_value == BooleanValues.TRUE or (
                copied_filter_2.resolved_value == BooleanValues.UNKNOWN
                and list_incomplete_models
            ):
                return (model_manifest.model_id, model_manifest.version)
            return None

        raise RuntimeError(
            "Filter expression in unevaluated state after using values from model specs. "
            "Model ID and version that is failing: "
            f"{(model_manifest.model_id, model_manifest.version)}."
        )

    with ThreadPoolExecutor(max_workers=MAX_SEARCH_WORKERS) as executor:
        futures = []
        for header in models_manifest_list:
            futures.append(executor.submit(evaluate_model, header))

        for future in as_completed(futures):
            # Re-raise the first worker failure in the consuming thread.
            error = future.exception()
            if error:
                raise error
            result = future.result()
            if result:
                yield result
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
def get_model_url(
    model_id: str,
    model_version: str,
    region: Optional[str] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
) -> str:
    """Retrieve web url describing pretrained model.

    Args:
        model_id (str): The model ID for which to retrieve the url.
        model_version (str): The model version for which to retrieve the url.
        region (str): Optional. The region from which to retrieve metadata.
            (Default: None)
        sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
            to retrieve the model url.
        config_name (Optional[str]): Optional. Forwarded unchanged to
            ``verify_model_region_and_return_specs``. (Default: None)

    Returns:
        str: The ``url`` attribute of the specs returned for the model.
    """
    # The model type is looked up with the region exactly as supplied (possibly None);
    # the fallback region is only resolved afterwards.
    resolved_model_type = validate_model_id_and_get_type(
        model_id=model_id,
        model_version=model_version,
        region=region,
        sagemaker_session=sagemaker_session,
    )

    if not region:
        region = get_region_fallback(sagemaker_session=sagemaker_session)

    specs = verify_model_region_and_return_specs(
        region=region,
        model_id=model_id,
        version=model_version,
        sagemaker_session=sagemaker_session,
        scope=JumpStartScriptScope.INFERENCE,
        model_type=resolved_model_type,
        config_name=config_name,
    )
    return specs.url
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module stores parameters related to SageMaker JumpStart."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
import datetime
|
|
16
|
+
|
|
17
|
+
# Defaults named for the S3 cache: item limit and expiration horizon.
# (Consumers of these values are not visible in this module.)
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS = 20
JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON = datetime.timedelta(hours=6)

# Defaults named for the semantic-version cache: item limit and expiration horizon.
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS = 20
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON = datetime.timedelta(hours=6)
|