mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177) hide show
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
  3. msprobe/README.md +27 -22
  4. msprobe/core/common/const.py +129 -60
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/inplace_ops.yaml +1 -0
  9. msprobe/core/common/utils.py +43 -33
  10. msprobe/core/compare/acc_compare.py +43 -74
  11. msprobe/core/compare/check.py +2 -6
  12. msprobe/core/compare/highlight.py +2 -0
  13. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  14. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  15. msprobe/core/compare/merge_result/merge_result.py +16 -9
  16. msprobe/core/compare/merge_result/utils.py +81 -0
  17. msprobe/core/compare/multiprocessing_compute.py +19 -12
  18. msprobe/core/compare/npy_compare.py +30 -12
  19. msprobe/core/compare/utils.py +30 -10
  20. msprobe/core/data_dump/api_registry.py +176 -0
  21. msprobe/core/data_dump/data_collector.py +58 -13
  22. msprobe/core/data_dump/data_processor/base.py +94 -10
  23. msprobe/core/data_dump/data_processor/factory.py +3 -0
  24. msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
  25. msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
  26. msprobe/core/data_dump/json_writer.py +61 -40
  27. msprobe/core/grad_probe/constant.py +1 -0
  28. msprobe/core/grad_probe/grad_compare.py +1 -1
  29. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  30. msprobe/docs/01.installation.md +27 -1
  31. msprobe/docs/02.config_introduction.md +27 -23
  32. msprobe/docs/03.config_examples.md +24 -0
  33. msprobe/docs/05.data_dump_PyTorch.md +103 -16
  34. msprobe/docs/06.data_dump_MindSpore.md +76 -32
  35. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  36. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  37. msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
  38. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  39. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  40. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  41. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  42. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  43. msprobe/docs/18.online_dispatch.md +1 -1
  44. msprobe/docs/19.monitor.md +332 -273
  45. msprobe/docs/21.visualization_PyTorch.md +42 -13
  46. msprobe/docs/22.visualization_MindSpore.md +43 -13
  47. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  48. msprobe/docs/27.dump_json_instruction.md +301 -27
  49. msprobe/docs/28.debugger_save_instruction.md +94 -0
  50. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  51. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  52. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  53. msprobe/docs/FAQ.md +3 -11
  54. msprobe/docs/img/compare_result.png +0 -0
  55. msprobe/docs/img/merge_result.png +0 -0
  56. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  57. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  58. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  59. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  60. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  61. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  63. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  64. msprobe/mindspore/__init__.py +4 -2
  65. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
  66. msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
  67. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  68. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  69. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  70. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  71. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  72. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  73. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
  74. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  75. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  76. msprobe/mindspore/common/const.py +61 -0
  77. msprobe/mindspore/common/utils.py +48 -18
  78. msprobe/mindspore/compare/ms_compare.py +27 -19
  79. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  80. msprobe/mindspore/debugger/debugger_config.py +31 -6
  81. msprobe/mindspore/debugger/precision_debugger.py +45 -14
  82. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  83. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  86. msprobe/mindspore/dump/jit_dump.py +21 -15
  87. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  88. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  89. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  90. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  91. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  92. msprobe/mindspore/grad_probe/global_context.py +2 -0
  93. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  94. msprobe/mindspore/grad_probe/hook.py +2 -4
  95. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  96. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  97. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  98. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  99. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  100. msprobe/mindspore/monitor/features.py +63 -0
  101. msprobe/mindspore/monitor/module_hook.py +873 -0
  102. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  103. msprobe/mindspore/monitor/utils.py +309 -0
  104. msprobe/mindspore/ms_config.py +8 -2
  105. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  106. msprobe/mindspore/service.py +114 -34
  107. msprobe/pytorch/__init__.py +0 -1
  108. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  109. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
  110. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  111. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  112. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  116. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  117. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  118. msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
  119. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
  120. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  121. msprobe/pytorch/common/utils.py +97 -4
  122. msprobe/pytorch/debugger/debugger_config.py +19 -9
  123. msprobe/pytorch/debugger/precision_debugger.py +24 -1
  124. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  125. msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
  126. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  127. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  132. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  133. msprobe/pytorch/function_factory.py +8 -2
  134. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  135. msprobe/pytorch/hook_module/api_register.py +131 -0
  136. msprobe/pytorch/hook_module/hook_module.py +19 -14
  137. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  138. msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
  139. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  140. msprobe/pytorch/monitor/csv2tb.py +18 -14
  141. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  142. msprobe/pytorch/monitor/module_hook.py +238 -193
  143. msprobe/pytorch/monitor/module_metric.py +9 -6
  144. msprobe/pytorch/monitor/optimizer_collect.py +100 -67
  145. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  146. msprobe/pytorch/monitor/utils.py +76 -44
  147. msprobe/pytorch/online_dispatch/compare.py +0 -2
  148. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  149. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  150. msprobe/pytorch/online_dispatch/utils.py +3 -0
  151. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  152. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  153. msprobe/pytorch/pt_config.py +30 -29
  154. msprobe/pytorch/service.py +114 -32
  155. msprobe/visualization/builder/graph_builder.py +75 -10
  156. msprobe/visualization/builder/msprobe_adapter.py +7 -6
  157. msprobe/visualization/compare/graph_comparator.py +42 -38
  158. msprobe/visualization/compare/mode_adapter.py +0 -19
  159. msprobe/visualization/graph/base_node.py +11 -3
  160. msprobe/visualization/graph/distributed_analyzer.py +71 -3
  161. msprobe/visualization/graph/graph.py +0 -11
  162. msprobe/visualization/graph/node_op.py +4 -3
  163. msprobe/visualization/graph_service.py +4 -5
  164. msprobe/visualization/utils.py +12 -35
  165. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
  166. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  167. msprobe/pytorch/hook_module/api_registry.py +0 -166
  168. msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
  169. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  171. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  172. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  173. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  174. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  175. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  176. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  177. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
