mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
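The largest single addition in this release is the new PyTorch training-state monitor package msprobe/pytorch/monitor/ (entries 225-238 above, documented in msprobe/docs/19.monitor.md); its main entry point, module_hook.py, is reproduced in full below. As orientation, the following sketch shows what a monitor config consumed by TrainerMon might look like. It is assembled only from the keys read in TrainerMon.__init__; the module and op names are hypothetical placeholders, and the authoritative schema is enforced by validate_config/validate_ops in msprobe/pytorch/monitor/utils.py, which is not shown in this diff.

import json

# Hypothetical example; key names come from the config.get(...) calls in TrainerMon.__init__.
monitor_config = {
    "targets": {"module.language_model.encoder.layers.0": {}},  # modules to hook (placeholder name)
    "module_ranks": [0],            # ranks on which hooks are installed; an empty list means all ranks
    "format": "tensorboard",        # default in __init__; must map to an entry of FORMAT_MAPPING
    "ops": ["min", "max", "norm"],  # statistic ops, validated by validate_ops (values are assumptions)
    "eps": 1e-8,
    "ndigits": 6,
    "xy_distribution": True,        # monitor module inputs/outputs (activations)
    "wg_distribution": True,        # monitor weight gradients
    "alert": {"rules": [], "dump": False},
    "step_count_per_record": 1,
}

with open("monitor_config.json", "w") as f:
    json.dump(monitor_config, f, indent=2)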
msprobe/pytorch/monitor/module_hook.py
@@ -0,0 +1,870 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import time
16
+ import json
17
+ import os
18
+ import uuid
19
+ from collections import defaultdict
20
+ from datetime import datetime, timezone
21
+ from functools import partial
22
+
23
+ import pytz
24
+ import torch
25
+ import torch.distributed as dist
26
+ from msprobe.core.common.const import MonitorConst
27
+ from msprobe.core.common.file_utils import load_json
28
+ from msprobe.core.common.log import logger
29
+ from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter
30
+ from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD, AnomalyDataFactory, \
31
+ CSVWriterWithAD, BaseWriterWithAD, WriterInput
32
+ from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
33
+ get_process_group
34
+ from msprobe.pytorch.monitor.features import get_sign_matches
35
+ from msprobe.pytorch.monitor.module_metric import get_metrics, write_metrics_base, get_summary_writer_tag_name, \
36
+ TensorMetrics, write_metrics_csv, squash_param_name
37
+ from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec
38
+ from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory, OptimizerMon
39
+ from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, is_recomputation
40
+ from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
41
+ from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
42
+ from torch.utils.hooks import BackwardHook
43
+
44
+ try:
45
+ import torch_npu
46
+ except ImportError:
47
+ pass
48
+
49
+ torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
50
+ if not torch_version_above_or_equal_2:
51
+ raise ValueError("monitor require torch>=2.0")
52
+
53
+ output_base_dir = os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR)
54
+
55
+ FORMAT_MAPPING = {
56
+ MonitorConst.TENSORBOARD: (SummaryWriterWithAD, write_metrics_base),
57
+ MonitorConst.CSV: (CSVWriterWithAD, write_metrics_csv),
58
+ MonitorConst.API: (BaseWriterWithAD, write_metrics_base)
59
+ }
60
+
61
+
62
+ def param_is_not_tensor_parallel_duplicate(param, tp_group):
63
+ return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or (
64
+ torch.distributed.get_rank(group=tp_group) == 0
65
+ )
66
+
67
+
68
+ def param_is_data_parallel_duplicate(dp_group):
69
+ return torch.distributed.get_rank(group=dp_group) != 0
70
+
71
+
72
+ class ModuleHookContext:
73
+ def __init__(self, module_name) -> None:
74
+ self.step = 0
75
+ self.micro_step = 0
76
+ self.actv = defaultdict(dict)
77
+ self.actvgrad = []
78
+ self.module_name = module_name
79
+ self.struct = {}
80
+ self.format_by_arg = {}
81
+ self.verified = False
82
+ self.focused_in_col = 0
83
+ self.focused_out_col = 0
84
+ self.ignore_in = False # no need to care when no key 'input' or 'input_grad' found
85
+
86
+ def set_format_by_arg(self, key_name: str, target_config: dict):
87
+ cared = target_config.get(self.module_name, self.struct)
88
+ if key_name in cared:
89
+ if isinstance(cared[key_name], dict):
90
+ # current cared is self.struct
91
+ config = cared[key_name].get('config')
92
+ self.format_by_arg[key_name] = config
93
+ else:
94
+ # current cared is target_config[self.module_name]
95
+ self.format_by_arg[key_name] = cared[key_name]
96
+ elif key_name in ['input', 'input_grad']:
97
+ self.ignore_in = True
98
+
99
+
100
+ class OptimizerContext:
101
+ def __init__(self) -> None:
102
+ self.step = 0
103
+ self.param_effective_rank = defaultdict(float)
104
+ self.param_mg_direction = defaultdict(float)
105
+ self.param_adam_update = defaultdict()
106
+ self.param_adam_ratio = defaultdict()
107
+ self.param_weight_grad = defaultdict()
108
+ self.param_exp_avg = defaultdict()
109
+ self.exp_avg_metric = {}
110
+ self.param_exp_avg_sq = defaultdict()
111
+ self.exp_avg_sq_metric = {}
112
+ self.metric_dict = {}
113
+ self.param_metric = {}
114
+
115
+
116
+ class CommunicationContext:
117
+ def __init__(self) -> None:
118
+ self.data = {}
119
+
120
+ @staticmethod
121
+ def _agg(data):
122
+ aggregated_data = {}
123
+ for tag, op2tensorlist in data.items():
124
+ aggregated_data[tag] = {}
125
+ for op, tensorlist in op2tensorlist.items():
126
+ aggregated_data[tag][op] = op_aggregate(op, tensorlist)
127
+ return aggregated_data
128
+
129
+ def reset(self):
130
+ self.data = {}
131
+
132
+ def aggregate(self):
133
+ self.data = self._agg(self.data)
134
+
135
+
136
+ class GradContext:
137
+ def __init__(self) -> None:
138
+ self.pre = {}
139
+ self.post = {}
140
+ self.acc_metric = {}
141
+ self.acc = {}
142
+ self.actv = {}
143
+
144
+ def reset(self):
145
+ self.pre.clear()
146
+ self.post.clear()
147
+ self.acc_metric.clear()
148
+ self.acc.clear()
149
+ self.actv.clear()
150
+
151
+
152
+ class TrainerMon:
153
+ tensor_metrics = TensorMetrics()
154
+
155
+ def __init__(self, config_file_path, process_group=None, params_have_main_grad=True, opt_ty=None) -> None:
156
+ """
157
+ opt_ty: "Megatron_Float16OptimizerWithFloat16Params" or "Megatron_DistributedOptimizer"
158
+ """
159
+ self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
160
+ self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
161
+ self.optimizer_context = defaultdict(OptimizerContext)
162
+ self.cc_context = defaultdict(CommunicationContext)
163
+ self.grad_context = GradContext()
164
+ self.process_group = get_process_group(process_group)
165
+ self.params_have_main_grad = params_have_main_grad
166
+ self.opt_ty = opt_ty
167
+ self.config = load_json(config_file_path)
168
+ validate_config(self.config)
169
+
170
+ self.module_rank_list = self.config.get("module_ranks", [])
171
+ self.format = self.config.get('format', 'tensorboard')
172
+ self.eps = self.config.get('eps', 1e-8)
173
+ self.ops = self.config.get('ops', [])
174
+ self.ndigits = self.config.get('ndigits', 6)
175
+ self.all_xy = self.config.get('all_xy', False)
176
+ self.xy_distribution = self.config.get('xy_distribution', False)
177
+ self.forward_only = self.config.get('forward_only', False)
178
+ self.backward_only = self.config.get('backward_only', False)
179
+ self.ur_distribution = self.config.get('ur_distribution', False)
180
+ self.mv_distribution = self.config.get("mv_distribution", False)
181
+ self.wg_distribution = self.config.get("wg_distribution", False)
182
+ self.param_distribution = self.config.get("param_distribution", False)
183
+ self.mg_direction = self.config.get('mg_direction', False)
184
+ self.cc_distribution = self.config.get("cc_distribution", {})
185
+ if not self.cc_distribution.get('enable', False):
186
+ self.cc_log_only = False
187
+ else:
188
+ self.cc_codeline = self.cc_distribution.get('cc_codeline', [])
189
+ self.cc_log_only = self.cc_distribution.get('cc_log_only', False)
190
+ self.cc_logged_stack = defaultdict(set)
191
+ self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False)
192
+ api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
193
+ api_register.redirect_api()
194
+
195
+ self.common_info()
196
+
197
+ alert_setting = self.config.get('alert', {"rules": []})
198
+ self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"])
199
+
200
+ # 设置时区,使用 'UTC' 作为示例
201
+ local_tz = pytz.timezone("Asia/Shanghai") # 根据需要调整为目标时区
202
+
203
+ cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
204
+ unique_id = str(uuid.uuid4())[:8]
205
+
206
+ if dist.is_initialized():
207
+ rank = dist.get_rank()
208
+ tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-rank{rank}-{unique_id}")
209
+ pp_stage = dist.get_group_rank(self.process_group, rank)
210
+ group_mates = dist.get_process_group_ranks(self.process_group)
211
+ else:
212
+ rank = 0
213
+ tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-{unique_id}")
214
+ pp_stage = 0
215
+ group_mates = [0]
216
+ self.rank = rank
217
+
218
+ # 初始化AnomalyData工厂
219
+ self.anomaly_data_factory = None
220
+ if alert_setting.get('dump', False):
221
+ self.anomaly_data_factory = AnomalyDataFactory(rank, pp_stage, group_mates)
222
+
223
+ if self.format not in FORMAT_MAPPING:
224
+ raise ValueError(f"Unsupported format: {self.format}")
225
+ writer, self.write_metrics = FORMAT_MAPPING[self.format]
226
+ self.step_count_per_record = self.config.get('step_count_per_record', 1)
227
+
228
+ if (rank in self.module_rank_list) or len(self.module_rank_list) == 0:
229
+ self.summary_writer = writer(
230
+ WriterInput(
231
+ tensorboard_dir,
232
+ self.alert_rules,
233
+ unique_id,
234
+ None,
235
+ self.anomaly_data_factory,
236
+ self.ndigits,
237
+ self.step_count_per_record
238
+ )
239
+ )
240
+ # 初始化anomaly detected文件目录
241
+ if self.anomaly_data_factory:
242
+ self.anomaly_data_writer = AnomalyDataWriter(os.path.join(output_base_dir, "anomaly_detected"), rank)
243
+ self.anomaly_data_writer.init_detected_json()
244
+
245
+ # A HeatmapVisualizer instance is associated with an image
246
+ self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
247
+ self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
248
+ self.micro_batch_number = 1
249
+
250
+ self.model = None
251
+ self.weight_hooked = False
252
+ self.optimizer_hooked = False
253
+ self.param_registered = False
254
+ self.vpp = False
255
+ self.dp_group = None
256
+ self.tp_group = None
257
+ self.enable_megatron = False
258
+
259
+ self.param2name = defaultdict(str)
260
+ self.name2index = defaultdict()
261
+ self.name2indices = defaultdict()
262
+ self.name2param = {}
263
+ self.param_name_call_id = {}
264
+ self.duplicate_param = {}
265
+ self.name2tag = {}
266
+ self.call_id = 0
267
+ self.grad_accs = []
268
+ self.handles = defaultdict(list)
269
+
270
+ self.mix_precision_optimizer_mon = OptimizerMonFactory.create_optimizer_mon(opt_ty)
271
+ self.print_struct = self.config.get("print_struct", False)
272
+ self.struct_printed = False
273
+ self.module_struct = {}
274
+
275
+ def __del__(self):
276
+ if hasattr(self, "summary_writer"):
277
+ self.summary_writer.close()
278
+
279
+ @property
280
+ def ops(self):
281
+ return self._ops
282
+
283
+ @ops.setter
284
+ def ops(self, value):
285
+ self._ops = validate_ops(value)
286
+
287
+ @staticmethod
288
+ def set_wrapped_optimizer(_wrapped_optimizer):
289
+ OptimizerMon.set_wrapped_optimizer(_wrapped_optimizer)
290
+
291
+ @staticmethod
292
+ def adhoc_check(target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
293
+ rank = None
294
+ if dist.is_initialized():
295
+ rank = dist.get_rank()
296
+ if (rank not in rank_list) and len(rank_list) != 0:
297
+ return
298
+ TrainerMon.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
299
+
300
+ @staticmethod
301
+ def build_tbtag_tensor_map(module_name, tag, tensor):
302
+ metrics = {}
303
+ rank = dist.get_rank() if dist.is_initialized() else None
304
+ key = get_summary_writer_tag_name(module_name, tag, rank)
305
+ if torch.is_tensor(tensor):
306
+ metrics[key] = tensor
307
+ return metrics
308
+
309
+ @staticmethod
310
+ def generate_cc_metrics(cc_name, cc_tensor):
311
+ metrics = defaultdict(dict)
312
+ rank = dist.get_rank() if dist.is_initialized() else None
313
+ for op, tag2tensor in cc_tensor.data.items():
314
+ for tag, tensor in tag2tensor.items():
315
+ key = get_summary_writer_tag_name(cc_name, tag, rank)
316
+ metrics[op].update({key: tensor})
317
+ cc_tensor.reset()
318
+ return metrics
319
+
320
+ def common_info(self):
321
+ if not self.xy_distribution:
322
+ logger.info_on_rank_0("> module input/output input_grad/output_grad are not monitored.")
323
+ if self.forward_only:
324
+ logger.info_on_rank_0("> only module forward is monitored. ")
325
+ if not self.ur_distribution:
326
+ logger.info_on_rank_0("> update vector and ratio vector of adam are not monitored.")
327
+ if not self.mv_distribution:
328
+ logger.info_on_rank_0("> momentum and variance of adam are not monitored.")
329
+ if not self.wg_distribution:
330
+ logger.info_on_rank_0("> weight grad of specified module is not monitored. ")
331
+ if not self.mg_direction:
332
+ logger.info_on_rank_0('> grad and momentum direction will not be compared.')
333
+ if not self.cc_distribution.get('enable', False):
334
+ logger.info_on_rank_0("> cc operator is not monitored.")
335
+ if not self.opt_ty:
336
+ if self.ur_distribution:
337
+ raise Exception("ur_distribution cannot be enabled with unknown optimizer.")
338
+ if self.mv_distribution:
339
+ raise Exception("mv_distribution cannot be enabled with unknown optimizer.")
340
+
341
+ def hook_modules(self, model: torch.nn.Module, grad_acc_steps):
342
+ if self.module_rank_list and (self.rank not in self.module_rank_list):
343
+ return
344
+
345
+ if not isinstance(model, list):
346
+ model = [model]
347
+ self.model = model
348
+ self._register_param_name(model)
349
+
350
+ self.micro_batch_number = grad_acc_steps
351
+
352
+ targets = self.config['targets']
353
+ module_in_all_stage = [key for key in targets.keys() if MonitorConst.VPP_SEP not in key]
354
+ for key in module_in_all_stage:
355
+ struct = targets.pop(key)
356
+ targets.update({f'{vpp_stage}{MonitorConst.VPP_SEP}{key}': struct for vpp_stage in range(len(model))})
357
+
358
+ hooked_count = 0
359
+ for vpp_stage, model_chunk in enumerate(model):
360
+ vpp_stage = f'{vpp_stage}{MonitorConst.VPP_SEP}'
361
+ targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
362
+ 'targets'].keys()
363
+ hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
364
+
365
+ logger.info_on_rank_0(f"> {hooked_count} out of {len(self.config['targets'])} are monitored.")
366
+
367
+ def clone_if_tensor(args):
368
+ if isinstance(args, tuple):
369
+ return tuple([clone_if_tensor(arg) for arg in args])
370
+ elif isinstance(args, torch.Tensor):
371
+ return args.clone()
372
+ else:
373
+ return args
374
+
375
+ @torch.no_grad
376
+ def wrap_hook_setup(setup):
377
+ def wrapped_setup(*args, **kwargs):
378
+ args = setup(*args, **kwargs)
379
+ args = clone_if_tensor(args)
380
+ return args
381
+
382
+ return wrapped_setup
383
+
384
+ BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
385
+
386
+ if not self.optimizer_hooked:
387
+ self.hook_optimizer()
388
+ return
389
+
390
+ def generate_param_metrics(self, opt_context):
391
+ get_metrics(self.ops, self.name2param, self.eps, opt_context.param_metric)
392
+
393
+ def generate_mv_metrics(self, opt_context):
394
+ if not self.mv_distribution:
395
+ return
396
+ opt_context.exp_avg_metric = {}
397
+ opt_context.exp_avg_sq_metric = {}
398
+ m_tag_tensor_map = self.generate_param_map('exp_avg', opt_context.param_exp_avg)
399
+ v_tag_tensor_map = self.generate_param_map('exp_avg_sq', opt_context.param_exp_avg_sq)
400
+ get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
401
+ get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
402
+
403
+ def generate_wgrad_metrics(self):
404
+ if not self.wg_distribution:
405
+ return {}, {}
406
+
407
+ if self.weight_hooked:
408
+ get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
409
+
410
+ grad_dict = {}
411
+ for param, name in self.param2name.items():
412
+ if self.duplicate_param.get(name, False):
413
+ continue
414
+ grad = param.main_grad if self.params_have_main_grad else param.grad
415
+ if grad is None:
416
+ logger.warning(f"grad is None: {name}, something may have gone wrong.")
417
+ continue
418
+ tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
419
+ grad_dict[tag] = grad
420
+
421
+ get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
422
+ return self.grad_context.post, self.grad_context.pre
423
+
424
+ def monitor_gnorm_with_ad(self, model, grad_acc_steps=1, optimizer=None, tp_group=None, dp_group=None):
425
+ """External interface"""
426
+ logger.info(f'grad acc steps {grad_acc_steps}')
427
+ self.hook_optimizer(optimizer)
428
+ self.micro_batch_number = grad_acc_steps
429
+
430
+ self.dp_group = dp_group
431
+ self.tp_group = tp_group
432
+
433
+ self._register_param_name(model)
434
+ self._patch_grad_sync()
435
+ self.hook_modules(model, grad_acc_steps)
436
+
437
+ def generate_param_map(self, tag, param_tensor):
438
+ metrics = {}
439
+ rank = dist.get_rank() if dist.is_initialized() else None
440
+ for name in self.param2name.values():
441
+ key = get_summary_writer_tag_name(name, tag, rank)
442
+ if name not in param_tensor or param_tensor[name] is None:
443
+ continue
444
+ metrics[key] = param_tensor[name]
445
+ return metrics
446
+
447
+ def generate_xy_metrics(self):
448
+ actv = {}
449
+ for fwd_context in self.module_fwd_hook_context_by_module.values():
450
+ actv.update(fwd_context.actv)
451
+
452
+ actv_grad = self.grad_context.actv
453
+
454
+ return actv, actv_grad
455
+
456
+ def reload_xy(self, xy_distribution=False):
457
+ self.xy_distribution = xy_distribution
458
+
459
+ for handle in self.handles['xy']:
460
+ handle.remove()
461
+ self.handles['xy'].clear()
462
+ self.hook_modules(self.model, self.micro_batch_number)
463
+ for _, fwd_context in self.module_fwd_hook_context_by_module.items():
464
+ fwd_context.actv.clear()
465
+
466
+ def write_adhoc_check(self, step):
467
+ TrainerMon.tensor_metrics.flush(self.summary_writer)
468
+
469
+ def write_xy_tb(self, step):
470
+ if not self.xy_distribution:
471
+ return
472
+ for _, fwd_context in self.module_fwd_hook_context_by_module.items():
473
+ if len(fwd_context.actv) == 0:
474
+ continue
475
+ self.write_metrics(self.ops, self.summary_writer, fwd_context.actv, step, 'actv')
476
+ fwd_context.actv.clear()
477
+ if self.grad_context.actv:
478
+ self.write_metrics(self.ops, self.summary_writer, self.grad_context.actv, step, 'actv_grad')
479
+
480
+ def write_param_tb(self, opt_context):
481
+ if not self.param_distribution:
482
+ return
483
+ self.write_metrics(self.ops, self.summary_writer, opt_context.param_metric, opt_context.step, 'param')
484
+
485
+ def write_mv_tb(self, opt_context):
486
+ if not self.mv_distribution:
487
+ return
488
+ self.write_metrics(self.ops, self.summary_writer, opt_context.exp_avg_metric,
489
+ opt_context.step, 'exp_avg')
490
+ self.write_metrics(self.ops, self.summary_writer, opt_context.exp_avg_sq_metric,
491
+ opt_context.step, 'exp_avg_sq')
492
+
493
+ def write_grad_tb(self, step):
494
+ if not self.wg_distribution:
495
+ return
496
+
497
+ if self.enable_megatron:
498
+ self.write_metrics(self.ops, self.summary_writer, self.grad_context.pre, step, 'grad_unreduced')
499
+ else:
500
+ self.write_metrics(self.ops, self.summary_writer, self.grad_context.acc_metric, step, 'grad_unreduced')
501
+ self.write_metrics(self.ops, self.summary_writer, self.grad_context.post, step, 'grad_reduced')
502
+
503
+ def hook_optimizer(self, optimizer=None):
504
+ # in DDP by default use params_have_main_grad
505
+ def optimizer_pre_step_hook(optimizer, args, kwargs):
506
+ context = self.optimizer_context[optimizer]
507
+ if self.opt_ty in MonitorConst.DEEPSPEED_OPT_TY:
508
+ if context.step == 0:
509
+ self.name2indices = self.mix_precision_optimizer_mon.get_param_index(self.param2name,
510
+ self.name2index)
511
+ mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name,
512
+ self.name2indices)
513
+ self.param2name = mv_result.grad
514
+ else:
515
+ mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name)
516
+ context.param_exp_avg = mv_result.exp_avg
517
+ context.param_exp_avg_sq = mv_result.exp_avg_sq
518
+ context.param_adam_update = mv_result.update
519
+ context.param_adam_ratio = mv_result.ratio
520
+
521
+ if (self.print_struct and not all(value == {} for value in self.module_struct.values())
522
+ and not self.struct_printed):
523
+ self._smallest_rank_print("> module struct:")
524
+ self._smallest_rank_print(json.dumps(self.module_struct))
525
+ self.struct_printed = True
526
+ if not self.cc_log_only:
527
+ raise Exception("exit after first step when print model struct")
528
+ if self.cc_log_only and context.step > 0:
529
+ self._smallest_rank_print("> Used communication ops and corresponding stack")
530
+ self._smallest_rank_print(
531
+ json.dumps({k: [i.split(';') for i in v] for k, v in self.cc_logged_stack.items()}))
532
+ raise Exception("exit after first step when print cc stack")
533
+
534
+ self.generate_wgrad_metrics()
535
+ self.generate_mv_metrics(context)
536
+ self.generate_param_metrics(context)
537
+
538
+ tbtag_tensor_map = {}
539
+ if self.mg_direction:
540
+ for param, name in self.param2name.items():
541
+ grad = param.main_grad if self.params_have_main_grad else param.grad
542
+ if grad is None:
543
+ logger.warning(f"grad is None: {name}, maybe something wrong happened.")
544
+ continue
545
+ if context.step == 0:
546
+ same_direction_ratio = torch.tensor(1.)
547
+ else:
548
+ same_direction_ratio = get_sign_matches(grad, context.param_exp_avg[name])
549
+ context.param_mg_direction[name] = same_direction_ratio
550
+ tbtag_tensor_map.update(self.generate_param_map('mg_direction', context.param_mg_direction))
551
+
552
+ metric_dict = {}
553
+ get_metrics(self.ops, tbtag_tensor_map, self.eps, metric_dict)
554
+ for cc in self.cc_context.values():
555
+ cc.aggregate()
556
+ metric_dict.update(cc.data)
557
+ cc.reset()
558
+
559
+ if not metric_dict:
560
+ return
561
+ context.metric_dict = metric_dict
562
+ return
563
+
564
+ def optimizer_post_step_hook(optimizer, args, kwargs):
565
+ context = self.optimizer_context[optimizer]
566
+ rank = dist.get_rank() if dist.is_initialized() else None
567
+
568
+ if self.anomaly_data_factory:
569
+ self.anomaly_data_factory.set_call_id(self.param_name_call_id)
570
+ self.write_xy_tb(context.step)
571
+ self.write_grad_tb(context.step)
572
+ self.write_mv_tb(context)
573
+ self.write_param_tb(context)
574
+ self.write_adhoc_check(context.step)
575
+
576
+ if self.ur_distribution:
577
+ for param_name, _ in context.param_adam_update.items():
578
+ self.update_heatmap_visualizer[param_name].visualize(
579
+ get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step, self.summary_writer)
580
+ for param_name, _ in context.param_adam_ratio.items():
581
+ self.ratio_heatmap_visualizer[param_name].visualize(
582
+ get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step, self.summary_writer)
583
+
584
+ if context.metric_dict:
585
+ self.write_metrics(self.ops, self.summary_writer, context.metric_dict, context.step, 'other')
586
+ context.metric_dict.clear()
587
+ context.step += 1
588
+ if self.anomaly_data_factory:
589
+ self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
590
+ self.summary_writer.clear_anomalies()
591
+ self.call_id = 0
592
+ return
593
+
594
+ def patch_step(func, optimizer):
595
+ def wrapper(*args, **kwargs):
596
+ optimizer_pre_step_hook(optimizer, args, kwargs)
597
+ out = func(*args, **kwargs)
598
+ optimizer_post_step_hook(optimizer, args, kwargs)
599
+ return out
600
+
601
+ return wrapper
602
+
603
+ if self.optimizer_hooked:
604
+ return
605
+
606
+ if optimizer:
607
+ optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
608
+
609
+ else:
610
+ if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list):
611
+ register_optimizer_step_pre_hook(optimizer_pre_step_hook)
612
+ register_optimizer_step_post_hook(optimizer_post_step_hook)
613
+ self.optimizer_hooked = True
614
+ return
615
+
616
+ def _smallest_rank_print(self, msg):
617
+ if dist.is_initialized():
618
+ if self.module_rank_list:
619
+ if dist.get_rank() == min(self.module_rank_list):
620
+ logger.info(msg)
621
+ else:
622
+ if dist.get_rank() == 0:
623
+ logger.info(msg)
624
+ else:
625
+ logger.info(msg)
626
+
627
+ def _is_target_param(self, param_name, param, prefix):
628
+ squash_name = prefix + squash_param_name(param_name)
629
+ name = prefix + param_name
630
+ for target in self.config['targets'].keys():
631
+ if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
632
+ setattr(param, "zero_out_wgrad", True)
633
+ return True
634
+
635
+ return False
636
+
637
+ def _register_chunk(self, model_chunk, prefix):
638
+ for index, (param_name, param) in enumerate(model_chunk.named_parameters()):
639
+ if not param.requires_grad:
640
+ continue
641
+ if self._is_target_param(param_name, param, prefix):
642
+ name = prefix + squash_param_name(param_name)
643
+ if name in self.param2name.values():
644
+ logger.error(f'same name {name} for different param. Current param is {param_name}. '
645
+ 'This may be an error in squash_param_name.')
646
+ raise Exception("param with same name will be overwritten.")
647
+ self.param2name[param] = name
648
+ self.name2param[name] = param
649
+ self.name2index[name] = index
650
+
651
+ if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group):
652
+ self.duplicate_param[name] = True
653
+ if self.dp_group and param_is_data_parallel_duplicate(self.dp_group):
654
+ self.duplicate_param[name] = True
655
+ self.name2tag[name] = {}
656
+ self.name2tag[name][MonitorConst.PRE_GRAD] = get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD,
657
+ self.rank)
658
+ self.name2tag[name][MonitorConst.POST_GRAD] = get_summary_writer_tag_name(name, MonitorConst.POST_GRAD,
659
+ self.rank)
660
+
661
+ def _register_param_name(self, model):
662
+ if self.param_registered:
663
+ return
664
+
665
+ if not isinstance(model, list):
666
+ model = [model]
667
+
668
+ if len(model) > 1:
669
+ self.vpp = True
670
+ self._smallest_rank_print('vpp enabled')
671
+
672
+ for vpp_stage, model_chunk in enumerate(model):
673
+ prefix = f'{vpp_stage}{MonitorConst.VPP_SEP}'
674
+ self._register_chunk(model_chunk, prefix)
675
+
676
+ self.param_registered = True
677
+
678
+ def _is_target_module(self, module_name, targets, vpp_stage):
679
+ if self.all_xy or self.print_struct:
680
+ return vpp_stage + squash_param_name(module_name)
681
+ for pattern in [
682
+ vpp_stage + squash_param_name(module_name),
683
+ vpp_stage + module_name,
684
+ ]:
685
+ if pattern in targets:
686
+ return pattern
687
+ return ""
688
+
689
+ def _hook_module(self, target_names, module: torch.nn.Module, vpp_stage=''):
690
+ if '_modules' not in module.__dict__:
691
+ # nothing to hook
692
+ return 0
693
+
694
+ def fwd_hook_fun(module, module_input, module_output, name):
695
+ if is_recomputation():
696
+ return
697
+ if module not in self.module_fwd_hook_context_by_module:
698
+ self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
699
+ context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
700
+ if not context.struct:
701
+ context.struct = {MonitorConst.ACTV_IN: get_param_struct(module_input),
702
+ MonitorConst.ACTV_OUT: get_param_struct(module_output)}
703
+ if self.print_struct:
704
+ if context.module_name not in self.module_struct:
705
+ self.module_struct[context.module_name] = {}
706
+ self.module_struct[context.module_name].update(context.struct)
707
+ return
708
+ if not module.training:
709
+ return
710
+ if not context.format_by_arg:
711
+ context.set_format_by_arg(MonitorConst.ACTV_IN, self.config['targets'])
712
+ context.set_format_by_arg(MonitorConst.ACTV_OUT, self.config['targets'])
713
+ if not context.format_by_arg:
714
+ return
715
+ if not context.verified:
716
+ if not context.ignore_in:
717
+ context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_IN],
718
+ module_input, context.module_name,
719
+ MonitorConst.ACTV_IN)
720
+ context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_OUT],
721
+ module_output, context.module_name,
722
+ MonitorConst.ACTV_OUT)
723
+ context.verified = True
724
+ # expect the output to be a tensor
725
+ tbtag_tensor_map = {}
726
+ if not context.ignore_in:
727
+ cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
728
+ tbtag_tensor_map.update(
729
+ self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN,
730
+ cared_input))
731
+ cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
732
+ tbtag_tensor_map.update(
733
+ self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_OUT,
734
+ cared_output))
735
+
736
+ get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
737
+
738
+ context.micro_step += 1
739
+ if context.micro_step == self.micro_batch_number:
740
+ context.micro_step = 0
741
+ context.step += 1
742
+ return
743
+
744
+ def bwd_hook_fun(module, input_grad, output_grad):
745
+ context: ModuleHookContext = self.module_bwd_hook_context_by_module[module]
746
+ if not context.struct:
747
+ context.struct = {MonitorConst.ACTVGRAD_IN: get_param_struct(input_grad),
748
+ MonitorConst.ACTVGRAD_OUT: get_param_struct(output_grad)}
749
+ if self.print_struct:
750
+ if context.module_name not in self.module_struct:
751
+ self.module_struct[context.module_name] = {}
752
+ self.module_struct[context.module_name].update(context.struct)
753
+ return
754
+ if not context.format_by_arg:
755
+ context.set_format_by_arg(MonitorConst.ACTVGRAD_IN, self.config['targets'])
756
+ context.set_format_by_arg(MonitorConst.ACTVGRAD_OUT, self.config['targets'])
757
+ if not context.format_by_arg:
758
+ return
759
+ if not context.verified:
760
+ if not context.ignore_in:
761
+ context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_IN],
762
+ input_grad, context.module_name,
763
+ MonitorConst.ACTVGRAD_IN)
764
+ context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_OUT],
765
+ output_grad, context.module_name,
766
+ MonitorConst.ACTVGRAD_OUT)
767
+ context.verified = True
768
+
769
+ tbtag_tensor_map = {}
770
+ if not context.ignore_in:
771
+ cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
772
+ tbtag_tensor_map.update(
773
+ self.build_tbtag_tensor_map(
774
+ f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN, cared_input_grad))
775
+ cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
776
+ tbtag_tensor_map.update(
777
+ self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_OUT,
778
+ cared_output_grad))
779
+
780
+ if context.micro_step == 0 and context.actvgrad:
781
+ logger.warning(f"actvgrad context of {context.module_name} is not empty at the first micro_step, "
782
+ f"something may have gone wrong. Clearing it now.")
783
+ context.actvgrad.clear()
784
+
785
+ get_metrics(self.ops, tbtag_tensor_map, self.eps, self.grad_context.actv)
786
+
787
+ context.micro_step += 1
788
+ if context.micro_step == self.micro_batch_number:
789
+ context.micro_step = 0
790
+ context.step += 1
791
+ return
792
+
793
+ if self.backward_only and self.forward_only:
794
+ logger.warning('do not enable backward_only and forward_only simultaneously')
795
+
796
+ hooked_count = 0
797
+ if self.xy_distribution or self.print_struct:
798
+ for module_name, submodule in module.named_modules():
799
+ name = self._is_target_module(module_name, target_names, vpp_stage)
800
+ if not name:
801
+ continue
802
+ if not self.backward_only:
803
+ handle = submodule.register_forward_hook(partial(fwd_hook_fun, name=name))
804
+ self.handles['xy'].append(handle)
805
+ if not self.forward_only:
806
+ handle = submodule.register_full_backward_hook(bwd_hook_fun)
807
+ self.handles['xy'].append(handle)
808
+ self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name)
809
+ logger.info_on_rank_0(f"> {name} is monitored successfully")
810
+ hooked_count += 1
811
+ return hooked_count
812
+
813
+ def _patch_grad_sync(self):
814
+ def patch_sync(sync_grad_func):
815
+ def wrapper(bucket):
816
+ grad_dict = {}
817
+ for param, name in self.param2name.items():
818
+ if param not in bucket.params_list:
819
+ continue
820
+ grad = param.main_grad if self.params_have_main_grad else param.grad
821
+ if grad is None:
822
+ logger.warning(f"grad is None: {name}, something may have gone wrong.")
823
+ continue
824
+ tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
825
+ if tag is None:
826
+ continue
827
+ grad_dict[tag] = grad
828
+ get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
829
+ out = sync_grad_func(bucket)
830
+ return out
831
+
832
+ return wrapper
833
+
834
+ try:
835
+ from megatron.core.distributed.param_and_grad_buffer import Bucket
836
+ self.enable_megatron = True
837
+ except ImportError:
838
+ self.enable_megatron = False
839
+
840
+ if self.enable_megatron:
841
+ Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync) # differ in different megatron version
842
+ else:
843
+ self._hook_weights()
844
+
845
+ def _hook_weights(self):
846
+ context = self.grad_context
847
+
848
+ @torch.no_grad
849
+ def param_hook(*args, context_dict, param, key, name):
850
+ param.micro_step += 1
851
+ self.param_name_call_id[name] = self.call_id
852
+ self.call_id += 1
853
+ if param.micro_step == self.micro_batch_number:
854
+ param.micro_step = 0
855
+ if self.params_have_main_grad:
856
+ context_dict[key] = param.main_grad.clone()
857
+ else:
858
+ context_dict[key] = param.grad.clone()
859
+
860
+ for param, name in self.param2name.items():
861
+ key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
862
+ setattr(param, 'micro_step', 0)
863
+ param_tmp = param.expand_as(param)
864
+ grad_acc = param_tmp.grad_fn.next_functions[0][0]
865
+ handle = grad_acc.register_hook(
866
+ partial(param_hook, context_dict=context.acc, param=param, key=key, name=name))
867
+ self.grad_accs.append(grad_acc)
868
+ self.handles['wgrads'].append(handle)
869
+
870
+ self.weight_hooked = True
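To make the public interface above concrete, the following minimal sketch wires TrainerMon into a training script. It uses only the signatures visible in this diff (the constructor and monitor_gnorm_with_ad); the model, optimizer and config path are placeholders, and end-to-end behaviour with a plain torch.optim optimizer depends on OptimizerMonFactory and on config validation, neither of which appears in this diff. The full integration guide is msprobe/docs/19.monitor.md.

import torch
from msprobe.pytorch.monitor.module_hook import TrainerMon

# Placeholder model and optimizer, for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

monitor = TrainerMon(
    "monitor_config.json",         # JSON config such as the sketch shown before the diff
    process_group=None,            # optional torch.distributed process group
    params_have_main_grad=False,   # True (the default) expects Megatron-style param.main_grad
    opt_ty=None,                   # e.g. "Megatron_Float16OptimizerWithFloat16Params" (see docstring)
)

# "External interface": registers forward/backward hooks on the configured targets,
# patches optimizer.step() for metric collection, and hooks gradient accumulation
# (or Megatron's Bucket.start_grad_sync when megatron.core is importable).
monitor.monitor_gnorm_with_ad(model, grad_acc_steps=1, optimizer=optimizer)

# Training then proceeds as usual; the patched optimizer.step() aggregates the collected
# statistics and writes them to output_base_dir (the environment variable named by
# MonitorConst.MONITOR_OUTPUT_DIR, or its default), per rank and per step.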