mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226) hide show
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,7 +21,7 @@ from dataclasses import dataclass
21
21
 
22
22
  import numpy as np
23
23
 
24
- from msprobe.core.common.const import Const, CompareConst
24
+ from msprobe.core.common.const import Const, CompareConst, FileCheckConst
25
25
  from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger, safe_get_value
26
26
  from msprobe.core.common.file_utils import check_file_or_directory_path
27
27
 
@@ -37,13 +37,20 @@ def extract_json(dirname, stack_json=False):
37
37
  # Provide robustness on invalid directory inputs
38
38
  if not json_path:
39
39
  if stack_json:
40
- logger.error(f'stack.json is not found in dump dir {dirname}.')
40
+ logger.warning(f'stack.json is not found in dump dir {dirname}.')
41
41
  else:
42
42
  logger.error(f'dump.json is not found in dump dir {dirname}.')
43
- raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
43
+ raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
44
44
  return json_path
45
45
 
46
46
 
47
+ def set_stack_json_path(input_param):
48
+ npu_data_dir = os.path.dirname(input_param.get("npu_json_path"))
49
+ stack_path = extract_json(npu_data_dir, stack_json=True)
50
+ input_param["stack_json_path"] = stack_path if stack_path else None
51
+ return bool(stack_path)
52
+
53
+
47
54
  def check_and_return_dir_contents(dump_dir, prefix):