@@ -170,6 +170,16 @@ def gen_op_item(op_data, op_name):
170
170
  elif op_item.get('type') == 'slice':
171
171
  op_item['dtype'] = op_data.get('type')
172
172
  op_item['shape'] = str(np.shape(np.array(op_data.get('value'))))
173
+ elif op_item.get('type') == 'ellipsis':
174
+ op_item['dtype'] = op_data.get('type')
175
+ op_item['shape'] = '[]'
176
+ for i in params:
177
+ op_item[i] = op_data.get('value')
178
+ elif op_item.get('type') == 'torch.ProcessGroup':
179
+ op_item['dtype'] = op_data.get('type')
180
+ op_item['shape'] = '[]'
181
+ for i in params:
182
+ op_item[i] = str(op_data.get('group_ranks'))
173
183
  else:
174
184
  op_item['dtype'] = str(type(op_data.get('value')))
175
185
  op_item['shape'] = '[]'
@@ -275,9 +285,9 @@ def result_item_init(n_info, b_info, dump_mode):
275
285
  md5_compare_result = CompareConst.PASS if n_info.struct[2] == b_info.struct[2] else CompareConst.DIFF
276
286
  result_item.extend([n_info.struct[2], b_info.struct[2], md5_compare_result])
277
287
  elif dump_mode == Const.SUMMARY:
278
- result_item.extend([" "] * 8)
288
+ result_item.extend([" "] * 8) # 8个统计量数据情况的比对指标
279
289
  else:
280
- result_item.extend([" "] * 5)
290
+ result_item.extend([" "] * 6) # 6个真实数据情况的比对指标
281
291
  else:
282
292
  err_msg = "index out of bounds error will occur in result_item_init, please check!\n" \
283
293
  f"npu_info_struct is {n_info.struct}\n" \
@@ -311,8 +321,8 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
311
321
  has_stack = npu_stack_info and bench_stack_info
312
322
 
313
323
  if dump_mode == Const.ALL:
314
- npu_data_name = n_dict.get("data_name", None)
315
- bench_data_name = b_dict.get("data_name", None)
324
+ npu_data_name_list = n_dict.get("data_name", None)
325
+ bench_data_name_list = b_dict.get("data_name", None)
316
326
 
317
327
  for index in range(min_len):
318
328
  n_name = safe_get_value(n_dict, n_start + index, "n_dict", key="op_name")
@@ -343,7 +353,9 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
343
353
  result_item.append(err_msg)
344
354
  result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
345
355
  if dump_mode == Const.ALL:
346
- result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name"))
356
+ npu_data_name = safe_get_value(npu_data_name_list, n_start + index, "npu_data_name_list")
357
+ bench_data_name = safe_get_value(bench_data_name_list, b_start + index, "bench_data_name_list")
358
+ result_item.append([npu_data_name, bench_data_name])
347
359
 
348
360
  result.append(result_item)
349
361
 
@@ -361,7 +373,7 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
361
373
  continue
362
374
  result_item = [
363
375
  n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
364
- " ", " ", " ", " ", " "
376
+ " ", " ", " ", " ", " ", " "
365
377
  ]
366
378
  summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
367
379
  result_item.extend(summary_data)
@@ -378,7 +390,8 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
378
390
  result_item.append(err_msg)
379
391
  result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
380
392
  if dump_mode == Const.ALL:
381
- result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name"))
393
+ npu_data_name = safe_get_value(npu_data_name_list, n_start + index, "npu_data_name_list")
394
+ result_item.append([npu_data_name, "-1"])
382
395
 
383
396
  result.append(result_item)
384
397
 
@@ -443,9 +456,9 @@ def get_un_match_accuracy(result, n_dict, dump_mode):
443
456
  result.append(result_item)
444
457
  continue
445
458
  if dump_mode == Const.SUMMARY:
446
- result_item.extend([CompareConst.N_A] * 8)
459
+ result_item.extend([CompareConst.N_A] * 8) # 8个统计量数据情况的比对指标
447
460
  if dump_mode == Const.ALL:
448
- result_item.extend([CompareConst.N_A] * 5)
461
+ result_item.extend([CompareConst.N_A] * 6) # 6个真实数据情况的比对指标
449
462
 
450
463
  npu_summary_data = safe_get_value(summary_reorder, index, "summary_reorder")
451
464
  bench_summary_data = [CompareConst.N_A] * 4
@@ -457,7 +470,7 @@ def get_un_match_accuracy(result, n_dict, dump_mode):
457
470
  result_item.append(err_msg)
458
471
  append_stack_info(result_item, npu_stack_info, index)
459
472
  if dump_mode == Const.ALL and result_item[1] == CompareConst.N_A:
460
- result_item.extend(["-1"])
473
+ result_item.extend([["-1", "-1"]])
461
474
  result.append(result_item)
462
475
 
463
476
 
@@ -532,10 +545,17 @@ def get_name_and_state(name):
532
545
 
533
546
  state type: input, output, kwargs, parameters, parameters_grad
534
547
  """
548
+ if not isinstance(name, str):
549
+ logger.error(f'Invalid name: {name}, type should be string, please check.')
550
+ raise CompareException(CompareException.INVALID_API_NAME_ERROR)
551
+
535
552
  if Const.PARAMS_GRAD in name.split(Const.SEP):
536
553
  return name.split(Const.PARAMS_GRAD)[0], Const.PARAMS_GRAD
537
554
 
538
555
  split = re.split(Const.REGEX_FORWARD_BACKWARD, name)
556
+ if len(split) < 3:
557
+ logger.error(f'Invalid name string: {name}, can not be split by forward/backward, please check.')
558
+ raise CompareException(CompareException.INVALID_API_NAME_ERROR)
539
559
  api = f'{split[0]}.{split[1]}.'
540
560
  state_str = split[2]
541
561
  match = re.match(r'^(\d+\.)?(input|output|kwargs|parameters)\..+$', state_str)
@@ -0,0 +1,176 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Dict, Any, Optional, Callable, Union, List, Tuple
17
+
18
+ from msprobe.core.common.const import Const
19
+ from msprobe.core.common.file_utils import load_yaml
20
+
21
+
22
def _get_attr(module, attr_name):
    """Look up ``attr_name`` on ``module``, descending one dotted level if needed.

    A name containing ``Const.SEP`` (e.g. ``"sub.attr"``) is resolved as
    ``module.sub.attr``. Returns ``None`` when any step of the lookup fails.
    """
    if Const.SEP not in attr_name:
        return getattr(module, attr_name, None)
    parent_name, leaf_name = attr_name.rsplit(Const.SEP, 1)
    parent = getattr(module, parent_name, None)
    # getattr tolerates parent being None thanks to the default value.
    return getattr(parent, leaf_name, None)
30
+
31
+
32
class ApiWrapper:
    """Builds hook-instrumented replacements for framework API functions.

    ``api_types`` maps framework name -> api type -> modules; by convention in
    this class, element 0 of each modules entry is the module the original
    attribute is read from. ``api_list_paths`` gives one YAML file of supported
    API names per framework (a single path is broadcast to all frameworks).
    """

    def __init__(
            self, api_types: Dict[str, Dict[str, Any]],
            api_list_paths: Union[str, List[str], Tuple[str]]
    ):
        self.api_types = api_types
        if not isinstance(api_list_paths, (list, tuple)):
            # A single path applies to every framework.
            api_list_paths = [api_list_paths] * len(self.api_types)
        elif len(api_list_paths) != len(self.api_types):
            raise RuntimeError("The number of api_list_paths must be equal to the number of frameworks in 'api_types', "
                               "when api_list_paths is a list or tuple.")
        self.api_list_paths = api_list_paths
        # framework -> api_type -> set of API names that exist on the module.
        self.api_names = self._get_api_names()
        # framework -> api_type -> {api_name: wrapped function}; filled by wrap_api().
        self.wrapped_api_functions = dict()

    def wrap_api(
            self, api_templates, hook_build_func: Optional[Callable]
    ):
        """Create wrapped versions of every supported API.

        ``api_templates`` is one callable-template per (framework, api_type)
        pair, in iteration order; a single template is broadcast to all pairs.
        Returns the nested dict also stored in ``self.wrapped_api_functions``.
        """
        api_types_num = sum([len(v) for v in self.api_types.values()])
        if not isinstance(api_templates, (list, tuple)):
            api_templates = [api_templates] * api_types_num
        elif len(api_templates) != api_types_num:
            raise RuntimeError("The number of api_templates must be equal to the number of api_types, "
                               "when api_templates is a list or tuple.")

        self.wrapped_api_functions.clear()
        index = 0
        for framework, api_types in self.api_types.items():
            wrapped_functions_in_framework = dict()
            for api_type, api_modules in api_types.items():
                wrapped_functions = dict()
                name_prefix = Const.API_DATA_PREFIX.get(framework, {}).get(api_type, "API")
                api_template = api_templates[index]
                index += 1
                for api_name in self.api_names.get(framework, {}).get(api_type, []):
                    ori_api = _get_attr(api_modules[0], api_name)
                    if callable(ori_api):
                        # Factory function binds the loop variables as arguments,
                        # avoiding Python's late-binding closure pitfall.
                        def wrap_api_func(api_name, api_func, prefix, hook_build_func, api_template):
                            def api_function(*args, **kwargs):
                                # NOTE: a fresh api_template instance is built on
                                # every call of the wrapped API.
                                return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs)
                            api_function.__name__ = api_name
                            return api_function
                        wrapped_functions[api_name] = wrap_api_func(api_name, ori_api, name_prefix,
                                                                    hook_build_func, api_template)
                wrapped_functions_in_framework[api_type] = wrapped_functions
            self.wrapped_api_functions[framework] = wrapped_functions_in_framework
        return self.wrapped_api_functions

    def _get_api_names(self):
        """Load the supported-API YAML per framework and keep only the names
        actually present on the corresponding module (one dotted level deep)."""
        api_names = dict()

        for index, framework in enumerate(self.api_types.keys()):
            api_list = load_yaml(self.api_list_paths[index])
            valid_names = dict()
            for api_type, api_modules in self.api_types.get(framework, {}).items():
                api_from_file = api_list.get(Const.SUPPORT_API_DICT_KEY_MAP.get(framework, {}).get(api_type), [])
                names = set()
                for api_name in api_from_file:
                    target_attr = api_name
                    target_module = api_modules[0]
                    if Const.SEP in api_name:
                        # Dotted name: resolve the sub-module before checking.
                        sub_module_name, target_attr = api_name.rsplit(Const.SEP, 1)
                        target_module = getattr(api_modules[0], sub_module_name, None)
                    if target_module and target_attr in dir(target_module):
                        names.add(api_name)
                valid_names[api_type] = names
            api_names[framework] = valid_names

        return api_names
101
+
102
+
103
class ApiRegistry:
    """
    Base class for api registry.

    Swaps framework API attributes between their original and hook-wrapped
    versions. By convention here, for each (framework, api_type):
    ``api_modules[0]`` is the module originals are read from, and
    ``api_modules[1]`` is the list of modules that get patched.
    ``inner_used_api`` maps a "framework.api_type" key to a tuple whose
    element 0 is the target namespace and elements 1+ are API names.
    """

    def __init__(self, api_types, inner_used_api, supported_api_list_path, api_templates):
        # "framework.api_type" -> {api_name: original attribute}
        self.ori_api_attr = dict()
        # "framework.api_type" -> {api_name: wrapped attribute}
        self.wrapped_api_attr = dict()
        self.inner_used_ori_attr = dict()
        self.inner_used_wrapped_attr = dict()
        self.api_types = api_types
        self.inner_used_api = inner_used_api
        self.supported_api_list_path = supported_api_list_path
        self.api_templates = api_templates

    @staticmethod
    def store_ori_attr(ori_api_group, api_list, api_ori_attr):
        # Record the original attribute for every name so it can be restored later.
        for api in api_list:
            api_ori_attr[api] = _get_attr(ori_api_group, api)

    @staticmethod
    def set_api_attr(api_group, attr_dict):
        # Assign each attribute onto api_group, descending one dotted level
        # (e.g. "sub.op" is set on api_group.sub). Missing sub-modules are skipped.
        for api, api_attr in attr_dict.items():
            if Const.SEP in api:
                sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
                sub_module = getattr(api_group, sub_module_name, None)
                if sub_module is not None:
                    setattr(sub_module, sub_op, api_attr)
            else:
                setattr(api_group, api, api_attr)

    def register_all_api(self):
        """Patch every target module with the wrapped API attributes."""
        for framework, api_types in self.api_types.items():
            for api_type, api_modules in api_types.items():
                api_type_with_framework = framework + Const.SEP + api_type
                for module in api_modules[1]:
                    self.set_api_attr(module, self.wrapped_api_attr.get(api_type_with_framework, {}))

    def register_inner_used_api(self):
        """Patch only the tool's internally-used APIs (element 0 is the namespace)."""
        for api_type in self.inner_used_api.keys():
            self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_wrapped_attr.get(api_type, {}))

    def restore_all_api(self):
        """Put the original attributes back on every target module."""
        for framework, api_types in self.api_types.items():
            for api_type, api_modules in api_types.items():
                api_type_with_framework = framework + Const.SEP + api_type
                for module in api_modules[1]:
                    self.set_api_attr(module, self.ori_api_attr.get(api_type_with_framework, {}))

    def restore_inner_used_api(self):
        """Restore the originals of the internally-used APIs."""
        for api_type in self.inner_used_api.keys():
            self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_ori_attr.get(api_type, {}))

    def initialize_hook(self, hook_build_func):
        """Wrap all supported APIs and cache both original and wrapped attrs.

        Must run before any register_*/restore_* call, since it populates the
        attribute caches those methods read.
        """
        api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path)
        wrapped_api_functions = api_wrapper.wrap_api(self.api_templates, hook_build_func)

        for framework, api_types in self.api_types.items():
            for api_type, api_modules in api_types.items():
                ori_attr = dict()
                self.store_ori_attr(api_modules[0], api_wrapper.api_names.get(framework).get(api_type), ori_attr)
                api_type_with_framework = framework + Const.SEP + api_type
                self.ori_api_attr[api_type_with_framework] = ori_attr
                self.wrapped_api_attr[api_type_with_framework] = wrapped_api_functions.get(framework).get(api_type)

        for inner_used_api_type, inner_used_api_list in self.inner_used_api.items():
            ori_attr = dict()
            wrapped_attr = dict()
            # Elements 1+ are the api names; keep only those already cached above.
            for api_name in inner_used_api_list[1:]:
                if self.ori_api_attr.get(inner_used_api_type, {}).get(api_name):
                    ori_attr[api_name] = self.ori_api_attr.get(inner_used_api_type).get(api_name)
                    wrapped_attr[api_name] = self.wrapped_api_attr.get(inner_used_api_type).get(api_name)
            self.inner_used_ori_attr[inner_used_api_type] = ori_attr
            self.inner_used_wrapped_attr[inner_used_api_type] = wrapped_attr
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -40,6 +40,7 @@ class DataCollector:
40
40
  self.scope = ScopeFactory(self.config).build_scope()
