sagemaker-core 1.0.47__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (363)
  1. sagemaker/core/__init__.py +16 -0
  2. sagemaker/core/_studio.py +116 -0
  3. sagemaker/core/_version.py +11 -0
  4. sagemaker/core/accept_types.py +131 -0
  5. sagemaker/core/analytics.py +744 -0
  6. sagemaker/core/apiutils/__init__.py +13 -0
  7. sagemaker/core/apiutils/_base_types.py +228 -0
  8. sagemaker/core/apiutils/_boto_functions.py +130 -0
  9. sagemaker/core/apiutils/_utils.py +34 -0
  10. sagemaker/core/base_deserializers.py +35 -0
  11. sagemaker/core/base_serializers.py +35 -0
  12. sagemaker/core/clarify/__init__.py +2898 -0
  13. sagemaker/core/collection.py +467 -0
  14. sagemaker/core/common_utils.py +2281 -0
  15. sagemaker/core/compute_resource_requirements/__init__.py +18 -0
  16. sagemaker/core/compute_resource_requirements/resource_requirements.py +94 -0
  17. sagemaker/core/config/__init__.py +181 -0
  18. sagemaker/core/config/config.py +238 -0
  19. sagemaker/core/config/config_manager.py +595 -0
  20. sagemaker/core/config/config_schema.py +1220 -0
  21. sagemaker/core/config/config_utils.py +297 -0
  22. {sagemaker_core/main → sagemaker/core}/config_schema.py +410 -4
  23. sagemaker/core/constants.py +73 -0
  24. sagemaker/core/content_types.py +137 -0
  25. sagemaker/core/debugger/__init__.py +39 -0
  26. sagemaker/core/debugger/debugger.py +945 -0
  27. sagemaker/core/debugger/framework_profile.py +292 -0
  28. sagemaker/core/debugger/metrics_config.py +468 -0
  29. sagemaker/core/debugger/profiler.py +42 -0
  30. sagemaker/core/debugger/profiler_config.py +190 -0
  31. sagemaker/core/debugger/profiler_constants.py +40 -0
  32. sagemaker/core/debugger/utils.py +148 -0
  33. sagemaker/core/deprecations.py +254 -0
  34. sagemaker/core/deserializers/__init__.py +10 -0
  35. sagemaker/core/deserializers/base.py +424 -0
  36. sagemaker/core/deserializers/implementations.py +157 -0
  37. sagemaker/core/drift_check_baselines.py +106 -0
  38. sagemaker/core/enums.py +51 -0
  39. sagemaker/core/environment_variables.py +101 -0
  40. sagemaker/core/exceptions.py +108 -0
  41. sagemaker/core/experiments/__init__.py +53 -0
  42. sagemaker/core/experiments/_api_types.py +251 -0
  43. sagemaker/core/experiments/_environment.py +124 -0
  44. sagemaker/core/experiments/_helper.py +294 -0
  45. sagemaker/core/experiments/_metrics.py +333 -0
  46. sagemaker/core/experiments/_run_context.py +58 -0
  47. sagemaker/core/experiments/_utils.py +216 -0
  48. sagemaker/core/experiments/experiment.py +244 -0
  49. sagemaker/core/experiments/run.py +970 -0
  50. sagemaker/core/experiments/trial.py +296 -0
  51. sagemaker/core/experiments/trial_component.py +387 -0
  52. sagemaker/core/explainer/__init__.py +24 -0
  53. sagemaker/core/explainer/clarify_explainer_config.py +298 -0
  54. sagemaker/core/explainer/explainer_config.py +44 -0
  55. sagemaker/core/fw_utils.py +1176 -0
  56. sagemaker/core/git_utils.py +349 -0
  57. sagemaker/core/helper/pipeline_variable.py +82 -0
  58. sagemaker/core/helper/session_helper.py +2965 -0
  59. sagemaker/core/huggingface/__init__.py +29 -0
  60. sagemaker/core/huggingface/llm_utils.py +150 -0
  61. sagemaker/core/huggingface/processing.py +139 -0
  62. sagemaker/core/huggingface/training_compiler/config.py +167 -0
  63. sagemaker/core/hyperparameters.py +172 -0
  64. sagemaker/core/image_retriever/__init__.py +3 -0
  65. sagemaker/core/image_retriever/image_retriever.py +640 -0
  66. sagemaker/core/image_retriever/image_retriever_utils.py +511 -0
  67. sagemaker/core/image_retriever/test.py +7 -0
  68. sagemaker/core/image_uri_config/__init__.py +13 -0
  69. sagemaker/core/image_uri_config/autogluon.json +1335 -0
  70. sagemaker/core/image_uri_config/blazingtext.json +50 -0
  71. sagemaker/core/image_uri_config/chainer.json +104 -0
  72. sagemaker/core/image_uri_config/clarify.json +39 -0
  73. sagemaker/core/image_uri_config/coach-mxnet.json +70 -0
  74. sagemaker/core/image_uri_config/coach-tensorflow.json +186 -0
  75. sagemaker/core/image_uri_config/data-wrangler.json +91 -0
  76. sagemaker/core/image_uri_config/debugger.json +34 -0
  77. sagemaker/core/image_uri_config/detailed-profiler.json +18 -0
  78. sagemaker/core/image_uri_config/djl-deepspeed.json +385 -0
  79. sagemaker/core/image_uri_config/djl-fastertransformer.json +167 -0
  80. sagemaker/core/image_uri_config/djl-lmi.json +136 -0
  81. sagemaker/core/image_uri_config/djl-neuronx.json +258 -0
  82. sagemaker/core/image_uri_config/djl-tensorrtllm.json +262 -0
  83. sagemaker/core/image_uri_config/factorization-machines.json +50 -0
  84. sagemaker/core/image_uri_config/forecasting-deepar.json +50 -0
  85. sagemaker/core/image_uri_config/huggingface-llm-neuronx.json +660 -0
  86. sagemaker/core/image_uri_config/huggingface-llm.json +1158 -0
  87. sagemaker/core/image_uri_config/huggingface-neuron.json +52 -0
  88. sagemaker/core/image_uri_config/huggingface-neuronx.json +510 -0
  89. sagemaker/core/image_uri_config/huggingface-tei-cpu.json +298 -0
  90. sagemaker/core/image_uri_config/huggingface-tei.json +298 -0
  91. sagemaker/core/image_uri_config/huggingface-training-compiler.json +195 -0
  92. sagemaker/core/image_uri_config/huggingface.json +2138 -0
  93. sagemaker/core/image_uri_config/hyperpod-recipes-neuron.json +52 -0
  94. sagemaker/core/image_uri_config/image-classification-neo.json +43 -0
  95. sagemaker/core/image_uri_config/image-classification.json +50 -0
  96. sagemaker/core/image_uri_config/inferentia-mxnet.json +88 -0
  97. sagemaker/core/image_uri_config/inferentia-pytorch.json +127 -0
  98. sagemaker/core/image_uri_config/inferentia-tensorflow.json +88 -0
  99. sagemaker/core/image_uri_config/instance_gpu_info.json +782 -0
  100. sagemaker/core/image_uri_config/ipinsights.json +50 -0
  101. sagemaker/core/image_uri_config/kmeans.json +50 -0
  102. sagemaker/core/image_uri_config/knn.json +50 -0
  103. sagemaker/core/image_uri_config/lda.json +26 -0
  104. sagemaker/core/image_uri_config/linear-learner.json +50 -0
  105. sagemaker/core/image_uri_config/model-monitor.json +42 -0
  106. sagemaker/core/image_uri_config/mxnet.json +1154 -0
  107. sagemaker/core/image_uri_config/neo-mxnet.json +64 -0
  108. sagemaker/core/image_uri_config/neo-pytorch.json +341 -0
  109. sagemaker/core/image_uri_config/neo-tensorflow.json +109 -0
  110. sagemaker/core/image_uri_config/ntm.json +50 -0
  111. sagemaker/core/image_uri_config/object-detection.json +50 -0
  112. sagemaker/core/image_uri_config/object2vec.json +50 -0
  113. sagemaker/core/image_uri_config/pca.json +50 -0
  114. sagemaker/core/image_uri_config/pytorch-neuron.json +43 -0
  115. sagemaker/core/image_uri_config/pytorch-smp.json +218 -0
  116. sagemaker/core/image_uri_config/pytorch-training-compiler.json +80 -0
  117. sagemaker/core/image_uri_config/pytorch.json +3101 -0
  118. sagemaker/core/image_uri_config/randomcutforest.json +50 -0
  119. sagemaker/core/image_uri_config/ray-pytorch.json +46 -0
  120. sagemaker/core/image_uri_config/ray-tensorflow.json +194 -0
  121. sagemaker/core/image_uri_config/sagemaker-base-python.json +46 -0
  122. sagemaker/core/image_uri_config/sagemaker-distribution.json +37 -0
  123. sagemaker/core/image_uri_config/sagemaker-geospatial.json +13 -0
  124. sagemaker/core/image_uri_config/sagemaker-tritonserver.json +212 -0
  125. sagemaker/core/image_uri_config/semantic-segmentation.json +50 -0
  126. sagemaker/core/image_uri_config/seq2seq.json +50 -0
  127. sagemaker/core/image_uri_config/sklearn.json +446 -0
  128. sagemaker/core/image_uri_config/spark.json +280 -0
  129. sagemaker/core/image_uri_config/sparkml-serving.json +97 -0
  130. sagemaker/core/image_uri_config/stabilityai.json +53 -0
  131. sagemaker/core/image_uri_config/tensorflow.json +5086 -0
  132. sagemaker/core/image_uri_config/vw.json +25 -0
  133. sagemaker/core/image_uri_config/xgboost-neo.json +43 -0
  134. sagemaker/core/image_uri_config/xgboost.json +888 -0
  135. sagemaker/core/image_uris.py +810 -0
  136. sagemaker/core/inference_config.py +144 -0
  137. sagemaker/core/inference_recommender/__init__.py +18 -0
  138. sagemaker/core/inference_recommender/inference_recommender_mixin.py +622 -0
  139. sagemaker/core/inputs.py +366 -0
  140. sagemaker/core/instance_group.py +61 -0
  141. sagemaker/core/instance_types.py +164 -0
  142. sagemaker/core/instance_types_gpu_info.py +43 -0
  143. sagemaker/core/interactive_apps/__init__.py +41 -0
  144. sagemaker/core/interactive_apps/base_interactive_app.py +204 -0
  145. sagemaker/core/interactive_apps/detail_profiler_app.py +139 -0
  146. sagemaker/core/interactive_apps/tensorboard.py +149 -0
  147. sagemaker/core/iterators.py +186 -0
  148. sagemaker/core/job.py +380 -0
  149. sagemaker/core/jumpstart/__init__.py +156 -0
  150. sagemaker/core/jumpstart/accessors.py +390 -0
  151. sagemaker/core/jumpstart/artifacts/__init__.py +69 -0
  152. sagemaker/core/jumpstart/artifacts/environment_variables.py +252 -0
  153. sagemaker/core/jumpstart/artifacts/hyperparameters.py +120 -0
  154. sagemaker/core/jumpstart/artifacts/image_uris.py +139 -0
  155. sagemaker/core/jumpstart/artifacts/incremental_training.py +87 -0
  156. sagemaker/core/jumpstart/artifacts/instance_types.py +223 -0
  157. sagemaker/core/jumpstart/artifacts/kwargs.py +289 -0
  158. sagemaker/core/jumpstart/artifacts/metric_definitions.py +117 -0
  159. sagemaker/core/jumpstart/artifacts/model_packages.py +202 -0
  160. sagemaker/core/jumpstart/artifacts/model_uris.py +252 -0
  161. sagemaker/core/jumpstart/artifacts/payloads.py +96 -0
  162. sagemaker/core/jumpstart/artifacts/predictors.py +540 -0
  163. sagemaker/core/jumpstart/artifacts/resource_names.py +86 -0
  164. sagemaker/core/jumpstart/artifacts/resource_requirements.py +162 -0
  165. sagemaker/core/jumpstart/artifacts/script_uris.py +172 -0
  166. sagemaker/core/jumpstart/cache.py +663 -0
  167. sagemaker/core/jumpstart/configs.py +50 -0
  168. sagemaker/core/jumpstart/constants.py +198 -0
  169. sagemaker/core/jumpstart/deserializers.py +81 -0
  170. sagemaker/core/jumpstart/document.py +76 -0
  171. sagemaker/core/jumpstart/enums.py +168 -0
  172. sagemaker/core/jumpstart/exceptions.py +236 -0
  173. sagemaker/core/jumpstart/factory/utils.py +833 -0
  174. sagemaker/core/jumpstart/filters.py +597 -0
  175. sagemaker/core/jumpstart/hub/__init__.py +0 -0
  176. sagemaker/core/jumpstart/hub/constants.py +16 -0
  177. sagemaker/core/jumpstart/hub/hub.py +291 -0
  178. sagemaker/core/jumpstart/hub/interfaces.py +936 -0
  179. sagemaker/core/jumpstart/hub/parser_utils.py +70 -0
  180. sagemaker/core/jumpstart/hub/parsers.py +288 -0
  181. sagemaker/core/jumpstart/hub/types.py +35 -0
  182. sagemaker/core/jumpstart/hub/utils.py +260 -0
  183. sagemaker/core/jumpstart/models.py +499 -0
  184. sagemaker/core/jumpstart/notebook_utils.py +575 -0
  185. sagemaker/core/jumpstart/parameters.py +20 -0
  186. sagemaker/core/jumpstart/payload_utils.py +239 -0
  187. sagemaker/core/jumpstart/region_config.json +163 -0
  188. sagemaker/core/jumpstart/search.py +171 -0
  189. sagemaker/core/jumpstart/serializers.py +81 -0
  190. sagemaker/core/jumpstart/session_utils.py +234 -0
  191. sagemaker/core/jumpstart/types.py +3044 -0
  192. sagemaker/core/jumpstart/utils.py +1731 -0
  193. sagemaker/core/jumpstart/validators.py +257 -0
  194. sagemaker/core/lambda_helper.py +312 -0
  195. sagemaker/core/lineage/__init__.py +42 -0
  196. sagemaker/core/lineage/_api_types.py +239 -0
  197. sagemaker/core/lineage/_utils.py +49 -0
  198. sagemaker/core/lineage/action.py +345 -0
  199. sagemaker/core/lineage/artifact.py +646 -0
  200. sagemaker/core/lineage/association.py +190 -0
  201. sagemaker/core/lineage/context.py +505 -0
  202. sagemaker/core/lineage/lineage_trial_component.py +191 -0
  203. sagemaker/core/lineage/query.py +732 -0
  204. sagemaker/core/lineage/visualizer.py +346 -0
  205. sagemaker/core/local/__init__.py +18 -0
  206. sagemaker/core/local/data.py +413 -0
  207. sagemaker/core/local/entities.py +678 -0
  208. sagemaker/core/local/exceptions.py +17 -0
  209. sagemaker/core/local/image.py +1243 -0
  210. sagemaker/core/local/local_session.py +739 -0
  211. sagemaker/core/local/utils.py +245 -0
  212. sagemaker/core/logs.py +181 -0
  213. sagemaker/core/metadata_properties.py +56 -0
  214. sagemaker/core/metric_definitions.py +91 -0
  215. sagemaker/core/mlflow/__init__.py +38 -0
  216. sagemaker/core/mlflow/forward_sagemaker_metrics.py +44 -0
  217. sagemaker/core/model_card/__init__.py +26 -0
  218. sagemaker/core/model_life_cycle.py +51 -0
  219. sagemaker/core/model_metrics.py +160 -0
  220. sagemaker/core/model_monitor/__init__.py +66 -0
  221. sagemaker/core/model_monitor/clarify_model_monitoring.py +1495 -0
  222. sagemaker/core/model_monitor/cron_expression_generator.py +82 -0
  223. sagemaker/core/model_monitor/data_capture_config.py +115 -0
  224. sagemaker/core/model_monitor/data_quality_monitoring_config.py +66 -0
  225. sagemaker/core/model_monitor/dataset_format.py +102 -0
  226. sagemaker/core/model_monitor/model_monitoring.py +4266 -0
  227. sagemaker/core/model_monitor/monitoring_alert.py +76 -0
  228. sagemaker/core/model_monitor/monitoring_files.py +506 -0
  229. sagemaker/core/model_monitor/utils.py +793 -0
  230. sagemaker/core/model_registry.py +480 -0
  231. sagemaker/core/model_uris.py +97 -0
  232. sagemaker/core/modules/__init__.py +19 -0
  233. sagemaker/core/modules/configs.py +226 -0
  234. sagemaker/core/modules/constants.py +37 -0
  235. sagemaker/core/modules/distributed.py +182 -0
  236. sagemaker/core/modules/local_core/__init__.py +0 -0
  237. sagemaker/core/modules/local_core/local_container.py +605 -0
  238. sagemaker/core/modules/templates.py +83 -0
  239. sagemaker/core/modules/train/__init__.py +14 -0
  240. sagemaker/core/modules/train/container_drivers/__init__.py +14 -0
  241. sagemaker/core/modules/train/container_drivers/common/__init__.py +14 -0
  242. sagemaker/core/modules/train/container_drivers/common/utils.py +213 -0
  243. sagemaker/core/modules/train/container_drivers/distributed_drivers/__init__.py +14 -0
  244. sagemaker/core/modules/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  245. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_driver.py +123 -0
  246. sagemaker/core/modules/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  247. sagemaker/core/modules/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  248. sagemaker/core/modules/train/container_drivers/scripts/__init__.py +14 -0
  249. sagemaker/core/modules/train/container_drivers/scripts/environment.py +305 -0
  250. sagemaker/core/modules/train/sm_recipes/__init__.py +0 -0
  251. sagemaker/core/modules/train/sm_recipes/utils.py +330 -0
  252. sagemaker/core/modules/types.py +19 -0
  253. sagemaker/core/modules/utils.py +194 -0
  254. sagemaker/core/network.py +185 -0
  255. sagemaker/core/parameter.py +173 -0
  256. sagemaker/core/payloads.py +185 -0
  257. sagemaker/core/processing.py +1597 -0
  258. sagemaker/core/remote_function/__init__.py +19 -0
  259. sagemaker/core/remote_function/checkpoint_location.py +47 -0
  260. sagemaker/core/remote_function/client.py +1285 -0
  261. sagemaker/core/remote_function/core/__init__.py +0 -0
  262. sagemaker/core/remote_function/core/_custom_dispatch_table.py +72 -0
  263. sagemaker/core/remote_function/core/pipeline_variables.py +353 -0
  264. sagemaker/core/remote_function/core/serialization.py +422 -0
  265. sagemaker/core/remote_function/core/stored_function.py +226 -0
  266. sagemaker/core/remote_function/custom_file_filter.py +128 -0
  267. sagemaker/core/remote_function/errors.py +104 -0
  268. sagemaker/core/remote_function/invoke_function.py +172 -0
  269. sagemaker/core/remote_function/job.py +2140 -0
  270. sagemaker/core/remote_function/logging_config.py +38 -0
  271. sagemaker/core/remote_function/runtime_environment/__init__.py +14 -0
  272. sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +605 -0
  273. sagemaker/core/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  274. sagemaker/core/remote_function/runtime_environment/runtime_environment_manager.py +554 -0
  275. sagemaker/core/remote_function/runtime_environment/spark_app.py +18 -0
  276. sagemaker/core/remote_function/spark_config.py +149 -0
  277. sagemaker/core/resource_requirements.py +168 -0
  278. {sagemaker_core/main → sagemaker/core}/resources.py +20121 -11728
  279. sagemaker/core/s3/__init__.py +41 -0
  280. sagemaker/core/s3/client.py +367 -0
  281. sagemaker/core/s3/utils.py +175 -0
  282. sagemaker/core/script_uris.py +93 -0
  283. sagemaker/core/serializers/__init__.py +11 -0
  284. sagemaker/core/serializers/base.py +510 -0
  285. sagemaker/core/serializers/implementations.py +159 -0
  286. sagemaker/core/serializers/utils.py +223 -0
  287. sagemaker/core/serverless_inference_config.py +63 -0
  288. sagemaker/core/session_settings.py +55 -0
  289. sagemaker/core/shapes/__init__.py +3 -0
  290. sagemaker/core/shapes/model_card_shapes.py +159 -0
  291. {sagemaker_core/main → sagemaker/core/shapes}/shapes.py +6384 -1865
  292. sagemaker/core/spark/__init__.py +16 -0
  293. sagemaker/core/spark/defaults.py +16 -0
  294. sagemaker/core/spark/processing.py +1380 -0
  295. sagemaker/core/telemetry/__init__.py +23 -0
  296. sagemaker/core/telemetry/constants.py +84 -0
  297. sagemaker/core/telemetry/telemetry_logging.py +284 -0
  298. sagemaker/core/tools/__init__.py +1 -0
  299. {sagemaker_core → sagemaker/core}/tools/codegen.py +4 -4
  300. {sagemaker_core → sagemaker/core}/tools/constants.py +23 -15
  301. {sagemaker_core → sagemaker/core}/tools/data_extractor.py +1 -1
  302. {sagemaker_core → sagemaker/core}/tools/method.py +1 -1
  303. sagemaker/core/tools/model_card/generate_model_card_from_schema.py +562 -0
  304. {sagemaker_core → sagemaker/core}/tools/resources_codegen.py +165 -98
  305. {sagemaker_core → sagemaker/core}/tools/resources_extractor.py +5 -13
  306. {sagemaker_core → sagemaker/core}/tools/shapes_codegen.py +16 -17
  307. {sagemaker_core → sagemaker/core}/tools/shapes_extractor.py +29 -67
  308. {sagemaker_core → sagemaker/core}/tools/templates.py +39 -17
  309. sagemaker/core/training/__init__.py +14 -0
  310. sagemaker/core/training/configs.py +333 -0
  311. sagemaker/core/training/constants.py +37 -0
  312. sagemaker/core/training/utils.py +77 -0
  313. sagemaker/core/training_compiler/__init__.py +16 -0
  314. sagemaker/core/training_compiler/config.py +197 -0
  315. sagemaker/core/training_compiler_config.py +197 -0
  316. sagemaker/core/transformer.py +793 -0
  317. sagemaker/core/user_agent.py +76 -0
  318. sagemaker/core/utilities/__init__.py +24 -0
  319. sagemaker/core/utilities/cache.py +169 -0
  320. sagemaker/core/utilities/search_expression.py +133 -0
  321. sagemaker/core/utils/__init__.py +48 -0
  322. sagemaker/core/utils/code_injection/__init__.py +0 -0
  323. {sagemaker_core/main → sagemaker/core/utils}/code_injection/codec.py +2 -2
  324. {sagemaker_core/main → sagemaker/core/utils}/code_injection/shape_dag.py +6479 -136
  325. {sagemaker_core/main → sagemaker/core/utils}/exceptions.py +8 -8
  326. sagemaker_core/main/default_configs_helper.py → sagemaker/core/utils/intelligent_defaults_helper.py +5 -6
  327. {sagemaker_core/main → sagemaker/core/utils}/logs.py +1 -2
  328. {sagemaker_core/main → sagemaker/core/utils}/utils.py +25 -20
  329. sagemaker/core/workflow/__init__.py +152 -0
  330. sagemaker/core/workflow/conditions.py +313 -0
  331. sagemaker/core/workflow/entities.py +58 -0
  332. sagemaker/core/workflow/execution_variables.py +89 -0
  333. sagemaker/core/workflow/functions.py +193 -0
  334. sagemaker/core/workflow/parameters.py +222 -0
  335. sagemaker/core/workflow/pipeline_context.py +394 -0
  336. sagemaker/core/workflow/pipeline_definition_config.py +31 -0
  337. sagemaker/core/workflow/properties.py +285 -0
  338. sagemaker/core/workflow/step_outputs.py +65 -0
  339. sagemaker/core/workflow/utilities.py +507 -0
  340. sagemaker/lineage/__init__.py +33 -0
  341. sagemaker/lineage/action.py +28 -0
  342. sagemaker/lineage/artifact.py +28 -0
  343. sagemaker/lineage/context.py +28 -0
  344. sagemaker/lineage/lineage_trial_component.py +28 -0
  345. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/METADATA +28 -9
  346. sagemaker_core-2.1.1.dist-info/RECORD +355 -0
  347. sagemaker_core-2.1.1.dist-info/top_level.txt +1 -0
  348. sagemaker_core/__init__.py +0 -4
  349. sagemaker_core/_version.py +0 -3
  350. sagemaker_core/helper/session_helper.py +0 -769
  351. sagemaker_core/resources/__init__.py +0 -1
  352. sagemaker_core/shapes/__init__.py +0 -1
  353. sagemaker_core/tools/__init__.py +0 -1
  354. sagemaker_core-1.0.47.dist-info/RECORD +0 -35
  355. sagemaker_core-1.0.47.dist-info/top_level.txt +0 -1
  356. {sagemaker_core → sagemaker/core}/helper/__init__.py +0 -0
  357. {sagemaker_core/main → sagemaker/core/huggingface/training_compiler}/__init__.py +0 -0
  358. {sagemaker_core/main/code_injection → sagemaker/core/jumpstart/factory}/__init__.py +0 -0
  359. {sagemaker_core/main → sagemaker/core/utils}/code_injection/base.py +0 -0
  360. {sagemaker_core/main → sagemaker/core/utils}/code_injection/constants.py +0 -0
  361. {sagemaker_core/main → sagemaker/core/utils}/user_agent.py +0 -0
  362. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/WHEEL +0 -0
  363. {sagemaker_core-1.0.47.dist-info → sagemaker_core-2.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,970 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Contains the SageMaker Experiment Run class."""
14
+ from __future__ import absolute_import
15
+
16
+ import datetime
17
+ import logging
18
+ from enum import Enum
19
+ from math import isnan, isinf
20
+ from numbers import Number
21
+ from typing import Optional, List, Dict, TYPE_CHECKING, Union
22
+
23
+ import dateutil
24
+ from numpy import array
25
+
26
+ from sagemaker.core.apiutils import _utils
27
+ import sagemaker.core.experiments._api_types as _api_types
28
+ from sagemaker.core.experiments._api_types import (
29
+ TrialComponentArtifact,
30
+ _TrialComponentStatusType,
31
+ )
32
+ from sagemaker.core.experiments._helper import (
33
+ _ArtifactUploader,
34
+ _LineageArtifactTracker,
35
+ _DEFAULT_ARTIFACT_PREFIX,
36
+ )
37
+ from sagemaker.core.experiments._environment import _RunEnvironment
38
+ from sagemaker.core.experiments._run_context import _RunContext
39
+ from sagemaker.core.experiments.experiment import Experiment
40
+ from sagemaker.core.experiments._metrics import _MetricsManager
41
+ from sagemaker.core.experiments.trial import _Trial
42
+ from sagemaker.core.experiments.trial_component import _TrialComponent
43
+
44
+ from sagemaker.core.common_utils import (
45
+ get_module,
46
+ unique_name_from_base,
47
+ format_tags,
48
+ Tags,
49
+ TagsDict,
50
+ )
51
+
52
+ from sagemaker.core.experiments._utils import (
53
+ guess_media_type,
54
+ resolve_artifact_name,
55
+ verify_length_of_true_and_predicted,
56
+ validate_invoked_inside_run_context,
57
+ get_tc_and_exp_config_from_job_env,
58
+ verify_load_input_names,
59
+ is_run_trial_component,
60
+ )
61
+
62
+ if TYPE_CHECKING:
63
+ from sagemaker.core.helper.session_helper import Session
64
+
65
logger = logging.getLogger(__name__)

# Keys of the experiment config dict handed to jobs launched inside a run.
EXPERIMENT_NAME = "ExperimentName"
TRIAL_NAME = "TrialName"
RUN_NAME = "RunName"

# Naming helpers: auto-generated run names start from this (lower-cased) base,
# and the per-experiment run group (trial) name is built from this template.
RUN_NAME_BASE = "Sagemaker-Run".lower()
TRIAL_NAME_TEMPLATE = "Default-Run-Group-{}"
DELIMITER = "-"

# Size limits.
MAX_RUN_TC_ARTIFACTS_LEN = 30
MAX_NAME_LEN_IN_BACKEND = 120

# Tag that marks a trial component as having been created by a Run.
RUN_TC_TAG_KEY = "sagemaker:trial-component-source"
RUN_TC_TAG_VALUE = "run"
RUN_TC_TAG = {"Key": RUN_TC_TAG_KEY, "Value": RUN_TC_TAG_VALUE}
78
+
79
+
80
class SortByType(Enum):
    """Property used to order the results returned by ``list_runs``."""

    # Order runs by their creation timestamp.
    CREATION_TIME = "CreationTime"
    # Order runs by their name.
    NAME = "Name"
85
+
86
+
87
class SortOrderType(Enum):
    """Direction in which list or search results are ordered."""

    # Smallest / earliest first.
    ASCENDING = "Ascending"
    # Largest / latest first.
    DESCENDING = "Descending"
92
+
93
+
94
+ class Run(object):
95
+ """A collection of parameters, metrics, and artifacts to create a ML model."""
96
+
97
def __init__(
    self,
    experiment_name: str,
    run_name: Optional[str] = None,
    experiment_display_name: Optional[str] = None,
    run_display_name: Optional[str] = None,
    tags: Optional[Tags] = None,
    sagemaker_session: Optional["Session"] = None,
    artifact_bucket: Optional[str] = None,
    artifact_prefix: Optional[str] = None,
):
    """Create a new SageMaker Experiments run, or load it if it already exists.

    SageMaker Experiments records the inputs, parameters, configurations and
    results of each iteration as a run. Runs can be assigned to, grouped in and
    organized by experiments, and can be created, compared and evaluated.

    The sample below initializes a run, logs a parameter to it and launches a
    training job inside the run context. The run's ``experiment_config``
    (experiment name, run name, etc.) is passed to the training job automatically.

    Note:
        Every log method (``log_parameter``, ``log_metric``, ...) must be called
        inside the run context, i.e. within the ``with`` statement; otherwise a
        ``RuntimeError`` is raised.

    .. code:: python

        with Run(experiment_name="my-exp", run_name="my-run", ...) as run:
            run.log_parameter(...)
            ...
            estimator.fit(job_name="my-job")  # Create a training job

    To log more data to an existing run, use ``load_run`` rather than this
    constructor. In particular, a job script should call ``load_run`` to pick up
    the run created before the job was launched. Otherwise a new run may be
    created on every launch.

    The snippet below loads the run created above from inside a custom training
    job script. No ``run_name`` or ``experiment_name`` needs to be passed,
    because both are read from the experiment config in the job environment.

    .. code:: python

        with load_run(sagemaker_session=sagemaker_session) as run:
            run.log_metric(...)
            ...

    Args:
        experiment_name (str): Name of the experiment; it must be unique within
            an account. The name is lower-cased before use.
        run_name (str): Name of the run. Auto-generated when omitted. The name
            is lower-cased before use.
        experiment_display_name (str): Experiment name shown in UIs such as
            SageMaker Studio (default: None). Used only when the experiment is
            created; ignored if an experiment with this name already exists.
        run_display_name (str): Run name shown in UIs (default: None). Used only
            when the run is created; ignored if the run already exists.
        tags (Optional[Tags]): Tags applied to every create call, e.g. when
            creating the experiment or the run group (default: None).
        sagemaker_session (sagemaker.core.helper.session_helper.Session): Session
            that manages interactions with Amazon SageMaker APIs and any other
            AWS services needed. When omitted, one is created from the default
            AWS configuration chain.
        artifact_bucket (str): S3 bucket that artifacts are uploaded to. When
            omitted, the default bucket of ``sagemaker_session`` is used.
        artifact_prefix (str): S3 key prefix used to build artifact upload paths
            (default: "trial-component-artifacts").
    """
    # TODO: we should revert the lower casting once backend fix reaches prod
    self.experiment_name = experiment_name.lower()
    session = sagemaker_session or _utils.default_session()
    # Lower-case the run name so its casing matches the trial component name.
    self.run_name = (run_name or unique_name_from_base(RUN_NAME_BASE)).lower()

    formatted_tags = format_tags(tags)

    tc_name = Run._generate_trial_component_name(
        run_name=self.run_name, experiment_name=self.experiment_name
    )
    self.run_group_name = Run._generate_trial_name(self.experiment_name)

    # Resolve (load or create) the experiment, then the run group (trial),
    # then the trial component that backs this run -- in that order.
    self._experiment = Experiment._load_or_create(
        experiment_name=self.experiment_name,
        display_name=experiment_display_name,
        tags=formatted_tags,
        sagemaker_session=session,
    )
    self._trial = _Trial._load_or_create(
        experiment_name=self.experiment_name,
        trial_name=self.run_group_name,
        tags=formatted_tags,
        sagemaker_session=session,
    )
    self._trial_component, already_exists = _TrialComponent._load_or_create(
        trial_component_name=tc_name,
        display_name=run_display_name,
        tags=Run._append_run_tc_label_to_tags(formatted_tags),
        sagemaker_session=session,
    )
    if already_exists:
        logger.info(
            "The run (%s) under experiment (%s) already exists. Loading it.",
            self.run_name,
            self.experiment_name,
        )

    loaded_tc = self._trial_component

    # Attach the trial component to the run group only if not attached yet.
    is_associated = _TrialComponent._trial_component_is_associated_to_trial(
        loaded_tc.trial_component_name,
        self._trial.trial_name,
        session,
    )
    if not is_associated:
        self._trial.add_trial_component(loaded_tc)

    if artifact_prefix is None:
        upload_prefix = _DEFAULT_ARTIFACT_PREFIX
    else:
        upload_prefix = artifact_prefix

    self._artifact_uploader = _ArtifactUploader(
        trial_component_name=loaded_tc.trial_component_name,
        sagemaker_session=session,
        artifact_bucket=artifact_bucket,
        artifact_prefix=upload_prefix,
    )
    self._lineage_artifact_tracker = _LineageArtifactTracker(
        trial_component_arn=loaded_tc.trial_component_arn,
        sagemaker_session=session,
    )
    self._metrics_manager = _MetricsManager(
        trial_component_name=loaded_tc.trial_component_name,
        sagemaker_session=session,
    )

    # Context-tracking flags, all cleared at construction; presumably toggled by
    # the context-manager methods, which are not visible here -- TODO confirm.
    self._inside_init_context = False
    self._inside_load_context = False
    self._in_load = False
238
+
239
@property
def experiment_config(self) -> dict:
    """Build the experiment config dict that identifies this run.

    Returns:
        dict: Maps ``EXPERIMENT_NAME`` to the experiment name, ``TRIAL_NAME`` to
        the run group name, and ``RUN_NAME`` to the name of the trial component
        that backs this run.
    """
    config = dict()
    config[EXPERIMENT_NAME] = self.experiment_name
    config[TRIAL_NAME] = self.run_group_name
    config[RUN_NAME] = self._trial_component.trial_component_name
    return config
247
+
248
@validate_invoked_inside_run_context
def log_parameter(self, name: str, value: Union[str, int, float]):
    """Record one parameter value on this run.

    A value previously recorded under the same ``name`` is overwritten. Inputs
    rejected by ``_is_input_valid`` are not recorded.

    Args:
        name (str): Name of the parameter.
        value (str or int or float): Value of the parameter.
    """
    if not self._is_input_valid("parameter", name, value):
        return
    self._trial_component.parameters[name] = value
260
+
261
@validate_invoked_inside_run_context
def log_parameters(self, parameters: Dict[str, Union[str, int, float]]):
    """Store several parameter values on this run's trial component at once.

    Every entry is validated first. Entries whose value is NaN or infinite
    are dropped, with a warning logged. The surviving entries are then merged
    into the trial component's parameters in a single update, overwriting
    existing values with the same names.

    Args:
        parameters (dict[str, str or int or float]): Parameter names mapped
            to the values to record.
    """
    accepted = {}
    for param_name, param_value in parameters.items():
        if self._is_input_valid("parameter", param_name, param_value):
            accepted[param_name] = param_value
    self._trial_component.parameters.update(accepted)
274
+
275
@validate_invoked_inside_run_context
def log_metric(
    self,
    name: str,
    value: float,
    timestamp: Optional[datetime.datetime] = None,
    step: Optional[int] = None,
):
    """Record a custom scalar metric value for this run.

    Note:
        This method is for manual custom metrics, for automatic metrics see the
        ``enable_sagemaker_metrics`` parameter on the ``estimator`` class.

    NaN or infinite values are rejected by ``_is_input_valid`` and are not
    forwarded to the metrics manager.

    Args:
        name (str): The name of the metric.
        value (float): The value of the metric.
        timestamp (datetime.datetime): The timestamp of the metric.
            If not specified, the current UTC time will be used.
        step (int): The integer iteration number of the metric value (default: None).
    """
    if not self._is_input_valid("metric", name, value):
        return
    self._metrics_manager.log_metric(
        metric_name=name,
        value=value,
        timestamp=timestamp,
        step=step,
    )
300
+
301
@validate_invoked_inside_run_context
def log_precision_recall(
    self,
    y_true: Union[list, array],
    predicted_probabilities: Union[list, array],
    positive_label: Optional[Union[str, int]] = None,
    title: Optional[str] = None,
    is_output: bool = True,
    no_skill: Optional[int] = None,
):
    """Create and log a precision recall graph artifact for Studio UI to render.

    The curve data is uploaded to S3 as JSON and tracked as a lineage artifact
    associated with the run, so it can be viewed in the UI. If your job is
    created by a pipeline execution, you can view the artifact by selecting
    the corresponding step in the pipelines UI.
    See also `SageMaker Pipelines <https://aws.amazon.com/sagemaker/pipelines/>`_

    This method requires sklearn library.

    Args:
        y_true (list or array): True labels. If labels are not binary
            then positive_label should be given.
        predicted_probabilities (list or array): Estimated/predicted probabilities.
        positive_label (str or int): Label of the positive class (default: None).
        title (str): Title of the graph (default: None).
        is_output (bool): Determines direction of association to the
            run. Defaults to True (output artifact).
            If set to False then represented as input association.
        no_skill (int): The precision threshold under which the classifier cannot discriminate
            between the classes and would predict a random class or a constant class in
            all cases (default: None).
    """
    verify_length_of_true_and_predicted(
        true_labels=y_true,
        predicted_attrs=predicted_probabilities,
        predicted_attrs_name="predicted probabilities",
    )

    get_module("sklearn")
    from sklearn.metrics import average_precision_score, precision_recall_curve

    # pos_label is only forwarded when given, so sklearn's own default applies otherwise.
    label_kwargs = {} if positive_label is None else {"pos_label": positive_label}

    precision, recall, _ = precision_recall_curve(
        y_true, predicted_probabilities, **label_kwargs
    )
    average_precision = average_precision_score(
        y_true, predicted_probabilities, average="micro", **label_kwargs
    )

    graph_payload = {
        "type": "PrecisionRecallCurve",
        "version": 0,
        "title": title,
        "precision": precision.tolist(),
        "recall": recall.tolist(),
        "averagePrecisionScore": average_precision,
        "noSkill": no_skill,
    }
    self._log_graph_artifact(
        artifact_name=title,
        data=graph_payload,
        graph_type="PrecisionRecallCurve",
        is_output=is_output,
    )
370
+
371
@validate_invoked_inside_run_context
def log_roc_curve(
    self,
    y_true: Union[list, array],
    y_score: Union[list, array],
    title: Optional[str] = None,
    is_output: bool = True,
    positive_label: Optional[Union[str, int]] = None,
):
    """Create and log a receiver operating characteristic (ROC curve) artifact.

    The curve data is uploaded to S3 as JSON and tracked as a lineage artifact
    associated with the run, so it can be viewed in the UI. If your job is
    created by a pipeline execution, you can view the artifact by selecting
    the corresponding step in the pipelines UI.
    See also `SageMaker Pipelines <https://aws.amazon.com/sagemaker/pipelines/>`_

    This method requires sklearn library.

    Args:
        y_true (list or array): True labels. If labels are not binary
            then positive_label should be given.
        y_score (list or array): Estimated/predicted probabilities.
        title (str): Title of the graph (default: None).
        is_output (bool): Determines direction of association to the
            run. Defaults to True (output artifact).
            If set to False then represented as input association.
        positive_label (str or int): Label of the positive class, forwarded to
            sklearn's ``roc_curve`` as ``pos_label`` (default: None, in which
            case sklearn's default is used).
    """
    verify_length_of_true_and_predicted(
        true_labels=y_true,
        predicted_attrs=y_score,
        predicted_attrs_name="predicted scores",
    )

    get_module("sklearn")
    from sklearn.metrics import auc, roc_curve

    # Forward pos_label only when supplied, mirroring log_precision_recall.
    roc_kwargs = {} if positive_label is None else {"pos_label": positive_label}
    fpr, tpr, _ = roc_curve(y_true, y_score, **roc_kwargs)

    # Use a distinct local name rather than rebinding the imported ``auc`` function.
    area_under_curve = auc(fpr, tpr)

    data = {
        "type": "ROCCurve",
        "version": 0,
        "title": title,
        "falsePositiveRate": fpr.tolist(),
        "truePositiveRate": tpr.tolist(),
        "areaUnderCurve": area_under_curve,
    }
    self._log_graph_artifact(
        artifact_name=title, data=data, graph_type="ROCCurve", is_output=is_output
    )
424
+
425
@validate_invoked_inside_run_context
def log_confusion_matrix(
    self,
    y_true: Union[list, array],
    y_pred: Union[list, array],
    title: Optional[str] = None,
    is_output: bool = True,
):
    """Create and log a confusion matrix artifact.

    The matrix computed by sklearn's ``confusion_matrix`` is uploaded to S3 as
    JSON and tracked as a lineage artifact associated with the run, so it can
    be viewed in the UI. If your job is created by a pipeline execution, you
    can view the artifact by selecting the corresponding step in the
    pipelines UI.
    See also `SageMaker Pipelines <https://aws.amazon.com/sagemaker/pipelines/>`_

    This method requires sklearn library.

    Args:
        y_true (list or array): True labels.
        y_pred (list or array): Predicted labels.
        title (str): Title of the graph (default: None).
        is_output (bool): Determines direction of association to the
            run. Defaults to True (output artifact).
            If set to False then represented as input association.
    """
    verify_length_of_true_and_predicted(
        true_labels=y_true,
        predicted_attrs=y_pred,
        predicted_attrs_name="predicted labels",
    )

    get_module("sklearn")
    from sklearn.metrics import confusion_matrix

    matrix = confusion_matrix(y_true, y_pred)

    data = {
        "type": "ConfusionMatrix",
        "version": 0,
        "title": title,
        "confusionMatrix": matrix.tolist(),
    }
    self._log_graph_artifact(
        artifact_name=title,
        data=data,
        graph_type="ConfusionMatrix",
        is_output=is_output,
    )
476
+
477
@validate_invoked_inside_run_context
def log_artifact(
    self,
    name: str,
    value: str,
    media_type: Optional[str] = None,
    is_output: bool = True,
):
    """Attach a single named artifact value to this run's trial component.

    An artifact previously recorded under ``name`` in the same direction is
    replaced.

    Args:
        name (str): The name of the artifact.
        value (str): The value.
        media_type (str): The MediaType (MIME type) of the value (default: None).
        is_output (bool): Determines direction of association to the
            run. Defaults to True (output artifact).
            If set to False then represented as input association.

    Raises:
        ValueError: If the artifacts in the chosen direction are already at
            the per-run limit.
    """
    self._verify_trial_component_artifacts_length(is_output=is_output)
    if is_output:
        target_artifacts = self._trial_component.output_artifacts
    else:
        target_artifacts = self._trial_component.input_artifacts
    target_artifacts[name] = TrialComponentArtifact(value, media_type=media_type)
506
+
507
@validate_invoked_inside_run_context
def log_file(
    self,
    file_path: str,
    name: Optional[str] = None,
    media_type: Optional[str] = None,
    is_output: Optional[bool] = True,
    extra_args: Optional[dict] = None,
):
    """Upload a local file to S3 and record it as an artifact of this run.

    Args:
        file_path (str): The path of the local file to upload.
        name (str): The name of the artifact (default: None). When falsy, a
            name is derived from ``file_path`` via ``resolve_artifact_name``.
        media_type (str): The MediaType (MIME type) of the file.
            If not specified, this library will attempt to infer the media type
            from the file extension of ``file_path``.
        is_output (bool): Determines direction of association to the
            run. Defaults to True (output artifact).
            If set to False then represented as input association.
        extra_args (dict): Optional extra arguments that may be passed to the upload operation.
            Similar to ExtraArgs parameter in S3 upload_file function. Please refer to the
            ExtraArgs parameter documentation here:
            https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-uploading-files.html#the-extraargs-parameter

    Raises:
        ValueError: If the artifacts in the chosen direction are already at
            the per-run limit.
    """
    self._verify_trial_component_artifacts_length(is_output)
    if not media_type:
        media_type = guess_media_type(file_path)
    if not name:
        name = resolve_artifact_name(file_path)

    uploaded_uri, _ = self._artifact_uploader.upload_artifact(file_path, extra_args=extra_args)
    artifact = TrialComponentArtifact(value=uploaded_uri, media_type=media_type)

    if is_output:
        self._trial_component.output_artifacts[name] = artifact
    else:
        self._trial_component.input_artifacts[name] = artifact
544
+
545
def close(self):
    """Persist any data saved locally.

    Saves the trial component first, then the lineage artifact tracker.
    The metrics manager, when set (truthy), is closed afterwards even if
    either save raises; the exception still propagates.
    """
    try:
        # Push the Run object's local changes to the trial component.
        self._trial_component.save()
        # Create lineage entities for the tracked artifacts.
        self._lineage_artifact_tracker.save()
    finally:
        manager = self._metrics_manager
        if manager:
            manager.close()
555
+
556
@staticmethod
def _generate_trial_name(base_name) -> str:
    """Build the reserved trial name for a given experiment name.

    ``base_name`` is truncated so that the length budget is
    ``MAX_NAME_LEN_IN_BACKEND`` minus the length of ``TRIAL_NAME_TEMPLATE``.
    The truncated name is then substituted into the template.

    Args:
        base_name (str): The ``experiment_name`` of this ``Run`` object.
    """
    truncated_base = base_name[: MAX_NAME_LEN_IN_BACKEND - len(TRIAL_NAME_TEMPLATE)]
    return TRIAL_NAME_TEMPLATE.format(truncated_base)
565
+
566
@staticmethod
def _is_input_valid(input_type, field_name, field_value) -> bool:
    """Report whether a value may be logged as a parameter or metric.

    Numeric values that are NaN or infinite are rejected and a warning is
    logged. Any other value (including non-numeric ones) is accepted.

    Args:
        input_type (str): The type of the input, one of ``parameter``, ``metric``.
        field_name (str): The name of the field to be checked.
        field_value (str or int or float): The value of the field to be checked.

    Returns:
        bool: False for a NaN/infinite number, True otherwise.
    """
    is_bad_number = isinstance(field_value, Number) and (
        isnan(field_value) or isinf(field_value)
    )
    if is_bad_number:
        logger.warning(
            "Failed to log %s %s. Received invalid value: %s.",
            input_type,
            field_name,
            field_value,
        )
    return not is_bad_number
584
+
585
def _log_graph_artifact(self, data, graph_type, is_output, artifact_name=None):
    """Log a graph artifact.

    Uploads ``data`` to S3 as a JSON object and registers the resulting S3 URI
    and ETag as a lineage artifact associated with the run's trial component.

    Args:
        data (dict): Artifacts data that will be saved to S3.
        graph_type (str): The type of the artifact. It is also the base for
            the generated name when ``artifact_name`` is not given.
        is_output (bool): Determines direction of association to the
            trial component. If True the artifact is added as an output,
            otherwise as an input association.
        artifact_name (str): Name of the artifact (default: None). When falsy,
            a unique name is generated with ``unique_name_from_base(graph_type)``.
    """
    # Generate an artifact name. The result must be assigned: previously it was
    # discarded, so untitled graphs were uploaded and tracked with name None.
    if not artifact_name:
        artifact_name = unique_name_from_base(graph_type)

    # Create a JSON object in S3.
    s3_uri, etag = self._artifact_uploader.upload_object_artifact(
        artifact_name, data, file_extension="json"
    )

    # Create a lineage artifact and its association in the requested direction.
    if is_output:
        add_artifact = self._lineage_artifact_tracker.add_output_artifact
    else:
        add_artifact = self._lineage_artifact_tracker.add_input_artifact
    add_artifact(
        name=artifact_name,
        source_uri=s3_uri,
        etag=etag,
        artifact_type=graph_type,
    )
623
+
624
def _verify_trial_component_artifacts_length(self, is_output):
    """Ensure another artifact can still be added in the given direction.

    Args:
        is_output (bool): True to check the output artifacts, False to check
            the input artifacts of the trial component.

    Raises:
        ValueError: If the selected artifacts already number at least
            ``MAX_RUN_TC_ARTIFACTS_LEN``.
    """
    if is_output:
        direction, existing = "output", self._trial_component.output_artifacts
    else:
        direction, existing = "input", self._trial_component.input_artifacts
    if len(existing) >= MAX_RUN_TC_ARTIFACTS_LEN:
        raise ValueError(
            "Cannot add more than {} {}_artifacts under run".format(
                MAX_RUN_TC_ARTIFACTS_LEN, direction
            )
        )
641
+
642
@staticmethod
def _generate_trial_component_name(run_name: str, experiment_name: str) -> str:
    """Generate the TrialComponentName based on run_name and experiment_name.

    The name is ``experiment_name``, ``DELIMITER`` and ``run_name`` joined
    together, then lower-cased.

    Args:
        run_name (str): The run_name supplied by the user.
        experiment_name (str): The experiment_name supplied by the user,
            which is prepended to the run_name to generate the TrialComponentName.

    Returns:
        str: The TrialComponentName used to create a trial component
            which is unique in an account.

    Raises:
        ValueError: If either the run_name or the experiment_name exceeds
            half of ``MAX_NAME_LEN_IN_BACKEND`` minus one (the run_name is
            checked first).
    """
    delimiter_buffer = 1  # leave room for the delimiter
    max_len = int(MAX_NAME_LEN_IN_BACKEND / 2) - delimiter_buffer
    for label, candidate in (("run_name", run_name), ("experiment_name", experiment_name)):
        if len(candidate) > max_len:
            raise ValueError(
                "The {} (length: {}) must have length less than or equal to {}".format(
                    label, len(candidate), max_len
                )
            )
    # Lower-cased due to mixed-case concerns on the backend.
    return f"{experiment_name}{DELIMITER}{run_name}".lower()
672
+
673
@staticmethod
def _extract_run_name_from_tc_name(trial_component_name: str, experiment_name: str) -> str:
    """Extract the user supplied run name from a trial component name.

    Removes the first occurrence of the lower-cased ``experiment_name``
    followed by ``DELIMITER``.

    Args:
        trial_component_name (str): The name of a run trial component.
        experiment_name (str): The experiment_name supplied by the user,
            which was prepended to the run_name to generate the trial_component_name.

    Returns:
        str: The name of the Run object supplied by a user.
    """
    # TODO: we should revert the lower casting once backend fix reaches prod
    experiment_prefix = f"{experiment_name.lower()}{DELIMITER}"
    return trial_component_name.replace(experiment_prefix, "", 1)
689
+
690
@staticmethod
def _append_run_tc_label_to_tags(tags: Optional[List[TagsDict]] = None) -> list:
    """Return tags used to create a run trial component, including the run label.

    The caller's ``tags`` are not modified. A new list is returned that holds
    the supplied tags plus ``RUN_TC_TAG``, which is appended only when it is
    not already present.

    Args:
        tags (List[TagsDict]): The tags supplied by users to initialize a Run object.

    Returns:
        list: The updated tags with the appended run trial component label.
    """
    # Copy rather than append in place: the original mutated the user's list,
    # leaking the internal run label into it. list() also accepts tuples.
    updated_tags = list(tags) if tags else []
    if RUN_TC_TAG not in updated_tags:
        updated_tags.append(RUN_TC_TAG)
    return updated_tags
705
+
706
def __enter__(self):
    """Enter the run context and mark the trial component as in progress.

    The run is registered on ``_RunContext`` unless it is a ``load_run``
    entered inside an active ``Run`` init context. The start time is set
    only if it is missing. The status is always set to InProgress, and the
    trial component is saved.

    Raises:
        RuntimeError: If ``with`` statements on ``load_run`` or on ``Run``
            are nested.

    Returns:
        object: self.
    """
    nested_err_template = "It is not allowed to use nested 'with' statements on the {}."
    if not self._in_load:
        if _RunContext.get_current_run():
            raise RuntimeError(nested_err_template.format("Run"))
        self._inside_init_context = True
        _RunContext.add_run_object(self)
    else:
        if self._inside_load_context:
            raise RuntimeError(nested_err_template.format("load_run"))
        self._inside_load_context = True
        # Register only when load_run is used on its own. Inside a Run init
        # context this object was already added by the branch above.
        if not self._inside_init_context:
            _RunContext.add_run_object(self)

    trial_component = self._trial_component
    if not trial_component.start_time:
        trial_component.start_time = datetime.datetime.now(dateutil.tz.tzlocal())
    trial_component.status = _api_types.TrialComponentStatus(
        primary_status=_TrialComponentStatusType.InProgress.value,
        message="Within a run context",
    )
    # Persist the start_time and status changes to the backend.
    trial_component.save()
    return self
739
+
740
def __exit__(self, exc_type, exc_value, exc_traceback):
    """Leave the run context, stamp the end time and status, then persist.

    The status becomes Failed (with ``str(exc_value)`` as the message) when an
    exception value is present, otherwise Completed. Nothing is returned, so
    an exception raised inside the ``with`` block is not suppressed.

    Args:
        exc_type (str): The exception type.
        exc_value (str): The exception value.
        exc_traceback (str): The stack trace of the exception.
    """
    if not self._in_load:
        self._inside_init_context = False
        _RunContext.drop_current_run()
    else:
        self._inside_load_context = False
        self._in_load = False
        # Only pop the context that this load_run pushed itself.
        if not self._inside_init_context:
            _RunContext.drop_current_run()

    trial_component = self._trial_component
    trial_component.end_time = datetime.datetime.now(dateutil.tz.tzlocal())
    if exc_value:
        final_status = _api_types.TrialComponentStatus(
            primary_status=_TrialComponentStatusType.Failed.value,
            message=str(exc_value),
        )
    else:
        final_status = _api_types.TrialComponentStatus(
            primary_status=_TrialComponentStatusType.Completed.value
        )
    trial_component.status = final_status

    self.close()
770
+
771
def __getstate__(self):
    """Refuse to pickle ``Run`` instances.

    Raise:
        NotImplementedError: Always, whenever pickling is attempted.
    """
    refusal = "Instance of Run type is not allowed to be pickled."
    raise NotImplementedError(refusal)
778
+
779
+
780
def load_run(
    run_name: Optional[str] = None,
    experiment_name: Optional[str] = None,
    sagemaker_session: Optional["Session"] = None,
    artifact_bucket: Optional[str] = None,
    artifact_prefix: Optional[str] = None,
    tags: Optional[List[Dict[str, str]]] = None,
) -> Run:
    """Load an existing run.

    In order to reuse an existing run to log extra data, ``load_run`` is recommended.
    It can be used in several ways:

    1. Use ``load_run`` by explicitly passing in ``run_name`` and ``experiment_name``.

    If ``run_name`` and ``experiment_name`` are passed in, they are honored over
    the default experiment config in the job environment or the run context
    (i.e. within the ``with`` block).

    Note:
        Both ``run_name`` and ``experiment_name`` should be supplied to make this usage work.
        Otherwise, you may get a ``ValueError``.

    .. code:: python

        with load_run(experiment_name="my-exp", run_name="my-run") as run:
            run.log_metric(...)
            ...

    2. Use the ``load_run`` in a job script without supplying ``run_name`` and ``experiment_name``.

    In this case, the default experiment config (specified when creating the job) is fetched
    from the job environment to load the run.

    .. code:: python

        # In a job script
        with load_run() as run:
            run.log_metric(...)
            ...

    3. Use the ``load_run`` in a notebook within a run context (i.e. the ``with`` block)
    but without supplying ``run_name`` and ``experiment_name``.

    Every time we call ``with Run(...) as run1:``, the initialized ``run1`` is tracked
    in the run context. Then when we call ``load_run()`` under this with statement, the ``run1``
    in the context is loaded by default.

    .. code:: python

        # In a notebook
        with Run(experiment_name="my-exp", run_name="my-run", ...) as run1:
            run1.log_parameter(...)

            with load_run() as run2: # run2 is the same object as run1
                run2.log_metric(...)
                ...

    Args:
        run_name (str): The name of the run to be loaded (default: None).
            If it is None, the ``RunName`` in the ``ExperimentConfig`` of the job will be
            fetched to load the run.
        experiment_name (str): The name of the Experiment that the to be loaded run
            is associated with (default: None).
            Note: the experiment_name must be supplied along with a valid run_name.
            Otherwise, it will be ignored.
        sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed. If not specified, one is created using the
            default AWS configuration chain.
        artifact_bucket (str): The S3 bucket to upload the artifact to.
            If not specified, the default bucket defined in `sagemaker_session`
            will be used.
        artifact_prefix (str): The S3 key prefix used to generate the S3 path
            to upload the artifact to (default: "trial-component-artifacts").
        tags (List[Dict[str, str]]): A list of tags to be used for all create calls,
            e.g. to create an experiment, a run group, etc. (default: None).

    Returns:
        Run: The loaded Run object.

    Raises:
        RuntimeError: If no run name is given, there is no current run in the
            run context, and no job environment is found.
        ValueError: Presumably raised by ``verify_load_input_names`` when only
            one of ``run_name``/``experiment_name`` is supplied (see the Note
            above) -- confirm against that helper.
    """
    environment = _RunEnvironment.load()

    verify_load_input_names(run_name=run_name, experiment_name=experiment_name)

    if run_name:
        logger.warning(
            "run_name is explicitly supplied in load_run, "
            "which will be prioritized to load the Run object. "
            "In other words, the run name in the experiment config, fetched from the "
            "job environment or the current run context, will be ignored."
        )
        run_instance = Run(
            experiment_name=experiment_name,
            run_name=run_name,
            sagemaker_session=sagemaker_session or _utils.default_session(),
            artifact_bucket=artifact_bucket,
            artifact_prefix=artifact_prefix,
            tags=tags,
        )
    else:
        # Look up the active run once instead of twice.
        run_instance = _RunContext.get_current_run()
        if not run_instance:
            if not environment:
                raise RuntimeError(
                    "Failed to load a Run object. "
                    "Please make sure a Run object has been initialized already."
                )
            # Resolve the session once so the config lookup and the Run share it,
            # instead of constructing two separate default sessions.
            session = sagemaker_session or _utils.default_session()
            exp_config = get_tc_and_exp_config_from_job_env(
                environment=environment,
                sagemaker_session=session,
            )
            experiment_name = exp_config[EXPERIMENT_NAME]
            run_name = Run._extract_run_name_from_tc_name(
                trial_component_name=exp_config[RUN_NAME],
                experiment_name=experiment_name,
            )
            run_instance = Run(
                experiment_name=experiment_name,
                run_name=run_name,
                sagemaker_session=session,
                artifact_bucket=artifact_bucket,
                artifact_prefix=artifact_prefix,
                tags=tags,
            )

    # Marks the object so __enter__/__exit__ treat it as a load_run context.
    run_instance._in_load = True
    return run_instance
908
+
909
+
910
def list_runs(
    experiment_name: str,
    created_before: Optional[datetime.datetime] = None,
    created_after: Optional[datetime.datetime] = None,
    sagemaker_session: Optional["Session"] = None,
    max_results: Optional[int] = None,
    next_token: Optional[str] = None,
    sort_by: SortByType = SortByType.CREATION_TIME,
    sort_order: SortOrderType = SortOrderType.DESCENDING,
) -> list:
    """Return a list of ``Run`` objects matching the given criteria.

    Trial components of the experiment are listed first. Only those for which
    ``is_run_trial_component`` is true are turned into ``Run`` objects, in
    listing order.

    Args:
        experiment_name (str): Only Run objects related to the specified experiment
            are returned.
        created_before (datetime.datetime): Return Run objects created before this instant
            (default: None).
        created_after (datetime.datetime): Return Run objects created after this instant
            (default: None).
        sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed. If not specified, one is created using the
            default AWS configuration chain.
        max_results (int): Maximum number of Run objects to retrieve (default: None).
        next_token (str): Token for next page of results (default: None).
        sort_by (SortByType): The property to sort results by. One of NAME, CREATION_TIME
            (default: CREATION_TIME).
        sort_order (SortOrderType): One of ASCENDING, or DESCENDING (default: DESCENDING).

    Returns:
        list: A list of ``Run`` objects.
    """
    # All trial components of the experiment are retrieved by default.
    summaries = _TrialComponent.list(
        experiment_name=experiment_name,
        created_before=created_before,
        created_after=created_after,
        sort_by=sort_by.value,
        sort_order=sort_order.value,
        sagemaker_session=sagemaker_session,
        max_results=max_results,
        next_token=next_token,
    )
    return [
        Run(
            experiment_name=experiment_name,
            run_name=Run._extract_run_name_from_tc_name(
                trial_component_name=summary.trial_component_name,
                experiment_name=experiment_name,
            ),
            sagemaker_session=sagemaker_session,
        )
        for summary in summaries
        if is_run_trial_component(
            trial_component_name=summary.trial_component_name,
            sagemaker_session=sagemaker_session,
        )
    ]