sagemaker-core 1.0.62__py3-none-any.whl → 2.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (358)
  1. sagemaker/__init__.py +2 -0
  2. sagemaker/core/__init__.py +16 -0
  3. sagemaker/core/_studio.py +116 -0
  4. sagemaker/core/_version.py +11 -0
  5. sagemaker/core/accept_types.py +131 -0
  6. sagemaker/core/analytics.py +744 -0
  7. sagemaker/core/apiutils/__init__.py +13 -0
  8. sagemaker/core/apiutils/_base_types.py +228 -0
  9. sagemaker/core/apiutils/_boto_functions.py +130 -0
  10. sagemaker/core/apiutils/_utils.py +34 -0
  11. sagemaker/core/base_deserializers.py +35 -0
  12. sagemaker/core/base_serializers.py +35 -0
  13. sagemaker/core/clarify/__init__.py +2898 -0
  14. sagemaker/core/collection.py +467 -0
  15. sagemaker/core/common_utils.py +2399 -0
  16. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  17. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  18. sagemaker/core/config/__init__.py +181 -0
  19. sagemaker/core/config/config.py +238 -0
  20. sagemaker/core/config/config_manager.py +595 -0
  21. sagemaker/core/config/config_schema.py +1220 -0
  22. sagemaker/core/config/config_utils.py +297 -0
  23. {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
  24. sagemaker/core/constants.py +73 -0
  25. sagemaker/core/content_types.py +137 -0
  26. sagemaker/core/debugger/__init__.py +39 -0
  27. sagemaker/core/debugger/debugger.py +945 -0
  28. sagemaker/core/debugger/framework_profile.py +292 -0
  29. sagemaker/core/debugger/metrics_config.py +468 -0
  30. sagemaker/core/debugger/profiler.py +42 -0
  31. sagemaker/core/debugger/profiler_config.py +190 -0
  32. sagemaker/core/debugger/profiler_constants.py +40 -0
  33. sagemaker/core/debugger/utils.py +148 -0
  34. sagemaker/core/deprecations.py +254 -0
  35. sagemaker/core/deserializers/__init__.py +10 -0
  36. sagemaker/core/deserializers/base.py +424 -0
  37. sagemaker/core/deserializers/implementations.py +157 -0
  38. sagemaker/core/drift_check_baselines.py +106 -0
  39. sagemaker/core/enums.py +51 -0
  40. sagemaker/core/environment_variables.py +101 -0
  41. sagemaker/core/exceptions.py +108 -0
  42. sagemaker/core/experiments/__init__.py +53 -0
  43. sagemaker/core/experiments/_api_types.py +251 -0
  44. sagemaker/core/experiments/_environment.py +124 -0
  45. sagemaker/core/experiments/_helper.py +294 -0
  46. sagemaker/core/experiments/_metrics.py +333 -0
  47. sagemaker/core/experiments/_run_context.py +58 -0
  48. sagemaker/core/experiments/_utils.py +216 -0
  49. sagemaker/core/experiments/experiment.py +247 -0
  50. sagemaker/core/experiments/run.py +970 -0
  51. sagemaker/core/experiments/trial.py +296 -0
  52. sagemaker/core/experiments/trial_component.py +387 -0
  53. sagemaker/core/explainer/__init__.py +24 -0
  54. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  55. sagemaker/core/explainer/explainer_config.py +44 -0
  56. sagemaker/core/fw_utils.py +1220 -0
  57. sagemaker/core/git_utils.py +415 -0
  58. sagemaker/core/helper/pipeline_variable.py +82 -0
  59. sagemaker/core/helper/session_helper.py +2977 -0
  60. sagemaker/core/hyperparameters.py +172 -0
  61. sagemaker/core/image_retriever/__init__.py +3 -0
  62. sagemaker/core/image_retriever/image_retriever.py +640 -0
  63. sagemaker/core/image_retriever/image_retriever_utils.py +509 -0
  64. sagemaker/core/image_retriever/test.py +7 -0
  65. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  66. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  67. sagemaker/core/image_uri_config/chainer.json +104 -0
  68. sagemaker/core/image_uri_config/clarify.json +39 -0
  69. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  70. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  71. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  72. sagemaker/core/image_uri_config/debugger.json +34 -0
  73. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  74. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  75. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  76. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  77. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  78. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  79. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  80. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  81. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +770 -0
  82. sagemaker/core/image_uri_config/huggingface-llm.json +1267 -0
  83. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  84. sagemaker/core/image_uri_config/huggingface-neuronx.json +686 -0
  85. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  86. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  87. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  88. sagemaker/core/image_uri_config/huggingface-vllm-neuronx.json +38 -0
  89. sagemaker/core/image_uri_config/huggingface.json +2287 -0
  90. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  91. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  92. sagemaker/core/image_uri_config/image-classification.json +50 -0
  93. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  94. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  95. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  96. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  97. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  98. sagemaker/core/image_uri_config/kmeans.json +50 -0
  99. sagemaker/core/image_uri_config/knn.json +50 -0
  100. sagemaker/core/image_uri_config/lda.json +26 -0
  101. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  102. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  103. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  104. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  105. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  106. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  107. sagemaker/core/image_uri_config/ntm.json +50 -0
  108. sagemaker/core/image_uri_config/object-detection.json +50 -0
  109. sagemaker/core/image_uri_config/object2vec.json +50 -0
  110. sagemaker/core/image_uri_config/pca.json +50 -0
  111. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  112. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  113. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  114. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  115. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  116. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  117. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  118. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  119. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  120. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  121. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +252 -0
  122. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  123. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  124. sagemaker/core/image_uri_config/sklearn.json +494 -0
  125. sagemaker/core/image_uri_config/spark.json +280 -0
  126. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  127. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  128. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  129. sagemaker/core/image_uri_config/vw.json +25 -0
  130. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  131. sagemaker/core/image_uri_config/xgboost.json +972 -0
  132. sagemaker/core/image_uris.py +816 -0
  133. sagemaker/core/inference_config.py +144 -0
  134. sagemaker/core/inference_recommender/__init__.py +18 -0
  135. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  136. sagemaker/core/inputs.py +366 -0
  137. sagemaker/core/instance_group.py +61 -0
  138. sagemaker/core/instance_types.py +164 -0
  139. sagemaker/core/instance_types_gpu_info.py +43 -0
  140. sagemaker/core/interactive_apps/__init__.py +41 -0
  141. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  142. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  143. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  144. sagemaker/core/iterators.py +197 -0
  145. sagemaker/core/job.py +380 -0
  146. sagemaker/core/jumpstart/__init__.py +156 -0
  147. sagemaker/core/jumpstart/accessors.py +390 -0
  148. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  149. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  150. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  151. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  152. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  153. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  154. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  155. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  156. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  157. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  158. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  159. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  160. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  161. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  162. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  163. sagemaker/core/jumpstart/cache.py +663 -0
  164. sagemaker/core/jumpstart/configs.py +50 -0
  165. sagemaker/core/jumpstart/constants.py +198 -0
  166. sagemaker/core/jumpstart/deserializers.py +81 -0
  167. sagemaker/core/jumpstart/document.py +76 -0
  168. sagemaker/core/jumpstart/enums.py +168 -0
  169. sagemaker/core/jumpstart/exceptions.py +236 -0
  170. sagemaker/core/jumpstart/factory/utils.py +833 -0
  171. sagemaker/core/jumpstart/filters.py +597 -0
  172. sagemaker/core/jumpstart/hub/constants.py +16 -0
  173. sagemaker/core/jumpstart/hub/hub.py +291 -0
  174. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  175. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  176. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  177. sagemaker/core/jumpstart/hub/types.py +35 -0
  178. sagemaker/core/jumpstart/hub/utils.py +260 -0
  179. sagemaker/core/jumpstart/models.py +501 -0
  180. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  181. sagemaker/core/jumpstart/parameters.py +20 -0
  182. sagemaker/core/jumpstart/payload_utils.py +239 -0
  183. sagemaker/core/jumpstart/region_config.json +171 -0
  184. sagemaker/core/jumpstart/search.py +171 -0
  185. sagemaker/core/jumpstart/serializers.py +81 -0
  186. sagemaker/core/jumpstart/session_utils.py +234 -0
  187. sagemaker/core/jumpstart/types.py +3044 -0
  188. sagemaker/core/jumpstart/utils.py +1731 -0
  189. sagemaker/core/jumpstart/validators.py +257 -0
  190. sagemaker/core/lambda_helper.py +312 -0
  191. sagemaker/core/lineage/__init__.py +42 -0
  192. sagemaker/core/lineage/_api_types.py +239 -0
  193. sagemaker/core/lineage/_utils.py +49 -0
  194. sagemaker/core/lineage/action.py +345 -0
  195. sagemaker/core/lineage/artifact.py +646 -0
  196. sagemaker/core/lineage/association.py +190 -0
  197. sagemaker/core/lineage/context.py +505 -0
  198. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  199. sagemaker/core/lineage/query.py +732 -0
  200. sagemaker/core/lineage/visualizer.py +346 -0
  201. sagemaker/core/local/__init__.py +18 -0
  202. sagemaker/core/local/data.py +423 -0
  203. sagemaker/core/local/entities.py +678 -0
  204. sagemaker/core/local/exceptions.py +17 -0
  205. sagemaker/core/local/image.py +1243 -0
  206. sagemaker/core/local/local_session.py +739 -0
  207. sagemaker/core/local/utils.py +246 -0
  208. sagemaker/core/logs.py +181 -0
  209. sagemaker/core/metadata_properties.py +56 -0
  210. sagemaker/core/metric_definitions.py +91 -0
  211. sagemaker/core/mlflow/__init__.py +38 -0
  212. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  213. sagemaker/core/model_card/__init__.py +26 -0
  214. sagemaker/core/model_life_cycle.py +51 -0
  215. sagemaker/core/model_metrics.py +160 -0
  216. sagemaker/core/model_monitor/__init__.py +66 -0
  217. sagemaker/core/model_monitor/clarify_model_monitoring.py +1497 -0
  218. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  219. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  220. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  221. sagemaker/core/model_monitor/dataset_format.py +102 -0
  222. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  223. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  224. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  225. sagemaker/core/model_monitor/utils.py +793 -0
  226. sagemaker/core/model_registry.py +480 -0
  227. sagemaker/core/model_uris.py +97 -0
  228. sagemaker/core/modules/__init__.py +19 -0
  229. sagemaker/core/modules/configs.py +239 -0
  230. sagemaker/core/modules/constants.py +37 -0
  231. sagemaker/core/modules/distributed.py +182 -0
  232. sagemaker/core/modules/local_core/local_container.py +605 -0
  233. sagemaker/core/modules/templates.py +83 -0
  234. sagemaker/core/modules/train/__init__.py +14 -0
  235. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  236. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  237. sagemaker/core/modules/train/container_drivers/common/utils.py +205 -0
  238. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  239. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  240. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  241. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  242. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  243. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  244. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  245. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  246. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  247. sagemaker/core/modules/types.py +19 -0
  248. sagemaker/core/modules/utils.py +194 -0
  249. sagemaker/core/network.py +185 -0
  250. sagemaker/core/parameter.py +173 -0
  251. sagemaker/core/payloads.py +185 -0
  252. sagemaker/core/processing.py +1599 -0
  253. sagemaker/core/remote_function/__init__.py +19 -0
  254. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  255. sagemaker/core/remote_function/client.py +1310 -0
  256. sagemaker/core/remote_function/core/__init__.py +0 -0
  257. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  258. sagemaker/core/remote_function/core/pipeline_variables.py +347 -0
  259. sagemaker/core/remote_function/core/serialization.py +410 -0
  260. sagemaker/core/remote_function/core/stored_function.py +223 -0
  261. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  262. sagemaker/core/remote_function/errors.py +102 -0
  263. sagemaker/core/remote_function/invoke_function.py +167 -0
  264. sagemaker/core/remote_function/job.py +2121 -0
  265. sagemaker/core/remote_function/logging_config.py +38 -0
  266. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  267. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  268. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  269. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  270. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  271. sagemaker/core/remote_function/spark_config.py +149 -0
  272. sagemaker/core/resource_requirements.py +168 -0
  273. {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
  274. sagemaker/core/s3/__init__.py +41 -0
  275. sagemaker/core/s3/client.py +367 -0
  276. sagemaker/core/s3/utils.py +175 -0
  277. sagemaker/core/script_uris.py +93 -0
  278. sagemaker/core/serializers/__init__.py +11 -0
  279. sagemaker/core/serializers/base.py +510 -0
  280. sagemaker/core/serializers/implementations.py +159 -0
  281. sagemaker/core/serializers/utils.py +223 -0
  282. sagemaker/core/serverless_inference_config.py +63 -0
  283. sagemaker/core/session_settings.py +55 -0
  284. sagemaker/core/shapes/__init__.py +3 -0
  285. sagemaker/core/shapes/model_card_shapes.py +159 -0
  286. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
  287. sagemaker/core/spark/__init__.py +16 -0
  288. sagemaker/core/spark/defaults.py +16 -0
  289. sagemaker/core/spark/processing.py +1380 -0
  290. sagemaker/core/telemetry/__init__.py +23 -0
  291. sagemaker/core/telemetry/constants.py +82 -0
  292. sagemaker/core/telemetry/telemetry_logging.py +285 -0
  293. sagemaker/core/tools/__init__.py +1 -0
  294. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  295. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  296. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  297. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  298. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  299. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  300. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  301. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  302. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  303. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  304. sagemaker/core/training/__init__.py +14 -0
  305. sagemaker/core/training/configs.py +345 -0
  306. sagemaker/core/training/constants.py +37 -0
  307. sagemaker/core/training/utils.py +77 -0
  308. sagemaker/core/training_compiler/__init__.py +16 -0
  309. sagemaker/core/training_compiler/config.py +197 -0
  310. sagemaker/core/training_compiler_config.py +197 -0
  311. sagemaker/core/transformer.py +793 -0
  312. sagemaker/core/user_agent.py +76 -0
  313. sagemaker/core/utilities/__init__.py +24 -0
  314. sagemaker/core/utilities/cache.py +169 -0
  315. sagemaker/core/utilities/search_expression.py +133 -0
  316. sagemaker/core/utils/__init__.py +48 -0
  317. sagemaker/core/utils/code_injection/__init__.py +0 -0
  318. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  319. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
  320. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  321. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  322. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  323. {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
  324. sagemaker/core/workflow/__init__.py +152 -0
  325. sagemaker/core/workflow/conditions.py +313 -0
  326. sagemaker/core/workflow/entities.py +58 -0
  327. sagemaker/core/workflow/execution_variables.py +89 -0
  328. sagemaker/core/workflow/functions.py +193 -0
  329. sagemaker/core/workflow/parameters.py +222 -0
  330. sagemaker/core/workflow/pipeline_context.py +394 -0
  331. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  332. sagemaker/core/workflow/properties.py +285 -0
  333. sagemaker/core/workflow/step_outputs.py +65 -0
  334. sagemaker/core/workflow/utilities.py +514 -0
  335. sagemaker/lineage/__init__.py +33 -0
  336. sagemaker/lineage/action.py +28 -0
  337. sagemaker/lineage/artifact.py +28 -0
  338. sagemaker/lineage/context.py +28 -0
  339. sagemaker/lineage/lineage_trial_component.py +28 -0
  340. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/METADATA +28 -9
  341. sagemaker_core-2.3.1.dist-info/RECORD +351 -0
  342. sagemaker_core-2.3.1.dist-info/top_level.txt +1 -0
  343. sagemaker_core/_version.py +0 -3
  344. sagemaker_core/helper/session_helper.py +0 -769
  345. sagemaker_core/resources/__init__.py +0 -1
  346. sagemaker_core/shapes/__init__.py +0 -1
  347. sagemaker_core/tools/__init__.py +0 -1
  348. sagemaker_core-1.0.62.dist-info/RECORD +0 -35
  349. sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
  350. {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
  351. {sagemaker_core/helper → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  352. {sagemaker_core/main → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
  353. {sagemaker_core/main/code_injection → sagemaker/core/modules/local_core}/__init__.py +0 -0
  354. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  355. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  356. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  357. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/WHEEL +0 -0
  358. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,732 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module contains code to query SageMaker lineage."""
14
+ from __future__ import absolute_import
15
+
16
+ from datetime import datetime
17
+ from enum import Enum
18
+ from typing import Any, Optional, Union, List, Dict
19
+ from json import dumps
20
+ from re import sub, search
21
+
22
+ from sagemaker.core.common_utils import get_module
23
+ from sagemaker.core.lineage._utils import get_resource_name_from_arn
24
+
25
+
26
class LineageEntityEnum(Enum):
    """Enum of lineage entities for use in a query filter.

    The string values are what ``Vertex.lineage_entity`` is compared against
    when a vertex is converted back into a lineage object.
    """

    # Experiment-related entities
    TRIAL = "Trial"
    # Core lineage entity kinds
    ACTION = "Action"
    ARTIFACT = "Artifact"
    CONTEXT = "Context"
    # Trial components (loaded as LineageTrialComponent objects)
    TRIAL_COMPONENT = "TrialComponent"
34
+
35
+
36
class LineageSourceEnum(Enum):
    """Enum of lineage source types for use in a query filter.

    The string values are what ``Vertex.lineage_source`` is compared against
    (e.g. to pick ``EndpointContext`` or a specialised ``Artifact`` subclass).
    """

    # Data and model artifacts
    CHECKPOINT = "Checkpoint"
    DATASET = "DataSet"
    ENDPOINT = "Endpoint"
    IMAGE = "Image"
    MODEL = "Model"
    MODEL_DATA = "ModelData"
    MODEL_DEPLOYMENT = "ModelDeployment"
    MODEL_GROUP = "ModelGroup"
    # NOTE: the member name is MODEL_REPLACE but its value is "ModelReplaced".
    MODEL_REPLACE = "ModelReplaced"
    TENSORBOARD = "TensorBoard"
    # Jobs and approvals
    TRAINING_JOB = "TrainingJob"
    APPROVAL = "Approval"
    PROCESSING_JOB = "ProcessingJob"
    TRANSFORM_JOB = "TransformJob"
53
+
54
+
55
class LineageQueryDirectionEnum(Enum):
    """Enum of query filter directions.

    Members: ``BOTH``, ``ASCENDANTS`` and ``DESCENDANTS``; the values are the
    capitalised strings used in the query filter.
    """

    BOTH = "Both"  # traverse in both directions
    ASCENDANTS = "Ascendants"  # traverse towards ascendants only
    DESCENDANTS = "Descendants"  # traverse towards descendants only
61
+
62
+
63
class Edge:
    """A connecting edge for a lineage graph.

    An edge links ``source_arn`` to ``destination_arn`` with an
    ``association_type`` label. Two edges are equal, and hash identically,
    when all three attributes match.
    """

    def __init__(
        self,
        source_arn: str,
        destination_arn: str,
        association_type: str,
    ):
        """Initialize ``Edge`` instance.

        Args:
            source_arn: ARN of the entity the edge starts from.
            destination_arn: ARN of the entity the edge points to.
            association_type: Association type label of the edge.
        """
        self.source_arn = source_arn
        self.destination_arn = destination_arn
        self.association_type = association_type

    def __hash__(self):
        """Define hash function for ``Edge`` (consistent with ``__eq__``)."""
        return hash(
            (
                "source_arn",
                self.source_arn,
                "destination_arn",
                self.destination_arn,
                "association_type",
                self.association_type,
            )
        )

    def __eq__(self, other):
        """Define equal function for ``Edge``.

        Returns ``NotImplemented`` when ``other`` lacks the compared attributes.
        Python then falls back to its default handling, so comparisons such as
        ``edge == None`` or ``edge in [None]`` evaluate to False instead of
        raising ``AttributeError``.
        """
        try:
            return (
                self.association_type == other.association_type
                and self.source_arn == other.source_arn
                and self.destination_arn == other.destination_arn
            )
        except AttributeError:
            return NotImplemented

    def __str__(self):
        """Define string representation of ``Edge``.

        Format:
        {
            'source_arn': 'string',
            'destination_arn': 'string',
            'association_type': 'string'
        }

        """
        return str(self.__dict__)

    def __repr__(self):
        """Define string representation of ``Edge``.

        The leading newline and tab put each edge on its own indented line
        when a list of edges is printed.

        Format:
        {
            'source_arn': 'string',
            'destination_arn': 'string',
            'association_type': 'string'
        }

        """
        return "\n\t" + str(self.__dict__)
123
+
124
+
125
class Vertex:
    """A vertex for a lineage graph.

    A vertex's identity (equality and hash) is its ``arn``, ``lineage_entity``
    and ``lineage_source``. The session is stored only so the vertex can be
    converted into a lineage object by ``to_lineage_object``.
    """

    def __init__(
        self,
        arn: str,
        lineage_entity: str,
        lineage_source: str,
        sagemaker_session,
    ):
        """Initialize ``Vertex`` instance.

        Args:
            arn: ARN of the lineage entity represented by this vertex.
            lineage_entity: Entity kind, compared against ``LineageEntityEnum`` values.
            lineage_source: Source type, compared against ``LineageSourceEnum`` values.
            sagemaker_session: Session forwarded to the ``load`` calls made by
                ``to_lineage_object``.
        """
        self.arn = arn
        self.lineage_entity = lineage_entity
        self.lineage_source = lineage_source
        self._session = sagemaker_session

    def __hash__(self):
        """Define hash function for ``Vertex`` (consistent with ``__eq__``)."""
        return hash(
            (
                "arn",
                self.arn,
                "lineage_entity",
                self.lineage_entity,
                "lineage_source",
                self.lineage_source,
            )
        )

    def __eq__(self, other):
        """Define equal function for ``Vertex``.

        The session is not part of the comparison. Returns ``NotImplemented``
        when ``other`` lacks the compared attributes. Python then falls back to
        its default handling, so comparisons such as ``vertex == None`` evaluate
        to False instead of raising ``AttributeError``.
        """
        try:
            return (
                self.arn == other.arn
                and self.lineage_entity == other.lineage_entity
                and self.lineage_source == other.lineage_source
            )
        except AttributeError:
            return NotImplemented

    def __str__(self):
        """Define string representation of ``Vertex``.

        Format:
        {
            'arn': 'string',
            'lineage_entity': 'string',
            'lineage_source': 'string',
            '_session': <sagemaker.session.Session object>
        }

        """
        return str(self.__dict__)

    def __repr__(self):
        """Define string representation of ``Vertex``.

        The leading newline and tab put each vertex on its own indented line
        when a list of vertices is printed.

        Format:
        {
            'arn': 'string',
            'lineage_entity': 'string',
            'lineage_source': 'string',
            '_session': <sagemaker.session.Session object>
        }

        """
        return "\n\t" + str(self.__dict__)

    def to_lineage_object(self):
        """Convert the ``Vertex`` object to its corresponding lineage object.

        Returns:
            The lineage object loaded for this vertex, chosen by ``lineage_entity``:
            a ``Context`` (``EndpointContext`` for endpoint sources), an
            ``Artifact`` (or a specialised subclass), an ``Action``, or a
            ``LineageTrialComponent``.

        Raises:
            ValueError: If ``lineage_entity`` is not one of the supported kinds.
        """
        # Imported lazily, presumably to avoid a circular import with the
        # lineage entity modules. TODO: confirm.
        from sagemaker.lineage.context import Context, EndpointContext
        from sagemaker.lineage.action import Action
        from sagemaker.lineage.lineage_trial_component import LineageTrialComponent

        if self.lineage_entity == LineageEntityEnum.CONTEXT.value:
            resource_name = get_resource_name_from_arn(self.arn)
            if self.lineage_source == LineageSourceEnum.ENDPOINT.value:
                return EndpointContext.load(
                    context_name=resource_name, sagemaker_session=self._session
                )
            return Context.load(context_name=resource_name, sagemaker_session=self._session)

        if self.lineage_entity == LineageEntityEnum.ARTIFACT.value:
            return self._artifact_to_lineage_object()

        if self.lineage_entity == LineageEntityEnum.ACTION.value:
            # The action name is taken as the path segment after the first "/".
            return Action.load(action_name=self.arn.split("/")[1], sagemaker_session=self._session)

        if self.lineage_entity == LineageEntityEnum.TRIAL_COMPONENT.value:
            trial_component_name = get_resource_name_from_arn(self.arn)
            return LineageTrialComponent.load(
                trial_component_name=trial_component_name, sagemaker_session=self._session
            )
        raise ValueError("Vertex cannot be converted to a lineage object.")

    def _artifact_to_lineage_object(self):
        """Convert the ``Vertex`` object to its corresponding ``Artifact``.

        Model, dataset and image sources load the matching ``Artifact``
        subclass. Any other source loads a plain ``Artifact``.
        """
        from sagemaker.lineage.artifact import Artifact, ModelArtifact, ImageArtifact
        from sagemaker.lineage.artifact import DatasetArtifact

        if self.lineage_source == LineageSourceEnum.MODEL.value:
            return ModelArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session)
        if self.lineage_source == LineageSourceEnum.DATASET.value:
            return DatasetArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session)
        if self.lineage_source == LineageSourceEnum.IMAGE.value:
            return ImageArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session)
        return Artifact.load(artifact_arn=self.arn, sagemaker_session=self._session)
234
+
235
+
236
+ class PyvisVisualizer(object):
237
+ """Create object used for visualizing graph using Pyvis library."""
238
+
239
def __init__(self, graph_styles, pyvis_options: Optional[Dict[str, Any]] = None):
    """Set up the visualizer with per-type graph styles and PyVis options.

    Args:
        graph_styles: A dictionary that contains graph style for node and edges by their type.
            Example: Display the nodes with different color by their lineage entity / different
            shape by start arn.
            lineage_graph_styles = {
                "TrialComponent": {
                    "name": "Trial Component",
                    "style": {"background-color": "#f6cf61"},
                    "isShape": "False",
                },
                "Context": {
                    "name": "Context",
                    "style": {"background-color": "#ff9900"},
                    "isShape": "False",
                },
                "StartArn": {
                    "name": "StartArn",
                    "style": {"shape": "star"},
                    "isShape": "True",
                    "symbol": "★",  # shape symbol for legend
                },
            }
        pyvis_options (optional): A dict containing PyVis options to customize visualization.
            (see https://visjs.github.io/vis-network/docs/network/#options for supported fields)
            When omitted, a left-to-right hierarchical layout with physics disabled is used.
    """
    # Load the visualization dependencies first; this fails early if they are missing.
    visual_modules = self._import_visual_modules()
    self.Network, self.Options, self.IFrame, self.BeautifulSoup = visual_modules

    self.graph_styles = graph_styles

    options = pyvis_options
    if options is None:
        # Built-in defaults: hierarchical LR layout, multiselect + navigation
        # buttons, physics switched off.
        options = {
            "configure": {"enabled": False},
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "blockShifting": True,
                    "direction": "LR",
                    "sortMethod": "directed",
                    "shakeTowards": "leaves",
                }
            },
            "interaction": {"multiselect": True, "navigationButtons": True},
            "physics": {
                "enabled": False,
                "hierarchicalRepulsion": {"centralGravity": 0, "avoidOverlap": None},
                "minVelocity": 0.75,
                "solver": "hierarchicalRepulsion",
            },
        }

    # PyVis receives the overrides as a JavaScript-style assignment string.
    self._pyvis_options = f"var options = {dumps(options)}"
300
+
301
def _import_visual_modules(self):
    """Import the third-party classes used for visualization.

    Returns:
        tuple: ``(Network, Options, IFrame, BeautifulSoup)`` in that order.
    """
    # NOTE(review): get_module presumably verifies the package is installed and
    # raises a helpful error otherwise. Confirm in sagemaker.core.common_utils.
    get_module("pyvis")
    from pyvis.network import Network as network_cls
    from pyvis.options import Options as options_cls
    from IPython.display import IFrame as iframe_cls

    get_module("bs4")
    from bs4 import BeautifulSoup as soup_cls

    return network_cls, options_cls, iframe_cls, soup_cls
312
+
313
def _node_color(self, entity):
    """Look up the ``background-color`` configured for ``entity`` in ``graph_styles``.

    Raises:
        KeyError: If ``entity`` or its background color is not configured.
    """
    entity_style = self.graph_styles[entity]["style"]
    return entity_style["background-color"]
316
+
317
def _get_legend_line(self, component_name):
    """Generate the legend ``<div>`` line for one entry of ``graph_styles``.

    If an entry's ``"isShape"`` is the string ``"False"``, it is rendered as a
    colored square using ``style["background-color"]``. Otherwise it is
    rendered as a glyph: the entry's ``"symbol"`` when present, else
    ``style["shape"]``. The entry's ``"name"`` follows the swatch.

    Args:
        component_name: Key of the entry in ``self.graph_styles``.

    Returns:
        str: One balanced ``<div>...</div>`` HTML fragment.

    Raises:
        KeyError: If a required style key is missing for the entry.
    """
    component = self.graph_styles[component_name]
    if component["isShape"] == "False":
        # Colored square swatch.
        color = component["style"]["background-color"]
        swatch = (
            f'<div style="background-color: {color}; width: 1.6vw; height: 1.6vw; '
            'display: inline-block; font-size: 1.5vw; vertical-align: -0.2em;"></div>'
        )
    else:
        # The "symbol" key holds the glyph meant for the legend (e.g. "★").
        # Fall back to the pyvis shape name when no symbol is configured.
        glyph = component["symbol"] if "symbol" in component else component["style"]["shape"]
        swatch = (
            '<div style="background-color: #ffffff; width: 1.6vw; height: 1.6vw; '
            f'display: inline-block; font-size: 0.9vw; vertical-align: -0.2em;">{glyph}</div>'
        )

    spacer = '<div style="width: 0.3vw; height: 1.5vw; display: inline-block;"></div>'
    label = f'<div style="display: inline-block; font-size: 1.5vw;">{component["name"]}</div>'
    # Wrap every line in an outer <div> so both branches produce balanced markup.
    return f"<div>{swatch}{spacer}{label}</div>"
335
+
336
+ def _add_legend(self, path):
337
+ """Embed legend to html file generated by pyvis."""
338
+ with open(path, "r") as f:
339
+ content = self.BeautifulSoup(f, "html.parser")
340
+
341
+ legend = """
342
+ <div style="display: inline-block; font-size: 1vw; font-family: verdana;
343
+ vertical-align: top; padding: 1vw;">
344
+ """
345
+ # iterate through graph styles to get legend
346
+ for component in self.graph_styles.keys():
347
+ legend += self._get_legend_line(component_name=component)
348
+
349
+ legend += "</div>"
350
+
351
+ legend_div = self.BeautifulSoup(legend, "html.parser")
352
+
353
+ content.div.insert_after(legend_div)
354
+
355
+ html = content.prettify()
356
+
357
+ with open(path, "w", encoding="utf8") as file:
358
+ file.write(html)
359
+
360
def render(self, elements, path="lineage_graph_pyvis.html"):
    """Render graph for lineage query result.

    Args:
        elements: A dictionary that contains the node and the edges of the graph.
            Example:
            elements["nodes"] contains list of tuples, each tuple represents a node
                format: (node arn, node lineage source, node lineage entity,
                node is start arn)
                The lineage source may be None when the vertex has no type.
            elements["edges"] contains list of tuples, each tuple represents an edge
                format: (edge source arn, edge destination arn, edge association type)

        path(optional): The path/filename of the rendered graph html file.
            (default path: "lineage_graph_pyvis.html")

    Returns:
        display graph: The interactive visualization is presented as a static HTML file,
        returned as ``self.IFrame(path, width="100%", height="600px")``.

    """
    net = self.Network(height="600px", width="82%", notebook=True, directed=True)
    net.set_options(self._pyvis_options)

    def _split_camel_case(text):
        # "TrialComponent" -> "Trial Component"
        return sub(r"(\w)([A-Z])", r"\1 \2", text)

    # add nodes to graph
    for arn, source, entity, is_start_arn in elements["nodes"]:
        entity_text = _split_camel_case(entity)
        # The lineage source is optional; avoid passing None to re.sub.
        source_text = _split_camel_case(source) if source else ""
        account_id = search(r":\d{12}:", arn)
        name = search(r"\/.*", arn)
        node_info = (
            "Entity: "
            + entity_text
            + "\nType: "
            + source_text
            + "\nAccount ID: "
            # strip the surrounding ':' delimiters; empty when the ARN has no account id
            + (account_id.group()[1:-1] if account_id else "")
            + "\nName: "
            # text after the first '/', or the whole ARN when there is no '/'
            + (name.group()[1:] if name else arn)
        )

        node_options = {
            # fall back to the entity name so untyped nodes still get a label
            "label": source_text or entity_text,
            "title": node_info,
            "color": self._node_color(entity),
            "borderWidth": 3,
        }
        if is_start_arn:
            # start ARNs are drawn as stars to distinguish them from other nodes
            node_options["shape"] = "star"
        net.add_node(arn, **node_options)

    # add edges to graph
    for src, dest, asso_type in elements["edges"]:
        net.add_edge(src, dest, title=asso_type, width=2)

    net.write_html(path)
    self._add_legend(path)

    return self.IFrame(path, width="100%", height="600px")
424
+
425
+
426
class LineageQueryResult(object):
    """A wrapper around the results of a lineage query."""

    def __init__(
        self,
        edges: List["Edge"] = None,
        vertices: List["Vertex"] = None,
        startarn: List[str] = None,
    ):
        """Init for LineageQueryResult.

        Args:
            edges (List[Edge]): The edges of the query result (default: empty list).
            vertices (List[Vertex]): The vertices of the query result (default: empty list).
            startarn (List[str]): The ARNs the query started from (default: empty list).
                Vertices whose ARN appears in this list are flagged as start nodes
                when the result is visualized.
        """
        # Assignment order is significant: ``__str__`` prints ``__dict__`` in
        # insertion order (edges, vertices, startarn).
        self.edges = edges if edges is not None else []
        self.vertices = vertices if vertices is not None else []
        self.startarn = startarn if startarn is not None else []

    def __str__(self):
        """Define string representation of ``LineageQueryResult``.

        Format:
            {
                'edges':[
                    {
                        'source_arn': 'string',
                        'destination_arn': 'string',
                        'association_type': 'string'
                    },
                ],

                'vertices':[
                    {
                        'arn': 'string',
                        'lineage_entity': 'string',
                        'lineage_source': 'string',
                        '_session': <sagemaker.session.Session object>
                    },
                ],

                'startarn':['string', ...]
            }

        """
        return (
            "{"
            + "\n\n".join("'{}': {},".format(key, val) for key, val in self.__dict__.items())
            + "\n}"
        )

    def _covert_edges_to_tuples(self):
        """Convert edges to tuple format for visualizer.

        Returns:
            list: One ``(source_arn, destination_arn, association_type)`` tuple per edge.
        """
        return [
            (edge.source_arn, edge.destination_arn, edge.association_type) for edge in self.edges
        ]

    def _covert_vertices_to_tuples(self):
        """Convert vertices to tuple format for visualizer.

        Returns:
            list: One ``(arn, lineage_source, lineage_entity, is_start_arn)`` tuple
            per vertex, where ``is_start_arn`` is True when the vertex ARN is one
            of ``self.startarn``.
        """
        return [
            (vert.arn, vert.lineage_source, vert.lineage_entity, vert.arn in self.startarn)
            for vert in self.vertices
        ]

    def _get_visualization_elements(self):
        """Get elements(nodes+edges) for visualization.

        Returns:
            dict: ``{"nodes": [...], "edges": [...]}`` in the tuple formats produced
            by ``_covert_vertices_to_tuples`` and ``_covert_edges_to_tuples``.
        """
        verts = self._covert_vertices_to_tuples()
        edges = self._covert_edges_to_tuples()
        return {"nodes": verts, "edges": edges}

    def visualize(
        self,
        path: Optional[str] = "lineage_graph_pyvis.html",
        pyvis_options: Optional[Dict[str, Any]] = None,
    ):
        """Visualize lineage query result.

        Creates a PyvisVisualizer object to render network graph with Pyvis library.
        Pyvis library should be installed before using this method (run "pip install pyvis")
        The elements(nodes & edges) are preprocessed in this method and sent to
        PyvisVisualizer for rendering graph.

        Args:
            path(optional): The path/filename of the rendered graph html file.
                (default path: "lineage_graph_pyvis.html")
            pyvis_options(optional): A dict containing PyVis options to customize visualization.
                (see https://visjs.github.io/vis-network/docs/network/#options for supported fields)

        Returns:
            display graph: The interactive visualization is presented as a static HTML file.
        """
        lineage_graph_styles = {
            # nodes can have shape / color
            "TrialComponent": {
                "name": "Trial Component",
                "style": {"background-color": "#f6cf61"},
                "isShape": "False",
            },
            "Context": {
                "name": "Context",
                "style": {"background-color": "#ff9900"},
                "isShape": "False",
            },
            "Action": {
                "name": "Action",
                "style": {"background-color": "#88c396"},
                "isShape": "False",
            },
            "Artifact": {
                "name": "Artifact",
                "style": {"background-color": "#146eb4"},
                "isShape": "False",
            },
            "StartArn": {
                "name": "StartArn",
                "style": {"shape": "star"},
                "isShape": "True",
                "symbol": "★",  # shape symbol for legend
            },
        }

        pyvis_vis = PyvisVisualizer(lineage_graph_styles, pyvis_options)
        elements = self._get_visualization_elements()
        return pyvis_vis.render(elements=elements, path=path)
568
+
569
+
570
class LineageFilter(object):
    """A filter used in a lineage query."""

    def __init__(
        self,
        entities: Optional[List[Union[LineageEntityEnum, str]]] = None,
        sources: Optional[List[Union[LineageSourceEnum, str]]] = None,
        created_before: Optional[datetime] = None,
        created_after: Optional[datetime] = None,
        modified_before: Optional[datetime] = None,
        modified_after: Optional[datetime] = None,
        properties: Optional[Dict[str, str]] = None,
    ):
        """Initialize ``LineageFilter`` instance.

        Args:
            entities: Lineage entity types to match (enum members or raw strings);
                sent as ``"LineageTypes"``.
            sources: Lineage source types to match (enum members or raw strings);
                sent as ``"Types"``.
            created_before: Sent as ``"CreatedBefore"`` when set.
            created_after: Sent as ``"CreatedAfter"`` when set.
            modified_before: Sent as ``"ModifiedBefore"`` when set.
            modified_after: Sent as ``"ModifiedAfter"`` when set.
            properties: Sent as ``"Properties"`` when non-empty.
        """
        self.entities = entities
        self.sources = sources
        self.created_before = created_before
        self.created_after = created_after
        self.modified_before = modified_before
        self.modified_after = modified_after
        self.properties = properties

    def _to_request_dict(self):
        """Convert the lineage filter to its API representation.

        Only truthy criteria are included. Enum members in ``sources`` and
        ``entities`` are replaced by their ``.value``; other items pass through.

        Returns:
            dict: The ``Filters`` payload for the lineage query request.
        """

        def _as_values(items, enum_type):
            return [item.value if isinstance(item, enum_type) else item for item in items]

        request = {}
        if self.sources:
            request["Types"] = _as_values(self.sources, LineageSourceEnum)
        if self.entities:
            request["LineageTypes"] = _as_values(self.entities, LineageEntityEnum)

        # The remaining criteria are forwarded unchanged, in this fixed order.
        passthrough = (
            ("CreatedBefore", self.created_before),
            ("CreatedAfter", self.created_after),
            ("ModifiedBefore", self.modified_before),
            ("ModifiedAfter", self.modified_after),
            ("Properties", self.properties),
        )
        for request_key, criterion in passthrough:
            if criterion:
                request[request_key] = criterion
        return request
614
+
615
+
616
class LineageQuery(object):
    """Creates an object used for performing lineage queries."""

    def __init__(self, sagemaker_session):
        """Initialize ``LineageQuery`` instance.

        Args:
            sagemaker_session: Session whose ``sagemaker_client`` is used to call
                ``query_lineage``; it is also attached to every returned Vertex.
        """
        self._session = sagemaker_session

    def _get_edge(self, edge):
        """Convert lineage query API response to an Edge.

        ``AssociationType`` is optional in the response and becomes None when absent.
        """
        return Edge(
            source_arn=edge["SourceArn"],
            destination_arn=edge["DestinationArn"],
            association_type=edge.get("AssociationType"),
        )

    def _get_vertex(self, vertex):
        """Convert lineage query API response to a Vertex.

        ``Type`` is optional in the response; when absent the vertex's
        ``lineage_source`` is None.
        """
        return Vertex(
            arn=vertex["Arn"],
            lineage_source=vertex.get("Type"),
            lineage_entity=vertex["LineageType"],
            sagemaker_session=self._session,
        )

    def _convert_api_response(self, response, converted) -> LineageQueryResult:
        """Convert the lineage query API response to its Python representation.

        Duplicate edges and vertices (as determined by ``Edge``/``Vertex``
        equality and hashing) are dropped, keeping the first occurrence in
        response order.

        Args:
            response (dict): The raw ``query_lineage`` response.
            converted (LineageQueryResult): Result object to populate in place.

        Returns:
            LineageQueryResult: ``converted``, with ``edges`` and ``vertices`` set.
        """
        # dict.fromkeys keeps the first occurrence of each equal key, in order.
        # Building new lists also avoids removing items from a list while
        # iterating over it, which skips elements and can leave duplicates.
        # Missing keys (e.g. no edges requested) are treated as empty lists.
        converted.edges = list(
            dict.fromkeys(self._get_edge(edge) for edge in response.get("Edges", []))
        )
        converted.vertices = list(
            dict.fromkeys(self._get_vertex(vertex) for vertex in response.get("Vertices", []))
        )
        return converted

    def _collapse_cross_account_artifacts(self, query_response):
        """Collapse the duplicate vertices and edges for cross-account.

        An edge between two different artifact ARNs that share the same
        resource id (the segment after the first ``/``) is treated as linking
        two copies of one artifact: the destination ARN is folded into the
        source ARN in all edges, the destination vertex is dropped, and the
        resulting self-loop edges are removed.

        Args:
            query_response (LineageQueryResult): Result to collapse in place.

        Returns:
            LineageQueryResult: ``query_response`` after collapsing.
        """
        for edge in query_response.edges:
            if (
                "artifact" in edge.source_arn
                and "artifact" in edge.destination_arn
                and edge.source_arn.split("/")[1] == edge.destination_arn.split("/")[1]
                and edge.source_arn != edge.destination_arn
            ):
                edge_source_arn = edge.source_arn
                edge_destination_arn = edge.destination_arn
                self._update_cross_account_edge(
                    edges=query_response.edges,
                    arn=edge_source_arn,
                    duplicate_arn=edge_destination_arn,
                )
                self._update_cross_account_vertex(
                    query_response=query_response, duplicate_arn=edge_destination_arn
                )

        # remove the duplicate edges from cross account (now self-loops)
        query_response.edges = [
            e for e in query_response.edges if e.source_arn != e.destination_arn
        ]

        return query_response

    def _update_cross_account_edge(self, edges, arn, duplicate_arn):
        """Replace the duplicate arn with arn in edges list (edges are mutated in place).

        For each edge, the destination is checked first; only one endpoint is
        rewritten per edge.
        """
        for e in edges:
            if e.destination_arn == duplicate_arn:
                e.destination_arn = arn
            elif e.source_arn == duplicate_arn:
                e.source_arn = arn

    def _update_cross_account_vertex(self, query_response, duplicate_arn):
        """Remove the vertex with duplicate arn in the vertices list."""
        query_response.vertices = [v for v in query_response.vertices if v.arn != duplicate_arn]

    def query(
        self,
        start_arns: List[str],
        direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH,
        include_edges: bool = True,
        query_filter: Optional[LineageFilter] = None,
        max_depth: int = 10,
    ) -> LineageQueryResult:
        """Perform a lineage query.

        Args:
            start_arns (List[str]): A list of ARNs that will be used as the starting point
                for the query.
            direction (LineageQueryDirectionEnum, optional): The direction of the query.
            include_edges (bool, optional): If true, return edges in addition to vertices.
            query_filter (LineageFilter, optional): The query filter.
            max_depth (int, optional): Passed to the API as ``MaxDepth`` (default: 10).

        Returns:
            LineageQueryResult: The lineage query result, de-duplicated and with
            cross-account artifact copies collapsed.
        """
        query_response = self._session.sagemaker_client.query_lineage(
            StartArns=start_arns,
            Direction=direction.value,
            IncludeEdges=include_edges,
            Filters=query_filter._to_request_dict() if query_filter else {},
            MaxDepth=max_depth,
        )
        # create query result for startarn info
        query_result = LineageQueryResult(startarn=start_arns)
        query_response = self._convert_api_response(query_response, query_result)
        query_response = self._collapse_cross_account_artifacts(query_response)

        return query_response