41
41
  self.backward_module_names = {}
42
42
  self.optimizer_status = ""
43
+ self.optimizer_status_first_start = {Const.OPTIMIZER: True, Const.CLIP_GRAD: True}
43
44
  atexit.register(self.write_json)
44
45
 
45
46
  @property
@@ -54,6 +55,17 @@ class DataCollector:
54
55
  def check_scope_and_pid(scope, name, pid):
55
56
  return (not scope or scope.check(name)) and pid == os.getpid()
56
57
 
58
+ @staticmethod
59
+ def set_is_recomputable(data_info, is_recompute):
60
+ if data_info and len(data_info) == 1 and is_recompute is not None: # 正常情况下data_info的长度应改为1
61
+ data_info[list(data_info.keys())[0]]["is_recompute"] = is_recompute
62
+
63
+ def reset_status(self):
64
+ self.optimizer_status = ""
65
+ self.optimizer_status_first_start = {Const.OPTIMIZER: True, Const.CLIP_GRAD: True}
66
+ self.data_writer.reset_cache()
67
+ self.backward_module_names.clear()
68
+
57
69
  def if_return_forward_new_output(self):
58
70
  return self.data_processor.if_return_forward_new_output()
59
71
 
@@ -77,7 +89,7 @@ class DataCollector:
77
89
  logger.debug(msg)
