mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197) hide show
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
  2. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +14 -19
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +155 -6
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +3 -0
  10. msprobe/core/common/utils.py +28 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +380 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/multiprocessing_compute.py +2 -2
  22. msprobe/core/compare/npy_compare.py +109 -147
  23. msprobe/core/compare/utils.py +189 -69
  24. msprobe/core/data_dump/data_collector.py +51 -21
  25. msprobe/core/data_dump/data_processor/base.py +38 -20
  26. msprobe/core/data_dump/data_processor/factory.py +5 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
  29. msprobe/core/data_dump/json_writer.py +29 -1
  30. msprobe/core/data_dump/scope.py +19 -18
  31. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  32. msprobe/core/overflow_check/checker.py +1 -1
  33. msprobe/core/overflow_check/utils.py +1 -1
  34. msprobe/docs/01.installation.md +96 -17
  35. msprobe/docs/02.config_introduction.md +5 -5
  36. msprobe/docs/05.data_dump_PyTorch.md +91 -61
  37. msprobe/docs/06.data_dump_MindSpore.md +57 -19
  38. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  39. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
  40. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  41. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  42. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  43. msprobe/docs/19.monitor.md +120 -27
  44. msprobe/docs/21.visualization_PyTorch.md +115 -35
  45. msprobe/docs/22.visualization_MindSpore.md +138 -41
  46. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  47. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  48. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  49. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  50. msprobe/docs/27.dump_json_instruction.md +521 -0
  51. msprobe/docs/FAQ.md +26 -2
  52. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  53. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  54. msprobe/docs/img/merge_result.png +0 -0
  55. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  56. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  57. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  58. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  59. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  60. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  61. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  63. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  64. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  65. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  66. msprobe/docs/visualization/GPTModel.png +0 -0
  67. msprobe/docs/visualization/ParallelMLP.png +0 -0
  68. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  69. msprobe/docs/visualization/mapping.png +0 -0
  70. msprobe/docs/visualization/mapping1.png +0 -0
  71. msprobe/docs/visualization/module_name.png +0 -0
  72. msprobe/docs/visualization/module_name1.png +0 -0
  73. msprobe/docs/visualization/no_mapping.png +0 -0
  74. msprobe/docs/visualization/no_mapping1.png +0 -0
  75. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  76. msprobe/docs/visualization/top_layer.png +0 -0
  77. msprobe/mindspore/__init__.py +10 -0
  78. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
  79. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  80. msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
  81. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  82. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  83. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  84. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  85. msprobe/mindspore/code_mapping/bind.py +264 -0
  86. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  87. msprobe/mindspore/code_mapping/graph.py +49 -0
  88. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  89. msprobe/mindspore/code_mapping/main.py +24 -0
  90. msprobe/mindspore/code_mapping/processor.py +34 -0
  91. msprobe/mindspore/common/const.py +3 -1
  92. msprobe/mindspore/common/utils.py +50 -5
  93. msprobe/mindspore/compare/distributed_compare.py +0 -2
  94. msprobe/mindspore/compare/ms_compare.py +105 -63
  95. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  96. msprobe/mindspore/debugger/debugger_config.py +3 -0
  97. msprobe/mindspore/debugger/precision_debugger.py +81 -12
  98. msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
  99. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  100. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  101. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  102. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  103. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  104. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  105. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  106. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  107. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  108. msprobe/mindspore/grad_probe/hook.py +13 -4
  109. msprobe/mindspore/mindtorch/__init__.py +18 -0
  110. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  111. msprobe/mindspore/ms_config.py +5 -1
  112. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  113. msprobe/mindspore/service.py +267 -101
  114. msprobe/msprobe.py +24 -3
  115. msprobe/pytorch/__init__.py +7 -6
  116. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  117. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  123. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  124. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
  125. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  126. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  127. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  128. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  129. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  130. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  131. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  132. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  133. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  134. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  135. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  136. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  140. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  141. msprobe/pytorch/common/parse_json.py +2 -1
  142. msprobe/pytorch/common/utils.py +45 -2
  143. msprobe/pytorch/compare/distributed_compare.py +17 -29
  144. msprobe/pytorch/compare/pt_compare.py +40 -20
  145. msprobe/pytorch/debugger/debugger_config.py +27 -12
  146. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  147. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  148. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  149. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
  150. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  151. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  152. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  153. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  154. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  155. msprobe/pytorch/hook_module/__init__.py +1 -1
  156. msprobe/pytorch/hook_module/hook_module.py +14 -11
  157. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  158. msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
  159. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  160. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  161. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  162. msprobe/pytorch/monitor/anomaly_detect.py +107 -22
  163. msprobe/pytorch/monitor/csv2tb.py +166 -0
  164. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  165. msprobe/pytorch/monitor/features.py +3 -3
  166. msprobe/pytorch/monitor/module_hook.py +483 -277
  167. msprobe/pytorch/monitor/module_metric.py +27 -48
  168. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  169. msprobe/pytorch/monitor/optimizer_collect.py +52 -14
  170. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  171. msprobe/pytorch/monitor/utils.py +77 -6
  172. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  173. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  174. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  175. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  176. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  177. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  178. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  179. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  180. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  181. msprobe/pytorch/service.py +176 -106
  182. msprobe/visualization/builder/graph_builder.py +62 -5
  183. msprobe/visualization/builder/msprobe_adapter.py +24 -2
  184. msprobe/visualization/compare/graph_comparator.py +64 -14
  185. msprobe/visualization/compare/mode_adapter.py +1 -15
  186. msprobe/visualization/graph/base_node.py +12 -17
  187. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  188. msprobe/visualization/graph/graph.py +9 -0
  189. msprobe/visualization/graph_service.py +97 -23
  190. msprobe/visualization/utils.py +14 -29
  191. msprobe/pytorch/functional/module_dump.py +0 -84
  192. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  193. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
  194. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
  195. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  196. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  197. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -89,12 +89,6 @@ class FuzzHandler(ABC):
