sagemaker-core 1.0.62__py3-none-any.whl → 2.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/__init__.py +2 -0
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2399 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +247 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1220 -0
- sagemaker/core/git_utils.py +415 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2977 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +509 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +770 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1267 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +686 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
- sagemaker/core/image_uri_config/huggingface.json +2287 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +252 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +494 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +972 -0
- sagemaker/core/image_uris.py +816 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +197 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +501 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +171 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +423 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +246 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1497 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +239 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +205 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1599 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1310 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +347 -0
- sagemaker/core/remote_function/core/serialization.py +410 -0
- sagemaker/core/remote_function/core/stored_function.py +223 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +102 -0
- sagemaker/core/remote_function/invoke_function.py +167 -0
- sagemaker/core/remote_function/job.py +2121 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +82 -0
- sagemaker/core/telemetry/telemetry_logging.py +285 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +345 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +514 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.3.1.dist-info/RECORD +351 -0
- sagemaker_core-2.3.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/modules/local_core}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,480 @@
|
|
|
1
|
+
from sagemaker.core.common_utils import (
|
|
2
|
+
format_tags,
|
|
3
|
+
resolve_value_from_config,
|
|
4
|
+
update_list_of_dicts_with_values_from_config,
|
|
5
|
+
_create_resource,
|
|
6
|
+
can_model_package_source_uri_autopopulate,
|
|
7
|
+
)
|
|
8
|
+
from sagemaker.core.config import (
|
|
9
|
+
MODEL_PACKAGE_VALIDATION_ROLE_PATH,
|
|
10
|
+
VALIDATION_ROLE,
|
|
11
|
+
VALIDATION_PROFILES,
|
|
12
|
+
MODEL_PACKAGE_INFERENCE_SPECIFICATION_CONTAINERS_PATH,
|
|
13
|
+
MODEL_PACKAGE_VALIDATION_PROFILES_PATH,
|
|
14
|
+
)
|
|
15
|
+
from botocore.exceptions import ClientError
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
# Module-wide logger for the "sagemaker" logging namespace; ``logger`` is kept
# as a lowercase alias bound to the very same Logger object as ``LOGGER``.
LOGGER = logging.getLogger("sagemaker")
logger = LOGGER
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_model_package_args(
    content_types=None,
    response_types=None,
    inference_instances=None,
    transform_instances=None,
    model_package_name=None,
    model_package_group_name=None,
    model_data=None,
    image_uri=None,
    model_metrics=None,
    metadata_properties=None,
    marketplace_cert=False,
    approval_status=None,
    description=None,
    tags=None,
    container_def_list=None,
    drift_check_baselines=None,
    customer_metadata_properties=None,
    validation_specification=None,
    domain=None,
    sample_payload_url=None,
    task=None,
    skip_model_validation=None,
    source_uri=None,
    model_card=None,
    model_life_cycle=None,
):
    """Collect snake_case keyword arguments describing a model package.

    The container list is taken from ``container_def_list`` when it is given.
    Otherwise a single container ``{"Image": image_uri}`` is built, with
    ``model_data`` added as its ``ModelDataUrl`` when that is not None.

    ``containers``, ``inference_instances``, ``transform_instances`` and
    ``marketplace_cert`` are always present in the result. Every other argument
    is included only when it is not None, with these conversions:

    * ``tags`` are passed through ``format_tags``.
    * ``model_metrics``, ``drift_check_baselines``, ``metadata_properties`` and
      ``model_life_cycle`` are converted with their ``_to_request_dict()``
      method.
    * ``model_card`` is converted with ``_create_request_args()``. A non-None
      ``ModelCardName`` entry is removed, and a non-None ``Content`` entry is
      renamed to ``ModelCardContent``.

    Returns:
        dict: keyword arguments for a create-model-package call (presumably
        unpacked into one by the caller -- verify against call sites).
    """
    if container_def_list is None:
        single_container = {"Image": image_uri}
        if model_data is not None:
            single_container["ModelDataUrl"] = model_data
        container_list = [single_container]
    else:
        container_list = container_def_list

    package_args = {
        "containers": container_list,
        "inference_instances": inference_instances,
        "transform_instances": transform_instances,
        "marketplace_cert": marketplace_cert,
    }

    # (keyword, value, conversion) kept in the order keys are added.
    # conversion: None -> copy as-is, "tags" -> format_tags(...),
    # "request" -> value._to_request_dict(). Conversions run lazily so that
    # helpers are only invoked for values that are actually supplied.
    optional_args = (
        ("content_types", content_types, None),
        ("response_types", response_types, None),
        ("model_package_name", model_package_name, None),
        ("model_package_group_name", model_package_group_name, None),
        ("model_metrics", model_metrics, "request"),
        ("drift_check_baselines", drift_check_baselines, "request"),
        ("metadata_properties", metadata_properties, "request"),
        ("approval_status", approval_status, None),
        ("description", description, None),
        ("tags", tags, "tags"),
        ("customer_metadata_properties", customer_metadata_properties, None),
        ("validation_specification", validation_specification, None),
        ("domain", domain, None),
        ("sample_payload_url", sample_payload_url, None),
        ("task", task, None),
        ("skip_model_validation", skip_model_validation, None),
        ("source_uri", source_uri, None),
        ("model_life_cycle", model_life_cycle, "request"),
    )
    for key, value, conversion in optional_args:
        if value is None:
            continue
        if conversion == "request":
            value = value._to_request_dict()
        elif conversion == "tags":
            value = format_tags(value)
        package_args[key] = value

    if model_card is not None:
        card_request = model_card._create_request_args()
        # The package-embedded card has no name of its own, and its content
        # field is called ModelCardContent rather than Content.
        if card_request.get("ModelCardName") is not None:
            del card_request["ModelCardName"]
        if card_request.get("Content") is not None:
            card_request["ModelCardContent"] = card_request.pop("Content")
        package_args["model_card"] = card_request

    return package_args
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def get_create_model_package_request(
    model_package_name=None,
    model_package_group_name=None,
    containers=None,
    content_types=None,
    response_types=None,
    inference_instances=None,
    transform_instances=None,
    model_metrics=None,
    metadata_properties=None,
    marketplace_cert=False,
    approval_status="PendingManualApproval",
    description=None,
    tags=None,
    drift_check_baselines=None,
    customer_metadata_properties=None,
    validation_specification=None,
    domain=None,
    sample_payload_url=None,
    task=None,
    skip_model_validation="None",
    source_uri=None,
    model_card=None,
    model_life_cycle=None,
):
    """Assemble the request body for the SageMaker ``CreateModelPackage`` API.

    Keys use the API's PascalCase field names. Which fields appear:

    * Always set: ``CertifyForMarketplace``, ``ModelApprovalStatus`` and
      ``SkipModelValidation``. The ``skip_model_validation`` default is the
      string ``"None"``, not Python ``None``.
    * Set only when truthy: ``ModelMetrics``, ``DriftCheckBaselines``,
      ``MetadataProperties`` and ``ValidationSpecification``.
    * Set when not None: every other optional field. ``tags`` are passed
      through ``format_tags``.
    * ``InferenceSpecification``: built when ``containers`` is not None, from
      the containers, the optional content/response types and the instance
      type lists.

    Raises:
        ValueError: if any of the following holds:

            * ``model_package_name`` and ``model_package_group_name`` are both
              truthy;
            * ``model_package_name`` is combined with ``source_uri``;
            * ``model_package_name`` is combined with a container whose
              ``ModelDataSource`` is not None;
            * ``containers`` is given without ``model_package_group_name`` and
              either ``inference_instances`` or ``transform_instances`` is
              missing or empty.

    Returns:
        dict: the request dictionary.
    """
    # A truthy model_package_name means an un-versioned package, which has
    # extra restrictions.
    if model_package_name and model_package_group_name:
        raise ValueError(
            "model_package_name and model_package_group_name cannot be present at the "
            "same time."
        )
    if model_package_name and source_uri:
        raise ValueError(
            "Un-versioned SageMaker Model Package currently cannot be "
            "created with source_uri."
        )
    if containers is not None:
        # Inspect every container up front, exactly as the list-based check did.
        has_data_source = [
            "ModelDataSource" in container and container["ModelDataSource"] is not None
            for container in containers
        ]
        if model_package_name and any(has_data_source):
            raise ValueError(
                "Un-versioned SageMaker Model Package currently cannot be "
                "created with ModelDataSource."
            )

    def _not_none(value):
        return value is not None

    # (API field, value, inclusion test) in the order fields are added.
    optional_fields = (
        ("ModelPackageName", model_package_name, _not_none),
        ("ModelPackageGroupName", model_package_group_name, _not_none),
        ("ModelPackageDescription", description, _not_none),
        ("Tags", tags, _not_none),
        ("ModelMetrics", model_metrics, bool),
        ("DriftCheckBaselines", drift_check_baselines, bool),
        ("MetadataProperties", metadata_properties, bool),
        ("CustomerMetadataProperties", customer_metadata_properties, _not_none),
        ("ValidationSpecification", validation_specification, bool),
        ("Domain", domain, _not_none),
        ("SamplePayloadUrl", sample_payload_url, _not_none),
        ("Task", task, _not_none),
        ("SourceUri", source_uri, _not_none),
    )
    request = {}
    for field, value, include in optional_fields:
        if include(value):
            request[field] = format_tags(value) if field == "Tags" else value

    if containers is not None:
        spec = {"Containers": containers}
        if content_types is not None:
            spec["SupportedContentTypes"] = content_types
        if response_types is not None:
            spec["SupportedResponseMIMETypes"] = response_types
        if model_package_group_name is None:
            # Without a group, both instance-type lists are mandatory.
            if not (inference_instances and transform_instances):
                raise ValueError(
                    "inference_instances and transform_instances "
                    "must be provided if model_package_group_name is not present."
                )
            spec["SupportedRealtimeInferenceInstanceTypes"] = inference_instances
            spec["SupportedTransformInstanceTypes"] = transform_instances
        else:
            if inference_instances is not None:
                spec["SupportedRealtimeInferenceInstanceTypes"] = inference_instances
            if transform_instances is not None:
                spec["SupportedTransformInstanceTypes"] = transform_instances
        request["InferenceSpecification"] = spec

    request["CertifyForMarketplace"] = marketplace_cert
    request["ModelApprovalStatus"] = approval_status
    request["SkipModelValidation"] = skip_model_validation
    if model_card is not None:
        request["ModelCard"] = model_card
    if model_life_cycle is not None:
        request["ModelLifeCycle"] = model_life_cycle
    return request
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def create_model_package_from_containers(
    sagemaker_session,
    containers=None,
    content_types=None,
    response_types=None,
    inference_instances=None,
    transform_instances=None,
    model_package_name=None,
    model_package_group_name=None,
    model_metrics=None,
    metadata_properties=None,
    marketplace_cert=False,
    approval_status="PendingManualApproval",
    description=None,
    drift_check_baselines=None,
    customer_metadata_properties=None,
    validation_specification=None,
    domain=None,
    sample_payload_url=None,
    task=None,
    skip_model_validation="None",
    source_uri=None,
    model_card=None,
    model_life_cycle=None,
):
    """Create a SageMaker Model Package from inference containers.

    Missing container and validation settings are merged from the SageMaker Config, the
    CreateModelPackage request is built, and the request is submitted. If a Model Package
    Group *name* (not an ARN) is given and no group with exactly that name is found, the
    group is created first.

    Args:
        sagemaker_session: Session used for SageMaker Config lookups and for the SageMaker
            API calls (``search``, ``sagemaker_client``, ``_intercept_create_request``).
        containers (list): A list of inference containers that can be used for inference
            specifications of Model Package (default: None). When provided, entries are
            passed to ``update_list_of_dicts_with_values_from_config`` to fill missing
            values from config (presumably in place).
        content_types (list): The supported MIME types for the input data (default: None).
        response_types (list): The supported MIME types for the output data (default: None).
        inference_instances (list): A list of the instance types that are used to
            generate inferences in real-time (default: None).
        transform_instances (list): A list of the instance types on which a transformation
            job can be run or on which an endpoint can be deployed (default: None).
        model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
            using `model_package_name` makes the Model Package un-versioned (default: None).
        model_package_group_name (str): Model Package Group name, exclusive to
            `model_package_name`, using `model_package_group_name` makes the Model Package
            versioned (default: None).
        model_metrics (ModelMetrics): ModelMetrics object (default: None).
        metadata_properties (MetadataProperties): MetadataProperties object (default: None)
        marketplace_cert (bool): A boolean value indicating if the Model Package is certified
            for AWS Marketplace (default: False).
        approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
            or "PendingManualApproval" (default: "PendingManualApproval").
        description (str): Model Package description (default: None).
        drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
        customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
            metadata properties (default: None).
        validation_specification (dict): Validation specification for the Model Package
            (default: None). Note: this dict is modified in place; its validation role is
            resolved from config and written back into it.
        domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
            "MACHINE_LEARNING" (default: None).
        sample_payload_url (str): The S3 path where the sample payload is stored
            (default: None).
        task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
            "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
            "CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
        skip_model_validation (str): Indicates if you want to skip model validation.
            Values can be "All" or "None" (default: "None").
        source_uri (str): The URI of the source for the model package (default: None).
        model_card (ModelCard or ModelPackageModelCard): document contains qualitative and
            quantitative information about a model (default: None).
        model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).

    Returns:
        Whatever ``sagemaker_session._intercept_create_request`` returns for the request.
        When the request is actually submitted, this is presumably the response of
        CreateModelPackage. It is the UpdateModelPackage response when a source URI that
        cannot auto-populate the inference spec is given. TODO confirm the interception
        semantics.
    """
    if containers:
        # Containers are provided. Now we can merge missing entries from config.
        # If Containers are not provided, it is safe to ignore. This is because,
        # if this object is provided to the API, then Image is required for Containers.
        # That is not supported by the config now. So if we merge values from config,
        # then API will throw an exception. In the future, when SageMaker Config starts
        # supporting other parameters we can add that.
        update_list_of_dicts_with_values_from_config(
            containers,
            MODEL_PACKAGE_INFERENCE_SPECIFICATION_CONTAINERS_PATH,
            required_key_paths=["Image"],
            sagemaker_session=sagemaker_session,
        )

    if validation_specification:
        # ValidationSpecification is provided. Now we can merge missing entries from config.
        # If ValidationSpecification is not provided, it is safe to ignore. This is because,
        # if this object is provided to the API, then both ValidationProfiles and ValidationRole
        # are required and for ValidationProfile, ProfileName is a required parameter. That is
        # not supported by the config now. So if we merge values from config, then API will
        # throw an exception. In the future, when SageMaker Config starts supporting other
        # parameters we can add that.
        validation_role = resolve_value_from_config(
            validation_specification.get(VALIDATION_ROLE, None),
            MODEL_PACKAGE_VALIDATION_ROLE_PATH,
            sagemaker_session=sagemaker_session,
        )
        validation_specification[VALIDATION_ROLE] = validation_role
        validation_profiles = validation_specification.get(VALIDATION_PROFILES, [])
        update_list_of_dicts_with_values_from_config(
            validation_profiles,
            MODEL_PACKAGE_VALIDATION_PROFILES_PATH,
            required_key_paths=["ProfileName", "TransformJobDefinition"],
            sagemaker_session=sagemaker_session,
        )
    model_pkg_request = get_create_model_package_request(
        model_package_name,
        model_package_group_name,
        containers,
        content_types,
        response_types,
        inference_instances,
        transform_instances,
        model_metrics,
        metadata_properties,
        marketplace_cert,
        approval_status,
        description,
        drift_check_baselines=drift_check_baselines,
        customer_metadata_properties=customer_metadata_properties,
        validation_specification=validation_specification,
        domain=domain,
        sample_payload_url=sample_payload_url,
        task=task,
        skip_model_validation=skip_model_validation,
        source_uri=source_uri,
        model_card=model_card,
        model_life_cycle=model_life_cycle,
    )

    def _group_listed(group_name):
        """Return True if ListModelPackageGroups reports a group named exactly ``group_name``."""
        list_kwargs = {"NameContains": group_name}
        while True:
            response = sagemaker_session.sagemaker_client.list_model_package_groups(
                **list_kwargs
            )
            # NameContains is a substring filter, so require an exact name match here.
            if any(
                summary.get("ModelPackageGroupName") == group_name
                for summary in response["ModelPackageGroupSummaryList"]
            ):
                return True
            next_token = response.get("NextToken")
            if not next_token:
                return False
            list_kwargs["NextToken"] = next_token

    def _group_exists(group_name):
        """Return True if a Model Package Group named exactly ``group_name`` exists."""
        try:
            search_response = sagemaker_session.search(
                resource="ModelPackageGroup",
                search_expression={
                    "Filters": [
                        {
                            "Name": "ModelPackageGroupName",
                            "Value": group_name,
                            "Operator": "Equals",
                        }
                    ],
                },
            )
            # A missing "Results" key raises TypeError here and triggers the fallback.
            return len(search_response.get("Results")) > 0
        except Exception:  # pylint: disable=W0703
            # Search is best-effort; fall back to paginating ListModelPackageGroups.
            return _group_listed(group_name)

    def submit(request):
        if model_package_group_name is not None and not model_package_group_name.startswith("arn:"):
            if not _group_exists(request["ModelPackageGroupName"]):
                _create_resource(
                    lambda: sagemaker_session.sagemaker_client.create_model_package_group(
                        ModelPackageGroupName=request["ModelPackageGroupName"]
                    )
                )
        if "SourceUri" in request and request["SourceUri"] is not None:
            # Remove inference spec from request if the
            # given source uri can lead to auto-population of it
            if can_model_package_source_uri_autopopulate(request["SourceUri"]):
                if "InferenceSpecification" in request:
                    del request["InferenceSpecification"]
                return sagemaker_session.sagemaker_client.create_model_package(**request)
            # If source uri can't autopopulate,
            # first create model package with just the inference spec
            # and then update model package with the source uri.
            # Done this way because passing source uri and inference spec together
            # in create/update model package is not allowed in the base sdk.
            request_source_uri = request["SourceUri"]
            del request["SourceUri"]
            model_package = sagemaker_session.sagemaker_client.create_model_package(**request)
            update_source_uri_args = {
                "ModelPackageArn": model_package.get("ModelPackageArn"),
                "SourceUri": request_source_uri,
            }
            return sagemaker_session.sagemaker_client.update_model_package(**update_source_uri_args)
        return sagemaker_session.sagemaker_client.create_model_package(**request)

    return sagemaker_session._intercept_create_request(
        model_pkg_request, submit, create_model_package_from_containers.__name__
    )
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
def create_model_package_from_algorithm(self, name, description, algorithm_arn, model_data):
    """Register a SageMaker Model Package built from an Algorithm Package training result.

    A ``ValidationException`` whose message contains "ModelPackage already exists" is
    logged as a warning and ignored, so repeating the call with the same name is
    harmless. Any other ``ClientError`` is re-raised.

    Args:
        name (str): Name of the Model Package.
        description (str): Description of the Model Package.
        algorithm_arn (str): ARN or name of the algorithm used for training.
        model_data (str or dict[str, Any]): S3 URI of the trained model artifacts, or a
            dictionary describing a ``ModelDataSource`` for them.
    """
    # A dict is sent as a ModelDataSource; anything else is treated as an S3 URL.
    artifact_key = "ModelDataSource" if isinstance(model_data, dict) else "ModelDataUrl"
    source_algorithm = {"AlgorithmName": algorithm_arn, artifact_key: model_data}

    create_args = dict(
        ModelPackageName=name,
        ModelPackageDescription=description,
        SourceAlgorithmSpecification={"SourceAlgorithms": [source_algorithm]},
    )

    logger.info("Creating model package with name: %s", name)
    try:
        self.sagemaker_client.create_model_package(**create_args)
    except ClientError as err:
        error_info = err.response["Error"]
        code, text = error_info["Code"], error_info["Message"]
        if code != "ValidationException" or "ModelPackage already exists" not in text:
            raise
        logger.warning("Using already existing model package: %s", name)
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""Accessors to retrieve the model artifact S3 URI of pretrained machine learning models."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
from typing import Optional
|
|
18
|
+
|
|
19
|
+
from sagemaker.core.jumpstart import utils as jumpstart_utils
|
|
20
|
+
from sagemaker.core.jumpstart import artifacts
|
|
21
|
+
from sagemaker.core.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
|
|
22
|
+
from sagemaker.core.jumpstart.enums import JumpStartModelType
|
|
23
|
+
from sagemaker.core.helper.session_helper import Session
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def retrieve(
    region: Optional[str] = None,
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    model_scope: Optional[str] = None,
    instance_type: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> str:
    """Retrieves the model artifact Amazon S3 URI for the model matching the given arguments.

    Args:
        region (str): The AWS Region for which to retrieve the JumpStart model S3 URI.
            (Default: None).
        model_id (str): The model ID of the JumpStart model for which to retrieve
            the model artifact S3 URI. Required.
        model_version (str): The version of the JumpStart model for which to retrieve
            the model artifact S3 URI. Required.
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        model_scope (str): The scope for which the model artifact URI is retrieved.
            Valid values: "training" and "inference". (Default: None).
        instance_type (str): The ML compute instance type for the specified scope. (Default: None).
        tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model
            specifications should be tolerated without raising an exception. If ``False``, raises an
            exception if the script used by this version of the model has dependencies with known
            security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): ``True`` if deprecated versions of model
            specifications should be tolerated without raising an exception. If ``False``, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.core.helper.session_helper.Session): A SageMaker Session
            object, used for SageMaker interactions.
            (Default: sagemaker.core.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
        model_type (JumpStartModelType): The type of the model, can be open weights model
            or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).

    Returns:
        str: The model artifact S3 URI for the corresponding model.

    Raises:
        NotImplementedError: If the scope is not supported.
        ValueError: If ``model_id`` or ``model_version`` is missing, or if the combination
            of arguments specified is not supported.
        VulnerableJumpStartModelError: If any of the dependencies required by the script have
            known security vulnerabilities.
        DeprecatedJumpStartModelError: If the version of the model is deprecated.
    """
    # Fail fast with a clear message before delegating to the artifact resolver.
    if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
        raise ValueError(
            "Must specify JumpStart `model_id` and `model_version` when retrieving model URIs."
        )

    return artifacts._retrieve_model_uri(
        model_id=model_id,
        model_version=model_version,  # type: ignore
        hub_arn=hub_arn,
        model_scope=model_scope,
        instance_type=instance_type,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        config_name=config_name,
        model_type=model_type,
    )
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""SageMaker modules directory."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
from sagemaker.core.utils.utils import logger as sagemaker_core_logger
|
|
17
|
+
from sagemaker.core.helper.session_helper import Session, get_execution_role # noqa: F401
|
|
18
|
+
|
|
19
|
+
# Expose the shared sagemaker-core logger (imported from sagemaker.core.utils.utils)
# under this package's ``logger`` name.
logger = sagemaker_core_logger
|