78
90
  self.data_writer.update_data(data_info)
79
91
 
80
- def forward_input_data_collect(self, name, module, pid, module_input_output):
92
+ def forward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
81
93
  if self.config.task == Const.FREE_BENCHMARK:
82
94
  backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
83
95
  if self.check_scope_and_pid(self.scope, backward_name, pid):
@@ -87,37 +99,48 @@ class DataCollector:
87
99
  if not self.check_scope_and_pid(self.scope, name, pid):
88
100
  return
89
101
 
90
- data_info = self.data_processor.analyze_forward_input(name, module, module_input_output)
102
+ data_info = {}
103
+ if self.config.task != Const.STRUCTURE:
104
+ data_info = self.data_processor.analyze_forward_input(name, module, module_input_output)
105
+ self.set_is_recomputable(data_info, is_recompute)
91
106
  if self.config.level == Const.LEVEL_L2:
92
107
  return
93
108
  self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
94
109
 
95
- def forward_output_data_collect(self, name, module, pid, module_input_output):
110
+ def forward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
96
111
  self.update_construct(name)
97
112
  if not self.check_scope_and_pid(self.scope, name, pid):
98
113
  return
99
114
 
100
- data_info = self.data_processor.analyze_forward_output(name, module, module_input_output)
115
+ data_info = {}
116
+ if self.config.task != Const.STRUCTURE:
117
+ data_info = self.data_processor.analyze_forward_output(name, module, module_input_output)
118
+ self.set_is_recomputable(data_info, is_recompute)
101
119
  if self.config.level == Const.LEVEL_L2:
102
120
  return
103
121
  self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
104
122
  self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
105
123
 
106
- def forward_data_collect(self, name, module, pid, module_input_output):
124
+ def forward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
107
125
  self.update_construct(name)
108
126
  if not self.check_scope_and_pid(self.scope, name, pid):
109
127
  return
110
128
 
111
- data_info = self.data_processor.analyze_forward(name, module, module_input_output)
129
+ data_info = {}
130
+ if self.config.task != Const.STRUCTURE:
131
+ data_info = self.data_processor.analyze_forward(name, module, module_input_output)
132
+ self.set_is_recomputable(data_info, is_recompute)
112
133
  self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
113
134
  self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
114
135
 