89
89
  )
90
90
  return origin_output_chunks, perturbed_output_chunks
91
91
 
92
- @staticmethod
93
- def convert_overflow_ratio_to_consistent(ratio):
94
- if math.isnan(ratio) or math.isinf(ratio):
95
- return ThresholdConfig.COMP_CONSISTENT
96
- return ratio
97
-
98
92
  @abstractmethod
99
93
  def get_threshold(self, dtype):
100
94
  pass
@@ -107,10 +101,10 @@ class FuzzHandler(ABC):
107
101
  self, origin_output, perturbed_output, norm_type, abs_tol
108
102
  ):
109
103
  if norm_type == NormType.ENDLESS_NORM:
110
- return self.calculate_error(origin_output, perturbed_output, abs_tol)
104
+ return self.calculate_max_ratio(origin_output, perturbed_output, abs_tol)
111
105
  return ThresholdConfig.COMP_CONSISTENT
112
106
 
113
- def calculate_error(self, origin_output, perturbed_output, abs_tol):
107
+ def calculate_max_ratio(self, origin_output, perturbed_output, abs_tol):
114
108
  origin_output_chunks, perturbed_output_chunks = (
115
109
  self.tensor_split_for_error_calculate(origin_output, perturbed_output)
116
110
  )
@@ -122,42 +116,30 @@ class FuzzHandler(ABC):
122
116
  raise FreeBenchmarkException(
123
117
  FreeBenchmarkException.OutputIndexError, err_msg
124
118
  )
125
- norm1 = -np.inf
126
- norm2 = -np.inf
127
- norm3 = np.inf
119
+
120
+ max_ratio = ThresholdConfig.COMP_CONSISTENT
128
121
  for i, chunk_origin in enumerate(origin_output_chunks):
129
122
  if chunk_origin.nelement() == 0:
130
123
  break
131
124
  chunk_perturbed = perturbed_output_chunks[i]
