mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
msprobe/pytorch/monitor/module_hook.py
@@ -0,0 +1,1076 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import json
16
+ import os
17
+ import uuid
18
+ from collections import defaultdict
19
+ from datetime import datetime
20
+ from functools import partial
21
+
22
+ import pytz
23
+ import torch
24
+ import torch.distributed as dist
25
+ from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
26
+ from torch.utils.hooks import BackwardHook
27
+
28
+ from msprobe.core.common.const import MonitorConst
29
+ from msprobe.core.common.file_utils import load_json, save_json
30
+ from msprobe.pytorch.common.log import logger
31
+ from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter
32
+ from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD, AnomalyDataFactory, \
33
+ CSVWriterWithAD, BaseWriterWithAD, WriterInput
34
+ from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
35
+ get_process_group
36
+ from msprobe.pytorch.monitor.features import get_sign_matches
37
+ from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \
38
+ TensorMetrics, squash_param_name
39
+ from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec
40
+ from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory, OptimizerMon
41
+ from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, is_recomputation, \
42
+ get_output_base_dir, get_target_output_dir
43
+ from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
44
+
45
+ torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
46
+ if not torch_version_above_or_equal_2:
47
+ raise ValueError("monitor require torch>=2.0")
48
+
49
+ FORMAT_MAPPING = {
50
+ MonitorConst.TENSORBOARD: SummaryWriterWithAD,
51
+ MonitorConst.CSV: CSVWriterWithAD,
52
+ MonitorConst.API: BaseWriterWithAD
53
+ }
54
+
55
+
56
+ def param_is_not_tensor_parallel_duplicate(param, tp_group):
57
+ return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or (
58
+ torch.distributed.get_rank(group=tp_group) == 0
59
+ )
60
+
61
+
62
+ def param_is_data_parallel_duplicate(dp_group):
63
+ return torch.distributed.get_rank(group=dp_group) != 0
64
+
65
+
66
+ class ModuleHookContext:
67
+ def __init__(self, module_name) -> None:
68
+ self.micro_step = 0
69
+ self.actv = defaultdict(dict)
70
+ self.actvgrad = []
71
+ self.module_name = module_name
72
+ self.struct = {}
73
+ self.format_by_arg = {}
74
+ self.verified = False
75
+ self.focused_in_col = 0
76
+ self.focused_out_col = 0
77
+
78
+ def set_format_by_arg(self, key_name: str, target_config: dict):
79
+ """ 按照监控对象配置format_by_arg
80
+ 1) module_name is configured as a monitored target in targets
81
+ 2) module_name is not configured in targets, and all_xy enables full monitoring
82
+ 3) module_name is not configured in targets, and all_xy does not enable full monitoring
83
+
84
+ :param key_name: str, one of [input, output, input_grad, output_grad]
85
+ :param target_config: target obj in config json.
86
+ :return:
87
+ """
88
+ valid_key = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT, MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT]
89
+ if key_name not in valid_key:
90
+ raise ValueError(f"key({key_name}) error, valid_key: {valid_key}")
91
+ cared = target_config.get(self.module_name, self.struct)
92
+ if key_name in cared:
93
+ target_module_config = cared[key_name]
94
+ if isinstance(target_module_config, dict):
95
+ # current cared is self.struct, monitor all data for module_name
96
+ self.format_by_arg[key_name] = target_module_config.get('config')
97
+ elif isinstance(target_module_config, str):
98
+ # current cared is target_config[self.module_name]
99
+ self.format_by_arg[key_name] = target_module_config
100
+ else:
101
+ logger.warning_on_rank_0(f"target module config error, result maybe empty."
102
+ f"module_name: {self.module_name}, key_name: {key_name}")
103
+ self.format_by_arg[key_name] = None
104
+ else:
105
+ self.format_by_arg[key_name] = self.struct.get(key_name).get('config')
106
+
107
+ def reset(self):
108
+ self.actv.clear()
109
+ self.actvgrad.clear()
110
+
111
+
112
+ start_step = 0
113
+
114
+
115
+ class OptimizerContext:
116
+ def __init__(self) -> None:
117
+ self.step = start_step
118
+ self.param_mg_direction = defaultdict(float)
119
+ self.param_adam_update = defaultdict()
120
+ self.param_adam_ratio = defaultdict()
121
+ self.param_weight_grad = defaultdict()
122
+ self.param_exp_avg = defaultdict()
123
+ self.exp_avg_metric = {}
124
+ self.param_exp_avg_sq = defaultdict()
125
+ self.exp_avg_sq_metric = {}
126
+ self.metric_dict = {}
127
+ self.param_metric = {}
128
+
129
+ def reset(self):
130
+ self.param_mg_direction.clear()
131
+ self.param_adam_update.clear()
132
+ self.param_adam_ratio.clear()
133
+ self.param_weight_grad.clear()
134
+ self.param_exp_avg.clear()
135
+ self.exp_avg_metric.clear()
136
+ self.param_exp_avg_sq.clear()
137
+ self.exp_avg_sq_metric.clear()
138
+ self.metric_dict.clear()
139
+ self.param_metric.clear()
140
+
141
+
142
+ class CommunicationContext:
143
+ def __init__(self) -> None:
144
+ self.data = {}
145
+
146
+ @staticmethod
147
+ def _agg(data):
148
+ aggregated_data = {}
149
+ for tag, op2tensorlist in data.items():
150
+ aggregated_data[tag] = {}
151
+ for op, tensorlist in op2tensorlist.items():
152
+ aggregated_data[tag][op] = op_aggregate(op, tensorlist)
153
+ return aggregated_data
154
+
155
+ def reset(self):
156
+ self.data = {}
157
+
158
+ def aggregate(self):
159
+ self.data = self._agg(self.data)
160
+
161
+
162
+ class GradContext:
163
+ def __init__(self) -> None:
164
+ self.pre = {}
165
+ self.post = {}
166
+ self.acc_metric = {}
167
+ self.acc = {}
168
+ self.actv = {}
169
+
170
+ def reset(self):
171
+ self.pre.clear()
172
+ self.post.clear()
173
+ self.acc_metric.clear()
174
+ self.acc.clear()
175
+ self.actv.clear()
176
+
177
+
178
+ class TrainerMon:
179
+ tensor_metrics = TensorMetrics()
180
+
181
+ def __init__(self, config_file_path, process_group=None, params_have_main_grad=True, opt_ty=None) -> None:
182
+ """
183
+ opt_ty: "Megatron_Float16OptimizerWithFloat16Params" or "Megatron_DistributedOptimizer"
184
+ """
185
+ # TYPE1: variables initialized only here; they are not reset when the config changes during training
186
+ self.config_file_path = config_file_path
187
+ self.process_group = get_process_group(process_group)
188
+ self.params_have_main_grad = params_have_main_grad
189
+ self.opt_ty = opt_ty
190
+ self.mix_precision_optimizer_mon = OptimizerMonFactory.create_optimizer_mon(opt_ty)
191
+ self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
192
+ self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
193
+ self.origin_step_func = None
194
+ self.config_timestamp = 0 # the timestamp is validated later; the first monitoring run does not need to touch the config file just to refresh it, monitoring can be enabled directly via the switch
195
+ self.config = load_json(config_file_path)
196
+ validate_config(self.config)
197
+
198
+ self.squash_name = self.config.get('squash_name', True) # must not be modified, otherwise names collected before and after would not match
199
+ local_tz = pytz.timezone("Asia/Shanghai") # adjust to the target time zone as needed
200
+ cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
201
+ self.unique_id = str(uuid.uuid4())[:8]
202
+ self.output_base_dir = get_output_base_dir()
203
+ time_tags = self.config.get("append_output", [])
204
+ if dist.is_initialized():
205
+ self.rank = dist.get_rank()
206
+ if time_tags:
207
+ output_append_dirs = get_target_output_dir(self.output_base_dir, time_tags[0], time_tags[1])
208
+ if str(self.rank) in output_append_dirs:
209
+ self.tensorboard_dir = output_append_dirs[str(self.rank)]
210
+ logger.info(f"append rank({self.rank}) result to {self.tensorboard_dir}")
211
+ else:
212
+ self.tensorboard_dir = os.path.join(self.output_base_dir,
213
+ f"{cur_time}-rank{self.rank}-{self.unique_id}")
214
+ self.pp_stage = dist.get_group_rank(self.process_group, self.rank)
215
+ self.group_mates = dist.get_process_group_ranks(self.process_group)
216
+ else:
217
+ self.rank = 0
218
+ self.tensorboard_dir = os.path.join(self.output_base_dir, f"{cur_time}-rank{self.rank}-{self.unique_id}")
219
+ self.pp_stage = 0
220
+ self.group_mates = [0]
221
+
222
+ # TYPE2: variables assigned only in the monitor_gnorm_with_ad() entry point
223
+ self.model = None
224
+ self.vpp = False
225
+ self.dp_group = None
226
+ self.tp_group = None
227
+ self.enable_megatron = False
228
+ self.micro_batch_number = 1
229
+
230
+ # TYPE3: variables reset when the config is updated or the monitoring state changes during training
231
+ self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
232
+ self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
233
+ self.optimizer_context = defaultdict(OptimizerContext)
234
+ self.cc_context = defaultdict(CommunicationContext)
235
+ self.grad_context = GradContext()
236
+ self.handles = defaultdict(list)
237
+ self.param2name = defaultdict(str)
238
+ self.name2index = defaultdict()
239
+ self.name2indices = defaultdict()
240
+ self.name2param = {}
241
+ self.duplicate_param = {}
242
+ self.name2tag = {}
243
+ self.param_name_call_id = {}
244
+ self.call_id = 0
245
+ self.module_struct = defaultdict(dict)
246
+ self.grad_accs = []
247
+ self.weight_hooked = False
248
+ self.optimizer_hooked = False
249
+ self.param_registered = False
250
+ self.struct_printed = False
251
+
252
+ # distinguish dynamic vs. static monitoring
253
+ self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true'
254
+ if self.dynamic_enable:
255
+ logger.warning(f"DYNAMIC_MONITOR is set, "
256
+ f"please make sure you have 'switch' and 'collect_times' item in {self.config_file_path}")
257
+ self.monitoring = False
258
+ else:
259
+ self.set_config()
260
+ # in static mode with collect_times > 0, self.monitoring can be True at step 0; dynamic mode starts at the next step by default
261
+ if self.collect_times > 0:
262
+ self.monitoring = True
263
+
264
+ def __del__(self):
265
+ if hasattr(self, "summary_writer"):
266
+ self.summary_writer.close()
267
+
268
+ @property
269
+ def ops(self):
270
+ return self._ops
271
+
272
+ @ops.setter
273
+ def ops(self, value):
274
+ self._ops = validate_ops(value)
275
+
276
+ @staticmethod
277
+ def set_wrapped_optimizer(_wrapped_optimizer):
278
+ OptimizerMon.set_wrapped_optimizer(_wrapped_optimizer)
279
+
280
+ @staticmethod
281
+ def has_register_backward_hook(module_name, module):
282
+ if hasattr(module, '_backward_hooks') and \
283
+ len(module._backward_hooks) > 0 and \
284
+ module._is_full_backward_hook is False:
285
+ logger.warning(
286
+ f"The {module_name} has registered deprecated register_backward_hook,"
287
+ f"which may cause abnormal data dump. The backward input/output for this module will be skipped."
288
+ )
289
+ return True
290
+ return False
291
+
292
+ @staticmethod
293
+ def generate_cc_metrics(cc_name, cc_tensor):
294
+ metrics = defaultdict(dict)
295
+ rank = dist.get_rank() if dist.is_initialized() else None
296
+ for op, tag2tensor in cc_tensor.data.items():
297
+ for tag, tensor in tag2tensor.items():
298
+ key = get_summary_writer_tag_name(cc_name, tag, rank)
299
+ metrics[op].update({key: tensor})
300
+ cc_tensor.reset()
301
+ return metrics
302
+
303
+ def set_config(self):
304
+ logger.info(f"current config: {self.config}")
305
+ self.start_step = self.config.get("start_step", 0)
306
+ self.collect_times = self.config.get("collect_times", 100000000) # defaults to a large value so collection never stops
307
+ self.step_interval = self.config.get("step_interval", 1)
308
+ self.has_collect_times = 0 # reset the collection counter
309
+ self.print_struct = self.config.get("print_struct", False)
310
+ self.module_rank_list = self.config.get("module_ranks", [])
311
+ self.format = self.config.get('format', 'tensorboard')
312
+ self.eps = self.config.get('eps', 1e-8)
313
+ self.ops = self.config.get('ops', [])
314
+ self.ndigits = self.config.get('ndigits', 6)
315
+ self.all_xy = self.config.get('all_xy', False)
316
+ self.xy_distribution = self.config.get('xy_distribution', False)
317
+ self.forward_only = self.config.get('forward_only', False)
318
+ self.backward_only = self.config.get('backward_only', False)
319
+ self.ur_distribution = self.config.get('ur_distribution', False)
320
+ self.mv_distribution = self.config.get("mv_distribution", False)
321
+ self.wg_distribution = self.config.get("wg_distribution", False)
322
+ self.param_distribution = self.config.get("param_distribution", False)
323
+ self.mg_direction = self.config.get('mg_direction', False)
324
+ self.cc_distribution = self.config.get("cc_distribution", {})
325
+
326
+ if not self.cc_distribution.get('enable', False):
327
+ self.cc_log_only = False
328
+ else:
329
+ self.cc_codeline = self.cc_distribution.get('cc_codeline', [])
330
+ self.cc_log_only = self.cc_distribution.get('cc_log_only', False)
331
+ self.cc_logged_stack = defaultdict(set)
332
+ self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False)
333
+ self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
334
+ api_register.redirect_api()
335
+
336
+ self.common_info()
337
+
338
+ # initialize the AnomalyData factory
339
+ alert_setting = self.config.get('alert', {"rules": []})
340
+ self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"])
341
+ self.anomaly_data_factory = None
342
+ if alert_setting.get('dump', False):
343
+ self.anomaly_data_factory = AnomalyDataFactory(self.rank, self.pp_stage, self.group_mates)
344
+
345
+ # initialize the writer and create the output directory
346
+ if self.format not in FORMAT_MAPPING:
347
+ raise ValueError(f"Unsupported format: {self.format}")
348
+ writer = FORMAT_MAPPING[self.format]
349
+ self.step_count_per_record = self.config.get('step_count_per_record', 1)
350
+
351
+ if (self.rank in self.module_rank_list) or len(self.module_rank_list) == 0:
352
+ self.summary_writer = writer(
353
+ WriterInput(
354
+ self.tensorboard_dir,
355
+ self.alert_rules,
356
+ self.unique_id,
357
+ self.anomaly_data_factory,
358
+ self.ndigits,
359
+ self.step_count_per_record
360
+ )
361
+ )
362
+ # initialize the anomaly_detected file directory
363
+ if self.anomaly_data_factory:
364
+ self.anomaly_data_writer = AnomalyDataWriter(os.path.join(self.output_base_dir, "anomaly_detected"),
365
+ self.rank)
366
+ self.anomaly_data_writer.init_detected_json()
367
+
368
+ def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
369
+ rank = None
370
+ if dist.is_initialized():
371
+ rank = dist.get_rank()
372
+ if (rank not in rank_list) and len(rank_list) != 0:
373
+ return
374
+ self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
375
+
376
+ def build_tbtag_tensor_map(self, module_name, tag, tensor):
377
+ key = get_summary_writer_tag_name(module_name, tag, self.rank)
378
+ self._register_param_call_id("_hook_module", key)
379
+ return {key: tensor}
380
+
381
+ def common_info(self):
382
+ if not self.xy_distribution:
383
+ logger.info_on_rank_0("> module input/output input_grad/output_grad is not monitored. ")
384
+ if self.forward_only:
385
+ logger.info_on_rank_0("> only module forward is monitored. ")
386
+ if not self.ur_distribution:
387
+ logger.info_on_rank_0("> update vector and ratio vector of adam is not monitored. ")
388
+ if not self.mv_distribution:
389
+ logger.info_on_rank_0("> momentum and variance of adam is not monitored. ")
390
+ if not self.wg_distribution:
391
+ logger.info_on_rank_0("> weight grad of specified module is not monitored. ")
392
+ if not self.mg_direction:
393
+ logger.info_on_rank_0('> grad and momentum direction will not be compared.')
394
+ if not self.cc_distribution.get('enable', False):
395
+ logger.info_on_rank_0("> cc operator is not monitored.")
396
+ if not self.opt_ty:
397
+ if self.ur_distribution:
398
+ raise Exception("ur_distribution cannot be enabled with unknown optimizer.")
399
+ if self.mv_distribution:
400
+ raise Exception("mv_distribution cannot be enabled with unknown optimizer.")
401
+
402
+ def hook_modules(self):
403
+ if self.module_rank_list and (self.rank not in self.module_rank_list):
404
+ return
405
+
406
+ targets = self.config['targets']
407
+ module_in_all_stage = [key for key in targets.keys() if MonitorConst.VPP_SEP not in key]
408
+ for key in module_in_all_stage:
409
+ struct = targets.pop(key)
410
+ targets.update({f'{vpp_stage}{MonitorConst.VPP_SEP}{key}': struct for vpp_stage in range(len(self.model))})
411
+
412
+ hooked_count = 0
413
+ for vpp_stage, model_chunk in enumerate(self.model):
414
+ vpp_stage = f'{vpp_stage}{MonitorConst.VPP_SEP}'
415
+ targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
416
+ 'targets'].keys()
417
+ hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
418
+
419
+ logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")
420
+
421
+ def clone_if_tensor(args):
422
+ if isinstance(args, tuple):
423
+ return tuple([clone_if_tensor(arg) for arg in args])
424
+ elif isinstance(args, torch.Tensor):
425
+ return args.clone()
426
+ else:
427
+ return args
428
+
429
+ @torch.no_grad
430
+ def wrap_hook_setup(setup):
431
+ def wrapped_setup(*args, **kwargs):
432
+ args = setup(*args, **kwargs)
433
+ args = clone_if_tensor(args)
434
+ return args
435
+
436
+ return wrapped_setup
437
+
438
+ BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
439
+
440
+ return
441
+
442
+ def generate_param_metrics(self, opt_context):
443
+ if not self.param_distribution:
444
+ return
445
+ get_metrics(self.ops, self.name2param, self.eps, opt_context.param_metric)
446
+
447
+ def generate_mv_metrics(self, opt_context):
448
+ if not self.mv_distribution:
449
+ return
450
+ opt_context.exp_avg_metric = {}
451
+ opt_context.exp_avg_sq_metric = {}
452
+ m_tag_tensor_map = self.generate_param_map('exp_avg', opt_context.param_exp_avg)
453
+ v_tag_tensor_map = self.generate_param_map('exp_avg_sq', opt_context.param_exp_avg_sq)
454
+ get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
455
+ get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
456
+
457
+ def generate_wgrad_metrics(self):
458
+ if not self.wg_distribution:
459
+ return {}, {}
460
+
461
+ if self.weight_hooked:
462
+ get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
463
+
464
+ grad_dict = {}
465
+ for param, name in self.param2name.items():
466
+ if self.duplicate_param.get(name, False):
467
+ continue
468
+ grad = param.main_grad if self.params_have_main_grad else param.grad
469
+ if grad is None:
470
+ logger.warning(f"grad is None: {name}, maybe something wrong happened.")
471
+ continue
472
+ tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
473
+ self._register_param_call_id("hook_optimizer", tag)
474
+ grad_dict[tag] = grad
475
+
476
+ get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
477
+ return self.grad_context.post, self.grad_context.pre
478
+
479
+ def monitor_gnorm_with_ad(
480
+ self,
481
+ model,
482
+ grad_acc_steps=1,
483
+ optimizer=None,
484
+ tp_group=None,
485
+ dp_group=None,
486
+ start_iteration=0
487
+ ):
488
+ """External interface"""
489
+ global start_step
490
+ start_step = start_iteration
491
+ logger.info(f'grad acc steps {grad_acc_steps}')
492
+ self.micro_batch_number = grad_acc_steps
493
+ self.dp_group = dp_group
494
+ self.tp_group = tp_group
495
+ self.hook_step_final(optimizer)
496
+ if not isinstance(model, list):
497
+ model = [model]
498
+ self.model = model
499
+ if len(model) > 1:
500
+ self.vpp = True
501
+ self._smallest_rank_print('vpp enabled')
502
+ if not self.dynamic_enable:
503
+ self.register_hooks(optimizer)
504
+
505
+ def register_hooks(self, optimizer):
506
+ self._register_param_name()
507
+ self.hook_optimizer(optimizer)
508
+ self._patch_grad_sync()
509
+ self.hook_modules()
510
+ self.monitoring = True
511
+
512
+ def generate_param_map(self, tag, param_tensor):
513
+ metrics = {}
514
+ for name in self.param2name.values():
515
+ key = get_summary_writer_tag_name(name, tag, self.rank)
516
+ self._register_param_call_id("optimizer_pre_step_hook", key)
517
+ if name not in param_tensor or param_tensor[name] is None:
518
+ continue
519
+ metrics[key] = param_tensor[name]
520
+ return metrics
521
+
522
+ def generate_xy_metrics(self):
523
+ actv = {}
524
+ for fwd_context in self.module_fwd_hook_context_by_module.values():
525
+ actv.update(fwd_context.actv)
526
+
527
+ actv_grad = self.grad_context.actv
528
+
529
+ return actv, actv_grad
530
+
531
+ def reload_xy(self, xy_distribution=False):
532
+ self.xy_distribution = xy_distribution
533
+
534
+ for handle in self.handles['xy']:
535
+ handle.remove()
536
+ self.handles['xy'].clear()
537
+ self.hook_modules()
538
+ for _, fwd_context in self.module_fwd_hook_context_by_module.items():
539
+ fwd_context.actv.clear()
540
+
541
+ def write_adhoc_check(self, step):
542
+ self.tensor_metrics.flush(self.summary_writer)
543
+
544
+ def write_xy_tb(self, step):
545
+ if not self.xy_distribution:
546
+ return
547
+ for _, fwd_context in self.module_fwd_hook_context_by_module.items():
548
+ if len(fwd_context.actv) == 0:
549
+ continue
550
+ self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, 'actv')
551
+ fwd_context.actv.clear()
552
+ if self.grad_context.actv:
553
+ self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, 'actv_grad')
554
+
555
+ def write_param_tb(self, opt_context):
556
+ if not self.param_distribution:
557
+ return
558
+ self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, 'param')
559
+
560
+ def write_mv_tb(self, opt_context):
561
+ if not self.mv_distribution:
562
+ return
563
+ self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step, 'exp_avg')
564
+ self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step, 'exp_avg_sq')
565
+
566
+ def write_grad_tb(self, step):
567
+ if not self.wg_distribution:
568
+ return
569
+
570
+ if self.enable_megatron:
571
+ self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
572
+ else:
573
+ self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
574
+ self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
575
+
576
+ def hook_optimizer(self, optimizer=None):
577
+ # in DDP by default use params_have_main_grad
578
+ def optimizer_pre_step_hook(optimizer, args, kwargs):
579
+ context = self.optimizer_context[optimizer]
580
+
581
+ if (self.print_struct and not all(value == {} for value in self.module_struct.values())
582
+ and not self.struct_printed):
583
+ self._save_module_struct()
584
+ if not self.cc_log_only:
585
+ raise Exception("exit after first monitor step when print model struct")
586
+ if self.cc_log_only and context.step > 0:
587
+ self._smallest_rank_print("> Used communication ops and corresponding stack")
588
+ self._smallest_rank_print(
589
+ json.dumps({k: [i.split(';') for i in v] for k, v in self.cc_logged_stack.items()}))
590
+ raise Exception("exit after first step when print cc stack")
591
+
592
+ # skip generate metrics
593
+ if context.step < self.start_step or (context.step - self.start_step) % self.step_interval != 0:
594
+ return
595
+ if self.opt_ty in MonitorConst.DEEPSPEED_OPT_TY:
596
+ if not self.name2indices:
597
+ self.name2indices = self.mix_precision_optimizer_mon.get_param_index(self.param2name,
598
+ self.name2index)
599
+ mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name,
600
+ self.name2indices)
601
+ self.param2name = mv_result.grad
602
+ else:
603
+ mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name)
604
+ context.param_exp_avg = mv_result.exp_avg
605
+ context.param_exp_avg_sq = mv_result.exp_avg_sq
606
+ context.param_adam_update = mv_result.update
607
+ context.param_adam_ratio = mv_result.ratio
608
+
609
+ self.generate_wgrad_metrics()
610
+ self.generate_mv_metrics(context)
611
+ self.generate_param_metrics(context)
612
+
613
+ tbtag_tensor_map = {}
614
+ if self.mg_direction:
615
+ for param, name in self.param2name.items():
616
+ grad = param.main_grad if self.params_have_main_grad else param.grad
617
+ if grad is None:
618
+ logger.warning(f"grad is None: {name}, maybe something wrong happened.")
619
+ continue
620
+ if context.step == 0:
621
+ same_direction_ratio = torch.tensor(1.)
622
+ else:
623
+ same_direction_ratio = get_sign_matches(grad, context.param_exp_avg[name])
624
+ context.param_mg_direction[name] = same_direction_ratio
625
+ tbtag_tensor_map.update(self.generate_param_map('mg_direction', context.param_mg_direction))
626
+
627
+ metric_dict = {}
628
+ get_metrics(self.ops, tbtag_tensor_map, self.eps, metric_dict)
629
+ for cc in self.cc_context.values():
630
+ cc.aggregate()
631
+ metric_dict.update(cc.data)
632
+ cc.reset()
633
+
634
+ if not metric_dict:
635
+ return
636
+ context.metric_dict = metric_dict
637
+ return
638
+
639
+ def patch_step(func, optimizer):
640
+ def wrapper(*args, **kwargs):
641
+ optimizer_pre_step_hook(optimizer, args, kwargs)
642
+ out = func(*args, **kwargs)
643
+ return out
644
+
645
+ return wrapper
646
+
647
+ if self.optimizer_hooked:
648
+ return
649
+
650
+ if optimizer:
651
+ optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
652
+ self.handles['optimizer'] = []
653
+ else:
654
+ if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list):
655
+ step_pre_hook = register_optimizer_step_pre_hook(optimizer_pre_step_hook)
656
+ self.handles['optimizer'] = [step_pre_hook]
657
+ self.optimizer_hooked = True
658
+ return
659
+
660
+ def dynamic_monitor(self, optimizer):
661
+ """
662
+ If dynamic monitor enabled and config.json updated,
663
+ remove hooks and register new hooks according to new configuration.
664
+ """
665
+ context = self.optimizer_context[optimizer]
666
+ if not self.dynamic_enable:
667
+ return
668
+ try:
669
+ # if the file timestamp has not changed, skip reading to save time
670
+ config_timestamp = os.path.getmtime(self.config_file_path)
671
+ if config_timestamp == self.config_timestamp:
672
+ return
673
+ # update the latest modification timestamp of the config file
674
+ self.config_timestamp = config_timestamp
675
+ config = load_json(self.config_file_path)
676
+ except Exception as e:
677
+ logger.error(f"get config.json wrong because {e}, not updated, please check!!!")
678
+ return
679
+
680
+ if config.get("switch", False):
681
+ try:
682
+ validate_config(config)
683
+ self.config = config
684
+ self.set_config()
685
+ logger.warning(f"config is updated at step{context.step - 1}, "
686
+ f"will start new hook at step{context.step}.")
687
+ except Exception as e:
688
+ logger.error(f"set config wrong because {e}, not updated, please check!!!")
689
+ return
690
+
691
+ self._remove_all_hooks(optimizer)
692
+ self.register_hooks(optimizer)
693
+
694
+ def hook_step_final(self, optimizer):
695
+ def step_final_hook(optimizer, args, kwargs):
696
+ context = self.optimizer_context[optimizer]
697
+ rank = dist.get_rank() if dist.is_initialized() else None
698
+ # static mode can save at step 0; dynamic mode cannot, because it is designed to start at the step after a reset, so self.monitoring is still False at step 0
699
+ if self.monitoring:
700
+ module_rank_valid = not self.module_rank_list or (
701
+ dist.is_initialized() and dist.get_rank() in self.module_rank_list)
702
+ step_condition = (context.step >= self.start_step and (
703
+ context.step - self.start_step) % self.step_interval == 0)
704
+ if module_rank_valid and step_condition:
705
+ self.has_collect_times += 1
706
+
707
+ if self.anomaly_data_factory:
708
+ self.anomaly_data_factory.set_call_id(self.param_name_call_id)
709
+ self.write_xy_tb(context.step)
710
+ self.write_grad_tb(context.step)
711
+ self.write_mv_tb(context)
712
+ self.write_param_tb(context)
713
+ self.write_adhoc_check(context.step)
714
+
715
+ if self.ur_distribution:
716
+ for param_name, _ in context.param_adam_update.items():
717
+ self.update_heatmap_visualizer[param_name].visualize(
718
+ get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step,
719
+ self.summary_writer)
720
+ for param_name, _ in context.param_adam_ratio.items():
721
+ self.ratio_heatmap_visualizer[param_name].visualize(
722
+ get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step,
723
+ self.summary_writer)
724
+
725
+ if context.metric_dict:
726
+ self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other')
727
+ context.metric_dict.clear()
728
+
729
+ if self.anomaly_data_factory:
730
+ self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
731
+ self.summary_writer.clear_anomalies()
732
+ self.call_id = 0
733
+ self.param_name_call_id.clear()
734
+
735
+ if self.has_collect_times >= self.collect_times:
736
+ self._remove_all_hooks_final(optimizer)
737
+
738
+ context.step += 1
739
+ self.dynamic_monitor(optimizer)
740
+
741
+ def patch_step(func, optimizer):
742
+ def wrapper(*args, **kwargs):
743
+ out = func(*args, **kwargs)
744
+ step_final_hook(optimizer, args, kwargs)
745
+ return out
746
+ return wrapper
747
+
748
+ if optimizer:
749
+ optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
750
+ self.origin_step_func = optimizer.__class__.step
751
+ else:
752
+ register_optimizer_step_post_hook(step_final_hook)
753
+ return
754
+
755
+ def _remove_all_hooks(self, optimizer):
756
+ # clear hook handles
757
+ for handle in self.handles['xy']:
758
+ handle.remove()
759
+ self.handles['xy'].clear()
760
+ # clear the corresponding context caches
761
+ for _, fwd_context in self.module_fwd_hook_context_by_module.items():
762
+ fwd_context.reset()
763
+ for _, bwd_context in self.module_bwd_hook_context_by_module.items():
764
+ bwd_context.reset()
765
+ self.grad_context.reset() # both weight gradients and activation gradients live here
766
+
767
+ for handle in self.handles['wgrads']:
768
+ handle.remove()
769
+ self.handles['wgrads'].clear()
770
+ self.weight_hooked = False
771
+
772
+ if len(self.handles['optimizer']) == 0 and self.optimizer_hooked:
773
+ optimizer.__class__.step = self.origin_step_func
774
+ else:
775
+ for handle in self.handles['optimizer']:
776
+ handle.remove()
777
+ self.handles['optimizer'].clear()
778
+ for _, context in self.optimizer_context.items():
779
+ context.reset()
780
+ self.optimizer_hooked = False
781
+
782
+ for handle in self.handles['cc']:
783
+ handle.remove()
784
+ self.handles['cc'].clear()
785
+ for _, context in self.cc_context.items():
786
+ context.reset()
787
+
788
+ # clear node caches
789
+ self.param2name.clear()
790
+ self.name2index.clear()
791
+ self.name2indices.clear()
792
+ self.name2param.clear()
793
+ self.duplicate_param.clear()
794
+ self.name2tag.clear()
795
+ self.module_struct.clear()
796
+ self.grad_accs.clear()
797
+
798
+ # turn off the collection state
799
+ self.monitoring = False
800
+
801
+ def _remove_all_hooks_final(self, optimizer):
802
+ if self.dynamic_enable:
803
+ # after finishing, automatically reset switch to False and wait for the user to re-enable it manually
804
+ try:
805
+ config = load_json(self.config_file_path)
806
+ config['switch'] = False
807
+ save_json(self.config_file_path, config, indent=2)
808
+ config_timestamp = os.path.getmtime(self.config_file_path)
809
+ self.config_timestamp = config_timestamp
810
+ logger.info(
811
+ "Finish monitor, set config'switch=False, will restart by set switch=True and update content")
812
+ except Exception as e:
813
+ logger.warning(f"Finish monitor, set config'switch=False fail because {e}, please check!!!")
814
+ logger.info("Finish monitor")
815
+ self._remove_all_hooks(optimizer)
816
+
817
+ def _smallest_rank_print(self, msg):
818
+ if dist.is_initialized():
819
+ if self.module_rank_list:
820
+ if dist.get_rank() == min(self.module_rank_list):
821
+ logger.info(msg)
822
+ else:
823
+ if dist.get_rank() == 0:
824
+ logger.info(msg)
825
+ else:
826
+ logger.info(msg)
827
+
828
+ def _save_module_struct(self):
829
+ save_module_struct = (not dist.is_initialized()
830
+ or (self.module_rank_list and dist.get_rank() == min(self.module_rank_list))
831
+ or (not self.module_rank_list and dist.get_rank() == 0))
832
+
833
+ if save_module_struct:
834
+ module_struct_file = os.path.realpath(os.path.join(get_output_base_dir(), 'module_struct.json'))
835
+ save_json(module_struct_file, self.module_struct, indent=2)
836
+ logger.info(f"> save module struct to {module_struct_file}")
837
+ self.struct_printed = True
838
+
839
+ def _is_target_param(self, param_name, param, prefix):
840
+ name = prefix + param_name
841
+ squash_name = prefix + squash_param_name(param_name, self.squash_name)
842
+ for target in self.config['targets'].keys():
843
+ if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
844
+ setattr(param, "zero_out_wgrad", True)
845
+ return True
846
+
847
+ return False
848
+
849
+ def _register_chunk(self, model_chunk, prefix):
850
+ index = 0
851
+ for (param_name, param) in model_chunk.named_parameters():
852
+ if not param.requires_grad:
853
+ continue
854
+ if self._is_target_param(param_name, param, prefix):
855
+ name = prefix + squash_param_name(param_name, self.squash_name)
856
+ if name in self.param2name.values():
857
+ name = prefix + param_name
858
+ self.param2name[param] = name
859
+ self.name2param[name] = param
860
+ self.name2index[name] = index
861
+
862
+ if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group):
863
+ self.duplicate_param[name] = True
864
+ if self.dp_group and param_is_data_parallel_duplicate(self.dp_group):
865
+ self.duplicate_param[name] = True
866
+ self.name2tag[name] = {
867
+ MonitorConst.PRE_GRAD: get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD, self.rank),
868
+ MonitorConst.POST_GRAD: get_summary_writer_tag_name(name, MonitorConst.POST_GRAD, self.rank)
869
+ }
870
+ index += 1
871
+
872
+ def _register_param_name(self):
873
+ for vpp_stage, model_chunk in enumerate(self.model):
874
+ prefix = f'{vpp_stage}{MonitorConst.VPP_SEP}'
875
+ self._register_chunk(model_chunk, prefix)
876
+
877
+ def _is_target_module(self, module_name, targets, vpp_stage):
878
+ if self.all_xy or self.print_struct:
879
+ return vpp_stage + squash_param_name(module_name, self.squash_name)
880
+ for pattern in [
881
+ vpp_stage + squash_param_name(module_name, self.squash_name),
882
+ vpp_stage + module_name,
883
+ ]:
884
+ if pattern in targets:
885
+ return pattern
886
+ return ""
887
+
888
+ def _hook_module(self, target_names, module: torch.nn.Module, vpp_stage=''):
889
+ if '_modules' not in module.__dict__:
890
+ # nothing to hook
891
+ return 0
892
+
893
+ def fwd_hook_fun(module, module_input, module_output, name):
894
+ if not module.training or is_recomputation():
895
+ # 1 only monitor training stage.
896
+ # 2 when recompute is enabled, skip the recomputed forward stage.
897
+ return
898
+ if module not in self.module_fwd_hook_context_by_module:
899
+ self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
900
+ context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
901
+ if not context.struct:
902
+ context.struct = {
903
+ MonitorConst.ACTV_IN: get_param_struct(module_input),
904
+ MonitorConst.ACTV_OUT: get_param_struct(module_output)
905
+ }
906
+ if self.print_struct:
907
+ self.module_struct[context.module_name].update(context.struct)
908
+ return
909
+ if not context.format_by_arg:
910
+ context.set_format_by_arg(MonitorConst.ACTV_IN, self.config['targets'])
911
+ context.set_format_by_arg(MonitorConst.ACTV_OUT, self.config['targets'])
912
+ if not context.format_by_arg:
913
+ return
914
+ if not context.verified:
915
+ context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_IN],
916
+ module_input, context.module_name,
917
+ MonitorConst.ACTV_IN)
918
+ context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_OUT],
919
+ module_output, context.module_name,
920
+ MonitorConst.ACTV_OUT)
921
+ context.verified = True
922
+ # expect the output to be a tensor type
923
+ tbtag_tensor_map = {}
924
+ cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
925
+ tbtag_tensor_map.update(
926
+ self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN,
927
+ cared_input))
928
+ cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
929
+ tbtag_tensor_map.update(
930
+ self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_OUT,
931
+ cared_output))
932
+
933
+ get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
934
+ context.micro_step += 1
935
+ if context.micro_step == self.micro_batch_number:
936
+ context.micro_step = 0
937
+ return
938
+
939
+ def bwd_hook_fun(module, input_grad, output_grad):
940
+ context: ModuleHookContext = self.module_bwd_hook_context_by_module[module]
941
+ if not context.struct:
942
+ context.struct = {
943
+ MonitorConst.ACTVGRAD_IN: get_param_struct(input_grad),
944
+ MonitorConst.ACTVGRAD_OUT: get_param_struct(output_grad)
945
+ }
946
+ if self.print_struct:
947
+ self.module_struct[context.module_name].update(context.struct)
948
+ return
949
+ if not context.format_by_arg:
950
+ context.set_format_by_arg(MonitorConst.ACTVGRAD_IN, self.config['targets'])
951
+ context.set_format_by_arg(MonitorConst.ACTVGRAD_OUT, self.config['targets'])
952
+ if not context.format_by_arg:
953
+ return
954
+ if not context.verified:
955
+ context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_IN],
956
+ input_grad, context.module_name,
957
+ MonitorConst.ACTVGRAD_IN)
958
+ context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_OUT],
959
+ output_grad, context.module_name,
960
+ MonitorConst.ACTVGRAD_OUT)
961
+ context.verified = True
962
+
963
+ tbtag_tensor_map = {}
964
+ cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
965
+ tbtag_tensor_map.update(
966
+ self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN,
967
+ cared_input_grad))
968
+ cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
969
+ tbtag_tensor_map.update(
970
+ self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_OUT,
971
+ cared_output_grad))
972
+
973
+ if context.micro_step == 0 and context.actvgrad:
974
+ logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, "
975
+ f"maybe something wrong happened. Now clear it.")
976
+ context.actvgrad.clear()
977
+
978
+ get_metrics(self.ops, tbtag_tensor_map, self.eps, self.grad_context.actv)
979
+
980
+ context.micro_step += 1
981
+ if context.micro_step == self.micro_batch_number:
982
+ context.micro_step = 0
983
+ return
984
+
985
+ if self.backward_only and self.forward_only:
986
+ logger.warning('do not enable backward_only and forward_only simultaneously')
987
+
988
+ hooked_count = 0
989
+ if self.xy_distribution or self.print_struct:
990
+ for module_name, submodule in module.named_modules():
991
+ name = self._is_target_module(module_name, target_names, vpp_stage)
992
+ if not name:
993
+ continue
994
+ if not self.backward_only:
995
+ handle = submodule.register_forward_hook(partial(fwd_hook_fun, name=name))
996
+ self.handles['xy'].append(handle)
997
+ if not self.forward_only and not self.has_register_backward_hook(name, submodule):
998
+ handle = submodule.register_full_backward_hook(bwd_hook_fun)
999
+ self.handles['xy'].append(handle)
1000
+ self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name)
1001
+ logger.info_on_rank_0(f"> {name} is monitored successfully")
1002
+ hooked_count += 1
1003
+ return hooked_count
1004
+
1005
+ def _patch_grad_sync(self):
1006
+ def patch_sync(sync_grad_func):
1007
+ def wrapper(bucket):
1008
+ grad_dict = {}
1009
+ bucket_params_id_list = [id(params) for params in bucket.params_list]
1010
+ for param, name in self.param2name.items():
1011
+ if id(param) not in bucket_params_id_list:
1012
+ continue
1013
+ grad = param.main_grad if self.params_have_main_grad else param.grad
1014
+ if grad is None:
1015
+ logger.warning(f"grad is None: {name}, maybe something wrong happened.")
1016
+ continue
1017
+ tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
1018
+ if tag is None:
1019
+ continue
1020
+ grad_dict[tag] = grad
1021
+ self._register_param_call_id("sync_grad_func", tag)
1022
+ get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
1023
+ out = sync_grad_func(bucket)
1024
+ return out
1025
+
1026
+ return wrapper
1027
+
1028
+ try:
1029
+ from megatron.core.distributed.param_and_grad_buffer import Bucket
1030
+ self.enable_megatron = True
1031
+ except ImportError:
1032
+ self.enable_megatron = False
1033
+
1034
+ if not self.wg_distribution:
1035
+ return
1036
+
1037
+ if self.enable_megatron:
1038
+ Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync) # differs across megatron versions
1039
+ else:
1040
+ self._hook_weights()
1041
+
1042
+ def _hook_weights(self):
1043
+ context = self.grad_context
1044
+
1045
+ @torch.no_grad
1046
+ def param_hook(*args, context_dict, param, key, name):
1047
+ param.micro_step += 1
1048
+ self._register_param_call_id("param_hook", key)
1049
+ if param.micro_step == self.micro_batch_number:
1050
+ param.micro_step = 0
1051
+ if self.params_have_main_grad:
1052
+ context_dict[key] = param.main_grad.clone()
1053
+ else:
1054
+ context_dict[key] = param.grad.clone()
1055
+
1056
+ for param, name in self.param2name.items():
1057
+ key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
1058
+ setattr(param, 'micro_step', 0)
1059
+ param_tmp = param.expand_as(param)
1060
+ grad_acc = param_tmp.grad_fn.next_functions[0][0]
1061
+ handle = grad_acc.register_hook(
1062
+ partial(param_hook, context_dict=context.acc, param=param, key=key, name=name))
1063
+ self.grad_accs.append(grad_acc)
1064
+ self.handles['wgrads'].append(handle)
1065
+
1066
+ self.weight_hooked = True
1067
+
1068
+ def _register_param_call_id(self, hook_name: str, key: str):
1069
+ """
1070
+ :param hook_name:
1071
+ :param key: str, '0:relu_0/output_grad'
1072
+ :return:
1073
+ """
1074
+ logger.debug(f"{hook_name} {key}: {self.call_id}")
1075
+ self.param_name_call_id[key] = self.call_id
1076
+ self.call_id += 1
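
For context, here is a minimal, hypothetical usage sketch of the TrainerMon entry point added in msprobe/pytorch/monitor/module_hook.py above. It is based only on the signatures visible in this diff (the constructor, set_wrapped_optimizer, and monitor_gnorm_with_ad); the config file path "monitor_config.json", its contents, and the toy model are illustrative assumptions, not part of the package.

```python
# Hypothetical sketch: wiring TrainerMon into a plain PyTorch training loop.
# Assumes torch >= 2.0 (enforced by the module) and a JSON config whose keys
# ("targets", "ops", "format", ...) match those read in TrainerMon.set_config().
import torch
from msprobe.pytorch.monitor.module_hook import TrainerMon

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# params_have_main_grad=False because a vanilla optimizer exposes param.grad,
# not the Megatron-style param.main_grad used by default.
monitor = TrainerMon("monitor_config.json", params_have_main_grad=False)
monitor.monitor_gnorm_with_ad(model, grad_acc_steps=1, optimizer=optimizer)

for step in range(10):
    x = torch.randn(4, 8)
    loss = model(x).sum()
    loss.backward()
    optimizer.step()      # the patched step() triggers metric collection and writing
    optimizer.zero_grad()
```

The sketch only illustrates the call order (construct, attach with monitor_gnorm_with_ad, then train); which metrics are actually collected depends entirely on the JSON config validated by validate_config.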