mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
  3. msprobe/README.md +6 -6
  4. msprobe/core/common/const.py +98 -41
  5. msprobe/core/common/db_manager.py +256 -0
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/log.py +7 -0
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/parallel_state.py +193 -0
  10. msprobe/core/common/utils.py +20 -13
  11. msprobe/core/common_config.py +5 -0
  12. msprobe/core/compare/acc_compare.py +140 -93
  13. msprobe/core/compare/check.py +13 -0
  14. msprobe/core/compare/compare_cli.py +64 -6
  15. msprobe/core/compare/config.py +10 -8
  16. msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
  17. msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
  18. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  19. msprobe/core/compare/find_first/__init__.py +0 -0
  20. msprobe/core/compare/find_first/analyzer.py +282 -0
  21. msprobe/core/compare/find_first/data_processor.py +35 -0
  22. msprobe/core/compare/find_first/graph.py +188 -0
  23. msprobe/core/compare/find_first/utils.py +189 -0
  24. msprobe/core/compare/highlight.py +74 -101
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
  26. msprobe/core/compare/merge_result/merge_result.py +2 -2
  27. msprobe/core/compare/multiprocessing_compute.py +45 -28
  28. msprobe/core/compare/npy_compare.py +7 -10
  29. msprobe/core/compare/utils.py +338 -130
  30. msprobe/core/config_check/checkers/dataset_checker.py +2 -1
  31. msprobe/core/config_check/checkers/env_args_checker.py +5 -5
  32. msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
  33. msprobe/core/config_check/checkers/pip_checker.py +4 -3
  34. msprobe/core/config_check/checkers/random_checker.py +3 -3
  35. msprobe/core/config_check/checkers/weights_checker.py +2 -1
  36. msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
  37. msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
  38. msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
  39. msprobe/core/config_check/utils/utils.py +10 -0
  40. msprobe/core/data_dump/api_registry.py +49 -30
  41. msprobe/core/data_dump/data_collector.py +71 -29
  42. msprobe/core/data_dump/data_processor/base.py +2 -0
  43. msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
  44. msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
  45. msprobe/core/data_dump/json_writer.py +81 -7
  46. msprobe/core/data_dump/scope.py +4 -6
  47. msprobe/core/hook_manager.py +129 -70
  48. msprobe/core/monitor/csv2db.py +361 -0
  49. msprobe/core/monitor/db_utils.py +278 -0
  50. msprobe/core/monitor/utils.py +35 -1
  51. msprobe/core/service.py +31 -39
  52. msprobe/core/single_save/single_comparator.py +16 -3
  53. msprobe/docs/01.installation.md +51 -19
  54. msprobe/docs/02.config_introduction.md +16 -20
  55. msprobe/docs/03.config_examples.md +26 -0
  56. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  57. msprobe/docs/05.data_dump_PyTorch.md +6 -2
  58. msprobe/docs/06.data_dump_MindSpore.md +44 -7
  59. msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
  60. msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
  61. msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
  62. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  63. msprobe/docs/19.monitor.md +94 -7
  64. msprobe/docs/21.visualization_PyTorch.md +71 -101
  65. msprobe/docs/22.visualization_MindSpore.md +69 -119
  66. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  67. msprobe/docs/25.tool_function_introduction.md +0 -1
  68. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  69. msprobe/docs/28.debugger_save_instruction.md +184 -81
  70. msprobe/docs/29.data_dump_MSAdapter.md +6 -0
  71. msprobe/docs/31.config_check.md +4 -2
  72. msprobe/docs/36.calculation_result_change.md +75 -0
  73. msprobe/docs/FAQ.md +22 -1
  74. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
  75. msprobe/docs/img/compare_result.png +0 -0
  76. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  77. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  78. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  79. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  80. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  81. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  82. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  83. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  84. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  85. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  86. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  87. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  88. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  89. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  90. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
  91. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  92. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  93. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
  94. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  95. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  96. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
  97. msprobe/mindspore/__init__.py +1 -1
  98. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  99. msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
  100. msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
  101. msprobe/mindspore/cell_processor.py +64 -25
  102. msprobe/mindspore/common/utils.py +51 -7
  103. msprobe/mindspore/compare/common_dir_compare.py +45 -37
  104. msprobe/mindspore/compare/ms_compare.py +10 -2
  105. msprobe/mindspore/compare/ms_graph_compare.py +47 -52
  106. msprobe/mindspore/debugger/debugger_config.py +18 -7
  107. msprobe/mindspore/debugger/precision_debugger.py +16 -12
  108. msprobe/mindspore/dump/cell_dump_process.py +130 -68
  109. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
  110. msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
  111. msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
  112. msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
  113. msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
  114. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
  115. msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
  116. msprobe/mindspore/exception_dump/__init__.py +0 -0
  117. msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
  118. msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
  119. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
  120. msprobe/mindspore/mindspore_service.py +2 -2
  121. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
  122. msprobe/mindspore/monitor/features.py +82 -0
  123. msprobe/mindspore/monitor/module_hook.py +168 -10
  124. msprobe/mindspore/monitor/utils.py +27 -1
  125. msprobe/mindspore/ms_config.py +12 -4
  126. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  127. msprobe/mindspore/task_handler_factory.py +3 -1
  128. msprobe/nan_analyze/graph.py +1 -1
  129. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  132. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  133. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  135. msprobe/pytorch/common/utils.py +1 -21
  136. msprobe/pytorch/compare/pt_compare.py +10 -2
  137. msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
  138. msprobe/pytorch/compare/utils.py +2 -1
  139. msprobe/pytorch/debugger/debugger_config.py +18 -23
  140. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  141. msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
  142. msprobe/pytorch/free_benchmark/main.py +7 -4
  143. msprobe/pytorch/hook_module/api_register.py +62 -24
  144. msprobe/pytorch/hook_module/hook_module.py +9 -29
  145. msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
  146. msprobe/pytorch/hook_module/script_wrapper.py +140 -0
  147. msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
  148. msprobe/pytorch/monitor/csv2tb.py +1 -1
  149. msprobe/pytorch/monitor/features.py +94 -0
  150. msprobe/pytorch/monitor/module_hook.py +221 -81
  151. msprobe/pytorch/monitor/module_metric.py +27 -1
  152. msprobe/pytorch/monitor/optimizer_collect.py +109 -4
  153. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  154. msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
  155. msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
  156. msprobe/pytorch/pt_config.py +2 -51
  157. msprobe/pytorch/pytorch_service.py +7 -14
  158. msprobe/visualization/builder/graph_builder.py +192 -63
  159. msprobe/visualization/builder/graph_merger.py +986 -0
  160. msprobe/visualization/builder/msprobe_adapter.py +17 -15
  161. msprobe/visualization/compare/graph_comparator.py +26 -16
  162. msprobe/visualization/db_utils.py +252 -0
  163. msprobe/visualization/graph/base_node.py +2 -22
  164. msprobe/visualization/graph/distributed_analyzer.py +12 -12
  165. msprobe/visualization/graph/graph.py +44 -16
  166. msprobe/visualization/graph_service.py +143 -59
  167. msprobe/visualization/utils.py +103 -4
  168. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  169. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  170. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  171. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  172. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  173. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  174. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  175. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  176. msprobe/pytorch/attl_manager.py +0 -65
  177. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  178. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  179. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  180. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
  181. /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