132
- ratio_tensor1 = TorchC.where(
133
- TorchC.abs(chunk_perturbed) > abs_tol,
134
- TorchC.div(
135
- TorchC.clamp(chunk_origin, min=abs_tol),
136
- TorchC.clamp(chunk_perturbed, min=abs_tol),
137
- ),
138
- 1,
139
- )
140
- ratio_tensor2 = TorchC.where(
141
- TorchC.abs(chunk_origin) > abs_tol,
142
- TorchC.div(
143
- TorchC.clamp(chunk_perturbed, min=abs_tol),
144
- TorchC.clamp(chunk_origin, min=abs_tol),
145
- ),
146
- 1,
125
+ # 如果乘积最小值 < 极小值乘积的负值,认为存在非极小值符号相反的情况
126
+ if TorchC.lt(
127
+ TorchC.min(TorchC.mul(chunk_origin, chunk_perturbed)), -(abs_tol**2)
128
+ ):
129
+ return ThresholdConfig.SYMBOL_FLIPPING
130
+ # 求A/B B/A的比值前,将值限制在大于极小值范围内
131
+ clamp_origin = TorchC.clamp(TorchC.abs(chunk_origin), min=abs_tol)
132
+ clamp_perturbed = TorchC.clamp(TorchC.abs(chunk_perturbed), min=abs_tol)
133
+ # 对于计算结果为nan的情况,认为两者没有差异
134
+ ratio_tensor = TorchC.nan_to_num(
135
+ TorchC.div(clamp_origin, clamp_perturbed),
136
+ nan=ThresholdConfig.COMP_CONSISTENT,
147
137
  )
148
- norm_values = TorchC.stack(
149
- [TorchC.max(ratio_tensor1), TorchC.max(ratio_tensor2)]
150
- )
151
- max_ratio1, max_ratio2 = norm_values.tolist()
152
- norm1 = max(norm1, self.convert_overflow_ratio_to_consistent(max_ratio1))
153
- norm2 = max(norm2, self.convert_overflow_ratio_to_consistent(max_ratio2))
154
- norm3 = min(norm3, self.convert_overflow_ratio_to_consistent(max_ratio1))
155
-
156
- if norm3 < 0:
157
- ratio = ThresholdConfig.SYMBOL_FLIPPING
158
- else:
159
- ratio = max(norm1, norm2)
160
- return ratio
138
+ # 求A/B 和 B/A比值最大值,其中 B/A的最大值为 A/B的最小值的倒数
139
+ min_ratio, max_ratio = TorchC.stack([*TorchC.aminmax(ratio_tensor)]).tolist()
140
+ min_ratio_reciprocal = np.inf if min_ratio == 0 else 1 / min_ratio
141
+ max_ratio = max(max_ratio, min_ratio_reciprocal)
142
+ return max_ratio
161
143
 
162
144
  def ratio_calculate(self, origin_output, perturbed_output, norm_type) -> float:
163
145
  try:
@@ -220,10 +202,12 @@ class FuzzHandler(ABC):
220
202
  )
221
203
  npu_consistent = is_consistent
222
204
  max_fuzz_ratio = (
223
- max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio)
205
+ max_fuzz_ratio
206
+ if not isinstance(ratio, (int, float))
207
+ else max(max_fuzz_ratio, ratio)
224
208
  )
225
- data_params.is_consistent = is_consistent and data_params.is_consistent
226
- if not is_consistent and data_params.grad_unequal_flag:
209
+ data_params.is_consistent = is_consistent
210
+ if not is_consistent:
227
211
  self.unequal_rows.append(
228
212
  make_unequal_row(data_params, self.params, ratio=ratio)
229
213
  )
@@ -235,12 +219,12 @@ class FuzzHandler(ABC):
235
219
  )
236
220
  npu_consistent = npu_consistent and is_consistent
237
221
  max_fuzz_ratio = (
238
- max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio)
239
- )
240
- data_params.is_consistent = (
241
- is_consistent and data_params.is_consistent
222
+ max_fuzz_ratio
223
+ if not isinstance(ratio, (int, float))
224
+ else max(max_fuzz_ratio, ratio)
242
225
  )