115
- def backward_data_collect(self, name, module, pid, module_input_output):
136
+ def backward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
116
137
  self.update_construct(name)
117
138
  if not self.check_scope_and_pid(self.scope, name, pid):
118
139
  return
119
140
 
120
- data_info = self.data_processor.analyze_backward(name, module, module_input_output)
141
+ data_info = {}
142
+ if self.config.task != Const.STRUCTURE:
143
+ data_info = self.data_processor.analyze_backward(name, module, module_input_output)
121
144
  if self.config.level == Const.LEVEL_L2:
122
145
  return
123
146
  # 获取执行反向的模块名称
@@ -127,25 +150,34 @@ class DataCollector:
127
150
  self.backward_module_names[module_name] = True
128
151
  self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
129
152
 
130
- def backward_input_data_collect(self, name, module, pid, module_input_output):
153
+ def backward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
131
154
  self.update_construct(name)
132
155
  if not self.check_scope_and_pid(self.scope, name, pid):
133
156
  return
134
157
 
135
- data_info = self.data_processor.analyze_backward_input(name, module, module_input_output)
158
+ data_info = {}
159
+ if self.config.task != Const.STRUCTURE:
160
+ data_info = self.data_processor.analyze_backward_input(name, module, module_input_output)
161
+ self.set_is_recomputable(data_info, is_recompute)
136
162
  self.handle_data(name, data_info)
137
163
 
138
- def backward_output_data_collect(self, name, module, pid, module_input_output):
164
+ def backward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
139
165
  self.update_construct(name)
140
166
  if not self.check_scope_and_pid(self.scope, name, pid):
141
167
  return
142
168
 
143
- data_info = self.data_processor.analyze_backward_output(name, module, module_input_output)
169
+ data_info = {}
170
+ if self.config.task != Const.STRUCTURE:
171
+ data_info = self.data_processor.analyze_backward_output(name, module, module_input_output)
172
+ self.set_is_recomputable(data_info, is_recompute)
144
173
  self.handle_data(name, data_info)
145
174
 
146
175
  def update_construct(self, name):
147
176
  if self.config.level not in DataCollector.level_without_construct:
148
177
  if self.optimizer_status in [Const.OPTIMIZER, Const.CLIP_GRAD]:
178
+ if self.optimizer_status_first_start[self.optimizer_status]:
179
+ self.data_writer.update_construct({self.optimizer_status: None})
180
+ self.optimizer_status_first_start[self.optimizer_status] = False
149
181
  self.data_writer.update_construct({name: self.optimizer_status})
150
182
  else:
151
183
  self.data_writer.update_construct({name: self.module_processor.api_parent_node})
@@ -183,3 +215,16 @@ class DataCollector:
183
215
 
184
216
  def fill_stack_tensor_data(self):
185
217
  self.data_writer.fill_stack_tensor_data()
218
+
219
+ def debug_data_collect_forward(self, variable, name_with_count):
220
+
221
+ data_info = self.data_processor.analyze_debug_forward(variable, name_with_count)
222
+ self.data_writer.update_debug({name_with_count: data_info})
223
+
224
+ def debug_data_collect_backward(self, variable, grad_name_with_count):
225
+ # prepare all None nested data structure
226
+ all_none_data_info = self.data_processor.analyze_element_to_all_none(variable)
227
+ self.data_writer.update_debug({grad_name_with_count: all_none_data_info})
228
+
229
+ # register tensor backward hook
230
+ self.data_processor.analyze_debug_backward(variable, grad_name_with_count, self.data_writer.cache_debug['data'])
@@ -17,6 +17,9 @@ import inspect
17
17
  import os
18
18
  from dataclasses import dataclass, is_dataclass
19
19
  from typing import Tuple, Dict, Optional, Any
20
+ from functools import partial
21
+ import copy
22
+ from typing import Union
20
23
 
21
24
  import numpy as np
22
25
 
@@ -87,7 +90,7 @@ class TensorStatInfo:
87
90
  class BaseDataProcessor:
88
91
  _recursive_key_stack = []
89
92
  special_type = (
90
- np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_,
93
+ np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_, np.ndarray,
91
94
  bool, int, float, str, slice,
92
95
  type(Ellipsis)
93
96
  )
@@ -143,6 +146,37 @@ class BaseDataProcessor:
143
146
  else:
144
147
  return data
145
148
 
