sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (363)
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/__init__.py +0 -0
  176. sagemaker/core/jumpstart/hub/constants.py +16 -0
  177. sagemaker/core/jumpstart/hub/hub.py +291 -0
  178. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  179. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  180. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  181. sagemaker/core/jumpstart/hub/types.py +35 -0
  182. sagemaker/core/jumpstart/hub/utils.py +260 -0
  183. sagemaker/core/jumpstart/models.py +499 -0
  184. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  185. sagemaker/core/jumpstart/parameters.py +20 -0
  186. sagemaker/core/jumpstart/payload_utils.py +239 -0
  187. sagemaker/core/jumpstart/region_config.json +163 -0
  188. sagemaker/core/jumpstart/search.py +171 -0
  189. sagemaker/core/jumpstart/serializers.py +81 -0
  190. sagemaker/core/jumpstart/session_utils.py +234 -0
  191. sagemaker/core/jumpstart/types.py +3044 -0
  192. sagemaker/core/jumpstart/utils.py +1731 -0
  193. sagemaker/core/jumpstart/validators.py +257 -0
  194. sagemaker/core/lambda_helper.py +312 -0
  195. sagemaker/core/lineage/__init__.py +42 -0
  196. sagemaker/core/lineage/_api_types.py +239 -0
  197. sagemaker/core/lineage/_utils.py +49 -0
  198. sagemaker/core/lineage/action.py +345 -0
  199. sagemaker/core/lineage/artifact.py +646 -0
  200. sagemaker/core/lineage/association.py +190 -0
  201. sagemaker/core/lineage/context.py +505 -0
  202. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  203. sagemaker/core/lineage/query.py +732 -0
  204. sagemaker/core/lineage/visualizer.py +346 -0
  205. sagemaker/core/local/__init__.py +18 -0
  206. sagemaker/core/local/data.py +413 -0
  207. sagemaker/core/local/entities.py +678 -0
  208. sagemaker/core/local/exceptions.py +17 -0
  209. sagemaker/core/local/image.py +1243 -0
  210. sagemaker/core/local/local_session.py +739 -0
  211. sagemaker/core/local/utils.py +245 -0
  212. sagemaker/core/logs.py +181 -0
  213. sagemaker/core/metadata_properties.py +56 -0
  214. sagemaker/core/metric_definitions.py +91 -0
  215. sagemaker/core/mlflow/__init__.py +38 -0
  216. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  217. sagemaker/core/model_card/__init__.py +26 -0
  218. sagemaker/core/model_life_cycle.py +51 -0
  219. sagemaker/core/model_metrics.py +160 -0
  220. sagemaker/core/model_monitor/__init__.py +66 -0
  221. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  222. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  223. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  224. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  225. sagemaker/core/model_monitor/dataset_format.py +102 -0
  226. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  227. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  228. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  229. sagemaker/core/model_monitor/utils.py +793 -0
  230. sagemaker/core/model_registry.py +480 -0
  231. sagemaker/core/model_uris.py +97 -0
  232. sagemaker/core/modules/__init__.py +19 -0
  233. sagemaker/core/modules/configs.py +226 -0
  234. sagemaker/core/modules/constants.py +37 -0
  235. sagemaker/core/modules/distributed.py +182 -0
  236. sagemaker/core/modules/local_core/__init__.py +0 -0
  237. sagemaker/core/modules/local_core/local_container.py +605 -0
  238. sagemaker/core/modules/templates.py +83 -0
  239. sagemaker/core/modules/train/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  242. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  247. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  249. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  250. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  251. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  252. sagemaker/core/modules/types.py +19 -0
  253. sagemaker/core/modules/utils.py +194 -0
  254. sagemaker/core/network.py +185 -0
  255. sagemaker/core/parameter.py +173 -0
  256. sagemaker/core/payloads.py +185 -0
  257. sagemaker/core/processing.py +1597 -0
  258. sagemaker/core/remote_function/__init__.py +19 -0
  259. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  260. sagemaker/core/remote_function/client.py +1285 -0
  261. sagemaker/core/remote_function/core/__init__.py +0 -0
  262. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  263. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  264. sagemaker/core/remote_function/core/serialization.py +422 -0
  265. sagemaker/core/remote_function/core/stored_function.py +226 -0
  266. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  267. sagemaker/core/remote_function/errors.py +104 -0
  268. sagemaker/core/remote_function/invoke_function.py +172 -0
  269. sagemaker/core/remote_function/job.py +2140 -0
  270. sagemaker/core/remote_function/logging_config.py +38 -0
  271. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  272. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  273. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  274. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  275. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  276. sagemaker/core/remote_function/spark_config.py +149 -0
  277. sagemaker/core/resource_requirements.py +168 -0
  278. {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
  279. sagemaker/core/s3/__init__.py +41 -0
  280. sagemaker/core/s3/client.py +367 -0
  281. sagemaker/core/s3/utils.py +175 -0
  282. sagemaker/core/script_uris.py +93 -0
  283. sagemaker/core/serializers/__init__.py +11 -0
  284. sagemaker/core/serializers/base.py +510 -0
  285. sagemaker/core/serializers/implementations.py +159 -0
  286. sagemaker/core/serializers/utils.py +223 -0
  287. sagemaker/core/serverless_inference_config.py +63 -0
  288. sagemaker/core/session_settings.py +55 -0
  289. sagemaker/core/shapes/__init__.py +3 -0
  290. sagemaker/core/shapes/model_card_shapes.py +159 -0
  291. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
  292. sagemaker/core/spark/__init__.py +16 -0
  293. sagemaker/core/spark/defaults.py +16 -0
  294. sagemaker/core/spark/processing.py +1380 -0
  295. sagemaker/core/telemetry/__init__.py +23 -0
  296. sagemaker/core/telemetry/constants.py +84 -0
  297. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  298. sagemaker/core/tools/__init__.py +1 -0
  299. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  300. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  301. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  302. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  303. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  304. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  305. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  306. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  307. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  308. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  309. sagemaker/core/training/__init__.py +14 -0
  310. sagemaker/core/training/configs.py +333 -0
  311. sagemaker/core/training/constants.py +37 -0
  312. sagemaker/core/training/utils.py +77 -0
  313. sagemaker/core/training_compiler/__init__.py +16 -0
  314. sagemaker/core/training_compiler/config.py +197 -0
  315. sagemaker/core/training_compiler_config.py +197 -0
  316. sagemaker/core/transformer.py +793 -0
  317. sagemaker/core/user_agent.py +76 -0
  318. sagemaker/core/utilities/__init__.py +24 -0
  319. sagemaker/core/utilities/cache.py +169 -0
  320. sagemaker/core/utilities/search_expression.py +133 -0
  321. sagemaker/core/utils/__init__.py +48 -0
  322. sagemaker/core/utils/code_injection/__init__.py +0 -0
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  324. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
  325. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  326. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  327. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  328. {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
  329. sagemaker/core/workflow/__init__.py +152 -0
  330. sagemaker/core/workflow/conditions.py +313 -0
  331. sagemaker/core/workflow/entities.py +58 -0
  332. sagemaker/core/workflow/execution_variables.py +89 -0
  333. sagemaker/core/workflow/functions.py +193 -0
  334. sagemaker/core/workflow/parameters.py +222 -0
  335. sagemaker/core/workflow/pipeline_context.py +394 -0
  336. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  337. sagemaker/core/workflow/properties.py +285 -0
  338. sagemaker/core/workflow/step_outputs.py +65 -0
  339. sagemaker/core/workflow/utilities.py +507 -0
  340. sagemaker/lineage/__init__.py +33 -0
  341. sagemaker/lineage/action.py +28 -0
  342. sagemaker/lineage/artifact.py +28 -0
  343. sagemaker/lineage/context.py +28 -0
  344. sagemaker/lineage/lineage_trial_component.py +28 -0
  345. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  346. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  347. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  348. sagemaker_core/__init__.py +0 -4
  349. sagemaker_core/_version.py +0 -3
  350. sagemaker_core/helper/session_helper.py +0 -769
  351. sagemaker_core/resources/__init__.py +0 -1
  352. sagemaker_core/shapes/__init__.py +0 -1
  353. sagemaker_core/tools/__init__.py +0 -1
  354. sagemaker_core-1.0.47.dist-info/RECORD +0 -35
  355. sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
  356. {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
  357. {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  358. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  361. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  362. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  363. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,646 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module contains code to create and manage SageMaker ``Artifact``."""
14
+ from __future__ import absolute_import
15
+
16
+ import logging
17
+ import math
18
+
19
+ from datetime import datetime
20
+ from typing import Iterator, Union, Any, Optional, List
21
+
22
+ from sagemaker.core.apiutils import _base_types, _utils
23
+ from sagemaker.core.lineage import _api_types
24
+ from sagemaker.core.lineage._api_types import ArtifactSource, ArtifactSummary
25
+ from sagemaker.core.lineage.query import (
26
+ LineageQuery,
27
+ LineageFilter,
28
+ LineageSourceEnum,
29
+ LineageEntityEnum,
30
+ LineageQueryDirectionEnum,
31
+ )
32
+ from sagemaker.core.lineage._utils import _disassociate, get_resource_name_from_arn
33
+ from sagemaker.core.lineage.association import Association
34
+ from sagemaker.core.common_utils import get_module, format_tags
35
+
36
+ LOGGER = logging.getLogger("sagemaker")
37
+
38
+
39
+ class Artifact(_base_types.Record):
40
+ """An Amazon SageMaker artifact, which is part of a SageMaker lineage.
41
+
42
+ Examples:
43
+ .. code-block:: python
44
+
45
+ from sagemaker.lineage import artifact
46
+
47
+ my_artifact = artifact.Artifact.create(
48
+ artifact_name='MyArtifact',
49
+ artifact_type='S3File',
50
+ source_uri='s3://...')
51
+
52
+ my_artifact.properties["added"] = "property"
53
+ my_artifact.save()
54
+
55
+ for artfct in artifact.Artifact.list():
56
+ print(artfct)
57
+
58
+ my_artifact.delete()
59
+
60
+ Attributes:
61
+ artifact_arn (str): The ARN of the artifact.
62
+ artifact_name (str): The name of the artifact.
63
+ artifact_type (str): The type of the artifact.
64
+ source (obj): The source of the artifact with a URI and types.
65
+ properties (dict): Dictionary of properties.
66
+ tags (List[dict[str, str]]): A list of tags to associate with the artifact.
67
+ creation_time (datetime): When the artifact was created.
68
+ created_by (obj): Contextual info on which account created the artifact.
69
+ last_modified_time (datetime): When the artifact was last modified.
70
+ last_modified_by (obj): Contextual info on which account created the artifact.
71
+ """
72
+
73
+ artifact_arn: str = None
74
+ artifact_name: str = None
75
+ artifact_type: str = None
76
+ source: ArtifactSource = None
77
+ properties: dict = None
78
+ tags: list = None
79
+ creation_time: datetime = None
80
+ created_by: str = None
81
+ last_modified_time: datetime = None
82
+ last_modified_by: str = None
83
+
84
+ _boto_create_method: str = "create_artifact"
85
+ _boto_load_method: str = "describe_artifact"
86
+ _boto_update_method: str = "update_artifact"
87
+ _boto_delete_method: str = "delete_artifact"
88
+
89
+ _boto_update_members = [
90
+ "artifact_arn",
91
+ "artifact_name",
92
+ "properties",
93
+ "properties_to_remove",
94
+ ]
95
+
96
+ _boto_delete_members = ["artifact_arn"]
97
+
98
+ _custom_boto_types = {"source": (_api_types.ArtifactSource, False)}
99
+
100
def save(self) -> "Artifact":
    """Push this artifact's current state to SageMaker.

    The request goes through ``_invoke_api`` using the class's update method
    (``update_artifact``) and the attribute names listed in
    ``_boto_update_members``.

    Note that this method must be run from a SageMaker context such as Studio
    or a training job due to restrictions on the CreateArtifact API.

    Returns:
        Artifact: A SageMaker `Artifact` object.
    """
    update_method = self._boto_update_method
    update_members = self._boto_update_members
    return self._invoke_api(update_method, update_members)
110
+
111
def delete(self, disassociate: bool = False):
    """Remove this artifact from SageMaker.

    Args:
        disassociate (bool): If True, first remove the associations in which this
            artifact is the source, then those in which it is the destination,
            and only then delete the artifact itself.
    """
    if disassociate:
        session = self.sagemaker_session
        arn = self.artifact_arn
        # Outgoing edges (artifact as source) first, then incoming edges.
        _disassociate(source_arn=arn, sagemaker_session=session)
        _disassociate(destination_arn=arn, sagemaker_session=session)
    self._invoke_api(self._boto_delete_method, self._boto_delete_members)
124
+
125
@classmethod
def load(cls, artifact_arn: str, sagemaker_session=None) -> "Artifact":
    """Fetch an existing artifact by ARN and wrap it in an ``Artifact``.

    The lookup is delegated to ``cls._construct`` using the class's load
    method (``describe_artifact``).

    Args:
        artifact_arn (str): ARN of the artifact to load.
        sagemaker_session (sagemaker.session.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed. If not specified, one is created using the
            default AWS configuration chain.

    Returns:
        Artifact: A SageMaker ``Artifact`` object.
    """
    return cls._construct(
        cls._boto_load_method,
        artifact_arn=artifact_arn,
        sagemaker_session=sagemaker_session,
    )
145
+
146
def downstream_trials(self, sagemaker_session=None) -> list:
    """Find the trials downstream of this artifact via its outgoing associations.

    Every association whose source is this artifact is listed, and the
    destination ARNs are resolved to trial names.

    Args:
        sagemaker_session (obj): Sagemaker Session to use. If not provided a default session
            will be created.

    Returns:
        list: Names (``TrialName`` values) of the trials that contain the
        destination trial components.
    """
    # No destination type filter: the trial-component end of an edge may be any of
    # SageMaker[TrainingJob|ProcessingJob|TransformJob|ExperimentTrialComponent].
    associations = Association.list(
        source_arn=self.artifact_arn, sagemaker_session=sagemaker_session
    )
    destination_arns = [assoc.destination_arn for assoc in associations]
    # NOTE(review): ``sagemaker_session`` is used only for listing associations;
    # the trial lookup below uses ``self.sagemaker_session``.
    return self._get_trial_from_trial_component(destination_arns)
164
+
165
def downstream_trials_v2(self) -> list:
    """Find the trials downstream of this artifact using a lineage query.

    Returns:
        list: Names of the trials reached in the ``DESCENDANTS`` direction.
    """
    descendants = LineageQueryDirectionEnum.DESCENDANTS
    return self._trials(direction=descendants)
172
+
173
def upstream_trials(self) -> List:
    """Find the trials upstream of this artifact using a lineage query.

    Returns:
        list: Names of the trials reached in the ``ASCENDANTS`` direction.
    """
    ascendants = LineageQueryDirectionEnum.ASCENDANTS
    return self._trials(direction=ascendants)
180
+
181
def _trials(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH
) -> List:
    """Run a lineage query from this artifact and resolve the trial components hit.

    The query starts at ``self.artifact_arn``, is restricted to trial-component
    vertices, and does not request edges.

    Args:
        direction (LineageQueryDirectionEnum, optional): Direction to traverse.

    Returns:
        list: Names of the trials containing the matched trial components.
    """
    trial_component_only = LineageFilter(entities=[LineageEntityEnum.TRIAL_COMPONENT])
    result = LineageQuery(self.sagemaker_session).query(
        start_arns=[self.artifact_arn],
        query_filter=trial_component_only,
        direction=direction,
        include_edges=False,
    )
    vertex_arns = [vertex.arn for vertex in result.vertices]
    return self._get_trial_from_trial_component(vertex_arns)
201
+
202
+ def _get_trial_from_trial_component(self, trial_component_arns: list) -> List:
203
+ """Retrieve all upstream trial runs which that use the trial component arns.
204
+
205
+ Args:
206
+ trial_component_arns (list): list of trial component arns
207
+
208
+ Returns:
209
+ [Trial]: A list of SageMaker `Trial` objects.
210
+ """
211
+ if not trial_component_arns:
212
+ # no outgoing associations for this artifact
213
+ return []
214
+
215
+ get_module("smexperiments")
216
+ from smexperiments import trial_component, search_expression
217
+
218
+ max_search_by_arn: int = 60
219
+ num_search_batches = math.ceil(len(trial_component_arns) % max_search_by_arn)
220
+ trial_components: list = []
221
+
222
+ sagemaker_session = self.sagemaker_session or _utils.default_session()
223
+ sagemaker_client = sagemaker_session.sagemaker_client
224
+
225
+ for i in range(num_search_batches):
226
+ start: int = i * max_search_by_arn
227
+ end: int = start + max_search_by_arn
228
+ arn_batch: list = trial_component_arns[start:end]
229
+ se: Any = self._get_search_expression(arn_batch, search_expression)
230
+ search_result: Any = trial_component.TrialComponent.search(
231
+ search_expression=se, sagemaker_boto_client=sagemaker_client
232
+ )
233
+
234
+ trial_components: list = trial_components + list(search_result)
235
+
236
+ trials: set = set()
237
+
238
+ for tc in list(trial_components):
239
+ for parent in tc.parents:
240
+ trials.add(parent["TrialName"])
241
+
242
+ return list(trials)
243
+
244
+ def _get_search_expression(self, arns: list, search_expression: object) -> object:
245
+ """Convert a set of arns to a search expression.
246
+
247
+ Args:
248
+ arns (list): Trial Component arns to search for.
249
+ search_expression (obj): smexperiments.search_expression
250
+
251
+ Returns:
252
+ search_expression (obj): Arns converted to a Trial Component search expression.
253
+ """
254
+ max_arn_per_filter: int = 3
255
+ num_filters: Union[float, int] = math.ceil(len(arns) / max_arn_per_filter)
256
+ filters: list = []
257
+
258
+ for i in range(num_filters):
259
+ start: int = i * max_arn_per_filter
260
+ end: int = i + max_arn_per_filter
261
+ batch_arns: list = arns[start:end]
262
+ search_filter = search_expression.Filter(
263
+ name="TrialComponentArn",
264
+ operator=search_expression.Operator.EQUALS,
265
+ value=",".join(batch_arns),
266
+ )
267
+
268
+ filters.append(search_filter)
269
+
270
+ search_expression = search_expression.SearchExpression(
271
+ filters=filters,
272
+ boolean_operator=search_expression.BooleanOperator.OR,
273
+ )
274
+ return search_expression
275
+
276
def set_tag(self, tag=None):
    """Attach a single tag to this artifact.

    Args:
        tag (obj): One key/value tag mapping to attach.

    Returns:
        list({str:str}): Whatever ``self._set_tags`` returns for this
            artifact's ARN.
    """
    tags_to_apply = [tag]
    target_arn = self.artifact_arn
    return self._set_tags(resource_arn=target_arn, tags=tags_to_apply)
286
+
287
def set_tags(self, tags=None):
    """Attach several tags to this artifact.

    Args:
        tags (Optional[Tags]): Tags to attach; they are normalized with
            ``format_tags`` before being applied.

    Returns:
        list({str:str}): Whatever ``self._set_tags`` returns for this
            artifact's ARN.
    """
    normalized_tags = format_tags(tags)
    return self._set_tags(resource_arn=self.artifact_arn, tags=normalized_tags)
297
+
298
@classmethod
def create(
    cls,
    artifact_name: Optional[str] = None,
    source_uri: Optional[str] = None,
    source_types: Optional[list] = None,
    artifact_type: Optional[str] = None,
    properties: Optional[dict] = None,
    tags: Optional[dict] = None,
    sagemaker_session=None,
) -> "Artifact":
    """Create a new artifact and wrap it in an ``Artifact`` object.

    Args:
        artifact_name (str, optional): Name of the artifact.
        source_uri (str, optional): Source URI of the artifact.
        source_types (list, optional): Source types of the artifact.
        artifact_type (str, optional): Type of the artifact.
        properties (dict, optional): Key/value properties.
        tags (dict, optional): AWS tags for the artifact.
        sagemaker_session (sagemaker.session.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed. If not specified, one is created using the
            default AWS configuration chain.

    Returns:
        Artifact: A SageMaker ``Artifact`` object.
    """
    # The source URI and its types are sent together as a single ArtifactSource.
    artifact_source = _api_types.ArtifactSource(
        source_uri=source_uri, source_types=source_types
    )
    base = super(Artifact, cls)
    return base._construct(
        cls._boto_create_method,
        artifact_name=artifact_name,
        source=artifact_source,
        artifact_type=artifact_type,
        properties=properties,
        tags=tags,
        sagemaker_session=sagemaker_session,
    )
335
+
336
@classmethod
def list(
    cls,
    source_uri: Optional[str] = None,
    artifact_type: Optional[str] = None,
    created_before: Optional[datetime] = None,
    created_after: Optional[datetime] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    max_results: Optional[int] = None,
    next_token: Optional[str] = None,
    sagemaker_session=None,
) -> Iterator[ArtifactSummary]:
    """Iterate over artifact summaries matching the given filters.

    Args:
        source_uri (str, optional): A source URI.
        artifact_type (str, optional): An artifact type.
        created_before (datetime.datetime, optional): Return artifacts created before this
            instant.
        created_after (datetime.datetime, optional): Return artifacts created after this
            instant.
        sort_by (str, optional): Which property to sort results by.
            One of 'SourceArn', 'CreatedBefore','CreatedAfter'
        sort_order (str, optional): One of 'Ascending', or 'Descending'.
        max_results (int, optional): maximum number of artifacts to retrieve
        next_token (str, optional): token for next page of results
        sagemaker_session (sagemaker.session.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed. If not specified, one is created using the
            default AWS configuration chain.

    Returns:
        collections.Iterator[ArtifactSummary]: An iterator
            over ``ArtifactSummary`` objects.
    """
    # Filter/paging arguments forwarded unchanged to the generic ``_list`` helper.
    list_kwargs = dict(
        source_uri=source_uri,
        artifact_type=artifact_type,
        created_before=created_before,
        created_after=created_after,
        sort_by=sort_by,
        sort_order=sort_order,
        max_results=max_results,
        next_token=next_token,
        sagemaker_session=sagemaker_session,
    )
    boto_method = "list_artifacts"
    summaries_key = "ArtifactSummaries"
    return super(Artifact, cls)._list(
        boto_method,
        _api_types.ArtifactSummary.from_boto,
        summaries_key,
        **list_kwargs,
    )
386
+
387
def s3_uri_artifacts(self, s3_uri: str) -> dict:
    """Look up artifacts whose source URI is the given S3 URI.

    Args:
        s3_uri (str): A S3 URI to match against artifact source URIs.

    Returns:
        dict: The raw ``list_artifacts`` response from the SageMaker client.
            The matching summaries are under its ``ArtifactSummaries`` key. A
            single call is made, so any ``NextToken`` in the response is not
            followed.
    """
    client = self.sagemaker_session.sagemaker_client
    return client.list_artifacts(SourceUri=s3_uri)
397
+
398
+
399
class ModelArtifact(Artifact):
    """A SageMaker lineage artifact representing a model.

    Common model specific lineage traversals to discover how the model is connected
    to other entities.
    """

    def endpoints(self) -> list:
        """Get association summaries for endpoints deployed with this model.

        Follows the ``Action`` associations out of this model artifact, then the
        ``Context`` associations out of each of those actions.

        Returns:
            [AssociationSummary]: A list of associations representing the endpoints using the model.
        """
        endpoint_development_actions: Iterator = Association.list(
            source_arn=self.artifact_arn,
            destination_type="Action",
            sagemaker_session=self.sagemaker_session,
        )

        endpoint_context_list: list = [
            endpoint_context_associations
            for endpoint_development_action in endpoint_development_actions
            for endpoint_context_associations in Association.list(
                source_arn=endpoint_development_action.destination_arn,
                destination_type="Context",
                sagemaker_session=self.sagemaker_session,
            )
        ]
        return endpoint_context_list

    def endpoint_contexts(
        self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
    ) -> List["Context"]:
        """Get contexts representing endpoints from the model's lineage.

        Args:
            direction (LineageQueryDirectionEnum, optional): The query direction.

        Returns:
            list of Contexts: Contexts representing an endpoint.
        """
        query_filter = LineageFilter(
            entities=[LineageEntityEnum.CONTEXT], sources=[LineageSourceEnum.ENDPOINT]
        )
        query_result = LineageQuery(self.sagemaker_session).query(
            start_arns=[self.artifact_arn],
            query_filter=query_filter,
            direction=direction,
            include_edges=False,
        )
        return [vertex.to_lineage_object() for vertex in query_result.vertices]

    def dataset_artifacts(
        self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
    ) -> List[Artifact]:
        """Get artifacts representing datasets from the model's lineage.

        Args:
            direction (LineageQueryDirectionEnum, optional): The query direction.

        Returns:
            list of Artifacts: Artifacts representing a dataset.
        """
        query_filter = LineageFilter(
            entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET]
        )
        query_result = LineageQuery(self.sagemaker_session).query(
            start_arns=[self.artifact_arn],
            query_filter=query_filter,
            direction=direction,
            include_edges=False,
        )
        return [vertex.to_lineage_object() for vertex in query_result.vertices]

    def training_job_arns(
        self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
    ) -> List[str]:
        """Get ARNs for all training jobs that appear in the model's lineage.

        Finds trial-component vertices whose source is a training job, then
        makes one ``describe_trial_component`` call per vertex to read the
        component's ``Source.SourceArn``.

        Args:
            direction (LineageQueryDirectionEnum, optional): The query direction.

        Returns:
            list of str: Training job ARNs.
        """
        query_filter = LineageFilter(
            entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRAINING_JOB]
        )
        query_result = LineageQuery(self.sagemaker_session).query(
            start_arns=[self.artifact_arn],
            query_filter=query_filter,
            direction=direction,
            include_edges=False,
        )

        client = self.sagemaker_session.sagemaker_client
        training_job_arns = []
        for vertex in query_result.vertices:
            trial_component_name = get_resource_name_from_arn(vertex.arn)
            trial_component = client.describe_trial_component(
                TrialComponentName=trial_component_name
            )
            training_job_arns.append(trial_component["Source"]["SourceArn"])
        return training_job_arns

    def pipeline_execution_arn(
        self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
    ) -> Optional[str]:
        """Get the ARN for the pipeline execution associated with this model (if any).

        Inspects the tags of each training job returned by ``training_job_arns``
        in order and returns the value of the first
        ``sagemaker:pipeline-execution-arn`` tag found.

        Args:
            direction (LineageQueryDirectionEnum, optional): The query direction
                passed through to ``training_job_arns``.

        Returns:
            Optional[str]: A pipeline execution ARN, or ``None`` when no training
                job in the lineage carries the tag.
        """
        client = self.sagemaker_session.sagemaker_client
        for training_job_arn in self.training_job_arns(direction=direction):
            # NOTE(review): only the first page of list_tags is inspected; confirm
            # the tag cannot land on a later NextToken page.
            tags = client.list_tags(ResourceArn=training_job_arn)["Tags"]
            for tag in tags:
                if tag["Key"] == "sagemaker:pipeline-execution-arn":
                    return tag["Value"]

        return None
526
+
527
+
528
class DatasetArtifact(Artifact):
    """A SageMaker lineage artifact representing a dataset.

    Encapsulates common dataset-specific lineage traversals to discover how the
    dataset is connected to related entities.
    """

    def trained_models(self) -> List[Association]:
        """Find model contexts reached from this dataset through trial components.

        Looks at every association leaving this dataset, keeps those whose
        destination ARN contains ``experiment-trial-component``, and collects the
        ``Context`` associations leaving each such trial component.

        Returns:
            list(Association): Associations whose destinations are contexts
                linked from the dataset's trial components.
        """
        outgoing = Association.list(
            source_arn=self.artifact_arn, sagemaker_session=self.sagemaker_session
        )
        return [
            model_association
            for association in outgoing
            if "experiment-trial-component" in association.destination_arn
            for model_association in Association.list(
                source_arn=association.destination_arn,
                destination_type="Context",
                sagemaker_session=self.sagemaker_session,
            )
        ]

    def endpoint_contexts(
        self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
    ) -> List["Context"]:
        """Collect endpoint contexts from this dataset's lineage.

        Args:
            direction (LineageQueryDirectionEnum, optional): The query direction.

        Returns:
            list of Contexts: Contexts representing an endpoint.
        """
        endpoint_filter = LineageFilter(
            entities=[LineageEntityEnum.CONTEXT], sources=[LineageSourceEnum.ENDPOINT]
        )
        lineage = LineageQuery(self.sagemaker_session).query(
            start_arns=[self.artifact_arn],
            query_filter=endpoint_filter,
            direction=direction,
            include_edges=False,
        )
        return [vertex.to_lineage_object() for vertex in lineage.vertices]

    def upstream_datasets(self) -> List[Artifact]:
        """Return dataset artifacts found by querying this dataset's ascendants.

        Returns:
            list of Artifacts: Artifacts representing a dataset.
        """
        return self._datasets(direction=LineageQueryDirectionEnum.ASCENDANTS)

    def downstream_datasets(self) -> List[Artifact]:
        """Return dataset artifacts found by querying this dataset's descendants.

        Returns:
            list of Artifacts: Artifacts representing a dataset.
        """
        return self._datasets(direction=LineageQueryDirectionEnum.DESCENDANTS)

    def _datasets(
        self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH
    ) -> List[Artifact]:
        """Query this dataset's lineage for dataset artifacts in one direction.

        Args:
            direction (LineageQueryDirectionEnum, optional): The query direction.

        Returns:
            list of Artifacts: Artifacts representing a dataset.
        """
        dataset_filter = LineageFilter(
            entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET]
        )
        lineage = LineageQuery(self.sagemaker_session).query(
            start_arns=[self.artifact_arn],
            query_filter=dataset_filter,
            direction=direction,
            include_edges=False,
        )
        found = []
        for vertex in lineage.vertices:
            found.append(vertex.to_lineage_object())
        return found
619
+
620
+
621
class ImageArtifact(Artifact):
    """A SageMaker lineage artifact representing an image.

    Provides image-specific lineage traversals to discover how the image is
    connected to other entities.
    """

    def datasets(self, direction: LineageQueryDirectionEnum) -> List[Artifact]:
        """Query this image's lineage for dataset artifacts.

        Args:
            direction (LineageQueryDirectionEnum): The query direction.

        Returns:
            list of Artifacts: Artifacts representing a dataset.
        """
        dataset_filter = LineageFilter(
            entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET]
        )
        lineage = LineageQuery(self.sagemaker_session).query(
            start_arns=[self.artifact_arn],
            query_filter=dataset_filter,
            direction=direction,
            include_edges=False,
        )
        dataset_artifacts = []
        for vertex in lineage.vertices:
            dataset_artifacts.append(vertex.to_lineage_object())
        return dataset_artifacts