@@ -25,13 +25,15 @@ from msprobe.core.common.exceptions import MsprobeException
25
25
  from msprobe.core.common.runtime import Runtime
26
26
  from msprobe.core.common.utils import ModuleQueue, ThreadSafe
27
27
  from msprobe.core.data_dump.scope import ModuleRangeScope, MixRangeScope, BaseScope
28
+ from msprobe.core.common.megatron_utils import wrap_megatron_step, get_micro_step, is_megatron
28
29
  from msprobe.mindspore.common.const import Const as MsConst
29
30
  from msprobe.mindspore.common.log import logger
30
31
  from msprobe.mindspore.common.utils import (
31
32
  is_mindtorch,
32
33
  get_cells_and_names_with_index,
33
34
  has_kwargs_in_forward_hook,
34
- is_graph_mode_cell_dump_allowed
35
+ is_graph_mode_cell_dump_allowed,
36
+ is_backward_hook_output_a_view
35
37
  )
36
38
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
37
39
  from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump
@@ -46,6 +48,28 @@ def get_cell_construct(construct):
46
48
  return _construct
47
49
 
48
50
 
51
+ def patch_schedules_step():
52
+ try:
53
+ from mindspeed.mindspore.core.pipeline_parallel import schedules
54
+ schedules.forward_step = wrap_megatron_step(schedules.forward_step)
55
+ schedules.backward_step = wrap_megatron_step(schedules.backward_step, is_forward=False)
56
+ logger.info_on_rank_0("Patch mindspeed.mindspore method success.")
57
+ except ImportError:
58
+ logger.info_on_rank_0("No mindspeed.mindspore find.")
59
+ except Exception as e:
60
+ logger.info_on_rank_0(f"Patch mindspeed.mindspore method failed, detail:{str(e)}")
61
+
62
+ try:
63
+ from megatron.core.pipeline_parallel import schedules
64
+ schedules.forward_step = wrap_megatron_step(schedules.forward_step)
65
+ schedules.backward_step = wrap_megatron_step(schedules.backward_step, is_forward=False)
66
+ logger.info_on_rank_0("Patch megatron method success.")
67
+ except ImportError:
68
+ logger.info_on_rank_0("No megatron find.")
69
+ except Exception as e:
70
+ logger.info_on_rank_0(f"Patch megatron method failed, detail:{str(e)}")
71
+
72
+
49
73
  class CellProcessor:
50
74
  cell_queue = ModuleQueue()
51
75
  cell_count = {}
@@ -83,6 +107,8 @@ class CellProcessor:
83
107
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
84
108
  'The model cannot be None, when level is "L0" or "mix"')
85
109
 
110
+ patch_schedules_step()
111
+
86
112
  is_registered = False
87
113
  model_type = Const.MODULE if is_mindtorch() else Const.CELL
88
114
  cells_with_index_in_pynative_mode, cells_with_index_in_graph_mode = get_cells_and_names_with_index(models)
@@ -116,19 +142,23 @@ class CellProcessor:
116
142
  cells_and_names_in_graph_mode = []
117
143
  for index, cells_and_names in cells_with_index_in_graph_mode.items():
118
144
  model = models if index == "-1" else models[int(index)]
119
- for name, cell in cells_and_names:
145
+ for name, cell, parent_cell in cells_and_names:
120
146
  if cell == model:
121
147
  continue
122
148
  cell_index = (index + Const.SEP) if index != "-1" else ""