149
+ @staticmethod
150
+ def set_value_into_nested_structure(data_structure, indexes, value):
151
+ '''
152
+ Args:
153
+ data_structure: nested data structure
154
+ indexes: List
155
+ value: value to be set
156
+ '''
157
+ if not indexes:
158
+ raise ValueError("set_value_into_nested_structure failed: "
159
+ "indexes need to be non empty when set value to nested data structure")
160
+ current_level = data_structure
161
+ for i, index in enumerate(indexes):
162
+ valid_for_list = isinstance(current_level, list) and isinstance(index, int) and len(current_level) > index
163
+ valid_for_dict = isinstance(current_level, dict) and index in current_level
164
+ is_last = i == len(indexes) - 1
165
+ if valid_for_dict or valid_for_list:
166
+ if is_last:
167
+ try:
168
+ current_level[index] = value
169
+ except Exception as e:
170
+ raise IndexError("set_value_into_nested_structure failed: passed indexes wrong") from e
171
+ else:
172
+ try:
173
+ current_level = current_level[index]
174
+ except Exception as e:
175
+ raise IndexError("set_value_into_nested_structure failed: passed indexes wrong") from e
176
+ else:
177
+ raise ValueError("set_value_into_nested_structure failed: "
178
+ "invalid data_structure type or invalid index")
179
+
146
180
  @staticmethod
147
181
  def _convert_numpy_to_builtin(arg):
148
182
  type_mapping = {
@@ -183,8 +217,22 @@ class BaseDataProcessor:
183
217
  return single_arg
184
218
 
185
219
  @staticmethod
186
- def _analyze_numpy(value, numpy_type):
187
- return {"type": numpy_type, "value": value}
220
+ def _analyze_numpy(ndarray, numpy_type):
221
+ ndarray_json = {}
222
+ ndarray_json.update({'type': 'numpy.ndarray'})
223
+ ndarray_json.update({'dtype': str(ndarray.dtype)})
224
+ ndarray_json.update({'shape': ndarray.shape})
225
+ if ndarray.size > 0:
226
+ ndarray_json.update({"Max": np.max(ndarray).item()})
227
+ ndarray_json.update({"Min": np.min(ndarray).item()})
228
+ ndarray_json.update({"Mean": np.mean(ndarray).item()})
229
+ ndarray_json.update({"Norm": np.linalg.norm(ndarray).item()})
230
+ else:
231
+ ndarray_json.update({"Max": None})
232
+ ndarray_json.update({"Min": None})
233
+ ndarray_json.update({"Mean": None})
234
+ ndarray_json.update({"Norm": None})
235
+ return ndarray_json
188
236
 
189
237
  @staticmethod
190
238
  def _get_allowed_data_mode(data_mode):
@@ -203,9 +251,9 @@ class BaseDataProcessor:
203
251
  return cls.special_type
204
252
 
205
253
  @classmethod
206
- def recursive_apply_transform(cls, args, transform, depth=0):
207
- if depth > Const.MAX_DEPTH:
208
- logger.error(f"The maximum depth of recursive transform, {Const.MAX_DEPTH} is reached.")
254
+ def recursive_apply_transform(cls, args, transform, depth=0) -> Union[dict, list, None]:
255
+ if depth > Const.DUMP_MAX_DEPTH:
256
+ logger.error(f"The maximum depth of recursive transform, {Const.DUMP_MAX_DEPTH} is reached.")
209
257
  raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
210
258
  if isinstance(args, cls.get_special_types()):
211
259
  arg_transform = transform(args, cls._recursive_key_stack)
@@ -220,7 +268,7 @@ class BaseDataProcessor:
220
268
  return cls.apply_transform_dict(args_dict, transform, depth)
221
269
  elif isinstance(args, (list, tuple)):
222
270
  result_list = cls.apply_transform_list(args, transform, depth)
223
- return type(args)(result_list)
271
+ return result_list
224
272
  elif isinstance(args, dict):
225
273
  return cls.apply_transform_dict(args, transform, depth)
226
274
  elif args is not None:
@@ -228,12 +276,12 @@ class BaseDataProcessor:
228
276
  return None
229
277
  else:
230
278
  return None
231
-
279
+
232
280
@classmethod
def apply_transform_dict(cls, args, transform, depth):
    """Apply *transform* to every value of a dict, recursing one level deeper.

    The dict key is pushed onto the shared ``_recursive_key_stack`` while
    its value is processed so nested transforms can see the access path.

    Returns a new dict with the same keys and transformed values.
    """
    result_dict = {}
    for k, arg in args.items():
        cls._recursive_key_stack.append(k)
        try:
            result_dict[k] = cls.recursive_apply_transform(arg, transform, depth=depth + 1)
        finally:
            # Fix: pop in a finally block so an exception raised by the
            # transform cannot leave a stale key on the shared class-level
            # stack and corrupt the paths of every later dump.
            cls._recursive_key_stack.pop()
    return result_dict
@@ -242,11 +290,21 @@ class BaseDataProcessor:
242
290
def apply_transform_list(cls, args, transform, depth):
    """Apply *transform* to every element of a list/tuple, one level deeper.

    Each element's position is pushed onto the shared
    ``_recursive_key_stack`` while it is processed so nested transforms
    can see the access path.

    Returns a new list of transformed elements.
    """
    result_list = []
    for pos, item in enumerate(args):
        cls._recursive_key_stack.append(pos)
        try:
            result_list.append(cls.recursive_apply_transform(item, transform, depth=depth + 1))
        finally:
            # Fix: pop in a finally block so an exception raised by the
            # transform cannot leave a stale index on the shared stack.
            cls._recursive_key_stack.pop()
    return result_list
249
297
 
298
@classmethod
def register_hook_single_element(cls, element, suffix_stack, hook_fn):
    """Attach a gradient hook to *element* when it supports hook registration.

    The current *suffix_stack* (access path inside the nested structure) is
    snapshotted at registration time so the hook knows, when it later fires,
    where the element originally sat.
    """
    if not cls.is_hookable_element(element):
        return
    # Deep-copy: the shared stack keeps mutating during the recursive walk.
    path_snapshot = copy.deepcopy(suffix_stack)

    def on_grad(grad):
        return hook_fn(grad, indexes=path_snapshot)

    element.register_hook(on_grad)
307
+
250
308
def if_return_forward_new_output(self):
    """Return the ``_return_forward_new_output`` flag (whether a replacement
    forward output should be returned to the caller)."""
    return self._return_forward_new_output
  return self._return_forward_new_output
252
310
 
@@ -383,3 +441,29 @@ class BaseDataProcessor:
383
441
  suffix + file_format)