48
55
  """
49
56
  check the given dump dir and validate files in dump dir by using the given prefix patterns to build a
@@ -75,6 +82,10 @@ def check_and_return_dir_contents(dump_dir, prefix):
75
82
 
76
83
 
77
84
  def rename_api(npu_name, process):
85
+ """
86
+ 原api: {api_type}.{api_name}.{API调用次数}.{前向反向}.{input/output}.{参数序号}
87
+ rename后: {api_type}.{api_name}.{input/output}.{参数序号}
88
+ """
78
89
  npu_split = npu_name.split(process)
79
90
  try:
80
91
  torch_func_index, in_out = npu_split[0], npu_split[1]
@@ -87,17 +98,13 @@ def rename_api(npu_name, process):
87
98
 
88
99
 
89
100
  def read_op(op_data, op_name):
90
- io_name_mapping = {
91
- Const.INPUT_ARGS: '.input',
92
- Const.INPUT_KWARGS: '.input',
93
- Const.INPUT: '.input',
94
- Const.OUTPUT: '.output'
95
- }
96
-
97
- op_parsed_list = []
98
- for name in io_name_mapping:
99
- if name in op_data:
100
- op_parsed_list.extend(op_item_parse(op_data[name], op_name + io_name_mapping[name]))
101
+ if Const.PARAMS_GRAD in op_name.split(Const.SEP):
102
+ op_parsed_list = op_item_parse(op_data, op_name)
103
+ else:
104
+ op_parsed_list = []
105
+ for name in CompareConst.IO_NAME_MAPPING:
106
+ if name in op_data:
107
+ op_parsed_list.extend(op_item_parse(op_data[name], op_name + CompareConst.IO_NAME_MAPPING[name]))
101
108
  return op_parsed_list
102
109
 
103
110
 
@@ -124,11 +131,14 @@ def op_item_parse(op_data, op_name: str, depth: int = 0) -> list:
124
131
  return [default_item]
125
132
  elif not op_data:
126
133
  return []
127
-
134
+
128
135
  item_list = []
129
136
  if isinstance(op_data, list):
130
137
  for i, data in enumerate(op_data):
131
- item_list.extend(op_item_parse(data, op_name + Const.SEP + str(i), depth + 1))
138
+ if Const.PARAMS_GRAD not in op_name.split(Const.SEP):
139
+ item_list.extend(op_item_parse(data, op_name + Const.SEP + str(i), depth + 1))
140
+ else:
141
+ item_list.extend(op_item_parse(data, op_name, depth + 1))
132
142
  elif isinstance(op_data, dict):
133
143
  if is_leaf_data(op_data):
134
144
  return [gen_op_item(op_data, op_name)]
@@ -144,14 +154,15 @@ def is_leaf_data(op_data):
144
154
  def gen_op_item(op_data, op_name):
145
155
  op_item = {}
146
156
  op_item.update(op_data)
147
- op_item['full_op_name'] = op_name
148
- op_item['data_name'] = op_data.get('data_name', '-1')
157
+ data_name = op_data.get('data_name') if op_data.get('data_name') else '-1' # 如果是""也返回-1
158
+ op_item['data_name'] = data_name
159
+ op_item['full_op_name'] = data_name.rsplit(Const.SEP, 1)[0] if data_name != '-1' else op_name
149
160
 
150
161
  params = ['Max', 'Min', 'Mean', 'Norm']
151
162
  for i in params:
152
163
  if i not in op_item:
153
164
  op_item[i] = None
154
-
165
+
155
166
  if not op_item.get('dtype'):
156
167
  if op_item.get('type') == 'torch.Size':
157
168
  op_item['dtype'] = op_data.get('type')
@@ -159,6 +170,16 @@ def gen_op_item(op_data, op_name):
159
170
  elif op_item.get('type') == 'slice':
160
171
  op_item['dtype'] = op_data.get('type')
161
172
  op_item['shape'] = str(np.shape(np.array(op_data.get('value'))))
173
+ elif op_item.get('type') == 'ellipsis':
174
+ op_item['dtype'] = op_data.get('type')
175
+ op_item['shape'] = '[]'
176
+ for i in params:
177
+ op_item[i] = op_data.get('value')
178
+ elif op_item.get('type') == 'torch.ProcessGroup':
179
+ op_item['dtype'] = op_data.get('type')
180
+ op_item['shape'] = '[]'
181
+ for i in params:
182
+ op_item[i] = str(op_data.get('group_ranks'))
162
183
  else:
163
184
  op_item['dtype'] = str(type(op_data.get('value')))
164
185
  op_item['shape'] = '[]'
@@ -166,7 +187,7 @@ def gen_op_item(op_data, op_name):
166
187
  op_item[i] = op_data.get('value')
167
188
  if not op_item.get('md5'):
168
189
  op_item['md5'] = f"{zlib.crc32(str(op_data.get('value', '')).encode()):08x}"
169
-
190
+
170
191
  return op_item
171
192
 
172
193
 
@@ -276,6 +297,22 @@ def result_item_init(n_info, b_info, dump_mode):
276
297
  return result_item
277
298
 
278
299
 
300
+ def count_struct(op_dict):
301
+ parts = [
302
+ CompareConst.OP_NAME,
303
+ CompareConst.INPUT_STRUCT,
304
+ CompareConst.OUTPUT_STRUCT,
305
+ CompareConst.PARAMS_STRUCT,
306
+ CompareConst.PARAMS_GRAD_STRUCT
307
+ ]
308
+ lengths = [len(op_dict.get(part, [])) for part in parts]
309
+ num = lengths[0]
310
+ if num != sum(lengths[1:]):
311
+ logger.error(f"Length of names and structs of op_dict not match. Please check! op_dict: {op_dict}")
312
+ raise CompareException(CompareException.NAMES_STRUCTS_MATCH_ERROR)
313
+ return tuple(lengths)
314
+
315
+
279
316
  def get_accuracy(result, n_dict, b_dict, dump_mode):
280
317
  def get_accuracy_core(n_start, n_len, b_start, b_len, key):
281
318
  min_len = min(n_len, b_len)
@@ -355,31 +392,50 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
355
392
 
356
393
  result.append(result_item)
357
394
 
358
- n_num = len(n_dict['op_name'])
359
- b_num = len(b_dict['op_name'])
360
- n_num_input = len([name for name in n_dict['op_name']
361
- if Const.INPUT in name.split(Const.SEP) or Const.KWARGS in name.split(Const.SEP)])
362
- b_num_input = len([name for name in b_dict['op_name']
363
- if Const.INPUT in name.split(Const.SEP) or Const.KWARGS in name.split(Const.SEP)])
364
- n_num_output = n_num - n_num_input
365
- b_num_output = b_num - b_num_input
366
- get_accuracy_core(0, n_num_input, 0, b_num_input, 'input_struct')
367
- get_accuracy_core(n_num_input, n_num_output, b_num_input, b_num_output, 'output_struct')
395
+ n_num, n_num_input, n_num_output, n_num_params, n_num_params_grad = count_struct(n_dict)
396
+ b_num, b_num_input, b_num_output, b_num_params, b_num_params_grad = count_struct(b_dict)
397
+
398
+ get_accuracy_core(0, n_num_input, 0, b_num_input, CompareConst.INPUT_STRUCT)
399
+ get_accuracy_core(n_num_input + n_num_output, n_num_params, b_num_input + b_num_output, b_num_params,
400
+ CompareConst.PARAMS_STRUCT)
401
+ get_accuracy_core(n_num_input, n_num_output, b_num_input, b_num_output, CompareConst.OUTPUT_STRUCT)
402
+ get_accuracy_core(n_num_input + n_num_output + n_num_params, n_num_params_grad,
403
+ b_num_input + b_num_output + b_num_params, b_num_params_grad,
404
+ CompareConst.PARAMS_GRAD_STRUCT)
405
+
406
+
407
+ def append_stack_info(result_item, npu_stack_info, index):
408
+ """添加堆栈信息到 result_item"""
409
+ if npu_stack_info and index == 0:
410
+ result_item.extend(npu_stack_info)
411
+ else:
412
+ result_item.append(CompareConst.NONE)
368
413
 
369
414
 
370
415
  def get_un_match_accuracy(result, n_dict, dump_mode):
371
- index_out = 0
372
416
  npu_stack_info = n_dict.get("stack_info", None)
373
417
  bench_name, bench_type, bench_shape = CompareConst.N_A, CompareConst.N_A, CompareConst.N_A
374
- err_msg = CompareConst.NO_BENCH
375
- accuracy_check_res = CompareConst.N_A
376
- for index, n_name in enumerate(n_dict["op_name"]):
377
- name_ele_list = n_name.split(Const.SEP)
378
- if Const.INPUT in name_ele_list or Const.KWARGS in name_ele_list:
379
- n_struct = safe_get_value(n_dict, index, "n_dict", key=CompareConst.INPUT_STRUCT)
380
- if Const.OUTPUT in name_ele_list:
381
- n_struct = safe_get_value(n_dict, index_out, "n_dict", key=CompareConst.OUTPUT_STRUCT)
382
- index_out += 1
418
+
419
+ struct_to_index_mapping = {
420
+ CompareConst.INPUT_STRUCT: 0,
421
+ CompareConst.OUTPUT_STRUCT: 0,
422
+ CompareConst.PARAMS_STRUCT: 0,
423
+ CompareConst.PARAMS_GRAD_STRUCT: 0
424
+ }
425
+
426
+ op_name_list = n_dict.get(CompareConst.OP_NAME)
427
+ summary_list = n_dict.get(Const.SUMMARY)
428
+ data_name_list = n_dict.get('data_name')
429
+ op_name_reorder, summary_reorder, _ = reorder_op_x_list(op_name_list,
430
+ summary_list,
431
+ data_name_list)
432
+ for index, n_name in enumerate(op_name_reorder):
433
+ _, state = get_name_and_state(n_name)
434
+ struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
435
+ if not struct_key:
436
+ continue
437
+ n_struct = safe_get_value(n_dict, struct_to_index_mapping.get(struct_key), "n_dict", key=struct_key)
438
+ struct_to_index_mapping[struct_key] += 1
383
439
 
384
440
  try:
385
441
  result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape]
@@ -390,28 +446,26 @@ def get_un_match_accuracy(result, n_dict, dump_mode):
390
446
  f"output_struct of n_dict is {n_dict[CompareConst.OUTPUT_STRUCT]}"
391
447
  logger.error(err_msg)
392
448
  raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
449
+
393
450
  if dump_mode == Const.MD5:
394
451
  result_item.extend([CompareConst.N_A] * 3)
395
- if npu_stack_info and index == 0:
396
- result_item.extend(npu_stack_info)
397
- else:
398
- result_item.append(CompareConst.NONE)
452
+ append_stack_info(result_item, npu_stack_info, index)
399
453
  result.append(result_item)
400
454
  continue
401
455
  if dump_mode == Const.SUMMARY:
402
456
  result_item.extend([CompareConst.N_A] * 8)
403
- else:
457
+ if dump_mode == Const.ALL:
404
458
  result_item.extend([CompareConst.N_A] * 5)
405
- npu_summary_data = safe_get_value(n_dict, index, "n_dict", key=CompareConst.SUMMARY)
406
- result_item.extend(npu_summary_data)
459
+
460
+ npu_summary_data = safe_get_value(summary_reorder, index, "summary_reorder")
407
461
  bench_summary_data = [CompareConst.N_A] * 4
462
+ result_item.extend(npu_summary_data)
408
463
  result_item.extend(bench_summary_data)
464
+ err_msg = CompareConst.NO_BENCH
465
+ accuracy_check_res = CompareConst.N_A
409
466
  result_item.append(accuracy_check_res)
410
467
  result_item.append(err_msg)
411
- if npu_stack_info and index == 0:
412
- result_item.extend(npu_stack_info)
413
- else:
414
- result_item.append(CompareConst.NONE)
468
+ append_stack_info(result_item, npu_stack_info, index)
415
469
  if dump_mode == Const.ALL and result_item[1] == CompareConst.N_A:
416
470
  result_item.extend(["-1"])
417
471
  result.append(result_item)
@@ -423,6 +477,8 @@ def merge_tensor(tensor_list, dump_mode):
423
477
  op_dict[CompareConst.INPUT_STRUCT] = []
424
478
  op_dict[CompareConst.KWARGS_STRUCT] = []
425
479
  op_dict[CompareConst.OUTPUT_STRUCT] = []
480
+ op_dict[CompareConst.PARAMS_STRUCT] = []
481
+ op_dict[CompareConst.PARAMS_GRAD_STRUCT] = []
426
482
  op_dict[Const.SUMMARY] = []
427
483
  op_dict["stack_info"] = []
428
484
 
@@ -430,30 +486,25 @@ def merge_tensor(tensor_list, dump_mode):
430
486
  op_dict["data_name"] = []
431
487
 
432
488
  for tensor in tensor_list:
489
+ # A dict(len=2) with 'full_op_name' and 'full_info' is added to the tensor only if self.stack_mode is True
433
490
  if len(tensor) == 2:
434
491
  op_dict['stack_info'].append(tensor['full_info'])
435
492
  break
493
+
436
494
  op_dict["op_name"].append(tensor['full_op_name'])
437
- name_ele_list = tensor['full_op_name'].split(Const.SEP)
438
- name_to_struct_mapping = {
439
- Const.INPUT: CompareConst.INPUT_STRUCT,
440
- Const.KWARGS: CompareConst.KWARGS_STRUCT,
441
- Const.OUTPUT: CompareConst.OUTPUT_STRUCT
442
- }
443
- for name_key, struct_key in name_to_struct_mapping.items():
444
- if name_key in name_ele_list:
445
- if dump_mode == Const.MD5:
446
- op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5]))
447
- else:
448
- op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE]))
449
- break
495
+
496
+ _, state = get_name_and_state(tensor['full_op_name'])
497
+ struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
498
+ if not struct_key:
499
+ continue
500
+ if dump_mode == Const.MD5:
501
+ op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5]))
502
+ else:
503
+ op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE]))
450
504
  op_dict[Const.SUMMARY].append([tensor[Const.MAX], tensor[Const.MIN], tensor[Const.MEAN], tensor[Const.NORM]])
451
505
 
452
506
  if dump_mode == Const.ALL:
453
507
  op_dict["data_name"].append(tensor['data_name'])
454
- data_name = safe_get_value(op_dict, -1, "op_dict", key="data_name").rsplit(Const.SEP, 1)[0]
455
- if data_name != "-1":
456
- op_dict["op_name"][-1] = data_name
457
508
 
458
509
  if not op_dict[CompareConst.KWARGS_STRUCT]:
459
510
  del op_dict[CompareConst.KWARGS_STRUCT]
@@ -467,11 +518,90 @@ def print_compare_ends_info():
467
518
  logger.info('*' * total_len)
468
519
 
469
520
 
521
+ def table_value_is_valid(value: str) -> bool:
522
+ if not isinstance(value, str):
523
+ return True
524
+ try:
525
+ # -1.00 or +1.00 should be consdiered as digit numbers
526
+ float(value)
527
+ except ValueError:
528
+ # otherwise, they will be considered as formular injections
529
+ return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
530
+ return True
531
+
532
+
533
+ def get_name_and_state(name):
534
+ """
535
+ Get api/module name and state
536
+ example:
537
+ name = 'conv2d.forward.1.input.0'
538
+ return: ('conv2d.forward.1.', 'input')
539
+
540
+ name = 'Functional.pad.0.backward.output.0'
541
+ return: ('Functional.pad.0.backward.', 'output')
542
+
543
+ state type: input, output, kwargs, parameters, parameters_grad
544
+ """
545
+ if Const.PARAMS_GRAD in name.split(Const.SEP):
546
+ return name.split(Const.PARAMS_GRAD)[0], Const.PARAMS_GRAD
547
+
548
+ split = re.split(Const.REGEX_FORWARD_BACKWARD, name)
549
+ api = f'{split[0]}.{split[1]}.'
550
+ state_str = split[2]
551
+ match = re.match(r'^(\d+\.)?(input|output|kwargs|parameters)\..+$', state_str)
552
+ if not match:
553
+ raise CompareException(f'Invalid name string: {name}')
554
+ if match.group(1):
555
+ api = f'{api}{match.group(1)}'
556
+ state = match.group(2)
557
+ return api, state
558
+
559
+
560
+ def reorder_op_name_list(op_name_list):
561
+ if not op_name_list:
562
+ return op_name_list
563
+
564
+ parameters = []
565
+ output = []
566
+ parameters_grad = []
567
+ others = []
568
+ for x in op_name_list:
569
+ state = get_name_and_state(x)[1]
570
+ if state == Const.PARAMS:
571
+ parameters.append(x)
572
+ elif state == Const.OUTPUT:
573
+ output.append(x)
574
+ elif state == Const.PARAMS_GRAD:
575
+ parameters_grad.append(x)
576
+ else:
577
+ others.append(x)
578
+ # 合并others, parameters, 和output,确保parameters排在output前面
579
+ op_name_reorder = others + parameters + output + parameters_grad
580
+ return op_name_reorder
581
+
582
+
583
+ def reorder_op_x_list(op_name_list, summary_list, data_name_list):
584
+ """对op_name, summary, data_name重新排序,把parameters放到input后output前,data_name由于统计量比对时,为None,单独处理"""
585
+ if not op_name_list or not summary_list:
586
+ return op_name_list, summary_list, data_name_list
587
+
588
+ index_map = {name: index for index, name in enumerate(op_name_list)}
589
+
590
+ op_name_reorder = reorder_op_name_list(op_name_list)
591
+ summary_reorder = [summary_list[index_map.get(name)] for name in op_name_reorder]
592
+ if data_name_list:
593
+ data_name_reorder = [data_name_list[index_map.get(name)] for name in op_name_reorder]
594
+ else:
595
+ data_name_reorder = data_name_list
596
+
597
+ return op_name_reorder, summary_reorder, data_name_reorder
598
+
599
+
470
600
  def _compare_parser(parser):
471
601
  parser.add_argument("-i", "--input_path", dest="input_path", type=str,
472
602
  help="<Required> The compare input path, a dict json.", required=True)
473
603
  parser.add_argument("-o", "--output_path", dest="output_path", type=str,
474
- help="<Required> The compare task result out path. Default path: ./output",
604
+ help="<Required> The compare task result out path. Default path: ./output",
475
605
  required=False, default="./output", nargs="?", const="./output")
476
606
  parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true",
477
607
  help="<optional> Whether to save stack info.", required=False)
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -38,6 +38,9 @@ class DataCollector:
38
38
  self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework)
39
39
  self.module_count = {}
40
40
  self.scope = ScopeFactory(self.config).build_scope()
41
+ self.backward_module_names = {}
42
+ self.optimizer_status = ""
43
+ self.optimizer_status_first_start = {Const.OPTIMIZER: True, Const.CLIP_GRAD: True}
41
44
  atexit.register(self.write_json)
42
45
 
43
46
  @property
@@ -53,8 +56,15 @@ class DataCollector:
53
56
  return (not scope or scope.check(name)) and pid == os.getpid()
54
57
 
55
58
  @staticmethod
56
- def is_inplace(module):
57
- return getattr(module, "op_is_inplace", False)
59
+ def set_is_recomputable(data_info, is_recompute):
60
+ if data_info and len(data_info) == 1 and is_recompute is not None: # 正常情况下data_info的长度应改为1
61
+ data_info[list(data_info.keys())[0]]["is_recompute"] = is_recompute
62
+
63
+ def reset_status(self):
64
+ self.optimizer_status = ""
65
+ self.optimizer_status_first_start = {Const.OPTIMIZER: True, Const.CLIP_GRAD: True}
66
+ self.data_writer.reset_cache()
67
+ self.backward_module_names.clear()
58
68
 
59
69
  def if_return_forward_new_output(self):
60
70
  return self.data_processor.if_return_forward_new_output()
@@ -79,69 +89,105 @@ class DataCollector:
79
89
  logger.debug(msg)
80
90
  self.data_writer.update_data(data_info)
81
91
 
82
- def pre_forward_data_collect(self, name, module, pid, module_input_output):
83
- if self.config.level == Const.LEVEL_L2 and self.check_scope_and_pid(self.scope, name, pid):
84
- self.data_processor.analyze_pre_forward(name, module, module_input_output)
92
+ def forward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
93
+ if self.config.task == Const.FREE_BENCHMARK:
94
+ backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
95
+ if self.check_scope_and_pid(self.scope, backward_name, pid):
96
+ self.data_processor.analyze_forward_input(backward_name, module, module_input_output)
97
+ return
98
+
99
+ if not self.check_scope_and_pid(self.scope, name, pid):
85
100
  return
86
101
 
87
- backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
88
- if self.check_scope_and_pid(self.scope, backward_name, pid):
89
- self.data_processor.analyze_pre_forward(backward_name, module, module_input_output)
90
- if not self.is_inplace(module) or not self.check_scope_and_pid(self.scope, name, pid):
102
+ data_info = {}
103
+ if self.config.task != Const.STRUCTURE:
104
+ data_info = self.data_processor.analyze_forward_input(name, module, module_input_output)
105
+ self.set_is_recomputable(data_info, is_recompute)
106
+ if self.config.level == Const.LEVEL_L2:
91
107
  return
92
- logger.info(f"API {name} is inplace.")
93
- data_info = self.data_processor.analyze_pre_forward_inplace(name, module_input_output)
94
108
  self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
95
109
 
96
- def forward_data_collect(self, name, module, pid, module_input_output):
110
+ def forward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
97
111
  self.update_construct(name)
98
112
  if not self.check_scope_and_pid(self.scope, name, pid):
99
113
  return
114
+
115
+ data_info = {}
116
+ if self.config.task != Const.STRUCTURE:
117
+ data_info = self.data_processor.analyze_forward_output(name, module, module_input_output)
118
+ self.set_is_recomputable(data_info, is_recompute)
100
119
  if self.config.level == Const.LEVEL_L2:
101
- self.data_processor.analyze_forward(name, module, module_input_output)
120
+ return
121
+ self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
122
+ self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
123
+
124
+ def forward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
125
+ self.update_construct(name)
126
+ if not self.check_scope_and_pid(self.scope, name, pid):
102
127
  return
103
128
 
104
- if not self.is_inplace(module):
129
+ data_info = {}
130
+ if self.config.task != Const.STRUCTURE:
105
131
  data_info = self.data_processor.analyze_forward(name, module, module_input_output)
106
- else:
107
- data_info = self.data_processor.analyze_forward_inplace(name, module_input_output)
132
+ self.set_is_recomputable(data_info, is_recompute)
108
133
  self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
109
134
  self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
110
135
 
111
- def backward_data_collect(self, name, module, pid, module_input_output):
136
+ def backward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
112
137
  self.update_construct(name)
113
138
  if not self.check_scope_and_pid(self.scope, name, pid):
114
139
  return
115
140
 
116
- data_info = self.data_processor.analyze_backward(name, module, module_input_output)
141
+ data_info = {}
142
+ if self.config.task != Const.STRUCTURE:
143
+ data_info = self.data_processor.analyze_backward(name, module, module_input_output)
117
144
  if self.config.level == Const.LEVEL_L2:
118
145
  return
146
+ # 获取执行反向的模块名称
147
+ if data_info and name.split(Const.SEP)[0] in Const.MODULE_PREFIX:
148
+ module_name = name.rsplit(Const.SEP, 2)[0]
149
+ # 将模块名称加入到反向模块名称集合中,用于梯度收集时判断是否需要收集梯度
150
+ self.backward_module_names[module_name] = True
119
151
  self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
120
152
 
121
- def backward_input_data_collect(self, name, module, pid, module_input_output):
153
+ def backward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
122
154
  self.update_construct(name)
123
155
  if not self.check_scope_and_pid(self.scope, name, pid):
124
156
  return
125
157
 
126
- data_info = self.data_processor.analyze_backward_input(name, module, module_input_output)
158
+ data_info = {}
159
+ if self.config.task != Const.STRUCTURE:
160
+ data_info = self.data_processor.analyze_backward_input(name, module, module_input_output)
161
+ self.set_is_recomputable(data_info, is_recompute)
127
162
  self.handle_data(name, data_info)
128
163
 
129
- def backward_output_data_collect(self, name, module, pid, module_input_output):
164
+ def backward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
130
165
  self.update_construct(name)
131
166
  if not self.check_scope_and_pid(self.scope, name, pid):
132
167
  return
133
168
 
134
- data_info = self.data_processor.analyze_backward_output(name, module, module_input_output)
169
+ data_info = {}
170
+ if self.config.task != Const.STRUCTURE:
171
+ data_info = self.data_processor.analyze_backward_output(name, module, module_input_output)
172
+ self.set_is_recomputable(data_info, is_recompute)
135
173
  self.handle_data(name, data_info)
136
174
 
137
175
  def update_construct(self, name):
138
176
  if self.config.level not in DataCollector.level_without_construct:
139
- self.data_writer.update_construct({name: self.module_processor.api_parent_node})
177
+ if self.optimizer_status in [Const.OPTIMIZER, Const.CLIP_GRAD]:
178
+ if self.optimizer_status_first_start[self.optimizer_status]:
179
+ self.data_writer.update_construct({self.optimizer_status: None})
180
+ self.optimizer_status_first_start[self.optimizer_status] = False
181
+ self.data_writer.update_construct({name: self.optimizer_status})
182
+ else:
183
+ self.data_writer.update_construct({name: self.module_processor.api_parent_node})
140
184
  self.data_writer.update_construct(self.module_processor.module_node)
141
185
 
142
186
  def handle_data(self, name, data_info, flush=False):
143
187
  if data_info:
144
188
  self.update_data(name, data_info)
189
+ if self.config.async_dump:
190
+ return
145
191
  if not flush:
146
192
  self.data_writer.flush_data_periodically()
147
193
  else:
@@ -149,7 +195,36 @@ class DataCollector:
149
195
 
150
196
  def update_dump_paths(self, *args):
151
197
  self.data_writer.update_dump_paths(*args)
152
- self.data_writer.initialize_json_file(task=self.config.task, level=self.config.level)
198
+
199
+ def initialize_json_file(self, framework=Const.UNKNOWN_FRAMEWORK):
200
+ self.data_writer.initialize_json_file(task=self.config.task, level=self.config.level, framework=framework)
153
201
 
154
202
  def update_iter(self, current_iter):
155
203
  self.data_processor.update_iter(current_iter)
204
+
205
+ def params_data_collect(self, name, param_name, pid, data):
206
+ grad_name = name + Const.SEP + Const.PARAMS_GRAD
207
+ # 校验scope和pid,以及当前name是否有过反向计算
208
+ if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
209
+ # 如果没有反向计算,则需要清除之前占位写入的grad数据
210
+ if self.data_writer.cache_data.get("data"):
211
+ self.data_writer.cache_data.get("data").pop(grad_name, None)
212
+ return
213
+ data_info = self.data_processor.analyze_params(grad_name, param_name, data)
214
+ self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
215
+
216
+ def fill_stack_tensor_data(self):
217
+ self.data_writer.fill_stack_tensor_data()
218
+
219
+ def debug_data_collect_forward(self, variable, name_with_count):
220
+
221
+ data_info = self.data_processor.analyze_debug_forward(variable, name_with_count)
222
+ self.data_writer.update_debug({name_with_count: data_info})
223
+
224
+ def debug_data_collect_backward(self, variable, grad_name_with_count):
225
+ # prepare all None nested data structure
226
+ all_none_data_info = self.data_processor.analyze_element_to_all_none(variable)
227
+ self.data_writer.update_debug({grad_name_with_count: all_none_data_info})
228
+
229
+ # register tensor backward hook
230
+ self.data_processor.analyze_debug_backward(variable, grad_name_with_count, self.data_writer.cache_debug['data'])