243
- if not is_consistent and data_params.grad_unequal_flag:
226
+ data_params.is_consistent = is_consistent
227
+ if not is_consistent:
244
228
  self.unequal_rows.append(
245
229
  make_unequal_row(
246
230
  data_params, self.params, ratio=ratio, index=index_
@@ -75,10 +75,6 @@ class PreheatHandler(FuzzHandler):
75
75
  if self.params.preheat_config.get("preheat_step") <= self.params.step:
76
76
  return data_params.original_result
77
77
 
78
- if not data_params.grad_unequal_flag:
79
- data_params.grad_unequal_flag = True
80
- data_params.is_consistent = False
81
- return data_params.original_result
82
78
  preheat_counter.add_api_called_time(self.pure_name)
83
79
 
84
80
  if not self._is_take_a_sample():
@@ -13,4 +13,4 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from .wrap_functional import remove_dropout
16
+ from msprobe.pytorch.common.utils import remove_dropout
@@ -15,17 +15,17 @@
15
15
 
16
16
  import functools
17
17
  import threading
18
+ from collections import defaultdict
18
19
 
19
20
  import torch
20
21
  import torch.nn as nn
21
22
  import torch.utils.hooks as full_hooks
22
23
 
23
- from msprobe.core.common.const import Const
24
24
  torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
25
25
 
26
26
 
27
27
  class HOOKModule(nn.Module):
28
- module_count = {}
28
+ module_count = defaultdict(int)
29
29
  inner_stop_hook = {}
30
30
 
31
31
  def __init__(self, build_hook) -> None:
@@ -41,12 +41,7 @@ class HOOKModule(nn.Module):
41
41
  if hasattr(self, "prefix_op_name_"):
42
42
  self.prefix = self.prefix_op_name_
43
43
 
44
- if self.prefix not in HOOKModule.module_count:
45
- HOOKModule.module_count[self.prefix] = 1
46
- self.prefix += '0' + Const.SEP
47
- else:
48
- HOOKModule.module_count[self.prefix] += 1
49
- self.prefix = self.prefix + str(HOOKModule.module_count[self.prefix] - 1) + Const.SEP
44
+ self.forward_data_collected = False
50
45
  forward_pre_hook, forward_hook, backward_hook, _ = build_hook(self.prefix)
51
46
  if torch_version_above_or_equal_2:
52
47
  self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
@@ -66,9 +61,17 @@ class HOOKModule(nn.Module):
66
61
  HOOKModule.inner_stop_hook[self.current_thread] = False
67
62
  return result
68
63
 
69
- @classmethod
70
- def reset_module_stats(cls):
71
- cls.module_count = {}
64
+ @staticmethod
65
+ def reset_module_stats():
66
+ HOOKModule.module_count = defaultdict(int)
67
+
68
+ @staticmethod
69
+ def add_module_count(name):
70
+ HOOKModule.module_count[name] += 1
71
+
72
+ @staticmethod
73
+ def get_module_count(name):
74
+ return HOOKModule.module_count[name]
72
75
 
73
76
  def _call_func(self, *args, **kwargs):
74
77
  full_backward_hooks, non_full_backward_hooks = [], []
@@ -0,0 +1,59 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from msprobe.core.common.const import Const
18
+ from msprobe.pytorch.common.log import logger
19
+
20
+ torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
21
+ if torch_version_above_or_equal_2:
22
+ from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
23
+
24
+
25
+ def register_optimizer_hook(data_collector):
26
+ def optimizer_pre_step_hook(optimizer, args, kwargs):
27
+ data_collector.optimizer_status = Const.OPTIMIZER
28
+
29
+ def optimizer_post_step_hook(optimizer, args, kwargs):
30
+ data_collector.optimizer_status = Const.END_PREFIX + Const.OPTIMIZER
31
+
32
+ def patch_clip_grad(func):
33
+ def wrapper(*args, **kwargs):
34
+ data_collector.optimizer_status = Const.CLIP_GRAD
35
+ func(*args, **kwargs)
36
+ data_collector.optimizer_status = Const.END_PREFIX + Const.CLIP_GRAD
37
+
38
+ return wrapper
39
+
40
+ if torch_version_above_or_equal_2:
41
+ register_optimizer_step_pre_hook(optimizer_pre_step_hook)
42
+ register_optimizer_step_post_hook(optimizer_post_step_hook)
43
+ else:
44
+ logger.info_on_rank_0("Pytorch version is below 2.0, cannot register optimizer hook.")
45
+
46
+ try:
47
+ torch.nn.utils.clip_grad_norm_ = patch_clip_grad(torch.nn.utils.clip_grad_norm_)
48
+ torch.nn.utils.clip_grad_norm = patch_clip_grad(torch.nn.utils.clip_grad_norm)
49
+ torch.nn.utils.clip_grad_value_ = patch_clip_grad(torch.nn.utils.clip_grad_value_)
50
+ except Exception as e:
51
+ logger.info_on_rank_0("Cannot patch clip grad function. detail:%s" % str(e))
52
+
53
+ try:
54
+ from megatron.core.optimizer import MegatronOptimizer
55
+ MegatronOptimizer.clip_grad_norm = patch_clip_grad(MegatronOptimizer.clip_grad_norm)
56
+ except ImportError:
57
+ pass
58
+ except Exception as e:
59
+ logger.info_on_rank_0("Cannot patch megatron clip grad function. detail:%s" % str(e))
@@ -138,6 +138,10 @@ functional:
138
138
  - fold
139
139
  - multi_head_attention_forward
140
140
  - scaled_dot_product_attention
141
+ - lp_pool3d
142
+ - dropout1d
143
+ - mish
144
+ - huber_loss
141
145
 
142
146
  tensor:
143
147
  - __add__
@@ -172,6 +176,7 @@ tensor:
172
176
  - __sub__
173
177
  - __truediv__
174
178
  - __xor__
179
+ - __pow__
175
180
  - abs
176
181
  - abs_
177
182
  - absolute
@@ -557,6 +562,27 @@ tensor:
557
562
  - view_as
558
563
  - xlogy
559
564
  - xlogy_
565
+ - split
566
+ - stft
567
+ - nan_to_num
568
+ - dsplit
569
+ - orgqr
570
+ - bitwise_left_shift_
571
+ - arctan2
572
+ - histogram
573
+ - q_zero_point
574
+ - adjoint
575
+ - ormqr
576
+ - bitwise_right_shift_
577
+ - nanquantile
578
+ - lu
579
+ - quantile
580
+ - arctan2_
581
+ - qr
582
+ - diagonal_scatter
583
+ - corrcoef
584
+ - vsplit
585
+ - aminmax
560
586
 
561
587
  torch:
562
588
  - linalg.norm
@@ -1131,6 +1157,14 @@ torch_npu:
1131
1157
  - npu_lstm
1132
1158
  - npu_apply_adam
1133
1159
  - npu_apply_adam_w
1160
+ - npu_anti_quant
1161
+ - npu_grouped_matmu
1162
+ - npu_quant_scatter
1163
+ - npu_group_norm_silu
1164
+ - npu_format_cast
1165
+ - npu_moe_finalize_routing
1166
+ - npu_moe_gating_top_k_softmax
1167
+ - npu_trans_quant_param
1134
1168
 
1135
1169
  aten:
1136
1170
  - signbit
@@ -21,7 +21,6 @@ from msprobe.pytorch.hook_module.hook_module import HOOKModule
21
21
  from msprobe.pytorch.common.utils import torch_device_guard
22
22
  from msprobe.core.common.const import Const
23
23
  from msprobe.core.common.file_utils import load_yaml
24
- from msprobe.core.common.inplace_op_checker import InplaceOpChecker
25
24
 
26
25
 
27
26
  cur_path = os.path.dirname(os.path.realpath(__file__))
@@ -49,17 +48,16 @@ class DistributedOPTemplate(HOOKModule):
49
48
  self.op_name_ = op_name
50
49
  self.prefix_op_name_ = "Distributed" + Const.SEP + str(op_name) + Const.SEP
51
50
  super().__init__(build_hook)
52
- if not self.stop_hook and InplaceOpChecker.check(self.op_name_, InplaceOpChecker.OP_DISTRIBUTED):
53
- self.op_is_inplace = True
51
+ if not self.stop_hook:
52
+ self.op_is_distributed = True
54
53
 
55
54
  @torch_device_guard
56
55
  def forward(self, *args, **kwargs):
56
+ handle = distributed_func.get(self.op_name_)(*args, **kwargs)
57
57
  if kwargs.get("async_op") or self.op_name_ in ["isend", "irecv"]:
58
- handle = distributed_func.get(self.op_name_)(*args, **kwargs)
59
- handle.wait()
60
- return handle
61
- else:
62
- return distributed_func.get(self.op_name_)(*args, **kwargs)
58
+ if handle and hasattr(handle, 'wait'):
59
+ handle.wait()
60
+ return handle
63
61
 
64
62
 
65
63
  def wrap_distributed_op(op_name, hook):
@@ -23,46 +23,6 @@ from msprobe.pytorch.common.log import logger
23
23
  from msprobe.core.common.file_utils import load_yaml
24
24
 
25
25
 
26
- def remove_dropout():
27
- if torch.__version__ > "1.8":
28
- logger.info_on_rank_0("For precision comparison, the probability p in the dropout method is set to 0.")
29
- import torch.nn.functional as F
30
- from torch import _VF
31
- from torch.overrides import has_torch_function_unary, handle_torch_function
32
-
33
- def function_dropout(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
34
- inplace: bool = False) -> torch.Tensor:
35
- if has_torch_function_unary(input_tensor):
36
- return handle_torch_function(
37
- function_dropout, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
38
- if p < 0.0 or p > 1.0:
39
- raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
40
- return _VF.dropout_(input_tensor, 0., training) if inplace else _VF.dropout(input_tensor, 0., training)
41
-
42
- def function_dropout2d(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
43
- inplace: bool = False) -> torch.Tensor:
44
- if has_torch_function_unary(input_tensor):
45
- return handle_torch_function(
46
- function_dropout2d, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
47
- if p < 0.0 or p > 1.0:
48
- raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
49
- return _VF.feature_dropout_(input_tensor, 0., training) if inplace else _VF.feature_dropout(input_tensor,
50
- 0., training)
51
-
52
- def function_dropout3d(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
53
- inplace: bool = False) -> torch.Tensor:
54
- if has_torch_function_unary(input_tensor):
55
- return handle_torch_function(
56
- function_dropout3d, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
57
- if p < 0.0 or p > 1.0:
58
- raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
59
- return _VF.feature_dropout_(input_tensor, 0., training) if inplace else _VF.feature_dropout(input_tensor,
60
- 0., training)
61
-
62
- F.dropout = function_dropout
63
- F.dropout2d = function_dropout2d
64
- F.dropout3d = function_dropout3d
65
-
66
26
  cur_path = os.path.dirname(os.path.realpath(__file__))
67
27
  yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
68
28
 
@@ -19,7 +19,7 @@ import argparse
19
19
  import ast
20
20
  import heapq
21
21
 
22
- from msprobe.core.common.log import logger
22
+ from msprobe.pytorch.common.log import logger
23
23
  from msprobe.core.common.const import MonitorConst
24
24
  from msprobe.core.common.file_utils import check_path_before_create, save_json, create_directory, remove_path, \
25
25
  check_file_or_directory_path, load_json
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,21 +12,22 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
-
15
+ import itertools
16
16
  import os
17
- import sys
18
17
  import statistics as st
18
+ import sys
19
19
  from abc import ABC
20
+ from collections import defaultdict
20
21
  from dataclasses import dataclass, field
21
22
  from typing import List
22
- from collections import defaultdict
23
23
 
24
24
  import pandas as pd
25
+ import torch
25
26
  from torch.utils.tensorboard import SummaryWriter
26
27
 
27
- from msprobe.core.common.log import logger
28
- from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv
29
28
  from msprobe.core.common.const import FileCheckConst, MonitorConst
29
+ from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv
30
+ from msprobe.pytorch.common.log import logger
30
31
 
31
32
 
32
33
  class ScanRule(ABC):
@@ -134,7 +135,7 @@ class AnomalyDataFactory(ABC):
134
135
  raise ValueError("tag must be a tuple with length 2")
135
136
  tag_name = tag[0]
136
137
  param_name = tag_name.split('/')[0]
137
- call_id = self.name2callid.get(param_name, -1)
138
+ call_id = self.name2callid.get(tag_name, -1)
138
139
  if MonitorConst.VPP_SEP in param_name:
139
140
  vpp_stage = int(param_name.split(MonitorConst.VPP_SEP)[0])
140
141
  else:
@@ -153,6 +154,24 @@ class AnomalyDataFactory(ABC):
153
154
  )
154
155
 
155
156
 
157
+ class TrainStage:
158
+ DEFAULT_STAGE = -1
159
+ FORWARD_STAGE = 0
160
+ BACKWARD_STAGE = 1
161
+ OPTIMIZER_STAGE = 2
162
+
163
+
164
+ FORWARD_KEY = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT]
165
+ BACKWARD_KEY = [MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT,
166
+ MonitorConst.PRE_GRAD, MonitorConst.POST_GRAD, MonitorConst.ACC_GRAD]
167
+ OPTIMIZER_KEY = [MonitorConst.EXP_AVG, MonitorConst.EFXP_AVG_SQ]
168
+ TRAIN_STAGE = {
169
+ **{key_: TrainStage.FORWARD_STAGE for key_ in FORWARD_KEY},
170
+ **{key_: TrainStage.BACKWARD_STAGE for key_ in BACKWARD_KEY},
171
+ **{key_: TrainStage.OPTIMIZER_STAGE for key_ in OPTIMIZER_KEY}
172
+ }
173
+
174
+
156
175
  @dataclass(eq=True)
157
176
  class GradAnomalyData:
158
177
  rank: int = 0
@@ -166,25 +185,48 @@ class GradAnomalyData:
166
185
  group_mates: list = field(default=None, compare=False)
167
186
 
168
187
  def __lt__(self, other):
188
+ """
189
+ 自定义比较函数,用于确定 GradAnomalyData 实例之间的顺序。
190
+ 比较规则为:
191
+ step 和 micro_step 值越小优先级越高;
192
+ vpp 和 pp 在前向阶段值越小优先级越高,在非前向阶段值越大优先级越高;
193
+ call_id 值越小优先级越高。
194
+ """
169
195
  if not isinstance(other, GradAnomalyData):
170
196
  return NotImplemented
171
- if self.step != other.step:
172
- return self.step < other.step
173
- if self.micro_step != other.micro_step:
174
- return self.micro_step < other.micro_step
175
- if self.vpp_stage != other.vpp_stage:
176
- return self.vpp_stage > other.vpp_stage
177
- if self.pp_stage != other.pp_stage:
178
- return self.pp_stage > other.pp_stage
179
- if self.call_id != other.call_id:
180
- return self.call_id < other.call_id
181
- return False
197
+
198
+ self_train_stage = self.get_train_stage(self.tag_name)
199
+ other_train_stage = self.get_train_stage(other.tag_name)
200
+
201
+ def vpp_pp_comparator(anomaly):
202
+ """
203
+ Determine the priority rule for vpp and pp based on train stage
204
+ Forward stage prefers smaller vpp and pp
205
+ Other stages prefer larger vpp and pp
206
+ """
207
+ if self_train_stage == TrainStage.FORWARD_STAGE:
208
+ return anomaly.vpp_stage, anomaly.pp_stage
209
+ else:
210
+ return -anomaly.vpp_stage, -anomaly.pp_stage
211
+
212
+ self_cmp = [self.step, self.micro_step, self_train_stage, *vpp_pp_comparator(self), self.call_id]
213
+ other_cmp = [other.step, other.micro_step, other_train_stage, *vpp_pp_comparator(other), other.call_id]
214
+ return self_cmp < other_cmp
182
215
 
183
216
  def __le__(self, other):
184
217
  if not isinstance(other, GradAnomalyData):
185
218
  return NotImplemented
186
219
  return self == other or self < other
187
220
 
221
+ @staticmethod
222
+ def get_train_stage(tag_name):
223
+ """
224
+ :param tag_name: "0:fc2_0/rank0/input", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/efxp_avg_sq"
225
+ :return: int, if forward return 0; if backward return 1; if optimizer return 2
226
+ """
227
+ key_ = tag_name.split("/")[-1]
228
+ return TRAIN_STAGE.get(key_, TrainStage.DEFAULT_STAGE)
229
+
188
230
  def to_dict(self):
189
231
  return self.__dict__
190
232
 
@@ -198,7 +240,6 @@ class WriterInput:
198
240
  path: str
199
241
  ad_rules: list
200
242
  job_id: str
201
- anomaly_inform: bool = False
202
243
  anomaly_factory: AnomalyDataFactory = None
203
244
  ndigits: int = 6
204
245
  step_count_per_record: int = 1
@@ -209,7 +250,6 @@ class BaseWriterWithAD:
209
250
  self.tag2scalars = {}
210
251
  self.ad_rules = writer_input.ad_rules
211
252
  self.job_id = writer_input.job_id
212
- self.anomaly_inform = writer_input.anomaly_inform
213
253
  self.anomaly_factory = writer_input.anomaly_factory
214
254
  self.anomalies = []
215
255
  self.ndigits = writer_input.ndigits
@@ -242,6 +282,27 @@ class BaseWriterWithAD:
242
282
  if self.anomaly_factory:
243
283
  self.anomalies.append(self.anomaly_factory.create(tag, exception_message, global_step))
244
284
 
285
+ def write_metrics(self, ops, metric_value, step, prefix=''):
286
+ if not metric_value:
287
+ return
288
+ tensors = []
289
+ tags = list(itertools.product(metric_value.keys(), ops))
290
+ for op2tensor in metric_value.values():
291
+ tensors.extend(op2tensor.values())
292
+ if not tensors:
293
+ return
294
+
295
+ n_slices = len(tensors) // MonitorConst.SLICE_SIZE
296
+ with torch.no_grad():
297
+ for i in range(n_slices + 1):
298
+ begin = i * MonitorConst.SLICE_SIZE
299
+ end = (i+1) * MonitorConst.SLICE_SIZE
300
+ if begin == len(tensors):
301
+ continue
302
+ metric_list = torch.stack(tensors[begin:end]).cpu()
303
+ for tag, metric in zip(tags[begin:end], metric_list):
304
+ self.add_scalar(tag, metric, step)
305
+
245
306
  def _ad(self, scalar_value, history):
246
307
  return AnomalyScanner.scan(self.ad_rules, history, cur=scalar_value)
247
308
 
@@ -291,7 +352,7 @@ class CSVWriterWithAD(BaseWriterWithAD):
291
352
  """
292
353
  if len(self.context_dict) == 0:
293
354
  return
294
-
355
+
295
356
  ster_start, step_end = self.get_step_interval(step)
296
357
  filepath = os.path.join(self.log_dir, f'{prefix}_{ster_start}-{step_end}.csv')
297
358
  if not os.path.exists(filepath):
@@ -304,7 +365,7 @@ class CSVWriterWithAD(BaseWriterWithAD):
304
365
  new_data.append([name] + [step] + metric_value)
305
366
  else:
306
367
  new_data.append(name.split(MonitorConst.VPP_SEP) + [step] + metric_value)
307
- new_data = pd.DataFrame(new_data).round(self.ndigits)
368
+ new_data = pd.DataFrame(new_data).round(self.ndigits).fillna("nan")
308
369
  write_df_to_csv(new_data, filepath, mode='a+', header=False)
309
370
  self.context_dict = defaultdict(list)
310
371
 
@@ -317,6 +378,30 @@ class CSVWriterWithAD(BaseWriterWithAD):
317
378
  name = tag[0].split('/')[0]
318
379
  self.context_dict[name].append(scalar_value.item())
319
380
 
381
+ def write_metrics(self, ops, metric_value, step, prefix=''):
382
+ super().write_metrics(ops, metric_value, step, prefix='')
383
+
384
+ # generate csv headers
385
+ # set hashmap to reduce the number of headers generated.
386
+ # 前向的norm用input.ops_和output.ops_,反向的用input_grad.ops_和output_grad.ops_
387
+ if prefix in {"actv", "actv_grad"}:
388
+ if prefix == "actv":
389
+ input_and_output = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT]
390
+ else:
391
+ input_and_output = [MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT]
392
+ ops_ = [MonitorConst.DOT.join(i) for i in itertools.product(input_and_output, ops)]
393
+ csv_header = ["module_name", "step", *ops_]
394
+ else:
395
+ csv_header = ["param_name", "step", *ops]
396
+
397
+ keys = list(metric_value.keys())
398
+ if keys and MonitorConst.VPP_SEP in keys[0]:
399
+ csv_header.insert(0, "vpp_stage")
400
+
401
+ self.header = csv_header
402
+ self.write_csv(prefix, step)
403
+ self.header = []
404
+
320
405
  def close(self):
321
406
  pass
322
407