123
- cells_and_names_in_graph_mode.append((f'{cell_index}{name}', cell))
149
+ cells_and_names_in_graph_mode.append((f'{cell_index}{name}', cell, parent_cell))
124
150
 
125
151
  if cells_and_names_in_graph_mode:
126
152
  Runtime.run_mode = MsConst.PYNATIVE_GRAPH_MODE
127
153
  GraphModeCellDump(config, cells_and_names_in_graph_mode, strict=False).handle()
128
154
 
155
+
129
156
  def build_cell_hook(self, cell_name, build_data_hook):
130
157
  @ThreadSafe.synchronized
131
158
  def forward_pre_hook(cell, args):
159
+ if not Runtime.is_running:
160
+ return args
161
+
132
162
  index = CellProcessor.set_and_get_calls_number(cell_name)
133
163
  full_forward_name = f'{cell_name}{Const.FORWARD}{Const.SEP}{index}'
134
164
  full_backward_name = f'{cell_name}{Const.BACKWARD}{Const.SEP}{index}'
@@ -174,7 +204,7 @@ class CellProcessor:
174
204
  bw_hook.register_backward_hook()
175
205
  CellProcessor.cell_bw_hook_kernels[full_forward_name] = bw_hook
176
206
 
177
- args = bw_hook(*args)
207
+ args = bw_hook(args) if is_backward_hook_output_a_view() else bw_hook(*args)
178
208
 
179
209
  return args
180
210
 
@@ -199,12 +229,15 @@ class CellProcessor:
199
229
  logger.warning("For backward hooks to be called,"
200
230
  " cell output should be a Tensor or a tuple of Tensors"
201
231
  f" but received {type(outputs)}")
202
- if isinstance(outputs, tuple):
203
- new_outputs = bw_hook(*outputs)
204
- else:
232
+ if is_backward_hook_output_a_view():
205
233
  new_outputs = bw_hook(outputs)
206
- if isinstance(outputs, tuple) and len(outputs) == 1:
207
- new_outputs = (new_outputs,)
234
+ else:
235
+ if isinstance(outputs, tuple):
236
+ new_outputs = bw_hook(*outputs)
237
+ else:
238
+ new_outputs = bw_hook(outputs)
239
+ if isinstance(outputs, tuple) and len(outputs) == 1:
240
+ new_outputs = (new_outputs,)
208
241
  outputs = new_outputs
209
242
 
210
243
  def get_backward_pre_hook(full_backward_name, backward_data_hook):
@@ -227,18 +260,21 @@ class CellProcessor:
227
260
  self.cell_backward_pre_hook[-1])
228
261
  bw_pre_hook.register_backward_pre_hook()
229
262
 
230
- if isinstance(outputs, tuple):
231
- result = bw_pre_hook(*outputs)
232
- else:
263
+ if is_backward_hook_output_a_view():
233
264
  result = bw_pre_hook(outputs)
234
- if isinstance(outputs, tuple):
235
- if len(outputs) == 1:
236
- result = (result,)
237
- if len(result) != len(outputs):
238
- raise TypeError(
239
- f"The backward pre hook return value size is {len(result)} "
240
- f"not equal to output size {len(outputs)}"
241
- )
265
+ else:
266
+ if isinstance(outputs, tuple):
267
+ result = bw_pre_hook(*outputs)
268
+ else:
269
+ result = bw_pre_hook(outputs)
270
+ if isinstance(outputs, tuple):
271
+ if len(outputs) == 1:
272
+ result = (result,)
273
+ if len(result) != len(outputs):
274
+ raise TypeError(
275
+ f"The backward pre hook return value size is {len(result)} "
276
+ f"not equal to output size {len(outputs)}"
277
+ )
242
278
  return result
243
279
 
244
280
  return forward_pre_hook
@@ -249,23 +285,26 @@ class CellProcessor:
249
285
  CellProcessor.cell_stack[tid] = []
250
286
 
251
287
  if self.cell_stack[tid]:
252
- CellProcessor.module_node[full_name] = self.cell_stack[tid][-1]
288
+ CellProcessor.module_node[full_name] = self.cell_stack[tid][-1] if not is_megatron() \
289
+ else [self.cell_stack[tid][-1], get_micro_step()]
253
290
  else:
254
291
  parent_name = CellProcessor.cell_queue.find_last(full_name)
255
- CellProcessor.module_node[full_name] = parent_name
292
+ CellProcessor.module_node[full_name] = parent_name if not is_megatron() else [parent_name, get_micro_step()]
256
293
 
257
294
  CellProcessor.cell_queue.add_name(full_name)
258
295
  CellProcessor.cell_stack[tid].append(full_name)
259
- CellProcessor.api_parent_node[tid] = full_name
296
+ CellProcessor.api_parent_node[tid] = full_name if not is_megatron() else [full_name, get_micro_step()]
260
297
  if self.scope:
261
298
  self.scope.begin_module(full_name)
262
299
 
263
300
  def set_construct_info_in_hook(self, full_name):
264
301
  tid = threading.get_ident()
265
- CellProcessor.api_parent_node[tid] = None
302
+ CellProcessor.cell_queue.remove_name(full_name)
303
+ CellProcessor.api_parent_node[tid] = None if not is_megatron() else [None, get_micro_step()]
266
304
  if self.cell_stack.get(tid):
267
305
  CellProcessor.cell_stack[tid].pop()
268
306
  if self.cell_stack.get(tid):
