sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module contains utilites for JumpStart model metadata."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
from pydantic import BaseModel, ConfigDict
|
|
17
|
+
from typing import Optional
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class BaseConfig(BaseModel):
    """Common pydantic base class for configuration models.

    The model configuration below makes instances immutable (``frozen``),
    rejects fields that are not declared on the model (``extra="forbid"``),
    and enables validation on attribute assignment.
    """

    model_config = ConfigDict(frozen=True, extra="forbid", validate_assignment=True)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class JumpStartConfig(BaseConfig):
    """Configuration Class for JumpStart.

    Attributes:
        model_id (str): The model ID of the JumpStart model. Required.
        model_version (Optional[str]): The version of the JumpStart model.
            Defaults to None.
        hub_name (Optional[str]): The name of the JumpStart hub. Defaults to None.
        accept_eula (Optional[bool]): Whether to accept the EULA. Defaults to False.
        training_config_name (Optional[str]): The name of the training configuration.
            Defaults to None.
        inference_config_name (Optional[str]): The name of the inference configuration.
            Defaults to None.
    """

    model_id: str
    model_version: Optional[str] = None
    hub_name: Optional[str] = None
    # The EULA is not accepted unless the caller explicitly opts in.
    accept_eula: Optional[bool] = False
    training_config_name: Optional[str] = None
    inference_config_name: Optional[str] = None
