sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (363) hide show
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/__init__.py +0 -0
  176. sagemaker/core/jumpstart/hub/constants.py +16 -0
  177. sagemaker/core/jumpstart/hub/hub.py +291 -0
  178. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  179. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  180. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  181. sagemaker/core/jumpstart/hub/types.py +35 -0
  182. sagemaker/core/jumpstart/hub/utils.py +260 -0
  183. sagemaker/core/jumpstart/models.py +499 -0
  184. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  185. sagemaker/core/jumpstart/parameters.py +20 -0
  186. sagemaker/core/jumpstart/payload_utils.py +239 -0
  187. sagemaker/core/jumpstart/region_config.json +163 -0
  188. sagemaker/core/jumpstart/search.py +171 -0
  189. sagemaker/core/jumpstart/serializers.py +81 -0
  190. sagemaker/core/jumpstart/session_utils.py +234 -0
  191. sagemaker/core/jumpstart/types.py +3044 -0
  192. sagemaker/core/jumpstart/utils.py +1731 -0
  193. sagemaker/core/jumpstart/validators.py +257 -0
  194. sagemaker/core/lambda_helper.py +312 -0
  195. sagemaker/core/lineage/__init__.py +42 -0
  196. sagemaker/core/lineage/_api_types.py +239 -0
  197. sagemaker/core/lineage/_utils.py +49 -0
  198. sagemaker/core/lineage/action.py +345 -0
  199. sagemaker/core/lineage/artifact.py +646 -0
  200. sagemaker/core/lineage/association.py +190 -0
  201. sagemaker/core/lineage/context.py +505 -0
  202. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  203. sagemaker/core/lineage/query.py +732 -0
  204. sagemaker/core/lineage/visualizer.py +346 -0
  205. sagemaker/core/local/__init__.py +18 -0
  206. sagemaker/core/local/data.py +413 -0
  207. sagemaker/core/local/entities.py +678 -0
  208. sagemaker/core/local/exceptions.py +17 -0
  209. sagemaker/core/local/image.py +1243 -0
  210. sagemaker/core/local/local_session.py +739 -0
  211. sagemaker/core/local/utils.py +245 -0
  212. sagemaker/core/logs.py +181 -0
  213. sagemaker/core/metadata_properties.py +56 -0
  214. sagemaker/core/metric_definitions.py +91 -0
  215. sagemaker/core/mlflow/__init__.py +38 -0
  216. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  217. sagemaker/core/model_card/__init__.py +26 -0
  218. sagemaker/core/model_life_cycle.py +51 -0
  219. sagemaker/core/model_metrics.py +160 -0
  220. sagemaker/core/model_monitor/__init__.py +66 -0
  221. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  222. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  223. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  224. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  225. sagemaker/core/model_monitor/dataset_format.py +102 -0
  226. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  227. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  228. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  229. sagemaker/core/model_monitor/utils.py +793 -0
  230. sagemaker/core/model_registry.py +480 -0
  231. sagemaker/core/model_uris.py +97 -0
  232. sagemaker/core/modules/__init__.py +19 -0
  233. sagemaker/core/modules/configs.py +226 -0
  234. sagemaker/core/modules/constants.py +37 -0
  235. sagemaker/core/modules/distributed.py +182 -0
  236. sagemaker/core/modules/local_core/__init__.py +0 -0
  237. sagemaker/core/modules/local_core/local_container.py +605 -0
  238. sagemaker/core/modules/templates.py +83 -0
  239. sagemaker/core/modules/train/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  242. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  247. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  249. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  250. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  251. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  252. sagemaker/core/modules/types.py +19 -0
  253. sagemaker/core/modules/utils.py +194 -0
  254. sagemaker/core/network.py +185 -0
  255. sagemaker/core/parameter.py +173 -0
  256. sagemaker/core/payloads.py +185 -0
  257. sagemaker/core/processing.py +1597 -0
  258. sagemaker/core/remote_function/__init__.py +19 -0
  259. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  260. sagemaker/core/remote_function/client.py +1285 -0
  261. sagemaker/core/remote_function/core/__init__.py +0 -0
  262. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  263. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  264. sagemaker/core/remote_function/core/serialization.py +422 -0
  265. sagemaker/core/remote_function/core/stored_function.py +226 -0
  266. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  267. sagemaker/core/remote_function/errors.py +104 -0
  268. sagemaker/core/remote_function/invoke_function.py +172 -0
  269. sagemaker/core/remote_function/job.py +2140 -0
  270. sagemaker/core/remote_function/logging_config.py +38 -0
  271. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  272. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  273. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  274. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  275. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  276. sagemaker/core/remote_function/spark_config.py +149 -0
  277. sagemaker/core/resource_requirements.py +168 -0
  278. {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
  279. sagemaker/core/s3/__init__.py +41 -0
  280. sagemaker/core/s3/client.py +367 -0
  281. sagemaker/core/s3/utils.py +175 -0
  282. sagemaker/core/script_uris.py +93 -0
  283. sagemaker/core/serializers/__init__.py +11 -0
  284. sagemaker/core/serializers/base.py +510 -0
  285. sagemaker/core/serializers/implementations.py +159 -0
  286. sagemaker/core/serializers/utils.py +223 -0
  287. sagemaker/core/serverless_inference_config.py +63 -0
  288. sagemaker/core/session_settings.py +55 -0
  289. sagemaker/core/shapes/__init__.py +3 -0
  290. sagemaker/core/shapes/model_card_shapes.py +159 -0
  291. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
  292. sagemaker/core/spark/__init__.py +16 -0
  293. sagemaker/core/spark/defaults.py +16 -0
  294. sagemaker/core/spark/processing.py +1380 -0
  295. sagemaker/core/telemetry/__init__.py +23 -0
  296. sagemaker/core/telemetry/constants.py +84 -0
  297. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  298. sagemaker/core/tools/__init__.py +1 -0
  299. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  300. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  301. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  302. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  303. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  304. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  305. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  306. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  307. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  308. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  309. sagemaker/core/training/__init__.py +14 -0
  310. sagemaker/core/training/configs.py +333 -0
  311. sagemaker/core/training/constants.py +37 -0
  312. sagemaker/core/training/utils.py +77 -0
  313. sagemaker/core/training_compiler/__init__.py +16 -0
  314. sagemaker/core/training_compiler/config.py +197 -0
  315. sagemaker/core/training_compiler_config.py +197 -0
  316. sagemaker/core/transformer.py +793 -0
  317. sagemaker/core/user_agent.py +76 -0
  318. sagemaker/core/utilities/__init__.py +24 -0
  319. sagemaker/core/utilities/cache.py +169 -0
  320. sagemaker/core/utilities/search_expression.py +133 -0
  321. sagemaker/core/utils/__init__.py +48 -0
  322. sagemaker/core/utils/code_injection/__init__.py +0 -0
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  324. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
  325. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  326. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  327. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  328. {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
  329. sagemaker/core/workflow/__init__.py +152 -0
  330. sagemaker/core/workflow/conditions.py +313 -0
  331. sagemaker/core/workflow/entities.py +58 -0
  332. sagemaker/core/workflow/execution_variables.py +89 -0
  333. sagemaker/core/workflow/functions.py +193 -0
  334. sagemaker/core/workflow/parameters.py +222 -0
  335. sagemaker/core/workflow/pipeline_context.py +394 -0
  336. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  337. sagemaker/core/workflow/properties.py +285 -0
  338. sagemaker/core/workflow/step_outputs.py +65 -0
  339. sagemaker/core/workflow/utilities.py +507 -0
  340. sagemaker/lineage/__init__.py +33 -0
  341. sagemaker/lineage/action.py +28 -0
  342. sagemaker/lineage/artifact.py +28 -0
  343. sagemaker/lineage/context.py +28 -0
  344. sagemaker/lineage/lineage_trial_component.py +28 -0
  345. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  346. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  347. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  348. sagemaker_core/__init__.py +0 -4
  349. sagemaker_core/_version.py +0 -3
  350. sagemaker_core/helper/session_helper.py +0 -769
  351. sagemaker_core/resources/__init__.py +0 -1
  352. sagemaker_core/shapes/__init__.py +0 -1
  353. sagemaker_core/tools/__init__.py +0 -1
  354. sagemaker_core-1.0.47.dist-info/RECORD +0 -35
  355. sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
  356. {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
  357. {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  358. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  361. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  362. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  363. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,744 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Placeholder docstring"""
14
+ from __future__ import print_function, absolute_import
15
+
16
+ from abc import ABCMeta, abstractmethod
17
+ from collections import defaultdict, OrderedDict
18
+ import datetime
19
+ import logging
20
+
21
+ from six import with_metaclass
22
+
23
+ from sagemaker.core.helper.session_helper import Session
24
+ from sagemaker.core.common_utils import DeferredError
25
+ from sagemaker.core.lineage import artifact
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ try:
30
+ import pandas as pd
31
+ except ImportError as e:
32
+ logger.warning("pandas failed to import. Analytics features will be impaired or broken.")
33
+ # Any subsequent attempt to use pandas will raise the ImportError
34
+ pd = DeferredError(e)
35
+
36
+ METRICS_PERIOD_DEFAULT = 60 # seconds
37
+
38
+
39
class AnalyticsMetricsBase(with_metaclass(ABCMeta, object)):
    """Abstract parent of the tuning job and training job analytics classes.

    Holds the behaviour they share: building the analytics dataframe on first
    use, caching it, and persisting it to a file.
    """

    def __init__(self):
        """Create the instance with an empty dataframe cache."""
        self._dataframe = None

    def export_csv(self, filename):
        """Write the analytics dataframe to a CSV file.

        Args:
            filename (str): The name of the file to save to.
        """
        frame = self.dataframe()
        frame.to_csv(filename)

    def dataframe(self, force_refresh=False):
        """Return the analytics dataframe, building it only when it is not cached.

        The dataframe is a tabular summary produced by the subclass's
        ``_fetch_dataframe`` implementation.

        Args:
            force_refresh (bool): Set to True to discard cached data (via
                ``clear_cache``) so the dataframe is fetched again.

        Returns:
            The cached result of ``_fetch_dataframe``.
        """
        if force_refresh:
            self.clear_cache()
        # Serve the cached copy whenever one exists.
        if self._dataframe is not None:
            return self._dataframe
        self._dataframe = self._fetch_dataframe()
        return self._dataframe

    @abstractmethod
    def _fetch_dataframe(self):
        """Build and return the dataframe; every subclass must implement this."""

    def clear_cache(self):
        """Drop the cached dataframe.

        The next call to ``dataframe`` will rebuild it through ``_fetch_dataframe``.
        """
        self._dataframe = None
83
+
84
+
85
class HyperparameterTuningJobAnalytics(AnalyticsMetricsBase):
    """Fetch results about a hyperparameter tuning job and make them accessible for analytics."""

    def __init__(self, hyperparameter_tuning_job_name, sagemaker_session=None):
        """Initialize a ``HyperparameterTuningJobAnalytics`` instance.

        Args:
            hyperparameter_tuning_job_name (str): name of the
                HyperparameterTuningJob to analyze.
            sagemaker_session (sagemaker.core.helper.session_helper.Session):
                Session object which manages interactions with Amazon SageMaker
                APIs and any other AWS services needed. If not specified, one
                is created with ``Session()``.
        """
        sagemaker_session = sagemaker_session or Session()
        self._sage_client = sagemaker_session.sagemaker_client
        self._tuning_job_name = hyperparameter_tuning_job_name
        self._tuning_job_describe_result = None
        self._training_job_summaries = None
        super(HyperparameterTuningJobAnalytics, self).__init__()
        self.clear_cache()

    @property
    def name(self):
        """Name of the HyperparameterTuningJob being analyzed"""
        return self._tuning_job_name

    def __repr__(self):
        """Human-readable representation override."""
        return "<sagemaker.HyperparameterTuningJobAnalytics for %s>" % self.name

    def clear_cache(self):
        """Clear the object of all local caches of API methods.

        Resets the cached dataframe, the cached describe result and the cached
        training job summaries.
        """
        super(HyperparameterTuningJobAnalytics, self).clear_cache()
        self._tuning_job_describe_result = None
        self._training_job_summaries = None

    def _fetch_dataframe(self):
        """Return a pandas dataframe with one row per training job.

        Each row holds the tuned hyperparameters, the job name and status, the
        final objective value, the start/end times and, when both times are
        present, the elapsed training time in seconds.
        """

        def reshape(training_summary):
            # Flatten a single training job summary into one dataframe record.
            out = {}
            for k, v in training_summary["TunedHyperParameters"].items():
                # Something (bokeh?) gets confused with ints so convert to float
                try:
                    v = float(v)
                except (TypeError, ValueError):
                    # Non-numeric values are kept as they are.
                    pass
                out[k] = v
            out["TrainingJobName"] = training_summary["TrainingJobName"]
            out["TrainingJobStatus"] = training_summary["TrainingJobStatus"]
            out["FinalObjectiveValue"] = training_summary.get(
                "FinalHyperParameterTuningJobObjectiveMetric", {}
            ).get("Value")

            start_time = training_summary.get("TrainingStartTime", None)
            end_time = training_summary.get("TrainingEndTime", None)
            out["TrainingStartTime"] = start_time
            out["TrainingEndTime"] = end_time
            # The elapsed time is only computable when both timestamps exist.
            if start_time and end_time:
                out["TrainingElapsedTimeSeconds"] = (end_time - start_time).total_seconds()
            if "TrainingJobDefinitionName" in training_summary:
                out["TrainingJobDefinitionName"] = training_summary["TrainingJobDefinitionName"]
            return out

        # Run that helper over all the summaries.
        return pd.DataFrame([reshape(tjs) for tjs in self.training_job_summaries()])

    @property
    def tuning_ranges(self):
        """A dictionary describing the ranges of all tuned hyperparameters.

        The keys are the names of the hyperparameter, and the values are the ranges.

        The output can take one of two forms:

            * If the 'TrainingJobDefinition' field is present in the job description, the output
                is a dictionary constructed from 'ParameterRanges' in
                'HyperParameterTuningJobConfig' of the job description. The keys are the
                parameter names, while the values are the parameter ranges.
                Example:
                >>> {
                >>>     "eta": {"MaxValue": "1", "MinValue": "0", "Name": "eta"},
                >>>     "gamma": {"MaxValue": "10", "MinValue": "0", "Name": "gamma"},
                >>>     "iterations": {"MaxValue": "100", "MinValue": "50", "Name": "iterations"},
                >>>     "num_layers": {"MaxValue": "30", "MinValue": "5", "Name": "num_layers"},
                >>> }
            * If the 'TrainingJobDefinitions' field (list) is present in the job description,
                the output is a dictionary with keys as the 'DefinitionName' values from
                all items in 'TrainingJobDefinitions', and each value would be a dictionary
                constructed from 'HyperParameterRanges' in each item in 'TrainingJobDefinitions'
                in the same format as above
                Example:
                >>> {
                >>>     "estimator_1": {
                >>>         "eta": {"MaxValue": "1", "MinValue": "0", "Name": "eta"},
                >>>         "gamma": {"MaxValue": "10", "MinValue": "0", "Name": "gamma"},
                >>>     },
                >>>     "estimator_2": {
                >>>         "framework": {"Values": ["TF", "MXNet"], "Name": "framework"},
                >>>         "gamma": {"MaxValue": "1.0", "MinValue": "0.2", "Name": "gamma"}
                >>>     }
                >>> }

        For more details about the 'TrainingJobDefinition' and 'TrainingJobDefinitions' fields
        in job description, see
        https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_hyper_parameter_tuning_job
        """
        description = self.description()

        if "TrainingJobDefinition" in description:
            return self._prepare_parameter_ranges(
                description["HyperParameterTuningJobConfig"]["ParameterRanges"]
            )

        return {
            training_job_definition["DefinitionName"]: self._prepare_parameter_ranges(
                training_job_definition["HyperParameterRanges"]
            )
            for training_job_definition in description["TrainingJobDefinitions"]
        }

    def _prepare_parameter_ranges(self, parameter_ranges):
        """Convert parameter ranges to a dictionary keyed by parameter range name.

        Args:
            parameter_ranges (dict): Mapping whose values are iterables of
                range dictionaries, each carrying a "Name" entry.

        Returns:
            dict: Every range dictionary found, keyed by its "Name".
        """
        out = {}
        # Only the grouped range lists matter; the group keys are not used.
        for ranges in parameter_ranges.values():
            for param in ranges:
                out[param["Name"]] = param
        return out

    def description(self, force_refresh=False):
        """Call ``DescribeHyperParameterTuningJob`` for the hyperparameter tuning job.

        Args:
            force_refresh (bool): Set to True to fetch the latest data from
                SageMaker API. This clears every cache held by this object,
                not only the describe result.

        Returns:
            dict: The Amazon SageMaker response for
                ``DescribeHyperParameterTuningJob``.
        """
        if force_refresh:
            self.clear_cache()
        if not self._tuning_job_describe_result:
            self._tuning_job_describe_result = self._sage_client.describe_hyper_parameter_tuning_job(  # noqa: E501 # pylint: disable=line-too-long
                HyperParameterTuningJobName=self.name
            )
        return self._tuning_job_describe_result

    def training_job_summaries(self, force_refresh=False):
        """A (paginated) list of everything from ``ListTrainingJobsForTuningJob``.

        At most 100 pages of up to 100 results each are requested. If the
        service still reports more results after the last page, a warning is
        logged because the returned list is then incomplete.

        Args:
            force_refresh (bool): Set to True to fetch the latest data from
                SageMaker API. This clears every cache held by this object,
                not only the summaries.

        Returns:
            list: The ``TrainingJobSummaries`` entries of every fetched page,
                concatenated in the order they were returned.
        """
        if force_refresh:
            self.clear_cache()
        if self._training_job_summaries is not None:
            return self._training_job_summaries

        max_pages = 100
        output = []
        next_args = {}
        for count in range(max_pages):
            logger.debug("Calling list_training_jobs_for_hyper_parameter_tuning_job %d", count)
            raw_result = self._sage_client.list_training_jobs_for_hyper_parameter_tuning_job(
                HyperParameterTuningJobName=self.name, MaxResults=100, **next_args
            )
            new_output = raw_result["TrainingJobSummaries"]
            output.extend(new_output)
            logger.debug(
                "Got %d more TrainingJobs. Total so far: %d",
                len(new_output),
                len(output),
            )
            # Stop when the service reports no further page, or returned an empty page.
            if "NextToken" not in raw_result or not new_output:
                break
            next_args["NextToken"] = raw_result["NextToken"]
        else:
            # The loop ran out of pages while a NextToken was still pending, so
            # the collected list is truncated; make that visible to the caller.
            logger.warning(
                "Stopped listing training jobs for tuning job %s after %d pages; "
                "%d summaries were collected and more remain unfetched.",
                self.name,
                max_pages,
                len(output),
            )
        self._training_job_summaries = output
        return output
275
+
276
+
277
class TrainingJobAnalytics(AnalyticsMetricsBase):
    """Fetch training curve data from CloudWatch Metrics for a specific training job."""

    CLOUDWATCH_NAMESPACE = "/aws/sagemaker/TrainingJobs"

    def __init__(
        self,
        training_job_name,
        metric_names=None,
        sagemaker_session=None,
        start_time=None,
        end_time=None,
        period=None,
    ):
        """Initialize a ``TrainingJobAnalytics`` instance.

        Args:
            training_job_name (str): name of the TrainingJob to analyze.
            metric_names (list, optional): string names of all the metrics to
                collect for this training job. If not specified, then it will
                use all metric names configured for this job.
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is specified using
                the default AWS configuration chain.
            start_time (datetime.datetime, optional): Start of the CloudWatch
                query window. If not specified, the ``TrainingStartTime`` of
                the job description is used.
            end_time (datetime.datetime, optional): End of the CloudWatch
                query window; used exactly as given. If not specified, the
                ``TrainingEndTime`` of the job description (or the current
                UTC time when the job has none) plus one minute is used.
            period (optional): Value sent as ``Period`` in the CloudWatch
                ``GetMetricStatistics`` request. Defaults to
                ``METRICS_PERIOD_DEFAULT``.
        """
        sagemaker_session = sagemaker_session or Session()
        self._sage_client = sagemaker_session.sagemaker_client
        self._cloudwatch = sagemaker_session.boto_session.client("cloudwatch")
        self._training_job_name = training_job_name
        self._start_time = start_time
        self._end_time = end_time
        self._period = period or METRICS_PERIOD_DEFAULT

        if metric_names:
            self._metric_names = metric_names
        else:
            self._metric_names = self._metric_names_for_training_job()

        super(TrainingJobAnalytics, self).__init__()
        self.clear_cache()

    @property
    def name(self):
        """Name of the TrainingJob being analyzed"""
        return self._training_job_name

    def __repr__(self):
        """The human-readable representation override."""
        return "<sagemaker.TrainingJobAnalytics for %s>" % self.name

    def clear_cache(self):
        """Clear the object of all local caches of API methods.

        This is so that the next time any properties are accessed they will be
        refreshed from the service. The query time interval is recomputed as
        part of clearing, which calls ``DescribeTrainingJob``.
        """
        super(TrainingJobAnalytics, self).clear_cache()
        self._data = defaultdict(list)
        self._time_interval = self._determine_timeinterval()

    def _determine_timeinterval(self):
        """Return a dict with two datetime objects.

        The dict includes the `start_time` and `end_time`, covering the interval
        of the training job. Values supplied to the constructor take precedence
        over the ones found in the job description.

        Returns:
            a dict with the `start_time` and `end_time`.
        """
        description = self._sage_client.describe_training_job(TrainingJobName=self.name)
        start_time = self._start_time or description["TrainingStartTime"]  # datetime object

        if self._end_time:
            # A caller-supplied end time is used verbatim (no padding).
            end_time = self._end_time
        else:
            job_end = description.get("TrainingEndTime")
            if job_end is None:
                # Job still running: query up to "now". A timezone-aware UTC
                # value replaces the deprecated, naive ``datetime.utcnow()``.
                job_end = datetime.datetime.now(datetime.timezone.utc)
            # Incrementing end time by 1 min since CloudWatch drops seconds before
            # finding the logs. This results in logs being searched in the time
            # range in which the correct log line was not present.
            # Example - Log time - 2018-10-22 08:25:55
            #   Here calculated end time would also be 2018-10-22 08:25:55
            #   (without 1 min addition). CW will consider end time as
            #   2018-10-22 08:25 and will not be able to search the correct log.
            end_time = job_end + datetime.timedelta(minutes=1)

        return {"start_time": start_time, "end_time": end_time}

    def _fetch_dataframe(self):
        """Fetch every configured metric and return the collected rows as a dataframe."""
        for metric_name in self._metric_names:
            self._fetch_metric(metric_name)
        return pd.DataFrame(self._data)

    def _fetch_metric(self, metric_name):
        """Fetch all the values of a named metric, and add them to _data

        Datapoints are stored in chronological order, with the timestamp
        expressed as seconds elapsed since the earliest datapoint of the metric.

        Args:
            metric_name: The metric name to fetch.
        """
        request = {
            "Namespace": self.CLOUDWATCH_NAMESPACE,
            "MetricName": metric_name,
            "Dimensions": [{"Name": "TrainingJobName", "Value": self.name}],
            "StartTime": self._time_interval["start_time"],
            "EndTime": self._time_interval["end_time"],
            "Period": self._period,
            "Statistics": ["Average"],
        }
        raw_cwm_data = self._cloudwatch.get_metric_statistics(**request)["Datapoints"]
        if not raw_cwm_data:
            logger.warning("Warning: No metrics called %s found", metric_name)
            return

        # Process data: sort chronologically, then normalize to the starting time.
        ordered = sorted(raw_cwm_data, key=lambda pt: pt["Timestamp"])
        base_time = ordered[0]["Timestamp"]
        all_xy = [
            ((pt["Timestamp"] - base_time).total_seconds(), pt["Average"]) for pt in ordered
        ]

        # Store everything in _data to make a dataframe from
        for elapsed_seconds, value in all_xy:
            self._add_single_metric(elapsed_seconds, metric_name, value)

    def _add_single_metric(self, timestamp, metric_name, value):
        """Store a single metric in the _data dict.

        This can be converted to a dataframe.

        Args:
            timestamp: The timestamp of the metric.
            metric_name: The name of the metric.
            value: The value of the metric.
        """
        # note that this method is built this way to make it possible to
        # support live-refreshing charts in Bokeh at some point in the future.
        self._data["timestamp"].append(timestamp)
        self._data["metric_name"].append(metric_name)
        self._data["value"].append(value)

    def _metric_names_for_training_job(self):
        """Helper method to discover the metrics defined for a training job.

        Returns:
            list: The ``Name`` of every metric definition in the job's
            ``AlgorithmSpecification``; empty when the job defines none.
        """
        training_description = self._sage_client.describe_training_job(
            TrainingJobName=self._training_job_name
        )

        # ``MetricDefinitions`` may be missing from the description; treat
        # that as "no metrics" rather than failing with a KeyError.
        algorithm_specification = training_description["AlgorithmSpecification"]
        metric_definitions = algorithm_specification.get("MetricDefinitions") or []
        return [md["Name"] for md in metric_definitions]
429
+
430
+
431
class ArtifactAnalytics(AnalyticsMetricsBase):
    """Fetch artifact data and make them accessible for analytics."""

    def __init__(
        self,
        sort_by=None,
        sort_order=None,
        source_uri=None,
        artifact_type=None,
        sagemaker_session=None,
    ):
        """Initialize a ``ArtifactAnalytics`` instance.

        Args:
            sort_by (str, optional): The name of the resource property used to sort
                the set of artifacts. Only ``"Name"`` is honored; any other value
                is discarded.
            sort_order (str, optional): How the results are ordered, valid values are
                Ascending and Descending. The default is Descending.
            source_uri (optional): The artifact source uri for filtering.
            artifact_type (optional): The artifact type for filtering.
            sagemaker_session (obj, optional): Sagemaker session. Defaults to None.
        """
        # Anything other than "Name" is not a supported sort key.
        self._sort_by = None
        if sort_by == "Name":
            self._sort_by = sort_by
        self._sort_order = sort_order
        self._source_uri = source_uri
        self._artifact_type = artifact_type
        self._sagemaker_session = sagemaker_session
        super(ArtifactAnalytics, self).__init__()
        self.clear_cache()

    def __repr__(self):
        """Human-readable representation override."""
        return "<sagemaker.ArtifactAnalytics>"

    def _reshape_source_type(self, artifact_source_types):
        """Reshape artifact source type.

        Every element is written under the same ``"ArtifactSourceType"`` key,
        so the result holds only the last element (and is empty for an empty
        input).
        """
        reshaped = OrderedDict()
        for source_type in artifact_source_types:
            reshaped["ArtifactSourceType"] = source_type
        return reshaped

    def _reshape(self, artifact_summary):
        """Turn one artifact summary into an ordered column-name -> value mapping."""
        return OrderedDict(
            [
                ("ArtifactName", artifact_summary.artifact_name),
                ("ArtifactArn", artifact_summary.artifact_arn),
                ("ArtifactType", artifact_summary.artifact_type),
                ("ArtifactSourceUri", artifact_summary.source.source_uri),
                ("CreationTime", artifact_summary.creation_time),
                ("LastModifiedTime", artifact_summary.last_modified_time),
            ]
        )

    def _fetch_dataframe(self):
        """Return a pandas dataframe with one row per listed artifact."""
        rows = []
        for summary in self._get_list_artifacts():
            rows.append(self._reshape(summary))
        return pd.DataFrame(rows)

    def _get_list_artifacts(self):
        """List artifacts using the filters and sort options given at construction."""
        return artifact.Artifact.list(
            source_uri=self._source_uri,
            artifact_type=self._artifact_type,
            sort_by=self._sort_by,
            sort_order=self._sort_order,
            sagemaker_session=self._sagemaker_session,
        )
498
+
499
+
500
class ExperimentAnalytics(AnalyticsMetricsBase):
    """Fetch trial component data and make them accessible for analytics."""

    MAX_TRIAL_COMPONENTS = 10000

    def __init__(
        self,
        experiment_name=None,
        search_expression=None,
        sort_by=None,
        sort_order=None,
        metric_names=None,
        parameter_names=None,
        sagemaker_session=None,
        input_artifact_names=None,
        output_artifact_names=None,
    ):
        """Initialize a ``ExperimentAnalytics`` instance.

        Args:
            experiment_name (str, optional): Name of the experiment if you want to constrain the
                search to only trial components belonging to an experiment.
            search_expression (dict, optional): The search query to find the set of trial components
                to use to populate the data frame. The dict is not modified.
            sort_by (str, optional): The name of the resource property used to sort
                the set of trial components.
            sort_order (str, optional): How trial components are ordered, valid values are
                Ascending and Descending. The default is Descending.
            metric_names (list, optional): string names of all the metrics to be shown in the
                data frame. If not specified, all metrics will be shown of all trials.
            parameter_names (list, optional): string names of the parameters to be shown in the
                data frame. If not specified, all parameters will be shown of all trials.
            sagemaker_session (sagemaker.session.Session): Session object which manages interactions
                with Amazon SageMaker APIs and any other AWS services needed. If not specified,
                one is created using the default AWS configuration chain.
            input_artifact_names (optional): Names of the input artifacts to be shown in the
                data frame. If not specified, all input artifacts are shown. Examples of
                input artifacts are datasets, algorithms, hyperparameters, source code, and
                instance types.
            output_artifact_names (optional): Names of the output artifacts to be shown in the
                data frame. If not specified, all output artifacts are shown. Examples
                of output artifacts are metrics, snapshots, logs, and images.

        Raises:
            ValueError: If neither ``experiment_name`` nor ``search_expression`` is supplied.
        """
        sagemaker_session = sagemaker_session or Session()
        self._sage_client = sagemaker_session.sagemaker_client

        if not experiment_name and not search_expression:
            raise ValueError("Either experiment_name or search_expression must be supplied.")

        self._experiment_name = experiment_name
        self._search_expression = search_expression
        self._sort_by = sort_by
        self._sort_order = sort_order
        self._metric_names = metric_names
        self._parameter_names = parameter_names
        self._input_artifact_names = input_artifact_names
        self._output_artifact_names = output_artifact_names
        self._trial_components = None
        super(ExperimentAnalytics, self).__init__()
        self.clear_cache()

    @property
    def name(self):
        """Name of the Experiment being analyzed."""
        return self._experiment_name

    def __repr__(self):
        """The human-readable representation override."""
        return "<sagemaker.ExperimentAnalytics for %s>" % self.name

    def clear_cache(self):
        """Clear the object of all local caches of API methods."""
        super(ExperimentAnalytics, self).clear_cache()
        self._trial_components = None

    def _reshape_parameters(self, parameters):
        """Reshape trial component parameters to a pandas column.

        Args:
            parameters (dict): trial component parameters, keyed by parameter name.
        Returns:
            dict: Key: Parameter name, Value: Parameter value (the ``NumberValue``
            when present, otherwise the ``StringValue``). Restricted to the
            configured parameter names when those were given.
        """
        out = OrderedDict()
        for name, value in sorted(parameters.items()):
            if self._parameter_names and name not in self._parameter_names:
                continue
            out[name] = value.get("NumberValue", value.get("StringValue"))
        return out

    def _reshape_metrics(self, metrics):
        """Reshape trial component metrics to a pandas column.

        Args:
            metrics (list): trial component metric summaries.
        Returns:
            dict: Key: ``"<metric name> - <statistic>"``, Value: statistic value.
            Statistics that are absent or ``None`` are left out.
        """
        statistic_types = ("Min", "Max", "Avg", "StdDev", "Last", "Count")
        out = OrderedDict()
        for metric_summary in metrics:
            metric_name = metric_summary["MetricName"]
            if self._metric_names and metric_name not in self._metric_names:
                continue

            for stat_type in statistic_types:
                stat_value = metric_summary.get(stat_type)
                if stat_value is not None:
                    out["{} - {}".format(metric_name, stat_type)] = stat_value
        return out

    def _reshape_artifacts(self, artifacts, _artifact_names):
        """Reshape trial component input/output artifacts to a pandas column.

        Args:
            artifacts (dict): trial component input/output artifacts, keyed by name.
            _artifact_names: names of the artifacts to keep; when empty or
                ``None`` every artifact is kept.
        Returns:
            dict: Key: ``"<artifact name> - MediaType"`` / ``"<artifact name> - Value"``,
            Value: the corresponding artifact field.
        """
        out = OrderedDict()
        for name, value in sorted(artifacts.items()):
            if _artifact_names and (name not in _artifact_names):
                continue
            out["{} - {}".format(name, "MediaType")] = value.get("MediaType")
            out["{} - {}".format(name, "Value")] = value.get("Value")
        return out

    def _reshape_parents(self, parents):
        """Reshape trial component parents to a pandas column.

        Args:
            parents: trial component parents (trials and experiments)
        Returns:
            dict: ``"Trials"`` mapped to the list of parent trial names and
            ``"Experiments"`` mapped to the list of parent experiment names.
        """
        out = OrderedDict()
        trials = []
        experiments = []
        for parent in parents:
            trials.append(parent["TrialName"])
            experiments.append(parent["ExperimentName"])
        out["Trials"] = trials
        out["Experiments"] = experiments
        return out

    def _reshape(self, trial_component):
        """Reshape trial component data to pandas columns.

        Args:
            trial_component: dict representing a trial component
        Returns:
            dict: Key-Value pair representing the data in the pandas dataframe
        """
        out = OrderedDict()
        for attribute in ["TrialComponentName", "DisplayName"]:
            out[attribute] = trial_component.get(attribute, "")

        source = trial_component.get("Source")
        if source:
            out["SourceArn"] = source["SourceArn"]

        # Parameters and artifacts are mappings: default to an empty dict so
        # that a component lacking them does not fail on ``.items()``.
        out.update(self._reshape_parameters(trial_component.get("Parameters", {})))
        out.update(self._reshape_metrics(trial_component.get("Metrics", [])))
        out.update(
            self._reshape_artifacts(
                trial_component.get("InputArtifacts", {}), self._input_artifact_names
            )
        )
        out.update(
            self._reshape_artifacts(
                trial_component.get("OutputArtifacts", {}), self._output_artifact_names
            )
        )
        out.update(self._reshape_parents(trial_component.get("Parents", [])))
        return out

    def _fetch_dataframe(self):
        """Return a pandas dataframe that includes all the trial components."""
        return pd.DataFrame(
            [self._reshape(component) for component in self._get_trial_components()]
        )

    def _get_trial_components(self, force_refresh=False):
        """Get all trial components matching the given search query expression.

        The result is cached on the instance until ``clear_cache`` is called.

        Args:
            force_refresh (bool): Set to True to fetch the latest data from SageMaker API.

        Returns:
            list: List of dicts representing the trial components
        """
        if force_refresh:
            self.clear_cache()
        if self._trial_components is not None:
            return self._trial_components

        # Build the effective expression on copies: the caller's dict must not
        # be modified, and repeated calls must not pile up duplicate filters.
        search_expression = dict(self._search_expression or {})
        if self._experiment_name:
            filters = list(search_expression.get("Filters") or [])
            filters.append(
                {
                    "Name": "Parents.ExperimentName",
                    "Operator": "Equals",
                    "Value": self._experiment_name,
                }
            )
            search_expression["Filters"] = filters

        self._trial_components = self._search(
            search_expression, self._sort_by, self._sort_order
        )
        return self._trial_components

    def _search(self, search_expression, sort_by, sort_order):
        """Perform a search query using SageMaker Search and return the matching trial components.

        Pages are requested until the service stops returning a ``NextToken``,
        a page comes back empty, or at least ``MAX_TRIAL_COMPONENTS`` components
        have been collected.

        Args:
            search_expression: Search expression to filter trial components.
            sort_by: The name of the resource property used to sort the trial components.
            sort_order: How trial components are ordered, valid values are Ascending
                and Descending. The default is Descending.
        Returns:
            list: List of dict representing trial components.
        """
        trial_components = []

        search_args = {
            "Resource": "ExperimentTrialComponent",
            "SearchExpression": search_expression,
        }

        if sort_by:
            search_args["SortBy"] = sort_by

        if sort_order:
            search_args["SortOrder"] = sort_order

        while len(trial_components) < self.MAX_TRIAL_COMPONENTS:
            search_response = self._sage_client.search(**search_args)
            components = [result["TrialComponent"] for result in search_response["Results"]]
            trial_components.extend(components)
            if "NextToken" in search_response and components:
                search_args["NextToken"] = search_response["NextToken"]
            else:
                break

        return trial_components