269
- CellProcessor.api_parent_node[tid] = CellProcessor.cell_stack[tid][-1]
307
+ CellProcessor.api_parent_node[tid] = CellProcessor.cell_stack[tid][-1] if not is_megatron() \
308
+ else [CellProcessor.cell_stack[tid][-1], get_micro_step()]
270
309
  if self.scope:
271
310
  self.scope.end_module(full_name)
@@ -16,6 +16,7 @@
16
16
  import inspect
17
17
  import os
18
18
  import random
19
+ import sys
19
20
  import types
20
21
 
21
22
  import mindspore as ms
@@ -41,6 +42,7 @@ else:
41
42
  mindtorch_check_result = None
42
43
  register_backward_hook_functions = {}
43
44
  kwargs_exist_in_forward_hook = None
45
+ is_output_of_backward_hook_a_view = None
44
46
 
45
47
 
46
48
  class MsprobeStep(ms.train.Callback):
@@ -129,7 +131,7 @@ def list_lowest_level_directories(root_dir):
129
131
  return lowest_level_dirs
130
132
 
131
133
 
132
- def seed_all(seed=1234, mode=False, rm_dropout=True):
134
+ def seed_all(seed=1234, mode=False, rm_dropout=False):
133
135
  check_seed_all(seed, mode, rm_dropout)
134
136
  os.environ['PYTHONHASHSEED'] = str(seed)
135
137
  ms.set_seed(seed)
@@ -179,6 +181,8 @@ def is_mindtorch():
179
181
  global mindtorch_check_result
180
182
  if mindtorch_check_result is None:
181
183
  mindtorch_check_result = False
184
+ if 'torch' not in sys.modules:
185
+ return mindtorch_check_result
182
186
  try:
183
187
  import torch
184
188
  except ImportError:
@@ -254,14 +258,14 @@ def is_decorated_by_jit(func):
254
258
 
255
259
 
256
260
  @recursion_depth_decorator('msprobe.mindspore.common.utils.get_cells_and_names')
257
- def get_cells_and_names(model, cells_set=None, name_prefix=''):
261
+ def get_cells_and_names(model, cells_set=None, name_prefix='', parent_cell=None):
258
262
  cells_set = cells_set if cells_set else set()
259
263
  if model in cells_set:
260
264
  return
261
265
 
262
266
  cells_set.add(model)
263
267
  jit_decorated = is_decorated_by_jit(model.construct)
264
- yield name_prefix, model, jit_decorated
268
+ yield name_prefix, model, jit_decorated, parent_cell
265
269
  if jit_decorated:
266
270
  return
267
271
 
@@ -271,9 +275,9 @@ def get_cells_and_names(model, cells_set=None, name_prefix=''):
271
275
  cells_name_prefix = f'{name_prefix}{Const.SEP}{name}' if name_prefix else name
272
276
  jit_decorated = is_decorated_by_jit(model.construct)
273
277
  if jit_decorated:
274
- yield cells_name_prefix, cell, jit_decorated
278
+ yield cells_name_prefix, cell, jit_decorated, model
275
279
  else:
276
- for ele in get_cells_and_names(cell, cells_set, cells_name_prefix):
280
+ for ele in get_cells_and_names(cell, cells_set, cells_name_prefix, model):
277
281
  yield ele
278
282
 
279
283
 
@@ -284,9 +288,9 @@ def get_cells_and_names_with_index(models):
284
288
  def distinguish_cells(cells):
285
289
  cells_in_pynative_mode = []
286
290
  cells_in_graph_mode = []
287
- for name, cell, jit_decorated in cells:
291
+ for name, cell, jit_decorated, parent_cell in cells:
288
292
  if jit_decorated:
289
- cells_in_graph_mode.append((name, cell))
293
+ cells_in_graph_mode.append((name, cell, parent_cell))
290
294
  else:
291
295
  cells_in_pynative_mode.append((name, cell))
292
296
  return cells_in_pynative_mode, cells_in_graph_mode
@@ -329,3 +333,43 @@ def has_kwargs_in_forward_hook():
329
333
  return kwargs_exist_in_forward_hook
330
334
 
331
335
  return kwargs_exist_in_forward_hook
336
+
337
+
338
+ def is_backward_hook_output_a_view():
339
+ global is_output_of_backward_hook_a_view
340
+
341
+ if is_output_of_backward_hook_a_view is None:
342
+ is_output_of_backward_hook_a_view = False
343
+ if getattr(ms, '__version__', '2.4.0') < '2.7.0':
344
+ return is_output_of_backward_hook_a_view
345
+ try:
346
+ from mindspore.ops.operations import _inner_ops as inner
347
+ call_func = getattr(inner.CellBackwardHook, '__call__')
348
+ func_params = inspect.signature(call_func).parameters
349
+ except Exception:
350
+ return is_output_of_backward_hook_a_view
351
+ if 'args' in func_params and func_params['args'].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
352
+ is_output_of_backward_hook_a_view = True
353
+
354
+ return is_output_of_backward_hook_a_view
355
+
356
+
357
+ def wrap_backward_hook_call_func(call_func):
358
+ if not is_backward_hook_output_a_view():
359
+ return call_func
360
+
361
+ from mindspore.common.api import _pynative_executor as executor
362
+ from mindspore._c_expression import CreationType
363
+
364
+ def new_call(self, args):
365
+ outputs = call_func(self, args)
366
+ if isinstance(outputs, ms.Tensor):
367
+ executor.set_creation_type(outputs, CreationType.DEFAULT)
368
+ elif isinstance(outputs, tuple):
369
+ for item in outputs:
370
+ if isinstance(item, ms.Tensor):
371
+ executor.set_creation_type(item, CreationType.DEFAULT)
372
+ return outputs
373
+ new_call.__name__ = '__call__'
374
+
375
+ return new_call
@@ -154,21 +154,34 @@ def find_npy_files(directory):
154
154
  dirs.clear()
