sagemaker-core 1.0.62__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (362)
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +408 -3
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/constants.py +16 -0
  176. sagemaker/core/jumpstart/hub/hub.py +291 -0
  177. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  178. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  179. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  180. sagemaker/core/jumpstart/hub/types.py +35 -0
  181. sagemaker/core/jumpstart/hub/utils.py +260 -0
  182. sagemaker/core/jumpstart/models.py +499 -0
  183. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  184. sagemaker/core/jumpstart/parameters.py +20 -0
  185. sagemaker/core/jumpstart/payload_utils.py +239 -0
  186. sagemaker/core/jumpstart/region_config.json +163 -0
  187. sagemaker/core/jumpstart/search.py +171 -0
  188. sagemaker/core/jumpstart/serializers.py +81 -0
  189. sagemaker/core/jumpstart/session_utils.py +234 -0
  190. sagemaker/core/jumpstart/types.py +3044 -0
  191. sagemaker/core/jumpstart/utils.py +1731 -0
  192. sagemaker/core/jumpstart/validators.py +257 -0
  193. sagemaker/core/lambda_helper.py +312 -0
  194. sagemaker/core/lineage/__init__.py +42 -0
  195. sagemaker/core/lineage/_api_types.py +239 -0
  196. sagemaker/core/lineage/_utils.py +49 -0
  197. sagemaker/core/lineage/action.py +345 -0
  198. sagemaker/core/lineage/artifact.py +646 -0
  199. sagemaker/core/lineage/association.py +190 -0
  200. sagemaker/core/lineage/context.py +505 -0
  201. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  202. sagemaker/core/lineage/query.py +732 -0
  203. sagemaker/core/lineage/visualizer.py +346 -0
  204. sagemaker/core/local/__init__.py +18 -0
  205. sagemaker/core/local/data.py +413 -0
  206. sagemaker/core/local/entities.py +678 -0
  207. sagemaker/core/local/exceptions.py +17 -0
  208. sagemaker/core/local/image.py +1243 -0
  209. sagemaker/core/local/local_session.py +739 -0
  210. sagemaker/core/local/utils.py +245 -0
  211. sagemaker/core/logs.py +181 -0
  212. sagemaker/core/metadata_properties.py +56 -0
  213. sagemaker/core/metric_definitions.py +91 -0
  214. sagemaker/core/mlflow/__init__.py +38 -0
  215. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  216. sagemaker/core/model_card/__init__.py +26 -0
  217. sagemaker/core/model_life_cycle.py +51 -0
  218. sagemaker/core/model_metrics.py +160 -0
  219. sagemaker/core/model_monitor/__init__.py +66 -0
  220. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  221. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  222. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  223. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  224. sagemaker/core/model_monitor/dataset_format.py +102 -0
  225. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  226. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  227. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  228. sagemaker/core/model_monitor/utils.py +793 -0
  229. sagemaker/core/model_registry.py +480 -0
  230. sagemaker/core/model_uris.py +97 -0
  231. sagemaker/core/modules/__init__.py +19 -0
  232. sagemaker/core/modules/configs.py +226 -0
  233. sagemaker/core/modules/constants.py +37 -0
  234. sagemaker/core/modules/distributed.py +182 -0
  235. sagemaker/core/modules/local_core/__init__.py +0 -0
  236. sagemaker/core/modules/local_core/local_container.py +605 -0
  237. sagemaker/core/modules/templates.py +83 -0
  238. sagemaker/core/modules/train/__init__.py +14 -0
  239. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  242. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  247. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  249. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  250. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  251. sagemaker/core/modules/types.py +19 -0
  252. sagemaker/core/modules/utils.py +194 -0
  253. sagemaker/core/network.py +185 -0
  254. sagemaker/core/parameter.py +173 -0
  255. sagemaker/core/payloads.py +185 -0
  256. sagemaker/core/processing.py +1597 -0
  257. sagemaker/core/remote_function/__init__.py +19 -0
  258. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  259. sagemaker/core/remote_function/client.py +1285 -0
  260. sagemaker/core/remote_function/core/__init__.py +0 -0
  261. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  262. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  263. sagemaker/core/remote_function/core/serialization.py +422 -0
  264. sagemaker/core/remote_function/core/stored_function.py +226 -0
  265. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  266. sagemaker/core/remote_function/errors.py +104 -0
  267. sagemaker/core/remote_function/invoke_function.py +172 -0
  268. sagemaker/core/remote_function/job.py +2140 -0
  269. sagemaker/core/remote_function/logging_config.py +38 -0
  270. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  271. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  272. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  273. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  274. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  275. sagemaker/core/remote_function/spark_config.py +149 -0
  276. sagemaker/core/resource_requirements.py +168 -0
  277. {sagemaker_core/main → sagemaker/core}/resources.py +19098 -10895
  278. sagemaker/core/s3/__init__.py +41 -0
  279. sagemaker/core/s3/client.py +367 -0
  280. sagemaker/core/s3/utils.py +175 -0
  281. sagemaker/core/script_uris.py +93 -0
  282. sagemaker/core/serializers/__init__.py +11 -0
  283. sagemaker/core/serializers/base.py +510 -0
  284. sagemaker/core/serializers/implementations.py +159 -0
  285. sagemaker/core/serializers/utils.py +223 -0
  286. sagemaker/core/serverless_inference_config.py +63 -0
  287. sagemaker/core/session_settings.py +55 -0
  288. sagemaker/core/shapes/__init__.py +3 -0
  289. sagemaker/core/shapes/model_card_shapes.py +159 -0
  290. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +5810 -1806
  291. sagemaker/core/spark/__init__.py +16 -0
  292. sagemaker/core/spark/defaults.py +16 -0
  293. sagemaker/core/spark/processing.py +1380 -0
  294. sagemaker/core/telemetry/__init__.py +23 -0
  295. sagemaker/core/telemetry/constants.py +84 -0
  296. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  297. sagemaker/core/tools/__init__.py +1 -0
  298. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  299. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  300. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  301. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  302. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  303. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  304. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  305. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  306. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  307. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  308. sagemaker/core/training/__init__.py +14 -0
  309. sagemaker/core/training/configs.py +333 -0
  310. sagemaker/core/training/constants.py +37 -0
  311. sagemaker/core/training/utils.py +77 -0
  312. sagemaker/core/training_compiler/__init__.py +16 -0
  313. sagemaker/core/training_compiler/config.py +197 -0
  314. sagemaker/core/training_compiler_config.py +197 -0
  315. sagemaker/core/transformer.py +793 -0
  316. sagemaker/core/user_agent.py +76 -0
  317. sagemaker/core/utilities/__init__.py +24 -0
  318. sagemaker/core/utilities/cache.py +169 -0
  319. sagemaker/core/utilities/search_expression.py +133 -0
  320. sagemaker/core/utils/__init__.py +48 -0
  321. sagemaker/core/utils/code_injection/__init__.py +0 -0
  322. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +5979 -176
  324. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  325. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  326. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  327. {sagemaker_core/main → sagemaker/core/utils}/utils.py +27 -22
  328. sagemaker/core/workflow/__init__.py +152 -0
  329. sagemaker/core/workflow/conditions.py +313 -0
  330. sagemaker/core/workflow/entities.py +58 -0
  331. sagemaker/core/workflow/execution_variables.py +89 -0
  332. sagemaker/core/workflow/functions.py +193 -0
  333. sagemaker/core/workflow/parameters.py +222 -0
  334. sagemaker/core/workflow/pipeline_context.py +394 -0
  335. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  336. sagemaker/core/workflow/properties.py +285 -0
  337. sagemaker/core/workflow/step_outputs.py +65 -0
  338. sagemaker/core/workflow/utilities.py +507 -0
  339. sagemaker/lineage/__init__.py +33 -0
  340. sagemaker/lineage/action.py +28 -0
  341. sagemaker/lineage/artifact.py +28 -0
  342. sagemaker/lineage/context.py +28 -0
  343. sagemaker/lineage/lineage_trial_component.py +28 -0
  344. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  345. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  346. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  347. sagemaker_core/_version.py +0 -3
  348. sagemaker_core/helper/session_helper.py +0 -769
  349. sagemaker_core/resources/__init__.py +0 -1
  350. sagemaker_core/shapes/__init__.py +0 -1
  351. sagemaker_core/tools/__init__.py +0 -1
  352. sagemaker_core-1.0.62.dist-info/RECORD +0 -35
  353. sagemaker_core-1.0.62.dist-info/top_level.txt +0 -1
  354. {sagemaker_core → sagemaker/core/helper}/__init__.py +0 -0
  355. {sagemaker_core/helper → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  356. {sagemaker_core/main → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  357. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/hub}/__init__.py +0 -0
  358. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  361. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  362. {sagemaker_core-1.0.62.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,294 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Contains the helper classes for SageMaker Experiment."""
14
+ from __future__ import absolute_import
15
+
16
+ import json
17
+ import logging
18
+ import os
19
+
20
+ import botocore
21
+
22
+ from sagemaker.core import s3
23
+ from sagemaker.core.experiments._utils import is_already_exist_error
24
+
25
logger = logging.getLogger(__name__)


# Default S3 key prefix (used by `_ArtifactUploader`) and default Lineage
# artifact type (used by `_LineageArtifactManager`).
_DEFAULT_ARTIFACT_PREFIX, _DEFAULT_ARTIFACT_TYPE = "trial-component-artifacts", "Tracker"
30
+
31
+
32
class _ArtifactUploader(object):
    """Upload files and JSON-serializable objects to S3 as trial component artifacts.

    Objects are written to the key ``<artifact_prefix>/<trial_component_name>/<artifact_name>``
    in ``artifact_bucket``.
    """

    def __init__(
        self,
        trial_component_name,
        sagemaker_session,
        artifact_bucket=None,
        artifact_prefix=_DEFAULT_ARTIFACT_PREFIX,
    ):
        """Initialize a `_ArtifactUploader` instance.

        Args:
            trial_component_name (str): The name of the trial component,
                which is used to generate the S3 path to upload the artifact to.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed.
            artifact_bucket (str): The S3 bucket to upload the artifact to.
                If not specified, the default bucket defined in `sagemaker_session`
                will be used.
            artifact_prefix (str): The S3 key prefix used to generate the S3 path
                to upload the artifact to (default: "trial-component-artifacts").
        """
        self.sagemaker_session = sagemaker_session
        self.trial_component_name = trial_component_name
        self.artifact_bucket = artifact_bucket
        self.artifact_prefix = artifact_prefix
        # S3 client built from the session's boto session; reused for every upload.
        self._s3_client = self.sagemaker_session.boto_session.client("s3")

    def upload_artifact(self, file_path, extra_args=None):
        """Upload an artifact file to S3.

        The object key is ``<artifact_prefix>/<trial_component_name>/<basename of file_path>``.

        Args:
            file_path (str): the file path of the artifact (``~`` is expanded).
            extra_args (dict): Optional extra arguments that may be passed to the upload operation.
                Similar to ExtraArgs parameter in S3 upload_file function. Please refer to the
                ExtraArgs parameter documentation here:
                https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-uploading-files.html#the-extraargs-parameter

        Returns:
            (str, str or None): The S3 URI of the uploaded file and its ETag. The ETag is
                None when it could not be read back from S3 (see `_try_get_etag`).

        Raises:
            ValueError: If ``file_path`` does not exist or is not a regular file.
        """
        file_path = os.path.expanduser(file_path)
        if not os.path.isfile(file_path):
            raise ValueError(
                "{} does not exist or is not a file. Please supply a file path.".format(file_path)
            )

        self._resolve_bucket_and_prefix()

        artifact_s3_key = self._build_s3_key(os.path.basename(file_path))
        self._s3_client.upload_file(
            file_path,
            self.artifact_bucket,
            artifact_s3_key,
            ExtraArgs=extra_args,
        )
        etag = self._try_get_etag(artifact_s3_key)
        return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag

    def upload_object_artifact(self, artifact_name, artifact_object, file_extension=None):
        """Upload an artifact object to S3, serialized as JSON.

        Args:
            artifact_name (str): the name of the artifact.
            artifact_object (obj): the object of the artifact; it is serialized with
                ``json.dumps`` before being uploaded.
            file_extension (str): Optional file extension appended to ``artifact_name``;
                a "." separator is inserted if the extension does not already start with one.

        Returns:
            (str, str or None): The S3 URI of the uploaded object and its ETag. The ETag is
                None when it could not be read back from S3 (see `_try_get_etag`).

        Raises:
            TypeError: If ``artifact_object`` is not JSON-serializable (raised by ``json.dumps``).
        """
        self._resolve_bucket_and_prefix()

        if file_extension:
            separator = "" if file_extension.startswith(".") else "."
            artifact_name = artifact_name + separator + file_extension
        artifact_s3_key = self._build_s3_key(artifact_name)
        self._s3_client.put_object(
            Body=json.dumps(artifact_object), Bucket=self.artifact_bucket, Key=artifact_s3_key
        )
        etag = self._try_get_etag(artifact_s3_key)
        return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag

    def _resolve_bucket_and_prefix(self):
        """Resolve the target bucket and key prefix and store them on the instance.

        If self.artifact_bucket is falsy, it will be set to sagemaker_session.default_bucket.
        In that case, and if sagemaker_session.default_bucket_prefix exists, self.artifact_prefix
        needs to be updated too (because not updating self.artifact_prefix would result in
        different behavior the 1st time an upload method is called vs the 2nd).
        """
        self.artifact_bucket, self.artifact_prefix = s3.determine_bucket_and_prefix(
            bucket=self.artifact_bucket,
            key_prefix=self.artifact_prefix,
            sagemaker_session=self.sagemaker_session,
        )

    def _build_s3_key(self, artifact_name):
        """Return the S3 key ``<artifact_prefix>/<trial_component_name>/<artifact_name>``."""
        return "{}/{}/{}".format(self.artifact_prefix, self.trial_component_name, artifact_name)

    def _try_get_etag(self, key):
        """Get ETag of given key and return None if not allowed

        Args:
            key (str): The S3 object key.

        Returns:
            str: The S3 object ETag if it allows, otherwise return None.
        """
        try:
            response = self._s3_client.head_object(Bucket=self.artifact_bucket, Key=key)
            return response["ETag"]
        except botocore.exceptions.ClientError as error:
            # HeadObject requires read permissions; treat failure as "ETag unknown".
            logger.warning("Failed to get ETag of %s due to %s", key, error)
            return None
158
+
159
+
160
class _LineageArtifactManager(object):
    """A helper class to manage Lineage Artifacts"""

    def __init__(
        self,
        name,
        source_uri,
        etag,
        source_arn=None,
        dest_arn=None,
        artifact_type=_DEFAULT_ARTIFACT_TYPE,
    ):
        """Initialize a `_LineageArtifactManager` instance.

        Args:
            name (str): The name of the Lineage artifact to be created.
            source_uri (str): The source URI used to create the Lineage artifact.
            etag (str): The S3 Etag used to create the Lineage artifact.
            source_arn (str): The source ARN of a trial component to associate
                this Lineage artifact with (default: None).
            dest_arn (str): The destination ARN of a trial component to associate
                this Lineage artifact with (default: None).
            artifact_type (str): The type of the Lineage artifact (default: "Tracker").
        """
        self.name = name
        self.source_uri = source_uri
        self.etag = etag
        self.source_arn = source_arn
        self.dest_arn = dest_arn
        # Filled in by create_artifact() from the CreateArtifact response.
        self.artifact_arn = None
        self.artifact_type = artifact_type

    def create_artifact(self, sagemaker_session):
        """Create the artifact by calling `CreateArtifact` API

        On success the returned ARN is stored in ``self.artifact_arn``. If the service
        reports that the artifact already exists, a warning is logged and
        ``self.artifact_arn`` is left unchanged.

        Args:
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed.

        Raises:
            botocore.exceptions.ClientError: For any service error other than
                "already exists".
        """
        source_ids = []
        if self.etag:
            source_ids.append({"SourceIdType": "S3ETag", "Value": self.etag})

        try:
            response = sagemaker_session.sagemaker_client.create_artifact(
                ArtifactName=self.name,
                ArtifactType=self.artifact_type,
                Source={"SourceUri": self.source_uri, "SourceTypes": source_ids},
            )
            self.artifact_arn = response["ArtifactArn"]
        except botocore.exceptions.ClientError as err:
            err_info = err.response["Error"]
            if not is_already_exist_error(err_info):
                raise
            logger.warning(
                "Skip creating the artifact since it already exists: %s", err_info["Message"]
            )

    def add_association(self, sagemaker_session):
        """Associate the artifact with a source/destination ARN (e.g. trial component arn)

        Whichever side (source or destination) was not given at construction time is
        filled with this artifact's own ARN. If that ARN is unknown (for example because
        `create_artifact` skipped an artifact that already existed), a warning is logged
        and no association is attempted, since the API requires both ARNs.

        Args:
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed.

        Raises:
            botocore.exceptions.ClientError: For any service error other than
                "already exists".
        """
        source_arn = self.source_arn or self.artifact_arn
        dest_arn = self.dest_arn or self.artifact_arn
        if not source_arn or not dest_arn:
            # Passing None to AddAssociation would fail client-side parameter validation.
            logger.warning(
                "Skip associating artifact %s since its ARN is unknown "
                "(source ARN: %s, destination ARN: %s)",
                self.name,
                source_arn,
                dest_arn,
            )
            return
        # if the trial component (job) is the source then it produced the artifact,
        # otherwise the artifact contributed to the trial component (job)
        association_edge_type = "Produced" if self.source_arn else "ContributedTo"
        try:
            sagemaker_session.sagemaker_client.add_association(
                SourceArn=source_arn, DestinationArn=dest_arn, AssociationType=association_edge_type
            )
        except botocore.exceptions.ClientError as err:
            err_info = err.response["Error"]
            if not is_already_exist_error(err_info):
                raise
            logger.warning(
                "Skip associating since the association already exists: %s", err_info["Message"]
            )
243
+
244
+
245
class _LineageArtifactTracker(object):
    """Collect Lineage input/output artifacts for one trial component, then persist them."""

    def __init__(self, trial_component_arn, sagemaker_session):
        """Initialize a `_LineageArtifactTracker` instance.

        Args:
            trial_component_arn (str): ARN of the trial component that the tracked
                input/output artifacts are associated with.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed.
        """
        self.trial_component_arn = trial_component_arn
        self.sagemaker_session = sagemaker_session
        # Pending artifacts, in the order they were added; persisted by save().
        self.artifacts = []

    def add_input_artifact(self, name, source_uri, etag, artifact_type):
        """Record, locally only, an artifact consumed by the trial component.

        The trial component becomes the association's destination.

        Args:
            name (str): Name of the Lineage input artifact.
            source_uri (str): Source URI for the Lineage input artifact.
            etag (str): S3 ETag for the Lineage input artifact.
            artifact_type (str): Type of the Lineage input artifact.
        """
        self.artifacts.append(
            _LineageArtifactManager(
                name,
                source_uri,
                etag,
                dest_arn=self.trial_component_arn,
                artifact_type=artifact_type,
            )
        )

    def add_output_artifact(self, name, source_uri, etag, artifact_type):
        """Record, locally only, an artifact produced by the trial component.

        The trial component becomes the association's source.

        Args:
            name (str): Name of the Lineage output artifact.
            source_uri (str): Source URI for the Lineage output artifact.
            etag (str): S3 ETag for the Lineage output artifact.
            artifact_type (str): Type of the Lineage output artifact.
        """
        self.artifacts.append(
            _LineageArtifactManager(
                name,
                source_uri,
                etag,
                source_arn=self.trial_component_arn,
                artifact_type=artifact_type,
            )
        )

    def save(self):
        """Create each recorded artifact and its association, in insertion order."""
        session = self.sagemaker_session
        for tracked in self.artifacts:
            tracked.create_artifact(session)
            tracked.add_association(session)
@@ -0,0 +1,333 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Contains classes to manage metrics for Sagemaker Experiment"""
14
+ from __future__ import absolute_import
15
+
16
+ import datetime
17
+ import logging
18
+ import os
19
+ import time
20
+ import threading
21
+ import queue
22
+
23
+ import dateutil.tz
24
+
25
+ from sagemaker.core.helper.session_helper import Session
26
+
27
# Value of the SAGEMAKER_METRICS_DIRECTORY environment variable ("" when unset).
METRICS_DIR = os.environ.get("SAGEMAKER_METRICS_DIRECTORY", "")

# Window of accepted metric timestamps relative to the current time, in seconds.
METRIC_TS_LOWER_BOUND_TO_NOW = 14 * 24 * 60 * 60  # two weeks into the past
METRIC_TS_UPPER_BOUND_FROM_NOW = 2 * 60 * 60  # two hours into the future

# Number of data points sent per BatchPutMetrics request.
BATCH_SIZE = 10

# NOTE(review): basicConfig configures the root logger as a side effect of
# importing this module; kept as-is for backward compatibility.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
35
+
36
+
37
+ class _RawMetricData(object):
38
+ """A Raw Metric Data Object"""
39
+
40
+ MetricName = None
41
+ Value = None
42
+ Timestamp = None
43
+ Step = None
44
+
45
+ def __init__(self, metric_name, value, timestamp=None, step=None):
46
+ """Construct a `_RawMetricData` instance.
47
+
48
+ Args:
49
+ metric_name (str): The name of the metric.
50
+ value (float): The value of the metric.
51
+ timestamp (datetime.datetime or float or str): Timestamp of the metric.
52
+ If not specified, the current UTC time will be used.
53
+ step (int): Iteration number of the metric (default: None).
54
+ """
55
+ if timestamp is None:
56
+ timestamp = time.time()
57
+ elif isinstance(timestamp, datetime.datetime):
58
+ # If the input is a datetime then convert it to UTC time.
59
+ # Assume a naive datetime is in local timezone
60
+ if not timestamp.tzinfo:
61
+ timestamp = timestamp.replace(tzinfo=dateutil.tz.tzlocal())
62
+ timestamp = (timestamp - timestamp.utcoffset()).replace(tzinfo=datetime.timezone.utc)
63
+ timestamp = timestamp.timestamp()
64
+ else:
65
+ timestamp = float(timestamp)
66
+
67
+ if timestamp < (time.time() - METRIC_TS_LOWER_BOUND_TO_NOW) or timestamp > (
68
+ time.time() + METRIC_TS_UPPER_BOUND_FROM_NOW
69
+ ):
70
+ raise ValueError(
71
+ "Supplied timestamp %f is invalid."
72
+ " Timestamps must be between two weeks before and two hours from now." % timestamp
73
+ )
74
+ value = float(value)
75
+
76
+ self.MetricName = metric_name
77
+ self.Value = float(value)
78
+ self.Timestamp = timestamp
79
+ if step is not None:
80
+ if not isinstance(step, int):
81
+ raise ValueError("step must be int.")
82
+ self.Step = step
83
+
84
+ def to_record(self):
85
+ """Convert the `_RawMetricData` object to dict"""
86
+ return self.__dict__
87
+
88
+ def to_raw_metric_data(self):
89
+ """Converts the metric data to a BatchPutMetrics RawMetricData item"""
90
+ # Convert timestamp from float to timestamp str.
91
+ # Otherwise will get ParamValidationError
92
+ raw_metric_data = {
93
+ "MetricName": self.MetricName,
94
+ "Value": self.Value,
95
+ "Timestamp": str(int(self.Timestamp)),
96
+ }
97
+ if self.Step is not None:
98
+ raw_metric_data["Step"] = int(self.Step)
99
+ return raw_metric_data
100
+
101
+ def __str__(self):
102
+ """String representation of the `_RawMetricData` object."""
103
+ return repr(self)
104
+
105
+ def __repr__(self):
106
+ """Return a string representation of this _RawMetricData` object."""
107
+ return "{}({})".format(
108
+ type(self).__name__,
109
+ ",".join(["{}={}".format(k, repr(v)) for k, v in vars(self).items()]),
110
+ )
111
+
112
+
113
class _MetricsManager(object):
    """Collects metrics and forwards them to a metrics sink.

    Unless a sink is supplied, a `_SyncMetricsSink` is used, which calls the
    SageMaker Metrics data plane API on the calling thread. Usable as a context
    manager: leaving the ``with`` block calls `close`.
    """

    def __init__(self, trial_component_name: str, sagemaker_session: Session, sink=None) -> None:
        """Initialize a `_MetricsManager` instance

        Args:
            trial_component_name (str): The Name of the Trial Component to log metrics to.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session
                object whose ``sagemaker_metrics_client`` backs the default sink.
                Only read when ``sink`` is None; no session is created here.
            sink (object): The metrics sink to use. It must provide
                ``log_metric(metric_data)`` and ``close()``. If None, a
                `_SyncMetricsSink` for ``trial_component_name`` is created.
        """
        if sink is None:
            sink = _SyncMetricsSink(
                trial_component_name, sagemaker_session.sagemaker_metrics_client
            )
        self.sink = sink

    def log_metric(self, metric_name, value, timestamp=None, step=None):
        """Validate a metric data point and pass it to the sink.

        Args:
            metric_name (str): The name of the metric.
            value (float): The value of the metric.
            timestamp (datetime.datetime or float or str): Timestamp of the metric
                (default: the current time).
            step (int): Iteration number of the metric (default: None).

        Raises:
            ValueError: If `_RawMetricData` rejects the data point (for example an
                out-of-range timestamp or a non-int step).
        """
        metric_data = _RawMetricData(metric_name, value, timestamp, step)
        self.sink.log_metric(metric_data)

    def __enter__(self):
        """Return self"""
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Execute self.close(); exceptions from the ``with`` body still propagate."""
        self.close()

    def close(self):
        """Close the metrics object by closing its sink."""
        self.sink.close()
151
+
152
+
153
+ class _SyncMetricsSink(object):
154
+ """Collects metrics and sends them directly to metrics service."""
155
+
156
+ def __init__(self, trial_component_name, metrics_client) -> None:
157
+ """Initialize a `_SyncMetricsSink` instance
158
+
159
+ Args:
160
+ trial_component_name (str): The Name of the Trial Component to log metrics.
161
+ metrics_client (boto3.client): boto client for metrics service
162
+ """
163
+ self._trial_component_name = trial_component_name
164
+ self._metrics_client = metrics_client
165
+ self._buffer = []
166
+
167
+ def log_metric(self, metric_data):
168
+ """Sends a metric to metrics service."""
169
+
170
+ # this is a simplistic solution which calls BatchPutMetrics
171
+ # on the same thread as the client code
172
+ self._buffer.append(metric_data)
173
+ self._drain()
174
+
175
+ def _drain(self, close=False):
176
+ """Pops off all metrics in the buffer and starts sending them to metrics service."""
177
+
178
+ if not self._buffer:
179
+ return
180
+
181
+ if len(self._buffer) < BATCH_SIZE and not close:
182
+ return
183
+
184
+ # pop all the available metrics
185
+ available_metrics, self._buffer = self._buffer, []
186
+
187
+ self._send_metrics(available_metrics)
188
+
189
+ def _send_metrics(self, metrics):
190
+ """Calls BatchPutMetrics directly on the metrics service."""
191
+ while metrics:
192
+ batch, metrics = (
193
+ metrics[:BATCH_SIZE],
194
+ metrics[BATCH_SIZE:],
195
+ )
196
+ request = self._construct_batch_put_metrics_request(batch)
197
+ response = self._metrics_client.batch_put_metrics(**request)
198
+ errors = response["Errors"] if "Errors" in response else None
199
+ if errors:
200
+ error_code = errors[0]["Code"]
201
+ raise Exception(f'{len(errors)} errors with error code "{error_code}"')
202
+
203
+ def _construct_batch_put_metrics_request(self, batch):
204
+ """Creates dictionary object used as request to metrics service."""
205
+ return {
206
+ "TrialComponentName": self._trial_component_name.lower(),
207
+ "MetricData": list(map(lambda x: x.to_raw_metric_data(), batch)),
208
+ }
209
+
210
+ def close(self):
211
+ """Drains any remaining metrics."""
212
+ self._drain(close=True)
213
+
214
+
215
+ class _MetricQueue(object):
216
+ """A thread safe queue for sending metrics to SageMaker.
217
+
218
+ Args:
219
+ trial_component_name (str): the ARN of the resource
220
+ metric_name (str): the name of the metric
221
+ metrics_client (boto_client): the boto client for SageMaker Metrics service
222
+ """
223
+
224
+ _CONSUMER_SLEEP_SECONDS = 5
225
+
226
+ def __init__(self, trial_component_name, metric_name, metrics_client):
227
+ # infinite queue size
228
+ self._queue = queue.Queue()
229
+ self._buffer = []
230
+ self._thread = threading.Thread(target=self._run)
231
+ self._started = False
232
+ self._finished = False
233
+ self._trial_component_name = trial_component_name
234
+ self._metrics_client = metrics_client
235
+ self._metric_name = metric_name
236
+ self._logged_metrics = 0
237
+
238
+ def log_metric(self, metric_data):
239
+ """Adds a metric data point to the queue"""
240
+ self._buffer.append(metric_data)
241
+
242
+ if len(self._buffer) < BATCH_SIZE:
243
+ return
244
+
245
+ self._enqueue_all()
246
+
247
+ if not self._started:
248
+ self._thread.start()
249
+ self._started = True
250
+
251
+ def _run(self):
252
+ """Starts the metric thread which sends metrics to SageMaker in batches"""
253
+
254
+ while not self._queue.empty() or not self._finished:
255
+ if self._queue.empty():
256
+ time.sleep(self._CONSUMER_SLEEP_SECONDS)
257
+ else:
258
+ batch = self._queue.get()
259
+ self._send_metrics(batch)
260
+
261
+ def _send_metrics(self, metrics_batch):
262
+ """Calls BatchPutMetrics directly on the metrics service."""
263
+ request = self._construct_batch_put_metrics_request(metrics_batch)
264
+ self._logged_metrics += len(metrics_batch)
265
+ self._metrics_client.batch_put_metrics(**request)
266
+
267
+ def _construct_batch_put_metrics_request(self, batch):
268
+ """Creates dictionary object used as request to metrics service."""
269
+
270
+ return {
271
+ "TrialComponentName": self._trial_component_name,
272
+ "MetricData": list(map(lambda x: x.to_raw_metric_data(), batch)),
273
+ }
274
+
275
+ def _enqueue_all(self):
276
+ """Enqueue all buffered metrics to be sent to SageMaker"""
277
+
278
+ available_metrics, self._buffer = self._buffer, []
279
+ if available_metrics:
280
+ self._queue.put(available_metrics)
281
+
282
+ def close(self):
283
+ """Flushes any buffered metrics"""
284
+
285
+ self._enqueue_all()
286
+ self._finished = True
287
+
288
+ def is_active(self):
289
+ """Is the thread active (still draining metrics to SageMaker)"""
290
+
291
+ return self._thread.is_alive()
292
+
293
+
294
class _AsyncMetricsSink(object):
    """Collects metrics and sends them to the metrics service on background threads.

    One `_MetricQueue` (with its own consumer thread) is kept per metric name.
    """

    # Polling interval used by `close` while waiting for the queues to drain.
    _COMPLETE_SLEEP_SECONDS = 1.0

    def __init__(self, trial_component_name, metrics_client) -> None:
        """Initialize a `_AsyncMetricsSink` instance

        Args:
            trial_component_name (str): The Name of the Trial Component to log metrics to.
            metrics_client (boto3.client): boto client for metrics service
        """
        self._trial_component_name = trial_component_name
        self._metrics_client = metrics_client
        # NOTE(review): _buffer and _is_draining are not read anywhere in this
        # class; kept in case other code inspects them.
        self._buffer = []
        self._is_draining = False
        self._metric_queues = {}

    def log_metric(self, metric_data):
        """Route ``metric_data`` to the queue for its metric name, creating that queue on first use."""
        name = metric_data.MetricName
        metric_queue = self._metric_queues.get(name)
        if metric_queue is None:
            metric_queue = _MetricQueue(self._trial_component_name, name, self._metrics_client)
            self._metric_queues[name] = metric_queue
        metric_queue.log_metric(metric_data)

    def close(self):
        """Close every metric queue and block until their consumer threads finish."""
        logger.debug("Closing")
        for metric_queue in self._metric_queues.values():
            metric_queue.close()

        # TODO should probably use join
        while any(q.is_active() for q in self._metric_queues.values()):
            time.sleep(self._COMPLETE_SLEEP_SECONDS)
        logger.debug("Closed")
@@ -0,0 +1,58 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Contains the SageMaker Experiment _RunContext class."""
14
+ from __future__ import absolute_import
15
+
16
+ from typing import TYPE_CHECKING
17
+
18
+ if TYPE_CHECKING:
19
+ from sagemaker.core.experiments import Run
20
+
21
+
22
+ class _RunContext:
23
+ """A static context variable to keep track of the current Run object"""
24
+
25
+ _context_run = None
26
+
27
+ @classmethod
28
+ def add_run_object(cls, run: "Run"):
29
+ """Keep track of the current executing Run object
30
+
31
+ by adding it to a class static variable.
32
+
33
+ Args:
34
+ run (Run): The current Run object to be tracked.
35
+ """
36
+ cls._context_run = run
37
+
38
+ @classmethod
39
+ def drop_current_run(cls) -> "Run":
40
+ """Drop the Run object tracked in the global static variable
41
+
42
+ as its execution finishes (its "with" block ends).
43
+
44
+ Return:
45
+ Run: the dropped Run object.
46
+ """
47
+ current_run = cls._context_run
48
+ cls._context_run = None
49
+ return current_run
50
+
51
+ @classmethod
52
+ def get_current_run(cls) -> "Run":
53
+ """Return the current Run object without dropping it.
54
+
55
+ Return:
56
+ Run: the current Run object to be returned.
57
+ """
58
+ return cls._context_run