sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sagemaker/core/__init__.py +16 -0
- sagemaker/core/_studio.py +116 -0
- sagemaker/core/_version.py +11 -0
- sagemaker/core/accept_types.py +131 -0
- sagemaker/core/analytics.py +744 -0
- sagemaker/core/apiutils/__init__.py +13 -0
- sagemaker/core/apiutils/_base_types.py +228 -0
- sagemaker/core/apiutils/_boto_functions.py +130 -0
- sagemaker/core/apiutils/_utils.py +34 -0
- sagemaker/core/base_deserializers.py +35 -0
- sagemaker/core/base_serializers.py +35 -0
- sagemaker/core/clarify/__init__.py +2898 -0
- sagemaker/core/collection.py +467 -0
- sagemaker/core/common_utils.py +2281 -0
- sagemaker/core/compute_resource_requirements/__init__.py +18 -0
- sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
- sagemaker/core/config/__init__.py +181 -0
- sagemaker/core/config/config.py +238 -0
- sagemaker/core/config/config_manager.py +595 -0
- sagemaker/core/config/config_schema.py +1220 -0
- sagemaker/core/config/config_utils.py +297 -0
- {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
- sagemaker/core/constants.py +73 -0
- sagemaker/core/content_types.py +137 -0
- sagemaker/core/debugger/__init__.py +39 -0
- sagemaker/core/debugger/debugger.py +945 -0
- sagemaker/core/debugger/framework_profile.py +292 -0
- sagemaker/core/debugger/metrics_config.py +468 -0
- sagemaker/core/debugger/profiler.py +42 -0
- sagemaker/core/debugger/profiler_config.py +190 -0
- sagemaker/core/debugger/profiler_constants.py +40 -0
- sagemaker/core/debugger/utils.py +148 -0
- sagemaker/core/deprecations.py +254 -0
- sagemaker/core/deserializers/__init__.py +10 -0
- sagemaker/core/deserializers/base.py +424 -0
- sagemaker/core/deserializers/implementations.py +157 -0
- sagemaker/core/drift_check_baselines.py +106 -0
- sagemaker/core/enums.py +51 -0
- sagemaker/core/environment_variables.py +101 -0
- sagemaker/core/exceptions.py +108 -0
- sagemaker/core/experiments/__init__.py +53 -0
- sagemaker/core/experiments/_api_types.py +251 -0
- sagemaker/core/experiments/_environment.py +124 -0
- sagemaker/core/experiments/_helper.py +294 -0
- sagemaker/core/experiments/_metrics.py +333 -0
- sagemaker/core/experiments/_run_context.py +58 -0
- sagemaker/core/experiments/_utils.py +216 -0
- sagemaker/core/experiments/experiment.py +244 -0
- sagemaker/core/experiments/run.py +970 -0
- sagemaker/core/experiments/trial.py +296 -0
- sagemaker/core/experiments/trial_component.py +387 -0
- sagemaker/core/explainer/__init__.py +24 -0
- sagemaker/core/explainer/clarify_explainer_config.py +298 -0
- sagemaker/core/explainer/explainer_config.py +44 -0
- sagemaker/core/fw_utils.py +1176 -0
- sagemaker/core/git_utils.py +349 -0
- sagemaker/core/helper/pipeline_variable.py +82 -0
- sagemaker/core/helper/session_helper.py +2965 -0
- sagemaker/core/huggingface/__init__.py +29 -0
- sagemaker/core/huggingface/llm_utils.py +150 -0
- sagemaker/core/huggingface/processing.py +139 -0
- sagemaker/core/huggingface/training_compiler/config.py +167 -0
- sagemaker/core/hyperparameters.py +172 -0
- sagemaker/core/image_retriever/__init__.py +3 -0
- sagemaker/core/image_retriever/image_retriever.py +640 -0
- sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
- sagemaker/core/image_retriever/test.py +7 -0
- sagemaker/core/image_uri_config/__init__.py +13 -0
- sagemaker/core/image_uri_config/autogluon.json +1335 -0
- sagemaker/core/image_uri_config/blazingtext.json +50 -0
- sagemaker/core/image_uri_config/chainer.json +104 -0
- sagemaker/core/image_uri_config/clarify.json +39 -0
- sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
- sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
- sagemaker/core/image_uri_config/data-wrangler.json +91 -0
- sagemaker/core/image_uri_config/debugger.json +34 -0
- sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
- sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
- sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
- sagemaker/core/image_uri_config/djl-lmi.json +136 -0
- sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
- sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
- sagemaker/core/image_uri_config/factorization-machines.json +50 -0
- sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
- sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
- sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
- sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
- sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
- sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
- sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
- sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
- sagemaker/core/image_uri_config/huggingface.json +2138 -0
- sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
- sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
- sagemaker/core/image_uri_config/image-classification.json +50 -0
- sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
- sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
- sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
- sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
- sagemaker/core/image_uri_config/ipinsights.json +50 -0
- sagemaker/core/image_uri_config/kmeans.json +50 -0
- sagemaker/core/image_uri_config/knn.json +50 -0
- sagemaker/core/image_uri_config/lda.json +26 -0
- sagemaker/core/image_uri_config/linear-learner.json +50 -0
- sagemaker/core/image_uri_config/model-monitor.json +42 -0
- sagemaker/core/image_uri_config/mxnet.json +1154 -0
- sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
- sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
- sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
- sagemaker/core/image_uri_config/ntm.json +50 -0
- sagemaker/core/image_uri_config/object-detection.json +50 -0
- sagemaker/core/image_uri_config/object2vec.json +50 -0
- sagemaker/core/image_uri_config/pca.json +50 -0
- sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
- sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
- sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
- sagemaker/core/image_uri_config/pytorch.json +3101 -0
- sagemaker/core/image_uri_config/randomcutforest.json +50 -0
- sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
- sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
- sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
- sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
- sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
- sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
- sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
- sagemaker/core/image_uri_config/seq2seq.json +50 -0
- sagemaker/core/image_uri_config/sklearn.json +446 -0
- sagemaker/core/image_uri_config/spark.json +280 -0
- sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
- sagemaker/core/image_uri_config/stabilityai.json +53 -0
- sagemaker/core/image_uri_config/tensorflow.json +5086 -0
- sagemaker/core/image_uri_config/vw.json +25 -0
- sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
- sagemaker/core/image_uri_config/xgboost.json +888 -0
- sagemaker/core/image_uris.py +810 -0
- sagemaker/core/inference_config.py +144 -0
- sagemaker/core/inference_recommender/__init__.py +18 -0
- sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
- sagemaker/core/inputs.py +366 -0
- sagemaker/core/instance_group.py +61 -0
- sagemaker/core/instance_types.py +164 -0
- sagemaker/core/instance_types_gpu_info.py +43 -0
- sagemaker/core/interactive_apps/__init__.py +41 -0
- sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
- sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
- sagemaker/core/interactive_apps/tensorboard.py +149 -0
- sagemaker/core/iterators.py +186 -0
- sagemaker/core/job.py +380 -0
- sagemaker/core/jumpstart/__init__.py +156 -0
- sagemaker/core/jumpstart/accessors.py +390 -0
- sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
- sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
- sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
- sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
- sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
- sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
- sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
- sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
- sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
- sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
- sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
- sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
- sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
- sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
- sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
- sagemaker/core/jumpstart/cache.py +663 -0
- sagemaker/core/jumpstart/configs.py +50 -0
- sagemaker/core/jumpstart/constants.py +198 -0
- sagemaker/core/jumpstart/deserializers.py +81 -0
- sagemaker/core/jumpstart/document.py +76 -0
- sagemaker/core/jumpstart/enums.py +168 -0
- sagemaker/core/jumpstart/exceptions.py +236 -0
- sagemaker/core/jumpstart/factory/utils.py +833 -0
- sagemaker/core/jumpstart/filters.py +597 -0
- sagemaker/core/jumpstart/hub/constants.py +16 -0
- sagemaker/core/jumpstart/hub/hub.py +291 -0
- sagemaker/core/jumpstart/hub/interfaces.py +936 -0
- sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
- sagemaker/core/jumpstart/hub/parsers.py +288 -0
- sagemaker/core/jumpstart/hub/types.py +35 -0
- sagemaker/core/jumpstart/hub/utils.py +260 -0
- sagemaker/core/jumpstart/models.py +499 -0
- sagemaker/core/jumpstart/notebook_utils.py +575 -0
- sagemaker/core/jumpstart/parameters.py +20 -0
- sagemaker/core/jumpstart/payload_utils.py +239 -0
- sagemaker/core/jumpstart/region_config.json +163 -0
- sagemaker/core/jumpstart/search.py +171 -0
- sagemaker/core/jumpstart/serializers.py +81 -0
- sagemaker/core/jumpstart/session_utils.py +234 -0
- sagemaker/core/jumpstart/types.py +3044 -0
- sagemaker/core/jumpstart/utils.py +1731 -0
- sagemaker/core/jumpstart/validators.py +257 -0
- sagemaker/core/lambda_helper.py +312 -0
- sagemaker/core/lineage/__init__.py +42 -0
- sagemaker/core/lineage/_api_types.py +239 -0
- sagemaker/core/lineage/_utils.py +49 -0
- sagemaker/core/lineage/action.py +345 -0
- sagemaker/core/lineage/artifact.py +646 -0
- sagemaker/core/lineage/association.py +190 -0
- sagemaker/core/lineage/context.py +505 -0
- sagemaker/core/lineage/lineage_trial_component.py +191 -0
- sagemaker/core/lineage/query.py +732 -0
- sagemaker/core/lineage/visualizer.py +346 -0
- sagemaker/core/local/__init__.py +18 -0
- sagemaker/core/local/data.py +413 -0
- sagemaker/core/local/entities.py +678 -0
- sagemaker/core/local/exceptions.py +17 -0
- sagemaker/core/local/image.py +1243 -0
- sagemaker/core/local/local_session.py +739 -0
- sagemaker/core/local/utils.py +245 -0
- sagemaker/core/logs.py +181 -0
- sagemaker/core/metadata_properties.py +56 -0
- sagemaker/core/metric_definitions.py +91 -0
- sagemaker/core/mlflow/__init__.py +38 -0
- sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
- sagemaker/core/model_card/__init__.py +26 -0
- sagemaker/core/model_life_cycle.py +51 -0
- sagemaker/core/model_metrics.py +160 -0
- sagemaker/core/model_monitor/__init__.py +66 -0
- sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
- sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
- sagemaker/core/model_monitor/data_capture_config.py +115 -0
- sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
- sagemaker/core/model_monitor/dataset_format.py +102 -0
- sagemaker/core/model_monitor/model_monitoring.py +4266 -0
- sagemaker/core/model_monitor/monitoring_alert.py +76 -0
- sagemaker/core/model_monitor/monitoring_files.py +506 -0
- sagemaker/core/model_monitor/utils.py +793 -0
- sagemaker/core/model_registry.py +480 -0
- sagemaker/core/model_uris.py +97 -0
- sagemaker/core/modules/__init__.py +19 -0
- sagemaker/core/modules/configs.py +226 -0
- sagemaker/core/modules/constants.py +37 -0
- sagemaker/core/modules/distributed.py +182 -0
- sagemaker/core/modules/local_core/__init__.py +0 -0
- sagemaker/core/modules/local_core/local_container.py +605 -0
- sagemaker/core/modules/templates.py +83 -0
- sagemaker/core/modules/train/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
- sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
- sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
- sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
- sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
- sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
- sagemaker/core/modules/types.py +19 -0
- sagemaker/core/modules/utils.py +194 -0
- sagemaker/core/network.py +185 -0
- sagemaker/core/parameter.py +173 -0
- sagemaker/core/payloads.py +185 -0
- sagemaker/core/processing.py +1597 -0
- sagemaker/core/remote_function/__init__.py +19 -0
- sagemaker/core/remote_function/checkpoint_location.py +47 -0
- sagemaker/core/remote_function/client.py +1285 -0
- sagemaker/core/remote_function/core/__init__.py +0 -0
- sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
- sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
- sagemaker/core/remote_function/core/serialization.py +422 -0
- sagemaker/core/remote_function/core/stored_function.py +226 -0
- sagemaker/core/remote_function/custom_file_filter.py +128 -0
- sagemaker/core/remote_function/errors.py +104 -0
- sagemaker/core/remote_function/invoke_function.py +172 -0
- sagemaker/core/remote_function/job.py +2140 -0
- sagemaker/core/remote_function/logging_config.py +38 -0
- sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
- sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
- sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
- sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
- sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
- sagemaker/core/remote_function/spark_config.py +149 -0
- sagemaker/core/resource_requirements.py +168 -0
- {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
- sagemaker/core/s3/__init__.py +41 -0
- sagemaker/core/s3/client.py +367 -0
- sagemaker/core/s3/utils.py +175 -0
- sagemaker/core/script_uris.py +93 -0
- sagemaker/core/serializers/__init__.py +11 -0
- sagemaker/core/serializers/base.py +510 -0
- sagemaker/core/serializers/implementations.py +159 -0
- sagemaker/core/serializers/utils.py +223 -0
- sagemaker/core/serverless_inference_config.py +63 -0
- sagemaker/core/session_settings.py +55 -0
- sagemaker/core/shapes/__init__.py +3 -0
- sagemaker/core/shapes/model_card_shapes.py +159 -0
- {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
- sagemaker/core/spark/__init__.py +16 -0
- sagemaker/core/spark/defaults.py +16 -0
- sagemaker/core/spark/processing.py +1380 -0
- sagemaker/core/telemetry/__init__.py +23 -0
- sagemaker/core/telemetry/constants.py +84 -0
- sagemaker/core/telemetry/telemetry_logging.py +284 -0
- sagemaker/core/tools/__init__.py +1 -0
- {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
- {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
- {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
- {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
- sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
- {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
- {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
- {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
- {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
- {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
- sagemaker/core/training/__init__.py +14 -0
- sagemaker/core/training/configs.py +333 -0
- sagemaker/core/training/constants.py +37 -0
- sagemaker/core/training/utils.py +77 -0
- sagemaker/core/training_compiler/__init__.py +16 -0
- sagemaker/core/training_compiler/config.py +197 -0
- sagemaker/core/training_compiler_config.py +197 -0
- sagemaker/core/transformer.py +793 -0
- sagemaker/core/user_agent.py +76 -0
- sagemaker/core/utilities/__init__.py +24 -0
- sagemaker/core/utilities/cache.py +169 -0
- sagemaker/core/utilities/search_expression.py +133 -0
- sagemaker/core/utils/__init__.py +48 -0
- sagemaker/core/utils/code_injection/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
- {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
- sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
- {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
- {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
- sagemaker/core/workflow/__init__.py +152 -0
- sagemaker/core/workflow/conditions.py +313 -0
- sagemaker/core/workflow/entities.py +58 -0
- sagemaker/core/workflow/execution_variables.py +89 -0
- sagemaker/core/workflow/functions.py +193 -0
- sagemaker/core/workflow/parameters.py +222 -0
- sagemaker/core/workflow/pipeline_context.py +394 -0
- sagemaker/core/workflow/pipeline_definition_config.py +31 -0
- sagemaker/core/workflow/properties.py +285 -0
- sagemaker/core/workflow/step_outputs.py +65 -0
- sagemaker/core/workflow/utilities.py +507 -0
- sagemaker/lineage/__init__.py +33 -0
- sagemaker/lineage/action.py +28 -0
- sagemaker/lineage/artifact.py +28 -0
- sagemaker/lineage/context.py +28 -0
- sagemaker/lineage/lineage_trial_component.py +28 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
- sagemaker_core-2.1.1.dist-info/RECORD +355 -0
- sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
- sagemaker_core/_version.py +0 -3
- sagemaker_core/helper/session_helper.py +0 -769
- sagemaker_core/resources/__init__.py +0 -1
- sagemaker_core/shapes/__init__.py +0 -1
- sagemaker_core/tools/__init__.py +0 -1
- sagemaker_core-1.0.62.dist-info/RECORD +0 -35
- sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
- {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
- {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
- {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
- {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
- {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,732 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
4
|
+
# may not use this file except in compliance with the License. A copy of
|
|
5
|
+
# the License is located at
|
|
6
|
+
#
|
|
7
|
+
# http://aws.amazon.com/apache2.0/
|
|
8
|
+
#
|
|
9
|
+
# or in the "license" file accompanying this file. This file is
|
|
10
|
+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
11
|
+
# ANY KIND, either express or implied. See the License for the specific
|
|
12
|
+
# language governing permissions and limitations under the License.
|
|
13
|
+
"""This module contains code to query SageMaker lineage."""
|
|
14
|
+
from __future__ import absolute_import
|
|
15
|
+
|
|
16
|
+
from datetime import datetime
|
|
17
|
+
from enum import Enum
|
|
18
|
+
from typing import Any, Optional, Union, List, Dict
|
|
19
|
+
from json import dumps
|
|
20
|
+
from re import sub, search
|
|
21
|
+
|
|
22
|
+
from sagemaker.core.common_utils import get_module
|
|
23
|
+
from sagemaker.core.lineage._utils import get_resource_name_from_arn
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class LineageEntityEnum(Enum):
    """Lineage entity kinds that a query filter can select on.

    Each member's ``value`` is the string form of the entity kind; ``Vertex``
    compares its ``lineage_entity`` attribute against these values.
    """

    TRIAL = "Trial"
    ACTION = "Action"
    ARTIFACT = "Artifact"
    CONTEXT = "Context"
    TRIAL_COMPONENT = "TrialComponent"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class LineageSourceEnum(Enum):
    """Lineage source types that a query filter can select on.

    Each member's ``value`` is the string form of the source type; ``Vertex``
    compares its ``lineage_source`` attribute against these values.
    """

    CHECKPOINT = "Checkpoint"
    DATASET = "DataSet"
    ENDPOINT = "Endpoint"
    IMAGE = "Image"
    MODEL = "Model"
    MODEL_DATA = "ModelData"
    MODEL_DEPLOYMENT = "ModelDeployment"
    MODEL_GROUP = "ModelGroup"
    # NOTE: the member is named MODEL_REPLACE while its value is "ModelReplaced";
    # both are kept as-is because callers may rely on either spelling.
    MODEL_REPLACE = "ModelReplaced"
    TENSORBOARD = "TensorBoard"
    TRAINING_JOB = "TrainingJob"
    APPROVAL = "Approval"
    PROCESSING_JOB = "ProcessingJob"
    TRANSFORM_JOB = "TransformJob"
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class LineageQueryDirectionEnum(Enum):
    """Directions in which a lineage query filter can traverse the graph.

    Each member's ``value`` is the string form of the direction.
    """

    BOTH = "Both"
    ASCENDANTS = "Ascendants"
    DESCENDANTS = "Descendants"
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class Edge:
    """A connecting edge for a lineage graph.

    Two edges are equal (and hash alike) when their source ARN, destination
    ARN and association type all match, so edges can be de-duplicated in sets
    and used as dictionary keys.
    """

    def __init__(
        self,
        source_arn: str,
        destination_arn: str,
        association_type: str,
    ):
        """Initialize ``Edge`` instance.

        Args:
            source_arn (str): ARN stored as the edge's source.
            destination_arn (str): ARN stored as the edge's destination.
            association_type (str): Association type stored on the edge.
        """
        self.source_arn = source_arn
        self.destination_arn = destination_arn
        self.association_type = association_type

    def __hash__(self):
        """Define hash function for ``Edge``.

        The hash covers the same three attributes that ``__eq__`` compares.
        """
        return hash(
            (
                "source_arn",
                self.source_arn,
                "destination_arn",
                self.destination_arn,
                "association_type",
                self.association_type,
            )
        )

    def __eq__(self, other):
        """Define equal function for ``Edge``.

        Returns:
            bool: True when ``other`` is an ``Edge`` with the same association
            type, source ARN and destination ARN. ``NotImplemented`` is
            returned for any non-``Edge`` operand so that comparisons such as
            ``edge == None`` evaluate to False instead of raising
            ``AttributeError``.
        """
        if not isinstance(other, Edge):
            return NotImplemented
        return (
            self.association_type == other.association_type
            and self.source_arn == other.source_arn
            and self.destination_arn == other.destination_arn
        )

    def __str__(self):
        """Define string representation of ``Edge``.

        Format:
            {
                'source_arn': 'string',
                'destination_arn': 'string',
                'association_type': 'string'
            }

        """
        return str(self.__dict__)

    def __repr__(self):
        """Define string representation of ``Edge``.

        Same as ``__str__`` but prefixed with a newline and a tab, so that a
        printed list of edges shows one edge per line.

        Format:
            {
                'source_arn': 'string',
                'destination_arn': 'string',
                'association_type': 'string'
            }

        """
        return "\n\t" + str(self.__dict__)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class Vertex:
    """A vertex for a lineage graph.

    Two vertices are equal (and hash alike) when their ARN, lineage entity and
    lineage source all match; the attached session is not part of identity.
    """

    def __init__(
        self,
        arn: str,
        lineage_entity: str,
        lineage_source: str,
        sagemaker_session,
    ):
        """Initialize ``Vertex`` instance.

        Args:
            arn (str): ARN of the lineage entity this vertex represents.
            lineage_entity (str): Entity kind; compared against
                ``LineageEntityEnum`` values when converting the vertex.
            lineage_source (str): Source type; compared against
                ``LineageSourceEnum`` values when converting the vertex.
            sagemaker_session: Session object forwarded to the ``load`` calls
                made by ``to_lineage_object``.
        """
        self.arn = arn
        self.lineage_entity = lineage_entity
        self.lineage_source = lineage_source
        self._session = sagemaker_session

    def __hash__(self):
        """Define hash function for ``Vertex``.

        The hash covers the same three attributes that ``__eq__`` compares.
        """
        return hash(
            (
                "arn",
                self.arn,
                "lineage_entity",
                self.lineage_entity,
                "lineage_source",
                self.lineage_source,
            )
        )

    def __eq__(self, other):
        """Define equal function for ``Vertex``.

        Returns:
            bool: True when ``other`` is a ``Vertex`` with the same ARN,
            lineage entity and lineage source. ``NotImplemented`` is returned
            for any non-``Vertex`` operand so that comparisons such as
            ``vertex == None`` evaluate to False instead of raising
            ``AttributeError``.
        """
        if not isinstance(other, Vertex):
            return NotImplemented
        return (
            self.arn == other.arn
            and self.lineage_entity == other.lineage_entity
            and self.lineage_source == other.lineage_source
        )

    def __str__(self):
        """Define string representation of ``Vertex``.

        Format:
            {
                'arn': 'string',
                'lineage_entity': 'string',
                'lineage_source': 'string',
                '_session': <sagemaker.session.Session object>
            }

        """
        return str(self.__dict__)

    def __repr__(self):
        """Define string representation of ``Vertex``.

        Same as ``__str__`` but prefixed with a newline and a tab, so that a
        printed list of vertices shows one vertex per line.

        Format:
            {
                'arn': 'string',
                'lineage_entity': 'string',
                'lineage_source': 'string',
                '_session': <sagemaker.session.Session object>
            }

        """
        return "\n\t" + str(self.__dict__)

    def to_lineage_object(self):
        """Convert the ``Vertex`` object to its corresponding lineage object.

        Returns:
            The ``Context`` (or ``EndpointContext``), ``Artifact`` (or one of
            its specialised subclasses), ``Action`` or
            ``LineageTrialComponent`` object loaded for this vertex.

        Raises:
            ValueError: If ``lineage_entity`` is not a context, artifact,
                action or trial component.
        """
        # Imported here rather than at module level, as in the original code.
        from sagemaker.lineage.context import Context, EndpointContext
        from sagemaker.lineage.action import Action
        from sagemaker.lineage.lineage_trial_component import LineageTrialComponent

        if self.lineage_entity == LineageEntityEnum.CONTEXT.value:
            resource_name = get_resource_name_from_arn(self.arn)
            # Endpoint contexts get their dedicated class; all others are generic.
            if self.lineage_source == LineageSourceEnum.ENDPOINT.value:
                return EndpointContext.load(
                    context_name=resource_name, sagemaker_session=self._session
                )
            return Context.load(context_name=resource_name, sagemaker_session=self._session)

        if self.lineage_entity == LineageEntityEnum.ARTIFACT.value:
            return self._artifact_to_lineage_object()

        if self.lineage_entity == LineageEntityEnum.ACTION.value:
            # The action name is taken as the ARN segment after the first "/".
            return Action.load(action_name=self.arn.split("/")[1], sagemaker_session=self._session)

        if self.lineage_entity == LineageEntityEnum.TRIAL_COMPONENT.value:
            trial_component_name = get_resource_name_from_arn(self.arn)
            return LineageTrialComponent.load(
                trial_component_name=trial_component_name, sagemaker_session=self._session
            )
        raise ValueError("Vertex cannot be converted to a lineage object.")

    def _artifact_to_lineage_object(self):
        """Convert the ``Vertex`` object to its corresponding ``Artifact``.

        Returns:
            A ``ModelArtifact``, ``DatasetArtifact`` or ``ImageArtifact`` when
            ``lineage_source`` is a model, dataset or image respectively;
            otherwise a generic ``Artifact``. All are loaded by ARN.
        """
        from sagemaker.lineage.artifact import Artifact, ModelArtifact, ImageArtifact
        from sagemaker.lineage.artifact import DatasetArtifact

        if self.lineage_source == LineageSourceEnum.MODEL.value:
            return ModelArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session)
        if self.lineage_source == LineageSourceEnum.DATASET.value:
            return DatasetArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session)
        if self.lineage_source == LineageSourceEnum.IMAGE.value:
            return ImageArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session)
        return Artifact.load(artifact_arn=self.arn, sagemaker_session=self._session)
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
class PyvisVisualizer(object):
|
|
237
|
+
"""Create object used for visualizing graph using Pyvis library."""
|
|
238
|
+
|
|
239
|
+
def __init__(self, graph_styles, pyvis_options: Optional[Dict[str, Any]] = None):
|
|
240
|
+
"""Init for PyvisVisualizer.
|
|
241
|
+
|
|
242
|
+
Args:
|
|
243
|
+
graph_styles: A dictionary that contains graph style for node and edges by their type.
|
|
244
|
+
Example: Display the nodes with different color by their lineage entity / different
|
|
245
|
+
shape by start arn.
|
|
246
|
+
lineage_graph_styles = {
|
|
247
|
+
"TrialComponent": {
|
|
248
|
+
"name": "Trial Component",
|
|
249
|
+
"style": {"background-color": "#f6cf61"},
|
|
250
|
+
"isShape": "False",
|
|
251
|
+
},
|
|
252
|
+
"Context": {
|
|
253
|
+
"name": "Context",
|
|
254
|
+
"style": {"background-color": "#ff9900"},
|
|
255
|
+
"isShape": "False",
|
|
256
|
+
},
|
|
257
|
+
"StartArn": {
|
|
258
|
+
"name": "StartArn",
|
|
259
|
+
"style": {"shape": "star"},
|
|
260
|
+
"isShape": "True",
|
|
261
|
+
"symbol": "★", # shape symbol for legend
|
|
262
|
+
},
|
|
263
|
+
}
|
|
264
|
+
pyvis_options(optional): A dict containing PyVis options to customize visualization.
|
|
265
|
+
(see https://visjs.github.io/vis-network/docs/network/#options for supported fields)
|
|
266
|
+
"""
|
|
267
|
+
# import visualization packages
|
|
268
|
+
(
|
|
269
|
+
self.Network,
|
|
270
|
+
self.Options,
|
|
271
|
+
self.IFrame,
|
|
272
|
+
self.BeautifulSoup,
|
|
273
|
+
) = self._import_visual_modules()
|
|
274
|
+
|
|
275
|
+
self.graph_styles = graph_styles
|
|
276
|
+
|
|
277
|
+
if pyvis_options is None:
|
|
278
|
+
# default pyvis graph options
|
|
279
|
+
pyvis_options = {
|
|
280
|
+
"configure": {"enabled": False},
|
|
281
|
+
"layout": {
|
|
282
|
+
"hierarchical": {
|
|
283
|
+
"enabled": True,
|
|
284
|
+
"blockShifting": True,
|
|
285
|
+
"direction": "LR",
|
|
286
|
+
"sortMethod": "directed",
|
|
287
|
+
"shakeTowards": "leaves",
|
|
288
|
+
}
|
|
289
|
+
},
|
|
290
|
+
"interaction": {"multiselect": True, "navigationButtons": True},
|
|
291
|
+
"physics": {
|
|
292
|
+
"enabled": False,
|
|
293
|
+
"hierarchicalRepulsion": {"centralGravity": 0, "avoidOverlap": None},
|
|
294
|
+
"minVelocity": 0.75,
|
|
295
|
+
"solver": "hierarchicalRepulsion",
|
|
296
|
+
},
|
|
297
|
+
}
|
|
298
|
+
# A string representation of a Javascript-like object used to override pyvis options
|
|
299
|
+
self._pyvis_options = f"var options = {dumps(pyvis_options)}"
|
|
300
|
+
|
|
301
|
+
def _import_visual_modules(self):
|
|
302
|
+
"""Import modules needed for visualization."""
|
|
303
|
+
get_module("pyvis")
|
|
304
|
+
from pyvis.network import Network
|
|
305
|
+
from pyvis.options import Options
|
|
306
|
+
from IPython.display import IFrame
|
|
307
|
+
|
|
308
|
+
get_module("bs4")
|
|
309
|
+
from bs4 import BeautifulSoup
|
|
310
|
+
|
|
311
|
+
return Network, Options, IFrame, BeautifulSoup
|
|
312
|
+
|
|
313
|
+
def _node_color(self, entity):
|
|
314
|
+
"""Return node color by background-color specified in graph styles."""
|
|
315
|
+
return self.graph_styles[entity]["style"]["background-color"]
|
|
316
|
+
|
|
317
|
+
def _get_legend_line(self, component_name):
|
|
318
|
+
"""Generate lengend div line for each graph component in graph_styles."""
|
|
319
|
+
if self.graph_styles[component_name]["isShape"] == "False":
|
|
320
|
+
return '<div><div style="background-color: {color}; width: 1.6vw; height: 1.6vw;\
|
|
321
|
+
display: inline-block; font-size: 1.5vw; vertical-align: -0.2em;"></div>\
|
|
322
|
+
<div style="width: 0.3vw; height: 1.5vw; display: inline-block;"></div>\
|
|
323
|
+
<div style="display: inline-block; font-size: 1.5vw;">{name}</div></div>'.format(
|
|
324
|
+
color=self.graph_styles[component_name]["style"]["background-color"],
|
|
325
|
+
name=self.graph_styles[component_name]["name"],
|
|
326
|
+
)
|
|
327
|
+
|
|
328
|
+
return '<div style="background-color: #ffffff; width: 1.6vw; height: 1.6vw;\
|
|
329
|
+
display: inline-block; font-size: 0.9vw; vertical-align: -0.2em;">{shape}</div>\
|
|
330
|
+
<div style="width: 0.3vw; height: 1.5vw; display: inline-block;"></div>\
|
|
331
|
+
<div style="display: inline-block; font-size: 1.5vw;">{name}</div></div>'.format(
|
|
332
|
+
shape=self.graph_styles[component_name]["style"]["shape"],
|
|
333
|
+
name=self.graph_styles[component_name]["name"],
|
|
334
|
+
)
|
|
335
|
+
|
|
336
|
+
def _add_legend(self, path):
|
|
337
|
+
"""Embed legend to html file generated by pyvis."""
|
|
338
|
+
with open(path, "r") as f:
|
|
339
|
+
content = self.BeautifulSoup(f, "html.parser")
|
|
340
|
+
|
|
341
|
+
legend = """
|
|
342
|
+
<div style="display: inline-block; font-size: 1vw; font-family: verdana;
|
|
343
|
+
vertical-align: top; padding: 1vw;">
|
|
344
|
+
"""
|
|
345
|
+
# iterate through graph styles to get legend
|
|
346
|
+
for component in self.graph_styles.keys():
|
|
347
|
+
legend += self._get_legend_line(component_name=component)
|
|
348
|
+
|
|
349
|
+
legend += "</div>"
|
|
350
|
+
|
|
351
|
+
legend_div = self.BeautifulSoup(legend, "html.parser")
|
|
352
|
+
|
|
353
|
+
content.div.insert_after(legend_div)
|
|
354
|
+
|
|
355
|
+
html = content.prettify()
|
|
356
|
+
|
|
357
|
+
with open(path, "w", encoding="utf8") as file:
|
|
358
|
+
file.write(html)
|
|
359
|
+
|
|
360
|
+
def render(self, elements, path="lineage_graph_pyvis.html"):
    """Render graph for lineage query result.

    Builds a directed network from ``elements``, writes it to ``path`` as html,
    embeds the legend into that file and returns an ``IFrame`` pointing at it.

    Node tooltips list the entity, type, account id and name. The account id
    and name are parsed from the ARN; an ARN without a 12-digit account segment
    yields an empty account id, and an ARN without a ``/`` uses the whole ARN as
    the name. A node whose lineage source is ``None`` gets an empty type.

    Args:
        elements: A dictionary that contains the node and the edges of the graph.
            Example:
                elements["nodes"] contains list of tuples, each tuple represents a node
                    format: (node arn, node lineage source, node lineage entity,
                        node is start arn)
                elements["edges"] contains list of tuples, each tuple represents an edge
                    format: (edge source arn, edge destination arn, edge association type)

        path(optional): The path/filename of the rendered graph html file.
            (default path: "lineage_graph_pyvis.html")

    Returns:
        display graph: The interactive visualization is presented as a static HTML file.

    """
    net = self.Network(height="600px", width="82%", notebook=True, directed=True)
    net.set_options(self._pyvis_options)

    # add nodes to graph
    for arn, source, entity, is_start_arn in elements["nodes"]:
        # split CamelCase identifiers into words, e.g. "TrialComponent" -> "Trial Component"
        entity_text = sub(r"(\w)([A-Z])", r"\1 \2", entity)
        # the lineage source is None when the query response carried no "Type"
        source_text = sub(r"(\w)([A-Z])", r"\1 \2", source) if source else ""

        account_match = search(r":\d{12}:", arn)
        name_match = search(r"\/.*", arn)
        # strip the surrounding ":" / leading "/" captured by the patterns
        account_id = account_match.group()[1:-1] if account_match else ""
        name = name_match.group()[1:] if name_match else arn

        node_info = "Entity: {}\nType: {}\nAccount ID: {}\nName: {}".format(
            entity_text, source_text, account_id, name
        )

        node_options = {
            "label": source_text,
            "title": node_info,
            "color": self._node_color(entity),
            "borderWidth": 3,
        }
        if is_start_arn:
            # start arns are drawn as stars, every other node keeps the default shape
            node_options["shape"] = "star"
        net.add_node(arn, **node_options)

    # add edges to graph
    for src, dest, asso_type in elements["edges"]:
        net.add_edge(src, dest, title=asso_type, width=2)

    net.write_html(path)
    self._add_legend(path)

    return self.IFrame(path, width="100%", height="600px")
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
class LineageQueryResult(object):
    """Holds the edges, vertices and start ARNs produced by a lineage query."""

    def __init__(
        self,
        edges: List[Edge] = None,
        vertices: List[Vertex] = None,
        startarn: List[str] = None,
    ):
        """Init for LineageQueryResult.

        Each argument left as ``None`` is stored as a new empty list.

        Args:
            edges (List[Edge]): The edges of the query result.
            vertices (List[Vertex]): The vertices of the query result.
            startarn (List[str]): The ARNs the query started from.
        """
        self.edges = [] if edges is None else edges
        self.vertices = [] if vertices is None else vertices
        self.startarn = [] if startarn is None else startarn

    def __str__(self):
        """Define string representation of ``LineageQueryResult``.

        Every instance attribute is rendered as ``'name': value,`` and the
        entries are separated by a blank line inside a pair of braces, e.g.::

            {'edges': [...],

            'vertices': [...],

            'startarn': ['string', ...],
            }
        """
        entries = ["'{}': {},".format(name, value) for name, value in vars(self).items()]
        return "{" + "\n\n".join(entries) + "\n}"

    def _covert_edges_to_tuples(self):
        """Convert edges to ``(source arn, destination arn, association type)`` tuples."""
        return [
            (edge.source_arn, edge.destination_arn, edge.association_type)
            for edge in self.edges
        ]

    def _covert_vertices_to_tuples(self):
        """Convert vertices to tuples for the visualizer.

        Each tuple is ``(arn, lineage source, lineage entity, is start arn)``; the
        last item is ``True`` when the vertex's arn is listed in ``self.startarn``.
        """
        return [
            (vert.arn, vert.lineage_source, vert.lineage_entity, vert.arn in self.startarn)
            for vert in self.vertices
        ]

    def _get_visualization_elements(self):
        """Get elements (nodes + edges) for visualization as a dict."""
        nodes = self._covert_vertices_to_tuples()
        links = self._covert_edges_to_tuples()
        return {"nodes": nodes, "edges": links}

    def visualize(
        self,
        path: Optional[str] = "lineage_graph_pyvis.html",
        pyvis_options: Optional[Dict[str, Any]] = None,
    ):
        """Visualize lineage query result.

        Creates a PyvisVisualizer object to render network graph with Pyvis library.
        Pyvis library should be installed before using this method (run "pip install pyvis")
        The elements (nodes & edges) are preprocessed in this method and sent to
        PyvisVisualizer for rendering graph.

        Args:
            path(optional): The path/filename of the rendered graph html file.
                (default path: "lineage_graph_pyvis.html")
            pyvis_options(optional): A dict containing PyVis options to customize visualization.
                (see https://visjs.github.io/vis-network/docs/network/#options for supported fields)

        Returns:
            display graph: The interactive visualization is presented as a static HTML file.
        """
        # (graph_styles key, legend name, node color) for the color-coded entities
        colored_entities = (
            ("TrialComponent", "Trial Component", "#f6cf61"),
            ("Context", "Context", "#ff9900"),
            ("Action", "Action", "#88c396"),
            ("Artifact", "Artifact", "#146eb4"),
        )
        lineage_graph_styles = {
            key: {
                "name": display_name,
                "style": {"background-color": color},
                "isShape": "False",
            }
            for key, display_name, color in colored_entities
        }
        # start arns are distinguished by shape rather than by color
        lineage_graph_styles["StartArn"] = {
            "name": "StartArn",
            "style": {"shape": "star"},
            "isShape": "True",
            "symbol": "★",  # shape symbol for legend
        }

        visualizer = PyvisVisualizer(lineage_graph_styles, pyvis_options)
        graph_elements = self._get_visualization_elements()
        return visualizer.render(elements=graph_elements, path=path)
|
|
568
|
+
|
|
569
|
+
|
|
570
|
+
class LineageFilter(object):
    """A filter used in a lineage query."""

    def __init__(
        self,
        entities: Optional[List[Union[LineageEntityEnum, str]]] = None,
        sources: Optional[List[Union[LineageSourceEnum, str]]] = None,
        created_before: Optional[datetime] = None,
        created_after: Optional[datetime] = None,
        modified_before: Optional[datetime] = None,
        modified_after: Optional[datetime] = None,
        properties: Optional[Dict[str, str]] = None,
    ):
        """Initialize ``LineageFilter`` instance.

        All arguments are optional and are stored unchanged on the instance.
        """
        self.entities = entities
        self.sources = sources
        self.created_before = created_before
        self.created_after = created_after
        self.modified_before = modified_before
        self.modified_after = modified_after
        self.properties = properties

    def _to_request_dict(self):
        """Convert the lineage filter to its API representation.

        Only truthy attributes are emitted. ``sources`` becomes ``"Types"`` and
        ``entities`` becomes ``"LineageTypes"``; enum members in those lists are
        replaced by their ``value`` while other items are passed through as-is.

        Returns:
            dict: The request dictionary, possibly empty.
        """
        filter_request = {}

        if self.sources:
            filter_request["Types"] = [
                item.value if isinstance(item, LineageSourceEnum) else item
                for item in self.sources
            ]
        if self.entities:
            filter_request["LineageTypes"] = [
                item.value if isinstance(item, LineageEntityEnum) else item
                for item in self.entities
            ]

        # The remaining fields map one-to-one onto request keys.
        passthrough = (
            ("CreatedBefore", self.created_before),
            ("CreatedAfter", self.created_after),
            ("ModifiedBefore", self.modified_before),
            ("ModifiedAfter", self.modified_after),
            ("Properties", self.properties),
        )
        for request_key, value in passthrough:
            if value:
                filter_request[request_key] = value

        return filter_request
|
|
614
|
+
|
|
615
|
+
|
|
616
|
+
class LineageQuery(object):
|
|
617
|
+
"""Creates an object used for performing lineage queries."""
|
|
618
|
+
|
|
619
|
+
def __init__(self, sagemaker_session):
|
|
620
|
+
"""Initialize ``LineageQuery`` instance."""
|
|
621
|
+
self._session = sagemaker_session
|
|
622
|
+
|
|
623
|
+
def _get_edge(self, edge):
    """Convert one edge of the lineage query API response to an ``Edge``.

    ``"SourceArn"`` and ``"DestinationArn"`` are required keys; the association
    type is ``None`` when ``"AssociationType"`` is absent.
    """
    association_type = edge.get("AssociationType")
    return Edge(
        source_arn=edge["SourceArn"],
        destination_arn=edge["DestinationArn"],
        association_type=association_type,
    )
|
|
630
|
+
|
|
631
|
+
def _get_vertex(self, vertex):
    """Convert one vertex of the lineage query API response to a ``Vertex``.

    ``"Arn"`` and ``"LineageType"`` are required keys; the lineage source is
    ``None`` when ``"Type"`` is absent. The query's session is handed to the
    created ``Vertex``.
    """
    return Vertex(
        arn=vertex["Arn"],
        lineage_source=vertex.get("Type"),
        lineage_entity=vertex["LineageType"],
        sagemaker_session=self._session,
    )
|
|
642
|
+
|
|
643
|
+
def _convert_api_response(self, response, converted) -> LineageQueryResult:
|
|
644
|
+
"""Convert the lineage query API response to its Python representation."""
|
|
645
|
+
converted.edges = [self._get_edge(edge) for edge in response["Edges"]]
|
|
646
|
+
converted.vertices = [self._get_vertex(vertex) for vertex in response["Vertices"]]
|
|
647
|
+
|
|
648
|
+
edge_set = set()
|
|
649
|
+
for edge in converted.edges:
|
|
650
|
+
if edge in edge_set:
|
|
651
|
+
converted.edges.remove(edge)
|
|
652
|
+
edge_set.add(edge)
|
|
653
|
+
|
|
654
|
+
vertex_set = set()
|
|
655
|
+
for vertex in converted.vertices:
|
|
656
|
+
if vertex in vertex_set:
|
|
657
|
+
converted.vertices.remove(vertex)
|
|
658
|
+
vertex_set.add(vertex)
|
|
659
|
+
|
|
660
|
+
return converted
|
|
661
|
+
|
|
662
|
+
def _collapse_cross_account_artifacts(self, query_response):
|
|
663
|
+
"""Collapse the duplicate vertices and edges for cross-account."""
|
|
664
|
+
for edge in query_response.edges:
|
|
665
|
+
if (
|
|
666
|
+
"artifact" in edge.source_arn
|
|
667
|
+
and "artifact" in edge.destination_arn
|
|
668
|
+
and edge.source_arn.split("/")[1] == edge.destination_arn.split("/")[1]
|
|
669
|
+
and edge.source_arn != edge.destination_arn
|
|
670
|
+
):
|
|
671
|
+
edge_source_arn = edge.source_arn
|
|
672
|
+
edge_destination_arn = edge.destination_arn
|
|
673
|
+
self._update_cross_account_edge(
|
|
674
|
+
edges=query_response.edges,
|
|
675
|
+
arn=edge_source_arn,
|
|
676
|
+
duplicate_arn=edge_destination_arn,
|
|
677
|
+
)
|
|
678
|
+
self._update_cross_account_vertex(
|
|
679
|
+
query_response=query_response, duplicate_arn=edge_destination_arn
|
|
680
|
+
)
|
|
681
|
+
|
|
682
|
+
# remove the duplicate edges from cross account
|
|
683
|
+
new_edge = [e for e in query_response.edges if not e.source_arn == e.destination_arn]
|
|
684
|
+
query_response.edges = new_edge
|
|
685
|
+
|
|
686
|
+
return query_response
|
|
687
|
+
|
|
688
|
+
def _update_cross_account_edge(self, edges, arn, duplicate_arn):
|
|
689
|
+
"""Replace the duplicate arn with arn in edges list."""
|
|
690
|
+
for idx, e in enumerate(edges):
|
|
691
|
+
if e.destination_arn == duplicate_arn:
|
|
692
|
+
edges[idx].destination_arn = arn
|
|
693
|
+
elif e.source_arn == duplicate_arn:
|
|
694
|
+
edges[idx].source_arn = arn
|
|
695
|
+
|
|
696
|
+
def _update_cross_account_vertex(self, query_response, duplicate_arn):
|
|
697
|
+
"""Remove the vertex with duplicate arn in the vertices list."""
|
|
698
|
+
query_response.vertices = [v for v in query_response.vertices if not v.arn == duplicate_arn]
|
|
699
|
+
|
|
700
|
+
def query(
    self,
    start_arns: List[str],
    direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH,
    include_edges: bool = True,
    query_filter: LineageFilter = None,
    max_depth: int = 10,
) -> LineageQueryResult:
    """Perform a lineage query.

    Calls ``query_lineage`` on the session's SageMaker client, converts the
    response into a ``LineageQueryResult`` that also records ``start_arns``,
    and collapses cross-account duplicate artifacts before returning it.

    Args:
        start_arns (List[str]): A list of ARNs that will be used as the starting point
            for the query.
        direction (LineageQueryDirectionEnum, optional): The direction of the query.
        include_edges (bool, optional): If true, return edges in addition to vertices.
        query_filter (LineageQueryFilter, optional): The query filter. When omitted,
            an empty filter dictionary is sent.
        max_depth (int, optional): Passed to the API as ``MaxDepth``.

    Returns:
        LineageQueryResult: The lineage query result.
    """
    direction_value = direction.value
    filters = query_filter._to_request_dict() if query_filter else {}

    api_response = self._session.sagemaker_client.query_lineage(
        StartArns=start_arns,
        Direction=direction_value,
        IncludeEdges=include_edges,
        Filters=filters,
        MaxDepth=max_depth,
    )

    # create query result for startarn info
    result = self._convert_api_response(
        api_response, LineageQueryResult(startarn=start_arns)
    )
    return self._collapse_cross_account_artifacts(result)
|