155
155
  for file in files:
156
156
  if file.endswith(".npy"):
157
- # 分割文件名并去掉最后两个元素
158
- file_name = file.split('_')
159
- if len(file_name) < 2:
157
+ # 正确移除文件扩展名
158
+ base_name = os.path.splitext(file)
159
+ if not base_name or len(base_name) < 1:
160
+ logger.warning("Invalid file encountered.")
160
161
  continue
161
- key = '_'.join(file_name[:-2])
162
- # 文件的完整路径
163
- value = os.path.join(root, file)
164
- # 添加到字典中
165
- if not npy_files_dict.get(key):
166
- npy_files_dict[key] = []
167
- npy_files_dict[key].append(value)
162
+ file_name = base_name[0]
163
+
164
+ logger.info(f"Generating file info for file: {file}")
165
+
166
+ # 使用一致的分割逻辑
167
+ file_ele = file_name.split('_')
168
+
169
+ if len(file_ele) < 2:
170
+ continue
171
+
172
+ key = '_'.join(file_ele[:-2])
173
+ if key:
174
+ # 文件的完整路径
175
+ value = os.path.join(root, file)
176
+ # 添加到字典中
177
+ if key not in npy_files_dict:
178
+ npy_files_dict[key] = []
179
+ npy_files_dict[key].append(value)
168
180
  return npy_files_dict
169
181
 
170
182
 
171
183
  def generate_map_dict(npu_file_dict, bench_file_dict, name_map_dict=None):
184
+ result_dict = {}
172
185
  for k, npu_file_list in npu_file_dict.items():
173
186
  bench_file_list = bench_file_dict.get(k)
174
187
  if not bench_file_list and k in name_map_dict:
@@ -176,7 +189,6 @@ def generate_map_dict(npu_file_dict, bench_file_dict, name_map_dict=None):
176
189
  bench_length = len(bench_file_list)
177
190
  if not (bench_file_list and bench_length):
178
191
  continue
179
- result_dict = {}
180
192
  for i, npu_file in enumerate(npu_file_list):
181
193
  if i >= bench_length:
182
194
  break
@@ -200,14 +212,14 @@ def do_multi_process(func, map_dict):
200
212
  df_chunks = [result_df]
201
213
  process_num = 1
202
214
  logger.info(f"Using {process_num} processes with chunk size {df_chunk_size}")
203
-
215
+
204
216
  # 分割字典
205
217
  map_chunks = split_dict(map_dict, df_chunk_size)
206
-
218
+
207
219
  # 创建结果列表和进程池
208
220
  results = []
209
221
  pool = multiprocessing.Pool(process_num)
210
-
222
+
211
223
  progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)
212
224
 
213
225
  def update_progress(size, progress_lock, extra_param=None):
@@ -216,34 +228,30 @@ def do_multi_process(func, map_dict):
216
228
 
217
229
  def err_call(args):
218
230
  logger.error('multiprocess compare failed! Reason: {}'.format(args))
219
- try:
220
- pool.close()
221
- except OSError as e:
222
- logger.error(f'pool terminate failed: {str(e)}')
231
+
223
232
  results = []
233
+
234
+ # 提交任务到进程池
235
+ for process_idx, (df_chunk, map_chunk) in enumerate(zip(df_chunks, map_chunks)):
236
+ start_idx = df_chunk_size * process_idx
237
+ result = pool.apply_async(
238
+ func,
239
+ args=(df_chunk, start_idx, map_chunk, lock),
240
+ error_callback=err_call,
241
+ callback=partial(update_progress, len(map_chunk), lock)
242
+ )
243
+ results.append(result)
244
+ pool.close()
245
+
224
246
  try:
225
- # 提交任务到进程池
226
- for process_idx, (df_chunk, map_chunk) in enumerate(zip(df_chunks, map_chunks)):
227
- start_idx = df_chunk_size * process_idx
228
- result = pool.apply_async(
229
- func,
230
- args=(df_chunk, start_idx, map_chunk, lock),
231
- error_callback=err_call,
232
- callback=partial(update_progress, len(map_chunk), lock)
233
- )
234
- results.append(result)
235
-
236
- final_results = [r.get() for r in results]
237
- # 等待所有任务完成
238
- pool.close()
239
- pool.join()
240
- return pd.concat(final_results, ignore_index=True)
247
+ final_results = [r.get(timeout=3600) for r in results]
241
248
  except Exception as e:
242
- logger.error(f"\nMain process error: {str(e)}")
249
+ logger.error(f"Task failed with exception: {e}")
243
250
  pool.terminate()
244
251
  return pd.DataFrame({})
245
- finally:
246
- pool.close()
252
+ # 等待所有任务完成
253
+ pool.join()
254
+ return pd.concat(final_results, ignore_index=True)
247
255
 
248
256
 
249
257
  def initialize_result_df(total_size):