|
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module contains constants for JumpStart."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
from __future__ import absolute_import
|
|
16
|
+
import logging
|
|
17
|
+
import os
|
|
18
|
+
from typing import Dict, Set, Type
|
|
19
|
+
import json
|
|
20
|
+
import boto3
|
|
21
|
+
from sagemaker.core.deserializers import BaseDeserializer, JSONDeserializer
|
|
22
|
+
from sagemaker.core.jumpstart.enums import (
|
|
23
|
+
JumpStartScriptScope,
|
|
24
|
+
SerializerType,
|
|
25
|
+
DeserializerType,
|
|
26
|
+
MIMEType,
|
|
27
|
+
JumpStartModelType,
|
|
28
|
+
)
|
|
29
|
+
from sagemaker.core.jumpstart.types import JumpStartLaunchedRegionInfo, JumpStartS3FileType
|
|
30
|
+
from sagemaker.core.serializers import (
|
|
31
|
+
BaseSerializer,
|
|
32
|
+
CSVSerializer,
|
|
33
|
+
DataSerializer,
|
|
34
|
+
IdentitySerializer,
|
|
35
|
+
JSONSerializer,
|
|
36
|
+
)
|
|
37
|
+
from sagemaker.core.helper.session_helper import Session
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
SAGEMAKER_PUBLIC_HUB = "SageMakerPublicHub"
|
|
41
|
+
DEFAULT_TRAINING_ENTRY_POINT = "transfer_learning.py"
|
|
42
|
+
|
|
43
|
+
JUMPSTART_LOGGER = logging.getLogger("sagemaker.jumpstart")
|
|
44
|
+
|
|
45
|
+
# disable logging if env var is set
|
|
46
|
+
JUMPSTART_LOGGER.addHandler(
|
|
47
|
+
type(
|
|
48
|
+
"",
|
|
49
|
+
(logging.StreamHandler,),
|
|
50
|
+
{
|
|
51
|
+
"emit": lambda self, *args, **kwargs: (
|
|
52
|
+
logging.StreamHandler.emit(self, *args, **kwargs)
|
|
53
|
+
if not os.environ.get(ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING)
|
|
54
|
+
else None
|
|
55
|
+
)
|
|
56
|
+
},
|
|
57
|
+
)()
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
_CURRENT_FILE_DIRECTORY_PATH = os.path.dirname(os.path.realpath(__file__))
|
|
62
|
+
REGION_CONFIG_JSON_FILENAME = "region_config.json"
|
|
63
|
+
REGION_CONFIG_JSON_FILEPATH = os.path.join(
|
|
64
|
+
_CURRENT_FILE_DIRECTORY_PATH, REGION_CONFIG_JSON_FILENAME
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _load_region_config(filepath: str) -> Set[JumpStartLaunchedRegionInfo]:
    """Load the JumpStart region config from a JSON file.

    Args:
        filepath (str): Path of the JSON file to read. Each top-level key is used
            as a region name; its value must provide ``content_bucket`` and may
            provide ``gated_content_bucket`` and ``neo_content_bucket``.

    Returns:
        Set[JumpStartLaunchedRegionInfo]: One entry per region found in the file,
            or an empty set if the file cannot be read or parsed.
    """
    JUMPSTART_LOGGER.debug("Loading JumpStart region config from '%s'.", filepath)
    try:
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(filepath, encoding="utf-8") as f:
            config = json.load(f)

        return {
            JumpStartLaunchedRegionInfo(
                region_name=region,
                content_bucket=data["content_bucket"],
                gated_content_bucket=data.get("gated_content_bucket"),
                neo_content_bucket=data.get("neo_content_bucket"),
            )
            for region, data in config.items()
        }
    except Exception:  # pylint: disable=W0703
        # Deliberately best-effort: any failure is logged with its traceback and
        # an empty set is returned instead of propagating the error.
        JUMPSTART_LOGGER.error("Unable to load JumpStart region config.", exc_info=True)
        return set()
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING = "DISABLE_JUMPSTART_LOGGING"
|
|
91
|
+
ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY = "DISABLE_JUMPSTART_TELEMETRY"
|
|
92
|
+
|
|
93
|
+
JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = _load_region_config(
|
|
94
|
+
REGION_CONFIG_JSON_FILEPATH
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT = {
|
|
98
|
+
region.region_name: region for region in JUMPSTART_LAUNCHED_REGIONS
|
|
99
|
+
}
|
|
100
|
+
JUMPSTART_REGION_NAME_SET = {region.region_name for region in JUMPSTART_LAUNCHED_REGIONS}
|
|
101
|
+
|
|
102
|
+
JUMPSTART_BUCKET_NAME_SET = {region.content_bucket for region in JUMPSTART_LAUNCHED_REGIONS}
|
|
103
|
+
JUMPSTART_GATED_BUCKET_NAME_SET = {
|
|
104
|
+
region.gated_content_bucket
|
|
105
|
+
for region in JUMPSTART_LAUNCHED_REGIONS
|
|
106
|
+
if region.gated_content_bucket is not None
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
JUMPSTART_GATED_AND_PUBLIC_BUCKET_NAME_SET = JUMPSTART_BUCKET_NAME_SET.union(
|
|
110
|
+
JUMPSTART_GATED_BUCKET_NAME_SET
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2"
|
|
114
|
+
NEO_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2"
|
|
115
|
+
|
|
116
|
+
JUMPSTART_MODEL_HUB_NAME = "SageMakerPublicHub"
|
|
117
|
+
|
|
118
|
+
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
|
|
119
|
+
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json"
|
|
120
|
+
|
|
121
|
+
HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$"
|
|
122
|
+
HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$"
|
|
123
|
+
|
|
124
|
+
INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py"
|
|
125
|
+
TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py"
|
|
126
|
+
|
|
127
|
+
SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope)
|
|
128
|
+
|
|
129
|
+
ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE"
|
|
130
|
+
ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE"
|
|
131
|
+
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE = "AWS_JUMPSTART_MODEL_BUCKET_OVERRIDE"
|
|
132
|
+
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE = "AWS_JUMPSTART_SCRIPT_BUCKET_OVERRIDE"
|
|
133
|
+
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE = (
|
|
134
|
+
"AWS_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE"
|
|
135
|
+
)
|
|
136
|
+
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE = "AWS_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE"
|
|
137
|
+
ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE = "AWS_NEO_CONTENT_BUCKET_OVERRIDE"
|
|
138
|
+
|
|
139
|
+
JUMPSTART_RESOURCE_BASE_NAME = "sagemaker-jumpstart"
|
|
140
|
+
|
|
141
|
+
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY = "SageMakerGatedModelS3Uri"
|
|
142
|
+
|
|
143
|
+
PROPRIETARY_MODEL_SPEC_PREFIX = "proprietary-models"
|
|
144
|
+
PROPRIETARY_MODEL_FILTER_NAME = "marketplace"
|
|
145
|
+
|
|
146
|
+
CONTENT_TYPE_TO_SERIALIZER_TYPE_MAP: Dict[MIMEType, SerializerType] = {
|
|
147
|
+
MIMEType.X_IMAGE: SerializerType.RAW_BYTES,
|
|
148
|
+
MIMEType.LIST_TEXT: SerializerType.JSON,
|
|
149
|
+
MIMEType.X_TEXT: SerializerType.TEXT,
|
|
150
|
+
MIMEType.JSON: SerializerType.JSON,
|
|
151
|
+
MIMEType.CSV: SerializerType.CSV,
|
|
152
|
+
MIMEType.WAV: SerializerType.RAW_BYTES,
|
|
153
|
+
}
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
ACCEPT_TYPE_TO_DESERIALIZER_TYPE_MAP: Dict[MIMEType, DeserializerType] = {
|
|
157
|
+
MIMEType.JSON: DeserializerType.JSON,
|
|
158
|
+
}
|
|
159
|
+
|
|
160
|
+
SERIALIZER_TYPE_TO_CLASS_MAP: Dict[SerializerType, Type[BaseSerializer]] = {
|
|
161
|
+
SerializerType.RAW_BYTES: DataSerializer,
|
|
162
|
+
SerializerType.JSON: JSONSerializer,
|
|
163
|
+
SerializerType.TEXT: IdentitySerializer,
|
|
164
|
+
SerializerType.CSV: CSVSerializer,
|
|
165
|
+
}
|
|
166
|
+
|
|
167
|
+
DESERIALIZER_TYPE_TO_CLASS_MAP: Dict[DeserializerType, Type[BaseDeserializer]] = {
|
|
168
|
+
DeserializerType.JSON: JSONDeserializer,
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
# Maps each JumpStart model type to the S3 file type of its manifest.
# Keys and values are members (instances) of their respective types, not classes,
# so the mapping is annotated with the types themselves rather than ``Type[...]``.
MODEL_TYPE_TO_MANIFEST_MAP: Dict[JumpStartModelType, JumpStartS3FileType] = {
    JumpStartModelType.OPEN_WEIGHTS: JumpStartS3FileType.OPEN_WEIGHT_MANIFEST,
    JumpStartModelType.PROPRIETARY: JumpStartS3FileType.PROPRIETARY_MANIFEST,
}

# Maps each JumpStart model type to the S3 file type of its specs.
MODEL_TYPE_TO_SPECS_MAP: Dict[JumpStartModelType, JumpStartS3FileType] = {
    JumpStartModelType.OPEN_WEIGHTS: JumpStartS3FileType.OPEN_WEIGHT_SPECS,
    JumpStartModelType.PROPRIETARY: JumpStartS3FileType.PROPRIETARY_SPECS,
}
|
|
180
|
+
|
|
181
|
+
MODEL_ID_LIST_WEB_URL = "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html"
|
|
182
|
+
|
|
183
|
+
try:
|
|
184
|
+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION = Session(
|
|
185
|
+
boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME)
|
|
186
|
+
)
|
|
187
|
+
except Exception as e: # pylint: disable=W0703
|
|
188
|
+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION = None
|
|
189
|
+
JUMPSTART_LOGGER.warning(
|
|
190
|
+
"Unable to create default JumpStart SageMaker Session due to the following error: %s.",
|
|
191
|
+
str(e),
|
|
192
|
+
)
|
|
193
|
+
|
|
194
|
+
EXTRA_MODEL_ID_TAGS = ["sm-jumpstart-id", "sagemaker-studio:jumpstart-model-id"]
|
|
195
|
+
EXTRA_MODEL_VERSION_TAGS = [
|
|
196
|
+
"sm-jumpstart-model-version",
|
|
197
|
+
"sagemaker-studio:jumpstart-model-version",
|
|
198
|
+
]
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""JumpStart deserializers module - provides retrieve_default function for backward compatibility."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
from typing import Optional
|
|
17
|
+
|
|
18
|
+
from sagemaker.core.deserializers import BaseDeserializer
|
|
19
|
+
from sagemaker.core.jumpstart import artifacts, utils as jumpstart_utils
|
|
20
|
+
from sagemaker.core.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
|
|
21
|
+
from sagemaker.core.jumpstart.enums import JumpStartModelType
|
|
22
|
+
from sagemaker.core.helper.session_helper import Session
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def retrieve_default(
    region: Optional[str] = None,
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
) -> BaseDeserializer:
    """Look up the default deserializer for a JumpStart model.

    The ``model_id``/``model_version`` pair is first checked with
    ``jumpstart_utils.is_jumpstart_model_input``; when the check passes, every
    argument is forwarded unchanged to ``artifacts._retrieve_default_deserializer``.

    Args:
        region (str): The AWS Region for which to retrieve the default deserializer.
            Defaults to ``None``.
        model_id (str): The model ID of the model for which to retrieve the
            default deserializer. (Default: None).
        model_version (str): The version of the model for which to retrieve the
            default deserializer. (Default: None).
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies
            with known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated models should be tolerated
            (exception not raised). False if these models should raise an exception.
            (Default: False).
        model_type (JumpStartModelType): The model type.
            (Default: JumpStartModelType.OPEN_WEIGHTS).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used
            for SageMaker interactions.
            (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart Model config to apply.
            (Default: None).

    Returns:
        BaseDeserializer: The default deserializer to use for the model.

    Raises:
        ValueError: If ``model_id`` and ``model_version`` are not accepted by
            ``jumpstart_utils.is_jumpstart_model_input``.
    """
    is_jumpstart_input = jumpstart_utils.is_jumpstart_model_input(model_id, model_version)

    if is_jumpstart_input:
        # Pure pass-through: the lookup itself lives in the artifacts module.
        return artifacts._retrieve_default_deserializer(
            model_id=model_id,
            model_version=model_version,
            hub_arn=hub_arn,
            region=region,
            tolerate_deprecated_model=tolerate_deprecated_model,
            tolerate_vulnerable_model=tolerate_vulnerable_model,
            sagemaker_session=sagemaker_session,
            model_type=model_type,
            config_name=config_name,
        )

    raise ValueError(
        "Must specify JumpStart `model_id` and `model_version` when retrieving deserializers."
    )
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module contains utilites for JumpStart model metadata."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
from typing import Optional, Tuple
|
|
18
|
+
from functools import lru_cache
|
|
19
|
+
from botocore.exceptions import ClientError
|
|
20
|
+
|
|
21
|
+
from sagemaker.core.helper.session_helper import Session
|
|
22
|
+
from sagemaker.core.utils.utils import logger
|
|
23
|
+
from sagemaker.core.resources import HubContent
|
|
24
|
+
from sagemaker.core.jumpstart.configs import JumpStartConfig
|
|
25
|
+
from sagemaker.core.jumpstart.models import HubContentDocument
|
|
26
|
+
from sagemaker.core.jumpstart.constants import SAGEMAKER_PUBLIC_HUB
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@lru_cache(maxsize=128)
def get_hub_content_and_document(
    jumpstart_config: JumpStartConfig,
    sagemaker_session: Optional[Session] = None,
) -> Tuple[HubContent, HubContentDocument]:
    """Fetch a JumpStart model's hub content and its parsed document.

    Results are memoized with ``lru_cache`` keyed on the argument pair, so both
    ``jumpstart_config`` and ``sagemaker_session`` must be hashable.

    Args:
        jumpstart_config (JumpStartConfig): JumpStart configuration. Its
            ``hub_name``, ``model_id`` and ``model_version`` attributes are read.
        sagemaker_session (Session, optional): SageMaker session. When ``None``,
            a new ``Session()`` is created. Defaults to None.

    Returns:
        Tuple[HubContent, HubContentDocument]: The hub content returned by
        ``HubContent.get`` and the ``HubContentDocument`` built from the JSON in
        its ``hub_content_document`` attribute.

    Raises:
        ClientError: Re-raised unchanged when ``HubContent.get`` fails. An error
            is additionally logged when the error code is ``ResourceNotFound``.
    """
    if sagemaker_session is None:
        sagemaker_session = Session()
        logger.debug("No sagemaker session provided. Using default session.")

    # A missing or empty hub name falls back to the SageMaker public hub.
    hub_name = jumpstart_config.hub_name or SAGEMAKER_PUBLIC_HUB
    # Public-hub entries are requested as "Model"; any other hub as "ModelReference".
    hub_content_type = "Model" if hub_name == SAGEMAKER_PUBLIC_HUB else "ModelReference"

    region = sagemaker_session.boto_region_name

    try:
        hub_content = HubContent.get(
            hub_name=hub_name,
            hub_content_name=jumpstart_config.model_id,
            hub_content_version=jumpstart_config.model_version,
            hub_content_type=hub_content_type,
            session=sagemaker_session.boto_session,
            region=region,
        )
    except ClientError as e:
        # Use .get() so a malformed error payload cannot raise KeyError and
        # mask the original ClientError.
        error_code = e.response.get("Error", {}).get("Code")
        if error_code == "ResourceNotFound":
            logger.error(
                f"Hub content {jumpstart_config.model_id} not found in {hub_name}.\n"
                "Please check that the Model ID is available in the specified hub."
            )
        # Bare raise re-raises the active exception without adding a frame.
        raise

    logger.info(
        f"hub_content_name: {hub_content.hub_content_name}, "
        f"hub_content_version: {hub_content.hub_content_version}"
    )
    document_json = json.loads(hub_content.hub_content_document)
    return (hub_content, HubContentDocument(**document_json))
|
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module stores enums related to SageMaker JumpStart."""
|
|
14
|
+
|
|
15
|
+
from __future__ import absolute_import
|
|
16
|
+
|
|
17
|
+
from enum import Enum
|
|
18
|
+
from typing import List
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ModelFramework(str, Enum):
    """Enum class for JumpStart model framework.

    The ML framework as referenced in the prefix of the model ID.
    This value does not necessarily correspond to the container name.

    Members also subclass ``str``, so each member compares equal to its raw
    string value (for example ``ModelFramework.PYTORCH == "pytorch"``).
    """

    PYTORCH = "pytorch"
    TENSORFLOW = "tensorflow"
    MXNET = "mxnet"
    HUGGINGFACE = "huggingface"
    LIGHTGBM = "lightgbm"
    CATBOOST = "catboost"
    XGBOOST = "xgboost"
    SKLEARN = "sklearn"
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class JumpStartModelType(str, Enum):
    """Enum class for JumpStart model type.

    OPEN_WEIGHTS: Publicly available models have open weights
    and are onboarded and maintained by JumpStart.
    PROPRIETARY: Proprietary models from third-party providers do not have open weights.
    You must subscribe to proprietary models in AWS Marketplace before use.

    Members also subclass ``str``, so each member compares equal to its raw
    string value.
    """

    OPEN_WEIGHTS = "open_weights"
    PROPRIETARY = "proprietary"
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class VariableScope(str, Enum):
    """Possible value of the ``scope`` attribute for a hyperparameter or environment variable.

    Used for hosting environment variables and training hyperparameters.

    Members also subclass ``str``, so each member compares equal to its raw
    string value.
    """

    CONTAINER = "container"
    ALGORITHM = "algorithm"
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class JumpStartScriptScope(str, Enum):
    """Enum class for JumpStart script scopes.

    Members also subclass ``str``, so each member compares equal to its raw
    string value.
    """

    INFERENCE = "inference"
    TRAINING = "training"
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class HyperparameterValidationMode(str, Enum):
    """Possible modes for validating hyperparameters.

    Members also subclass ``str``, so each member compares equal to its raw
    string value.
    """

    # NOTE(review): the exact validation performed per mode is defined by the
    # consumers of this enum, not here -- see the hyperparameter validation code.
    VALIDATE_PROVIDED = "validate_provided"
    VALIDATE_ALGORITHM = "validate_algorithm"
    VALIDATE_ALL = "validate_all"
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class VariableTypes(str, Enum):
    """Possible types for hyperparameters and environment variables.

    Members also subclass ``str``, so each member compares equal to its raw
    string value.
    """

    TEXT = "text"
    INT = "int"
    FLOAT = "float"
    BOOL = "bool"
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class HubContentCapability(str, Enum):
    """Enum class for HubContent capabilities.

    Members also subclass ``str``, so each member compares equal to its raw
    string value. Only one capability is currently declared.
    """

    BEDROCK_CONSOLE = "BEDROCK_CONSOLE"
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class JumpStartTag(str, Enum):
    """Enum class for tag keys to apply to JumpStart models.

    Each member's value is the literal tag key. Members also subclass ``str``,
    so each member compares equal to that key string.
    """

    # Tag keys in the ``aws-jumpstart-`` namespace, named for inference/training
    # model and script URIs.
    INFERENCE_MODEL_URI = "aws-jumpstart-inference-model-uri"
    INFERENCE_SCRIPT_URI = "aws-jumpstart-inference-script-uri"
    TRAINING_MODEL_URI = "aws-jumpstart-training-model-uri"
    TRAINING_SCRIPT_URI = "aws-jumpstart-training-script-uri"

    # Tag keys in the ``sagemaker-sdk:`` namespace, named for the model's
    # ID, version and type.
    MODEL_ID = "sagemaker-sdk:jumpstart-model-id"
    MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version"
    MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type"

    # Tag keys named for the inference/training config.
    INFERENCE_CONFIG_NAME = "sagemaker-sdk:jumpstart-inference-config-name"
    TRAINING_CONFIG_NAME = "sagemaker-sdk:jumpstart-training-config-name"

    HUB_CONTENT_ARN = "sagemaker-sdk:hub-content-arn"

    BEDROCK = "sagemaker-sdk:bedrock"
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class SerializerType(str, Enum):
    """Enum class for serializers associated with JumpStart models.

    Members also subclass ``str``, so each member compares equal to its raw
    string value.
    """

    TEXT = "text"
    JSON = "json"
    CSV = "csv"
    RAW_BYTES = "raw_bytes"
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class DeserializerType(str, Enum):
    """Enum class for deserializers associated with JumpStart models.

    Members also subclass ``str``, so each member compares equal to its raw
    string value. Only a JSON deserializer type is currently declared.
    """

    JSON = "json"
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class MIMEType(str, Enum):
    """Enum class for MIME types associated with JumpStart models.

    Members also subclass ``str``, so each member compares equal to its raw
    media-type string (for example ``MIMEType.JSON == "application/json"``).
    """

    X_IMAGE = "application/x-image"
    LIST_TEXT = "application/list-text"
    X_TEXT = "application/x-text"
    JSON = "application/json"
    CSV = "text/csv"
    WAV = "audio/wav"

    @staticmethod
    def from_suffixed_type(mime_type_with_suffix: str) -> "MIMEType":
        """Removes suffix from type and instantiates enum.

        Everything from the first ``;`` onwards (for example ``;charset=utf-8``
        or ``;verbose``) is discarded. The remaining base type is stripped of
        surrounding whitespace and lower-cased before lookup, because media
        type names are case-insensitive and all member values are lower-case.

        Args:
            mime_type_with_suffix (str): A media type, optionally followed by
                ``;``-separated parameters.

        Returns:
            MIMEType: The member whose value equals the normalized base type.

        Raises:
            ValueError: If the normalized base type is not a member value.
        """
        base_type, _, _ = mime_type_with_suffix.partition(";")
        return MIMEType(base_type.strip().lower())
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class NamingConventionType(str, Enum):
    """Enum class for naming conventions.

    Members also subclass ``str``, so each member compares equal to its raw
    string value.
    """

    SNAKE_CASE = "snake_case"
    UPPER_CAMEL_CASE = "upper_camel_case"
    # Same value as UPPER_CAMEL_CASE, so Enum treats DEFAULT as an alias:
    # ``NamingConventionType.DEFAULT is NamingConventionType.UPPER_CAMEL_CASE``
    # and iterating the enum yields only the two members above.
    DEFAULT = UPPER_CAMEL_CASE
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class ModelSpecKwargType(str, Enum):
    """Kinds of kwargs found in a model hub content document and in model specs.

    Each member's value is the key under which that group of kwargs is stored.
    """

    FIT = "fit_kwargs"
    MODEL = "model_kwargs"
    ESTIMATOR = "estimator_kwargs"
    DEPLOY = "deploy_kwargs"

    @classmethod
    def arg_keys(cls) -> List[str]:
        """Return the value of every member, in definition order."""
        keys: List[str] = []
        for kwarg_type in cls:
            keys.append(kwarg_type.value)
        return keys
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class JumpStartConfigRankingName(str, Enum):
    """Enum class for ranking of JumpStart config.

    Members also subclass ``str``, so each member compares equal to its raw
    string value.
    """

    # The default ranking is identified by the string "overall".
    DEFAULT = "overall"
|