384
442
  file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
385
443
  return dump_data_name, file_path
444
+
445
def analyze_element_to_all_none(self, element):
    """Build a copy of *element*'s nested structure with every leaf set to None."""
    def to_none(_value, _stack):
        # Renamed from the original lambda whose parameter shadowed `element`.
        return None

    return self.recursive_apply_transform(element, to_none)
447
+
448
def analyze_debug_forward(self, variable, name_with_count):
    """Collect statistics for *variable* in a debug-level forward dump.

    ``current_api_or_module_name`` and ``api_data_category`` are set first
    because they are used to construct the dumped tensor file name of the
    form {name_with_count}.tensor.{indexes}.npy/pt.
    """
    self.current_api_or_module_name = name_with_count
    self.api_data_category = Const.TENSOR
    return self.analyze_element(variable)
454
+
455
def analyze_debug_backward(self, variable, grad_name_with_count, nested_data_structure):
    """Register gradient hooks on every hookable leaf of *variable*.

    When a gradient arrives, its statistics are analyzed and written into
    the matching slot of *nested_data_structure* (the debug cache), keyed
    by *grad_name_with_count* plus the leaf's access path.
    """
    def hook_fn(grad, indexes):
        suffix = Const.SEP.join([str(index) for index in indexes])
        # save_name drives the dumped tensor file name:
        # {grad_name_with_count}.tensor.{indexes}
        self.save_name = grad_name_with_count + Const.SEP + Const.TENSOR + Const.SEP + suffix
        try:
            grad_data_info = self.analyze_element(grad)
        finally:
            # Fix: always clear save_name, even if analyze_element raises,
            # so a later dump cannot reuse a stale file name.
            self.save_name = None
        full_index = [grad_name_with_count] + indexes
        try:
            self.set_value_into_nested_structure(nested_data_structure, full_index, grad_data_info)
        except (ValueError, IndexError) as e:
            # Fix: corrected spelling in the warning ("occurred", "information").
            logger.warning(f"error occurred while recording statistics of {grad_name_with_count} variable, "
                           f"skip current recording, detailed information: {e}")
        return grad

    wrap_register_hook_single_element = partial(self.register_hook_single_element, hook_fn=hook_fn)
    self.recursive_apply_transform(variable, wrap_register_hook_single_element)
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  from msprobe.core.common.const import Const
17
+ from msprobe.core.data_dump.data_processor.base import BaseDataProcessor
17
18
 
18
19
 
19
20
  class DataProcessorFactory:
@@ -62,6 +63,7 @@ class DataProcessorFactory:
62
63
  cls.register_processor(Const.PT_FRAMEWORK, Const.OVERFLOW_CHECK, PytorchOverflowCheckDataProcessor)
63
64
  cls.register_processor(Const.PT_FRAMEWORK, Const.FREE_BENCHMARK, PytorchFreeBenchmarkDataProcessor)
64
65
  cls.register_processor(Const.PT_FRAMEWORK, Const.KERNEL_DUMP, PytorchKernelDumpDataProcessor)
66
+ cls.register_processor(Const.PT_FRAMEWORK, Const.STRUCTURE, BaseDataProcessor)
65
67
  cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser)
66
68
  elif framework == Const.MS_FRAMEWORK:
67
69
  from msprobe.core.data_dump.data_processor.mindspore_processor import (
@@ -75,4 +77,5 @@ class DataProcessorFactory:
75
77
  cls.register_processor(Const.MS_FRAMEWORK, Const.TENSOR, MindsporeTensorDataProcessor)
76
78
  cls.register_processor(Const.MS_FRAMEWORK, Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor)
77
79
  cls.register_processor(Const.MS_FRAMEWORK, Const.KERNEL_DUMP, MindsporeKernelDumpDataProcessor)
80
+ cls.register_processor(Const.MS_FRAMEWORK, Const.STRUCTURE, BaseDataProcessor)
78
81
  cls.register_module_processor(Const.MS_FRAMEWORK, CellProcessor)