@@ -35,8 +35,16 @@ def ms_compare(input_param, output_path, **kwargs):
35
35
  config.data_mapping = generate_data_mapping_by_layer_mapping(input_param, config.layer_mapping, output_path)
36
36
 
37
37
  is_cross_framework = check_cross_framework(input_param.get('bench_json_path'))
38
- mode_config = ModeConfig(config.stack_mode, config.auto_analyze, config.fuzzy_match,
39
- config.dump_mode, config.compared_file_type)
38
+
39
+ config_dict = {
40
+ 'stack_mode': config.stack_mode,
41
+ 'auto_analyze': config.auto_analyze,
42
+ 'fuzzy_match': config.fuzzy_match,
43
+ 'highlight': config.highlight,
44
+ 'dump_mode': config.dump_mode,
45
+ 'compared_file_type': config.compared_file_type
46
+ }
47
+ mode_config = ModeConfig(**config_dict)
40
48
  mapping_config = MappingConfig(config.cell_mapping, config.api_mapping, config.data_mapping)
41
49
  ms_comparator = Comparator(read_real_data, mode_config, mapping_config, is_cross_framework)
42
50
  ms_comparator.compare_core(input_param, output_path, suffix=config.suffix)
@@ -34,10 +34,11 @@ class RowData:
34
34
  self.basic_data = copy.deepcopy(CompareConst.MS_GRAPH_BASE)
35
35
  self.npy_data = copy.deepcopy(CompareConst.MS_GRAPH_NPY)
36
36
  self.statistic_data = copy.deepcopy(CompareConst.MS_GRAPH_STATISTIC)
37
+ self.csv = copy.deepcopy(CompareConst.MS_GRAPH_CSV)
37
38
  if mode == GraphMode.NPY_MODE:
38
39
  self.data = {**self.basic_data, **self.npy_data}
39
40
  else:
40
- self.data = {**self.basic_data, **self.statistic_data}
41
+ self.data = {**self.basic_data, **self.statistic_data, **self.csv}
41
42
 
42
43
  def __call__(self):
43
44
  return self.data
@@ -80,8 +81,8 @@ def statistic_data_read(statistic_file_list, statistic_file_path):
80
81
  data_list = []
81
82
  statistic_data_list = []
82
83
  header_index = {
83
- 'Data Type': None, 'Shape': None, 'Max Value': None,
84
- 'Min Value': None, 'Avg Value': None, 'L2Norm Value': None
84
+ 'Data Type': None, 'Shape': None,
85
+ 'Max Value': None, 'Min Value': None, 'Avg Value': None, 'L2Norm Value': None
85
86
  }
86
87
  for statistic_file in statistic_file_list:
87
88
  content = read_csv(statistic_file, as_pd=False)
@@ -107,7 +108,7 @@ def statistic_data_read(statistic_file_list, statistic_file_path):
107
108
  logger.error(f'Dump file {statistic_file_path} has been modified into incorrect format!')
108
109
  raise CompareException(f'Dump file {statistic_file_path} has been modified into incorrect format!')
109
110
  compare_key = f"{data[1]}.{data[2]}.{data[5]}.{data[6]}" # OpName, TaskId, IO, Slot
110
- op_name = f"{compare_key} {statistic_file_path}"
111
+ op_name = f"{compare_key}"
111
112
  timestamp = int(data[4])
112
113
  result_data = [op_name, compare_key, timestamp]
113
114
  for key in header_index.keys():
@@ -115,6 +116,8 @@ def statistic_data_read(statistic_file_list, statistic_file_path):
115
116
  result_data.append(np.nan)
116
117
  else:
117
118
  result_data.append(data[header_index[key]])
119
+ csv_file = f"{statistic_file_path}"
120
+ result_data.append(csv_file)
118
121
  data_list.append(result_data)
119
122
  return data_list
120
123
 
@@ -230,6 +233,17 @@ class GraphMSComparator:
230
233
  result[f'{prefix} min'] = np.float32(rows[f'{prefix} min'])
231
234
  result[f'{prefix} mean'] = np.float32(rows[f'{prefix} mean'])
232
235
  result[f'{prefix} l2norm'] = np.float32(rows[f'{prefix} l2norm'])
236
+ result[f'{prefix} CSV File'] = rows[f'{prefix} CSV File']
237
+
238
+ def calculate_relative_error(numerator, denominator):
239
+ """Calculates relative error, handling division by zero and NaN."""
240
+ if denominator != 0:
241
+ result = numerator / denominator
242
+ if not np.isnan(result):
243
+ return str(abs(result * 100)) + "%"
244
+ else:
245
+ return CompareConst.NAN
246
+ return CompareConst.N_A
233
247
 
234
248
  # 使用示例
235
249
  update_result_dict(result_dict, row, 'NPU')
@@ -237,34 +251,26 @@ class GraphMSComparator:
237
251
  error_flag, error_message = statistics_data_check(result_dict)
238
252
  result_dict[CompareConst.ERROR_MESSAGE] += error_message
239
253
  if not error_flag:
240
- result_dict[CompareConst.MAX_DIFF] = np.abs(
241
- result_dict[CompareConst.NPU_MAX] - result_dict[CompareConst.BENCH_MAX])
242
- result_dict[CompareConst.MIN_DIFF] = np.abs(
243
- result_dict[CompareConst.NPU_MIN] - result_dict[CompareConst.BENCH_MIN])
244
- result_dict[CompareConst.MEAN_DIFF] = np.abs(
245
- result_dict[CompareConst.NPU_MEAN] - result_dict[CompareConst.BENCH_MEAN])
246
- result_dict[CompareConst.NORM_DIFF] = np.abs(
247
- result_dict[CompareConst.NPU_NORM] - result_dict[CompareConst.BENCH_NORM])
248
- result_dict[CompareConst.MAX_RELATIVE_ERR] = result_dict[CompareConst.MAX_DIFF] / result_dict[
249
- CompareConst.BENCH_MAX] if result_dict[CompareConst.BENCH_MAX] > 0 else 0
250
- if not np.isnan(result_dict[CompareConst.MAX_RELATIVE_ERR]):
251
- result_dict[CompareConst.MAX_RELATIVE_ERR] = str(
252
- result_dict[CompareConst.MAX_RELATIVE_ERR] * 100) + "%"
253
- result_dict[CompareConst.MIN_RELATIVE_ERR] = result_dict[CompareConst.MIN_DIFF] / result_dict[
254
- CompareConst.BENCH_MIN] if result_dict[CompareConst.BENCH_MIN] > 0 else 0
255
- if not np.isnan(result_dict[CompareConst.MIN_RELATIVE_ERR]):
256
- result_dict[CompareConst.MIN_RELATIVE_ERR] = \
257
- str(result_dict[CompareConst.MIN_RELATIVE_ERR] * 100) + "%"
258
- result_dict[CompareConst.MEAN_RELATIVE_ERR] = result_dict[CompareConst.MEAN_DIFF] / result_dict[
259
- CompareConst.BENCH_MEAN] if result_dict[CompareConst.BENCH_MEAN] > 0 else 0
260
- if not np.isnan(result_dict[CompareConst.MEAN_RELATIVE_ERR]):
261
- result_dict[CompareConst.MEAN_RELATIVE_ERR] = str(
262
- result_dict[CompareConst.MEAN_RELATIVE_ERR] * 100) + "%"
263
- result_dict[CompareConst.NORM_RELATIVE_ERR] = result_dict[CompareConst.NORM_DIFF] / result_dict[
264
- CompareConst.BENCH_NORM] if result_dict[CompareConst.BENCH_NORM] > 0 else 0
265
- if not np.isnan(result_dict[CompareConst.NORM_RELATIVE_ERR]):
266
- result_dict[CompareConst.NORM_RELATIVE_ERR] = str(
267
- result_dict[CompareConst.NORM_RELATIVE_ERR] * 100) + "%"
254
+ metrics = [
255
+ (CompareConst.MAX_DIFF, CompareConst.NPU_MAX, CompareConst.BENCH_MAX),
256
+ (CompareConst.MIN_DIFF, CompareConst.NPU_MIN, CompareConst.BENCH_MIN),
257
+ (CompareConst.MEAN_DIFF, CompareConst.NPU_MEAN, CompareConst.BENCH_MEAN),
258
+ (CompareConst.NORM_DIFF, CompareConst.NPU_NORM, CompareConst.BENCH_NORM),
259
+ ]
260
+ relative_error_metrics = [
261
+ (CompareConst.MAX_RELATIVE_ERR, CompareConst.MAX_DIFF, CompareConst.BENCH_MAX),
262
+ (CompareConst.MIN_RELATIVE_ERR, CompareConst.MIN_DIFF, CompareConst.BENCH_MIN),
263
+ (CompareConst.MEAN_RELATIVE_ERR, CompareConst.MEAN_DIFF, CompareConst.BENCH_MEAN),
264
+ (CompareConst.NORM_RELATIVE_ERR, CompareConst.NORM_DIFF, CompareConst.BENCH_NORM),
265
+ ]
266
+
267
+ for diff_metric, npu_metric, bench_metric in metrics:
268
+ result_dict[diff_metric] = result_dict[npu_metric] - result_dict[bench_metric]
269
+
270
+ for rel_metric, diff_metric, bench_metric in relative_error_metrics:
271
+ result_dict[rel_metric] = calculate_relative_error(result_dict[diff_metric],
272
+ result_dict[bench_metric])
273
+
268
274
  magnitude_diff = result_dict[CompareConst.MAX_DIFF] / (
269
275
  max(result_dict[CompareConst.NPU_MAX], result_dict[CompareConst.BENCH_MAX]) + 1e-10)
270
276
  if np.isnan(result_dict[CompareConst.NPU_MAX]) and np.isnan(result_dict[CompareConst.BENCH_MAX]):
@@ -296,20 +302,8 @@ class GraphMSComparator:
296
302
  compare_result_df = self.do_multi_process(compare_result_df, mode)
297
303
  compare_result_name = add_time_with_xlsx(f"compare_result_{str(rank_id)}_{str(step_id)}")
298
304
  compare_result_path = os.path.join(os.path.realpath(self.output_path), f"{compare_result_name}")
299
- self.to_excel(compare_result_df, compare_result_path)
300
- logger.info(f"Compare rank: {rank_id} step: {step_id} finish. Compare result: {compare_result_path}.")
301
-
302
- def to_excel(self, compare_result_df: pd.DataFrame, compare_result_path: str, slice_num=0, need_slice=False) -> int:
303
- size = len(compare_result_df)
304
- # sheet size cannot be larger than 1048576
305
- if size < CompareConst.MAX_EXCEL_LENGTH:
306
- compare_result_path = compare_result_path.replace('.xlsx', f'_slice_{slice_num}.xlsx') if \
307
- need_slice else compare_result_path
308
305
  save_excel(compare_result_path, compare_result_df)
