mindstudio_probe-1.1.1-py3-none-any.whl → mindstudio_probe-1.2.2-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (226)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
msprobe/pytorch/hook_module/hook_module.py
@@ -15,17 +15,17 @@
 
 import functools
 import threading
+from collections import defaultdict
 
 import torch
 import torch.nn as nn
 import torch.utils.hooks as full_hooks
 
-from msprobe.core.common.const import Const
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 
 
 class HOOKModule(nn.Module):
-    module_count = {}
+    module_count = defaultdict(int)
     inner_stop_hook = {}
 
     def __init__(self, build_hook) -> None:
@@ -41,12 +41,7 @@ class HOOKModule(nn.Module):
         if hasattr(self, "prefix_op_name_"):
             self.prefix = self.prefix_op_name_
 
-        if self.prefix not in HOOKModule.module_count:
-            HOOKModule.module_count[self.prefix] = 1
-            self.prefix += '0' + Const.SEP
-        else:
-            HOOKModule.module_count[self.prefix] += 1
-            self.prefix = self.prefix + str(HOOKModule.module_count[self.prefix] - 1) + Const.SEP
+        self.forward_data_collected = False
         forward_pre_hook, forward_hook, backward_hook, _ = build_hook(self.prefix)
         if torch_version_above_or_equal_2:
             self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
@@ -66,9 +61,17 @@ class HOOKModule(nn.Module):
             HOOKModule.inner_stop_hook[self.current_thread] = False
         return result
 
-    @classmethod
-    def reset_module_stats(cls):
-        cls.module_count = {}
+    @staticmethod
+    def reset_module_stats():
+        HOOKModule.module_count = defaultdict(int)
+
+    @staticmethod
+    def add_module_count(name):
+        HOOKModule.module_count[name] += 1
+
+    @staticmethod
+    def get_module_count(name):
+        return HOOKModule.module_count[name]
 
     def _call_func(self, *args, **kwargs):
         full_backward_hooks, non_full_backward_hooks = [], []
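The per-API numbering that used to happen inside HOOKModule.__init__ (appending '0' + Const.SEP, '1' + Const.SEP, ... to self.prefix) is gone; callers now drive the counter through the new static helpers. A minimal sketch of how a hook-building caller could reproduce the old numbered prefix, assuming it owns the counting (the helper name and separator below are illustrative, not part of the diff):

    # Hypothetical helper, not part of the package: rebuild an "ApiName.N." prefix
    # from the new static counter instead of mutating self.prefix in __init__.
    def build_numbered_prefix(api_name, sep="."):
        index = HOOKModule.get_module_count(api_name)  # current count, 0 for an unseen name
        HOOKModule.add_module_count(api_name)          # bump for the next call
        return f"{api_name}{index}{sep}"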
msprobe/pytorch/hook_module/register_optimizer_hook.py (new file)
@@ -0,0 +1,59 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from msprobe.core.common.const import Const
+from msprobe.pytorch.common.log import logger
+
+torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
+if torch_version_above_or_equal_2:
+    from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
+
+
+def register_optimizer_hook(data_collector):
+    def optimizer_pre_step_hook(optimizer, args, kwargs):
+        data_collector.optimizer_status = Const.OPTIMIZER
+
+    def optimizer_post_step_hook(optimizer, args, kwargs):
+        data_collector.optimizer_status = Const.END_PREFIX + Const.OPTIMIZER
+
+    def patch_clip_grad(func):
+        def wrapper(*args, **kwargs):
+            data_collector.optimizer_status = Const.CLIP_GRAD
+            func(*args, **kwargs)
+            data_collector.optimizer_status = Const.END_PREFIX + Const.CLIP_GRAD
+
+        return wrapper
+
+    if torch_version_above_or_equal_2:
+        register_optimizer_step_pre_hook(optimizer_pre_step_hook)
+        register_optimizer_step_post_hook(optimizer_post_step_hook)
+    else:
+        logger.info_on_rank_0("Pytorch version is below 2.0, cannot register optimizer hook.")
+
+    try:
+        torch.nn.utils.clip_grad_norm_ = patch_clip_grad(torch.nn.utils.clip_grad_norm_)
+        torch.nn.utils.clip_grad_norm = patch_clip_grad(torch.nn.utils.clip_grad_norm)
+        torch.nn.utils.clip_grad_value_ = patch_clip_grad(torch.nn.utils.clip_grad_value_)
+    except Exception as e:
+        logger.info_on_rank_0("Cannot patch clip grad function. detail:%s" % str(e))
+
+    try:
+        from megatron.core.optimizer import MegatronOptimizer
+        MegatronOptimizer.clip_grad_norm = patch_clip_grad(MegatronOptimizer.clip_grad_norm)
+    except ImportError:
+        pass
+    except Exception as e:
+        logger.info_on_rank_0("Cannot patch megatron clip grad function. detail:%s" % str(e))
msprobe/pytorch/hook_module/support_wrap_ops.yaml
@@ -138,6 +138,10 @@ functional:
   - fold
   - multi_head_attention_forward
   - scaled_dot_product_attention
+  - lp_pool3d
+  - dropout1d
+  - mish
+  - huber_loss
 
 tensor:
   - __add__
@@ -172,6 +176,7 @@ tensor:
   - __sub__
   - __truediv__
   - __xor__
+  - __pow__
   - abs
   - abs_
   - absolute
@@ -557,6 +562,27 @@ tensor:
   - view_as
   - xlogy
   - xlogy_
+  - split
+  - stft
+  - nan_to_num
+  - dsplit
+  - orgqr
+  - bitwise_left_shift_
+  - arctan2
+  - histogram
+  - q_zero_point
+  - adjoint
+  - ormqr
+  - bitwise_right_shift_
+  - nanquantile
+  - lu
+  - quantile
+  - arctan2_
+  - qr
+  - diagonal_scatter
+  - corrcoef
+  - vsplit
+  - aminmax
 
 torch:
   - linalg.norm
@@ -1131,6 +1157,14 @@ torch_npu:
   - npu_lstm
   - npu_apply_adam
   - npu_apply_adam_w
+  - npu_anti_quant
+  - npu_grouped_matmu
+  - npu_quant_scatter
+  - npu_group_norm_silu
+  - npu_format_cast
+  - npu_moe_finalize_routing
+  - npu_moe_gating_top_k_softmax
+  - npu_trans_quant_param
 
 aten:
   - signbit
@@ -1877,4 +1911,5 @@ distributed:
   - all_to_all_single
   - all_to_all
   - all_gather_into_tensor
-  - reduce_scatter_tensor
+  - reduce_scatter_tensor
+  - batch_isend_irecv
msprobe/pytorch/hook_module/wrap_distributed.py
@@ -21,7 +21,6 @@ from msprobe.pytorch.hook_module.hook_module import HOOKModule
 from msprobe.pytorch.common.utils import torch_device_guard
 from msprobe.core.common.const import Const
 from msprobe.core.common.file_utils import load_yaml
-from msprobe.core.common.inplace_op_checker import InplaceOpChecker
 
 
 cur_path = os.path.dirname(os.path.realpath(__file__))
@@ -49,17 +48,20 @@ class DistributedOPTemplate(HOOKModule):
         self.op_name_ = op_name
         self.prefix_op_name_ = "Distributed" + Const.SEP + str(op_name) + Const.SEP
         super().__init__(build_hook)
-        if not self.stop_hook and InplaceOpChecker.check(self.op_name_, InplaceOpChecker.OP_DISTRIBUTED):
-            self.op_is_inplace = True
+        if not self.stop_hook:
+            self.op_is_distributed = True
 
     @torch_device_guard
     def forward(self, *args, **kwargs):
+        handle = distributed_func.get(self.op_name_)(*args, **kwargs)
        if kwargs.get("async_op") or self.op_name_ in ["isend", "irecv"]:
-            handle = distributed_func.get(self.op_name_)(*args, **kwargs)
-            handle.wait()
-            return handle
-        else:
-            return distributed_func.get(self.op_name_)(*args, **kwargs)
+            if handle and hasattr(handle, 'wait'):
+                handle.wait()
+        if self.op_name_ == "batch_isend_irecv":
+            if isinstance(handle, list):
+                for req in handle:
+                    req.wait()
+        return handle
 
 
 def wrap_distributed_op(op_name, hook):
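The rewritten forward accounts for the fact that torch.distributed.batch_isend_irecv returns a list of work objects rather than a single handle, so each request has to be waited on individually. A small illustration of that return shape, assuming a standard point-to-point exchange (ranks and tensors are made up):

    # Illustration of why the wrapper checks isinstance(handle, list):
    # batch_isend_irecv returns one Work object per P2POp.
    import torch
    import torch.distributed as dist

    def exchange(tensor, peer):
        ops = [dist.P2POp(dist.isend, tensor, peer),
               dist.P2POp(dist.irecv, torch.empty_like(tensor), peer)]
        reqs = dist.batch_isend_irecv(ops)  # list of requests, not a single handle
        for req in reqs:
            req.wait()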
msprobe/pytorch/hook_module/wrap_functional.py
@@ -23,46 +23,6 @@ from msprobe.pytorch.common.log import logger
 from msprobe.core.common.file_utils import load_yaml
 
 
-def remove_dropout():
-    if torch.__version__ > "1.8":
-        logger.info_on_rank_0("For precision comparison, the probability p in the dropout method is set to 0.")
-        import torch.nn.functional as F
-        from torch import _VF
-        from torch.overrides import has_torch_function_unary, handle_torch_function
-
-        def function_dropout(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
-                             inplace: bool = False) -> torch.Tensor:
-            if has_torch_function_unary(input_tensor):
-                return handle_torch_function(
-                    function_dropout, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
-            if p < 0.0 or p > 1.0:
-                raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
-            return _VF.dropout_(input_tensor, 0., training) if inplace else _VF.dropout(input_tensor, 0., training)
-
-        def function_dropout2d(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
-                               inplace: bool = False) -> torch.Tensor:
-            if has_torch_function_unary(input_tensor):
-                return handle_torch_function(
-                    function_dropout2d, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
-            if p < 0.0 or p > 1.0:
-                raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
-            return _VF.feature_dropout_(input_tensor, 0., training) if inplace else _VF.feature_dropout(input_tensor,
-                                                                                                        0., training)
-
-        def function_dropout3d(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
-                               inplace: bool = False) -> torch.Tensor:
-            if has_torch_function_unary(input_tensor):
-                return handle_torch_function(
-                    function_dropout3d, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
-            if p < 0.0 or p > 1.0:
-                raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
-            return _VF.feature_dropout_(input_tensor, 0., training) if inplace else _VF.feature_dropout(input_tensor,
-                                                                                                        0., training)
-
-        F.dropout = function_dropout
-        F.dropout2d = function_dropout2d
-        F.dropout3d = function_dropout3d
-
 cur_path = os.path.dirname(os.path.realpath(__file__))
 yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
 
msprobe/pytorch/monitor/anomaly_analyse.py
@@ -19,7 +19,7 @@ import argparse
 import ast
 import heapq
 
-from msprobe.core.common.log import logger
+from msprobe.pytorch.common.log import logger
 from msprobe.core.common.const import MonitorConst
 from msprobe.core.common.file_utils import check_path_before_create, save_json, create_directory, remove_path, \
     check_file_or_directory_path, load_json
msprobe/pytorch/monitor/anomaly_detect.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,21 +12,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import itertools
 import os
-import sys
 import statistics as st
+import sys
 from abc import ABC
+from collections import defaultdict
 from dataclasses import dataclass, field
 from typing import List
-from collections import defaultdict
 
 import pandas as pd
+import torch
 from torch.utils.tensorboard import SummaryWriter
 
-from msprobe.core.common.log import logger
-from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv
 from msprobe.core.common.const import FileCheckConst, MonitorConst
+from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv
+from msprobe.pytorch.common.log import logger
 
 
 class ScanRule(ABC):
@@ -134,9 +135,9 @@ class AnomalyDataFactory(ABC):
             raise ValueError("tag must be a tuple with length 2")
         tag_name = tag[0]
         param_name = tag_name.split('/')[0]
-        call_id = self.name2callid.get(param_name, -1)
-        if MonitorConst.VPP_SEP in param_name:
-            vpp_stage = int(param_name.split(MonitorConst.VPP_SEP)[0])
+        call_id = self.name2callid.get(tag_name, -1)
+        if MonitorConst.NAME_SEP in param_name:
+            vpp_stage = int(param_name.split(MonitorConst.NAME_SEP)[0])
         else:
             vpp_stage = 0
 
@@ -153,6 +154,24 @@ class AnomalyDataFactory(ABC):
         )
 
 
+class TrainStage:
+    DEFAULT_STAGE = -1
+    FORWARD_STAGE = 0
+    BACKWARD_STAGE = 1
+    OPTIMIZER_STAGE = 2
+
+
+FORWARD_KEY = [MonitorConst.ACTV]
+BACKWARD_KEY = [MonitorConst.ACTVGRAD, MonitorConst.PRE_GRAD,
+                MonitorConst.POST_GRAD, MonitorConst.ACC_GRAD]
+OPTIMIZER_KEY = [MonitorConst.EXP_AVG, MonitorConst.EXP_AVG_SQ]
+TRAIN_STAGE = {
+    **{key_: TrainStage.FORWARD_STAGE for key_ in FORWARD_KEY},
+    **{key_: TrainStage.BACKWARD_STAGE for key_ in BACKWARD_KEY},
+    **{key_: TrainStage.OPTIMIZER_STAGE for key_ in OPTIMIZER_KEY}
+}
+
+
 @dataclass(eq=True)
 class GradAnomalyData:
     rank: int = 0
@@ -166,25 +185,48 @@ class GradAnomalyData:
     group_mates: list = field(default=None, compare=False)
 
     def __lt__(self, other):
+        """
+        Custom comparison used to order GradAnomalyData instances.
+        Rules:
+        smaller step and micro_step values have higher priority;
+        smaller vpp and pp values have higher priority in the forward stage, larger values in other stages;
+        smaller call_id values have higher priority.
+        """
         if not isinstance(other, GradAnomalyData):
             return NotImplemented
-        if self.step != other.step:
-            return self.step < other.step
-        if self.micro_step != other.micro_step:
-            return self.micro_step < other.micro_step
-        if self.vpp_stage != other.vpp_stage:
-            return self.vpp_stage > other.vpp_stage
-        if self.pp_stage != other.pp_stage:
-            return self.pp_stage > other.pp_stage
-        if self.call_id != other.call_id:
-            return self.call_id < other.call_id
-        return False
+
+        self_train_stage = self.get_train_stage(self.tag_name)
+        other_train_stage = self.get_train_stage(other.tag_name)
+
+        def vpp_pp_comparator(anomaly):
+            """
+            Determine the priority rule for vpp and pp based on train stage
+            Forward stage prefers smaller vpp and pp
+            Other stages prefer larger vpp and pp
+            """
+            if self_train_stage == TrainStage.FORWARD_STAGE:
+                return anomaly.vpp_stage, anomaly.pp_stage
+            else:
+                return -anomaly.vpp_stage, -anomaly.pp_stage
+
+        self_cmp = [self.step, self.micro_step, self_train_stage, *vpp_pp_comparator(self), self.call_id]
+        other_cmp = [other.step, other.micro_step, other_train_stage, *vpp_pp_comparator(other), other.call_id]
+        return self_cmp < other_cmp
 
     def __le__(self, other):
         if not isinstance(other, GradAnomalyData):
             return NotImplemented
         return self == other or self < other
 
+    @staticmethod
+    def get_train_stage(tag_name):
+        """
+        :param tag_name: "0:fc2.input:0/rank0/actv", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq"
+        :return: int, if forward return 0; if backward return 1; if optimizer return 2
+        """
+        key_ = tag_name.split("/")[-1]
+        return TRAIN_STAGE.get(key_, TrainStage.DEFAULT_STAGE)
+
     def to_dict(self):
         return self.__dict__
 
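get_train_stage keys off the last "/"-separated segment of the tag, so the ordering of anomalies now depends on the training stage encoded in the tag name. A hedged illustration using the tag formats quoted in the docstring, assuming the MonitorConst constants equal the literal suffixes shown there ("actv", "post_grad", "exp_avg_sq"):

    GradAnomalyData.get_train_stage("0:fc2.input:0/rank0/actv")       # TrainStage.FORWARD_STAGE (0)
    GradAnomalyData.get_train_stage("0:fc1.weight/rank0/post_grad")   # TrainStage.BACKWARD_STAGE (1)
    GradAnomalyData.get_train_stage("0:fc2.weight/rank0/exp_avg_sq")  # TrainStage.OPTIMIZER_STAGE (2)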
@@ -198,7 +240,6 @@ class WriterInput:
     path: str
     ad_rules: list
     job_id: str
-    anomaly_inform: bool = False
     anomaly_factory: AnomalyDataFactory = None
     ndigits: int = 6
     step_count_per_record: int = 1
@@ -209,7 +250,6 @@ class BaseWriterWithAD:
         self.tag2scalars = {}
         self.ad_rules = writer_input.ad_rules
         self.job_id = writer_input.job_id
-        self.anomaly_inform = writer_input.anomaly_inform
         self.anomaly_factory = writer_input.anomaly_factory
         self.anomalies = []
         self.ndigits = writer_input.ndigits
@@ -242,6 +282,27 @@ class BaseWriterWithAD:
             if self.anomaly_factory:
                 self.anomalies.append(self.anomaly_factory.create(tag, exception_message, global_step))
 
+    def write_metrics(self, ops, metric_value, step, prefix=''):
+        if not metric_value:
+            return
+        tensors = []
+        tags = list(itertools.product(metric_value.keys(), ops))
+        for op2tensor in metric_value.values():
+            tensors.extend(op2tensor.values())
+        if not tensors:
+            return
+
+        n_slices = len(tensors) // MonitorConst.SLICE_SIZE
+        with torch.no_grad():
+            for i in range(n_slices + 1):
+                begin = i * MonitorConst.SLICE_SIZE
+                end = (i+1) * MonitorConst.SLICE_SIZE
+                if begin == len(tensors):
+                    continue
+                metric_list = torch.stack(tensors[begin:end]).cpu()
+                for tag, metric in zip(tags[begin:end], metric_list):
+                    self.add_scalar(tag, metric, step)
+
     def _ad(self, scalar_value, history):
         return AnomalyScanner.scan(self.ad_rules, history, cur=scalar_value)
 
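write_metrics consumes a nested mapping of {tag_name: {op: tensor}} and flushes the scalars in slices of MonitorConst.SLICE_SIZE, stacking each slice and moving it to the CPU in one transfer before calling add_scalar. A hedged sketch of the expected input shape (tag and op names are illustrative; the real op list comes from the monitor configuration):

    ops = ["min", "max", "norm"]
    metric_value = {
        "vp0:fc1.weight/post_grad": {"min": torch.tensor(-0.1),
                                     "max": torch.tensor(0.2),
                                     "norm": torch.tensor(1.3)},
    }
    # writer: an instance of a BaseWriterWithAD subclass
    writer.write_metrics(ops, metric_value, step=10, prefix="post_grad")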
@@ -291,7 +352,7 @@ class CSVWriterWithAD(BaseWriterWithAD):
         """
         if len(self.context_dict) == 0:
             return
-
+
         ster_start, step_end = self.get_step_interval(step)
         filepath = os.path.join(self.log_dir, f'{prefix}_{ster_start}-{step_end}.csv')
         if not os.path.exists(filepath):
@@ -300,11 +361,11 @@ class CSVWriterWithAD(BaseWriterWithAD):
 
         new_data = []
         for name, metric_value in self.context_dict.items():
-            if MonitorConst.VPP_SEP not in name:
-                new_data.append([name] + [step] + metric_value)
-            else:
-                new_data.append(name.split(MonitorConst.VPP_SEP) + [step] + metric_value)
-        new_data = pd.DataFrame(new_data).round(self.ndigits)
+            new_line = name.split(MonitorConst.NAME_SEP) + metric_value
+            new_line.insert(2, step)
+            new_data.append(new_line)
+
+        new_data = pd.DataFrame(new_data).round(self.ndigits).fillna("nan")
         write_df_to_csv(new_data, filepath, mode='a+', header=False)
         self.context_dict = defaultdict(list)
 
@@ -317,6 +378,15 @@ class CSVWriterWithAD(BaseWriterWithAD):
         name = tag[0].split('/')[0]
         self.context_dict[name].append(scalar_value.item())
 
+    def write_metrics(self, ops, metric_value, step, prefix=''):
+        super().write_metrics(ops, metric_value, step, prefix='')
+
+        if prefix in [MonitorConst.ACTV, MonitorConst.ACTVGRAD]:
+            self.header = MonitorConst.CSV_HEADER_XY + ops
+        else:
+            self.header = MonitorConst.CSV_HEADER + ops
+        self.write_csv(prefix, step)
+
     def close(self):
         pass
 
msprobe/pytorch/monitor/csv2tb.py (new file)
@@ -0,0 +1,164 @@
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import datetime
+import os
+import re
+from multiprocessing import Process
+
+import pytz
+from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
+
+from msprobe.core.common.const import MonitorConst
+from msprobe.core.common.file_utils import read_csv, create_directory, remove_path
+from msprobe.core.common.utils import is_int
+from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.monitor.utils import get_target_output_dir
+
+all_data_type_list = ["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param"]
+CSV_FILE_SUFFIX = r"_\d+-\d+\.csv"
+
+
+def parse_step_line(line, ops):
+    vp_id = line["vpp_stage"]
+    module_name = line[MonitorConst.HEADER_NAME]
+    step = line["step"]
+    vpp_name = f"vp{vp_id}:{module_name}"
+    if 'micro_step' in line:
+        vpp_name = f'{vpp_name}{MonitorConst.NAME_SEP}micro{line["micro_step"]}'
+    ops_result = {}
+    for op in ops:
+        ops_result[op] = line[op]
+    return vpp_name, step, ops_result
+
+
+def parse_step_fn(filepath):
+    data = read_csv(filepath)
+    ops = [k for k in data.keys() if k in MonitorConst.OP_LIST]
+    parse_step_result = {}
+
+    for _, line in data.iterrows():
+        vpp_name, step, ops_result = parse_step_line(line, ops)
+        if vpp_name not in parse_step_result:
+            parse_step_result[vpp_name] = {}
+        if step in parse_step_result[vpp_name]:
+            raise Exception(f"duplicated step({step})")
+        parse_step_result[vpp_name][step] = ops_result
+    return parse_step_result
+
+
+def write_step(output_dirpath, parse_step_result, rank, data_type):
+    tb_output_path = os.path.join(output_dirpath, f"rank{rank}", data_type)
+    if os.path.exists(tb_output_path):
+        remove_path(tb_output_path)
+        logger.warning(f"existing path {tb_output_path} will be recovered")
+    writer = SummaryWriter(tb_output_path)
+    for vpp_name, step_data_dict in parse_step_result.items():
+        step_data_list = [(step, ops) for step, ops in step_data_dict.items()]
+        step_data_list.sort(key=lambda x: x[0])
+        for step_data in step_data_list:
+            step = step_data[0]
+            ops = step_data[1]
+            for op, value in ops.items():
+                tag = f"{vpp_name}/{op}"
+                writer.add_scalar(tag, value, step)
+
+
+def update_dict(dict1, dict2):
+    for key, value in dict2.items():
+        if key in dict1:
+            if isinstance(dict1[key], dict) and isinstance(value, dict):
+                try:
+                    update_dict(dict1[key], value)
+                except Exception as e:
+                    raise Exception(f"Error updating nested dict failed at key '{key}': {e}") from e
+            else:
+                raise Exception(f"duplicate key: {key}")
+        else:
+            dict1[key] = value
+    return dict1
+
+
+def csv2tb_by_step_work(target_output_dirs, output_dirpath, data_type_list):
+    for directory in tqdm(target_output_dirs):
+        dirpath = directory["path"]
+        rank = directory["rank"]
+        for data_type in data_type_list:
+            all_step_result = {}
+            for filename in os.listdir(dirpath):
+                if not re.match(f"{data_type}{CSV_FILE_SUFFIX}", filename):
+                    continue
+                filepath = os.path.join(dirpath, filename)
+                try:
+                    parse_step_result = parse_step_fn(filepath)
+                except Exception as e:
+                    logger.error(f"csv2tensorboard parse {filepath} failed \n {e}")
+                    break
+
+                all_step_result = update_dict(all_step_result, parse_step_result)
+            if all_step_result:
+                write_step(output_dirpath, all_step_result, rank, data_type)
+
+
+def check_process_num(process_num):
+    if not is_int(process_num) or process_num <= 0:
+        raise ValueError(f"process_num({process_num}) is not a positive integer")
+
+
+def check_data_type_list(data_type_list):
+    if data_type_list is None:
+        logger.info(f"data_type_list is None, use defualt all_data_type_list: {all_data_type_list}")
+        return
+    if not isinstance(data_type_list, list):
+        raise ValueError(f"data_type_list({data_type_list}) is not a list")
+    for data_type in data_type_list:
+        if data_type not in all_data_type_list:
+            raise ValueError(f"data type({data_type}) is not supported, supported data type: {all_data_type_list}")
+
+
+def csv2tensorboard_by_step(
+        monitor_path,
+        time_start=None,
+        time_end=None,
+        process_num=1,
+        data_type_list=None,
+        output_dirpath=None
+):
+    check_process_num(process_num)
+    check_data_type_list(data_type_list)
+    target_output_dirs = get_target_output_dir(monitor_path, time_start, time_end)
+    target_output_dirs = [{"rank": rank, "path": path} for rank, path in target_output_dirs.items()]
+    if output_dirpath is None:
+        local_tz = pytz.timezone("Asia/Shanghai")  # adjust to the target time zone as needed
+        cur_time = datetime.datetime.now(local_tz).strftime("%b%d_%H-%M-%S")
+        output_dirpath = os.path.join(monitor_path, f"{cur_time}-csv2tensorboard_by_step")
+    create_directory(output_dirpath)
+
+    task_num = len(target_output_dirs)
+    task_num_per_pro = task_num // process_num
+    target_data_type = data_type_list if data_type_list else all_data_type_list
+
+    processes = []
+    for pro_id in range(process_num):
+        task_start_id = pro_id * task_num_per_pro
+        task_end_id = (pro_id + 1) * task_num_per_pro if pro_id != process_num - 1 else task_num
+        task_dirs = target_output_dirs[task_start_id: task_end_id]
+
+        p = Process(target=csv2tb_by_step_work, args=(task_dirs, output_dirpath, target_data_type))
+        processes.append(p)
+        p.start()
+    for p in processes:
+        p.join()
+    logger.info(f"output has been saved to: {output_dirpath}")