309
- return slice_num + 1
310
- else:
311
- slice_num = self.to_excel(compare_result_df.iloc[0: size // 2], compare_result_path, slice_num, True)
312
- return self.to_excel(compare_result_df.iloc[size // 2:], compare_result_path, slice_num, True)
306
+ logger.info(f"Compare rank: {rank_id} step: {step_id} finish. Compare result: {compare_result_path}.")
313
307
 
314
308
  def compare_process(self, rank_id, step_id):
315
309
  # generate data_path
@@ -331,7 +325,7 @@ class GraphMSComparator:
331
325
  bench_data_list.extend(data_list)
332
326
 
333
327
  if npu_mode == GraphMode.ERROR_MODE or bench_mode == GraphMode.ERROR_MODE:
334
- logger.warning(f"Data_path {npu_data_path} or {bench_data_path} is not exist.")
328
+ logger.warning(f"Data path: npu_data_path or bench_data_path does not exist.")
335
329
  return [], ''
336
330
  if npu_mode != bench_mode:
337
331
  logger.error(f"NPU mode {npu_mode} not equal to MATCH mode {bench_mode}.")
@@ -344,14 +338,15 @@ class GraphMSComparator:
344
338
  npu_data_df = pd.DataFrame(npu_data_list,
345
339
  columns=[CompareConst.NPU_NAME, 'Compare Key', 'TimeStamp',
346
340
  CompareConst.NPU_DTYPE, CompareConst.NPU_SHAPE,
347
- CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN,
348
- CompareConst.NPU_NORM])
341
+ CompareConst.NPU_MAX, CompareConst.NPU_MIN,
342
+ CompareConst.NPU_MEAN, CompareConst.NPU_NORM,
343
+ CompareConst.NPU_CSV_FILE])
349
344
  bench_data_df = pd.DataFrame(bench_data_list,
350
345
  columns=[CompareConst.BENCH_NAME, 'Compare Key', 'TimeStamp',
351
- CompareConst.BENCH_DTYPE,
352
- CompareConst.BENCH_SHAPE, CompareConst.BENCH_MAX,
353
- CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN,
354
- CompareConst.BENCH_NORM])
346
+ CompareConst.BENCH_DTYPE, CompareConst.BENCH_SHAPE,
347
+ CompareConst.BENCH_MAX, CompareConst.BENCH_MIN,
348
+ CompareConst.BENCH_MEAN, CompareConst.BENCH_NORM,
349
+ CompareConst.BENCH_CSV_FILE])
355
350
 
356
351
  npu_float_type = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
357
352
  npu_float_data_df = npu_data_df[npu_float_type].astype(str)
@@ -49,8 +49,9 @@ class DebuggerConfig:
49
49
  self.summary_mode = task_config.summary_mode
50
50
  self.stat_cal_mode = task_config.stat_cal_mode if hasattr(task_config, 'stat_cal_mode') else None
51
51
  self.device_stat_precision_mode = task_config.device_stat_precision_mode \
52
- if hasattr(task_config, 'device_stat_precision_mode') else None
52
+ if hasattr(task_config, 'device_stat_precision_mode') else None
53
53
  self.async_dump = common_config.async_dump if common_config.async_dump else False
54
+ self.precision = common_config.precision if common_config.precision else Const.DUMP_PRECISION_LOW
54
55
  self.check()
55
56
  self._check_statistics_config(task_config)
56
57
  create_directory(self.dump_path)
@@ -115,18 +116,28 @@ class DebuggerConfig:
115
116
  self.check_mode = "all"
116
117
  if not isinstance(self.async_dump, bool):
117
118
  raise Exception("The parameters async_dump should be bool.")
118
- if self.async_dump and self.task == Const.TENSOR:
119
- if self.level_ori == Const.LEVEL_DEBUG:
120
- self.list = [] # async_dump + debug level case ignore list
121
- if not self.list and self.level_ori != Const.LEVEL_DEBUG:
122
- raise Exception("The parameters async_dump is true in tensor task,"
123
- " the parameters list cannot be empty.")
124
119
  if self.task == Const.STRUCTURE and self.level_ori not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
125
120
  logger.warning_on_rank_0(
126
121
  f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. "
127
122
  f"If not, the default level is {Const.LEVEL_MIX}."
128
123
  )
129
124
  self.level_ori = Const.LEVEL_MIX
125
+ if self.async_dump:
126
+ if self.task == Const.TENSOR:
127
+ if self.level_ori == Const.LEVEL_DEBUG:
128
+ self.list = [] # async_dump + debug level case ignore list
129
+ if not self.list and self.level_ori != Const.LEVEL_DEBUG:
130
+ raise MsprobeException(
131
+ MsprobeException.INVALID_PARAM_ERROR,
132
+ "The parameters async_dump is true in tensor task, the parameters list cannot be empty."
133
+ )
134
+ is_unsupported_mode = self.summary_mode == Const.MD5 or \
135
+ isinstance(self.summary_mode, list) and Const.MD5 in self.summary_mode
136
+ if is_unsupported_mode:
137
+ raise MsprobeException(
138
+ MsprobeException.INVALID_PARAM_ERROR,
139
+ f"The parameters async_dump is true, the parameters summary_mode cannot be/contain md5."
140
+ )
130
141
  return True
131
142
 
132
143
  def check_config_with_l2(self, is_graph_config):