mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
  2. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +14 -19
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +155 -6
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +3 -0
  10. msprobe/core/common/utils.py +28 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +380 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/multiprocessing_compute.py +2 -2
  22. msprobe/core/compare/npy_compare.py +109 -147
  23. msprobe/core/compare/utils.py +189 -69
  24. msprobe/core/data_dump/data_collector.py +51 -21
  25. msprobe/core/data_dump/data_processor/base.py +38 -20
  26. msprobe/core/data_dump/data_processor/factory.py +5 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
  29. msprobe/core/data_dump/json_writer.py +29 -1
  30. msprobe/core/data_dump/scope.py +19 -18
  31. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  32. msprobe/core/overflow_check/checker.py +1 -1
  33. msprobe/core/overflow_check/utils.py +1 -1
  34. msprobe/docs/01.installation.md +96 -17
  35. msprobe/docs/02.config_introduction.md +5 -5
  36. msprobe/docs/05.data_dump_PyTorch.md +91 -61
  37. msprobe/docs/06.data_dump_MindSpore.md +57 -19
  38. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  39. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
  40. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  41. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  42. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  43. msprobe/docs/19.monitor.md +120 -27
  44. msprobe/docs/21.visualization_PyTorch.md +115 -35
  45. msprobe/docs/22.visualization_MindSpore.md +138 -41
  46. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  47. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  48. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  49. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  50. msprobe/docs/27.dump_json_instruction.md +521 -0
  51. msprobe/docs/FAQ.md +26 -2
  52. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  53. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  54. msprobe/docs/img/merge_result.png +0 -0
  55. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  56. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  57. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  58. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  59. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  60. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  61. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  63. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  64. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  65. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  66. msprobe/docs/visualization/GPTModel.png +0 -0
  67. msprobe/docs/visualization/ParallelMLP.png +0 -0
  68. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  69. msprobe/docs/visualization/mapping.png +0 -0
  70. msprobe/docs/visualization/mapping1.png +0 -0
  71. msprobe/docs/visualization/module_name.png +0 -0
  72. msprobe/docs/visualization/module_name1.png +0 -0
  73. msprobe/docs/visualization/no_mapping.png +0 -0
  74. msprobe/docs/visualization/no_mapping1.png +0 -0
  75. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  76. msprobe/docs/visualization/top_layer.png +0 -0
  77. msprobe/mindspore/__init__.py +10 -0
  78. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
  79. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  80. msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
  81. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  82. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  83. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  84. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  85. msprobe/mindspore/code_mapping/bind.py +264 -0
  86. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  87. msprobe/mindspore/code_mapping/graph.py +49 -0
  88. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  89. msprobe/mindspore/code_mapping/main.py +24 -0
  90. msprobe/mindspore/code_mapping/processor.py +34 -0
  91. msprobe/mindspore/common/const.py +3 -1
  92. msprobe/mindspore/common/utils.py +50 -5
  93. msprobe/mindspore/compare/distributed_compare.py +0 -2
  94. msprobe/mindspore/compare/ms_compare.py +105 -63
  95. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  96. msprobe/mindspore/debugger/debugger_config.py +3 -0
  97. msprobe/mindspore/debugger/precision_debugger.py +81 -12
  98. msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
  99. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  100. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  101. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  102. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  103. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  104. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  105. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  106. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  107. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  108. msprobe/mindspore/grad_probe/hook.py +13 -4
  109. msprobe/mindspore/mindtorch/__init__.py +18 -0
  110. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  111. msprobe/mindspore/ms_config.py +5 -1
  112. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  113. msprobe/mindspore/service.py +267 -101
  114. msprobe/msprobe.py +24 -3
  115. msprobe/pytorch/__init__.py +7 -6
  116. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  117. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  123. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  124. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
  125. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  126. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  127. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  128. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  129. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  130. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  131. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  132. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  133. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  134. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  135. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  136. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  140. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  141. msprobe/pytorch/common/parse_json.py +2 -1
  142. msprobe/pytorch/common/utils.py +45 -2
  143. msprobe/pytorch/compare/distributed_compare.py +17 -29
  144. msprobe/pytorch/compare/pt_compare.py +40 -20
  145. msprobe/pytorch/debugger/debugger_config.py +27 -12
  146. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  147. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  148. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  149. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
  150. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  151. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  152. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  153. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  154. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  155. msprobe/pytorch/hook_module/__init__.py +1 -1
  156. msprobe/pytorch/hook_module/hook_module.py +14 -11
  157. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  158. msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
  159. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  160. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  161. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  162. msprobe/pytorch/monitor/anomaly_detect.py +107 -22
  163. msprobe/pytorch/monitor/csv2tb.py +166 -0
  164. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  165. msprobe/pytorch/monitor/features.py +3 -3
  166. msprobe/pytorch/monitor/module_hook.py +483 -277
  167. msprobe/pytorch/monitor/module_metric.py +27 -48
  168. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  169. msprobe/pytorch/monitor/optimizer_collect.py +52 -14
  170. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  171. msprobe/pytorch/monitor/utils.py +77 -6
  172. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  173. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  174. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  175. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  176. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  177. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  178. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  179. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  180. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  181. msprobe/pytorch/service.py +176 -106
  182. msprobe/visualization/builder/graph_builder.py +62 -5
  183. msprobe/visualization/builder/msprobe_adapter.py +24 -2
  184. msprobe/visualization/compare/graph_comparator.py +64 -14
  185. msprobe/visualization/compare/mode_adapter.py +1 -15
  186. msprobe/visualization/graph/base_node.py +12 -17
  187. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  188. msprobe/visualization/graph/graph.py +9 -0
  189. msprobe/visualization/graph_service.py +97 -23
  190. msprobe/visualization/utils.py +14 -29
  191. msprobe/pytorch/functional/module_dump.py +0 -84
  192. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  193. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
  194. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
  195. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  196. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  197. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -46,6 +46,13 @@ class KernelGraphOverflowCheck:
46
46
  self.dump_json["common_dump_settings"]["op_debug_mode"] = 2
47
47
 
48
48
  def handle(self):
49
+ try:
50
+ from msprobe.lib import _msprobe_c
51
+ return
52
+ except ImportError:
53
+ # 如果没有_msprobe_c走MindSpore老流程
54
+ logger.info("Module _msprobe_c has not been installed, use interface in mindspore instead.")
55
+
49
56
  if os.getenv("GRAPH_OP_RUN") == "1":
50
57
  raise Exception("Must run in graph mode, not kbk mode")
51
58
  json_path = self.dump_json["common_dump_settings"]["path"]
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,6 +20,8 @@ from collections import defaultdict
20
20
 
21
21
  import mindspore as ms
22
22
  from mindspore import nn
23
+ from mindspore.common.api import _no_grad
24
+ from mindspore.ops.primitive import Primitive
23
25
  try:
24
26
  from mindspore.common._pijit_context import PIJitCaptureContext
25
27
  except ImportError:
@@ -27,19 +29,25 @@ except ImportError:
27
29
  else:
28
30
  pijit_label = True
29
31
 
30
-
31
32
  from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
32
33
  from msprobe.core.common.file_utils import create_directory
33
34
  from msprobe.core.common.utils import Const, print_tools_ends_info
34
35
  from msprobe.core.data_dump.data_collector import build_data_collector
35
- from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs
36
+ from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs,
37
+ ModuleBackwardInputs)
36
38
  from msprobe.core.data_dump.scope import BaseScope
37
39
  from msprobe.mindspore.cell_processor import CellProcessor
38
40
  from msprobe.mindspore.common.log import logger
39
- from msprobe.mindspore.common.utils import get_rank_if_initialized
41
+ from msprobe.mindspore.common.utils import (get_rank_if_initialized, clean_input_kwargs,
42
+ is_mindtorch, register_backward_hook_functions)
40
43
  from msprobe.mindspore.dump.hook_cell.api_registry import api_register
41
44
  from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
42
45
  from msprobe.mindspore.dump.jit_dump import JitDump
46
+ from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
47
+ from msprobe.mindspore.dump.kernel_dump.kernel_config import create_kernel_config_json
48
+
49
+ if is_mindtorch():
50
+ import torch
43
51
 
44
52
 
45
53
  class Service:
@@ -51,54 +59,144 @@ class Service:
51
59
  self.cell_processor = CellProcessor(self.data_collector.scope)
52
60
  self.primitive_hook_service = PrimitiveHookService(self)
53
61
  self.switch = False
62
+ self.inner_switch = False
54
63
  self.primitive_switch = False
55
64
  self.current_iter = 0
56
65
  self.first_start = True
57
66
  self.current_rank = None
58
67
  self.dump_iter_dir = None
59
68
  self.start_call = False
60
- self.check_level_valid()
61
69
  self.should_stop_service = False
70
+ self.params_grad_info = {}
71
+ # 提前注册,确保注册尽可能多的API hook
72
+ self.register_api_hook()
62
73
 
63
74
  @staticmethod
64
- def check_model_valid(model):
65
- if not model or isinstance(model, nn.Cell):
66
- return model
67
- raise MsprobeException(
68
- MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是 mindspore.nn.Cell 类型。"
69
- )
75
+ def check_model_valid(models):
76
+ target_module_type = (torch.nn.Module, "torch.nn.Module") if is_mindtorch() else (nn.Cell, "mindspore.nn.Cell")
77
+ if models is None or isinstance(models, target_module_type[0]):
78
+ return models
79
+ error_model = None
80
+ if isinstance(models, (list, tuple)):
81
+ for model in models:
82
+ if not isinstance(model, target_module_type[0]):
83
+ error_model = model
84
+ break
85
+ else:
86
+ error_model = models
70
87
 
71
- def check_level_valid(self):
72
- if self.config.level == Const.LEVEL_L2:
88
+ if error_model is not None:
89
+ error_info = (f"The 'model' parameter must be a {target_module_type[1]} or list[{target_module_type[1]}] "
90
+ f"type, currently there is a {type(error_model)} type.")
73
91
  raise MsprobeException(
74
- MsprobeException.INVALID_PARAM_ERROR, "L2 level dump function is currently not supported."
75
- )
92
+ MsprobeException.INVALID_PARAM_ERROR, error_info)
93
+ return models
94
+
95
+ @staticmethod
96
+ def prepare_module_input_output(target_type, cell, input_data, output):
97
+ if target_type == BaseScope.Module_Type_Module:
98
+ module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs={}, output=output)
99
+ else:
100
+ module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs, output=output)
101
+ return module_input_output
76
102
 
77
103
  def build_hook(self, target_type, name):
78
- def forward_hook(api_or_cell_name, cell, input_data, output):
79
- if not self.should_excute_hook():
80
- if hasattr(cell, 'input_kwargs'):
81
- del cell.input_kwargs
104
+ def pre_hook(api_or_cell_name, cell, input_data):
105
+ if not self.should_execute_hook(target_type, cell, True):
106
+ clean_input_kwargs(cell)
82
107
  return None
83
108
 
84
- if target_type == BaseScope.Module_Type_Module:
85
- api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
86
- module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs={}, output=output)
87
- else:
88
- module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs,
89
- output=output)
109
+ with _no_grad():
110
+ self.inner_switch = True
111
+ if target_type == BaseScope.Module_Type_Module:
112
+ api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
113
+ else:
114
+ cell.forward_data_collected = True
115
+ HOOKCell.add_cell_count(name)
116
+ module_input_output = self.prepare_module_input_output(target_type, cell, input_data, None)
117
+ self.data_collector.update_api_or_module_name(api_or_cell_name)
118
+ self.data_collector.forward_input_data_collect(api_or_cell_name, cell, pid, module_input_output)
119
+ self.inner_switch = False
120
+ return input_data
121
+
122
+ def grad_hook(cell, ori_name, param_name):
123
+ def hook_fn(grad):
124
+ if not self.should_execute_hook(target_type, cell, False):
125
+ return None
126
+ self.inner_switch = True
127
+ self.data_collector.params_data_collect(ori_name, param_name, pid, grad)
128
+ self.inner_switch = False
129
+ return None
90
130
 
91
- self.data_collector.update_api_or_module_name(api_or_cell_name)
92
- self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
93
- if self.data_collector.if_return_forward_new_output():
94
- return self.data_collector.get_forward_new_output()
95
- if hasattr(cell, 'input_kwargs'):
96
- del cell.input_kwargs
97
- return output
131
+ return hook_fn
132
+
133
+ def register_param_hook(ori_name, cell, params_dict):
134
+ '''
135
+ 注册参数hook
136
+ '''
137
+ # data_mode为forward时,不注册参数hook
138
+ if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
139
+ for param_name, param in params_dict.items():
140
+ if param.requires_grad:
141
+ param.register_hook(grad_hook(cell, ori_name, param_name))
142
+
143
+ def init_params_grad_info(cell, params_dict):
144
+ '''
145
+ 初始化参数梯度信息, 在前向hook结束后, 将参数梯度信息写入cache_data中用于占位
146
+ '''
147
+ if not params_dict:
148
+ return
149
+ if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
150
+ grad_name = cell.params_grad_name if hasattr(cell, 'params_grad_name') else None
151
+ # 判断是否已经在cache_data中进行了占位, 若没有则先写入cache_data中
152
+ if not self.params_grad_info.get(grad_name):
153
+ data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}}
154
+ # 当模块中的参数有requires_grad属性为True时,才会进行梯度计算,此时才需要占位
155
+ if data_info.get(grad_name):
156
+ # 将grad_name的data_info先写入cache_data中, 梯度计算后再更新
157
+ self.data_collector.handle_data(grad_name, data_info,
158
+ flush=self.data_collector.data_processor.is_terminated)
159
+ # 记录当前模块的参数梯度信息已占位
160
+ self.params_grad_info[grad_name] = True
161
+
162
+ def forward_hook(api_or_cell_name, cell, input_data, output):
163
+ if not self.should_execute_hook(target_type, cell, True):
164
+ clean_input_kwargs(cell)
165
+ return None
166
+ with _no_grad():
167
+ self.inner_switch = True
168
+ module_input_output = self.prepare_module_input_output(target_type, cell, input_data, output)
169
+ if target_type == BaseScope.Module_Type_Module:
170
+ api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
171
+ params_dict = {key.split(Const.SEP)[-1]: value for key, value in cell.parameters_dict(
172
+ recurse=False).items()}
173
+ setattr(module_input_output, Const.PARAMS, params_dict)
174
+ # 判断是否需要注册参数hook
175
+ if not hasattr(cell, 'params_grad_name') and params_dict:
176
+ ori_name = api_or_cell_name.rsplit(Const.SEP, 2)[0]
177
+ grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
178
+ # 首次执行前向hook时,添加params_grad_name属性,并注册参数hook
179
+ setattr(cell, 'params_grad_name', grad_name)
180
+ register_param_hook(ori_name, cell, params_dict)
181
+ self.data_collector.update_api_or_module_name(api_or_cell_name)
182
+ self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
183
+ init_params_grad_info(cell, params_dict)
184
+ else:
185
+ self.data_collector.update_api_or_module_name(api_or_cell_name)
186
+ self.data_collector.forward_output_data_collect(api_or_cell_name, cell, pid, module_input_output)
187
+
188
+ if self.data_collector.if_return_forward_new_output():
189
+ forward_new_output = self.data_collector.get_forward_new_output()
190
+ self.inner_switch = False
191
+ return forward_new_output
192
+ clean_input_kwargs(cell)
193
+ self.inner_switch = False
194
+ return output
98
195
 
99
196
  def backward_hook(api_or_cell_name, cell, grad_input, grad_output):
100
- if not self.should_excute_hook():
197
+ if not self.should_execute_hook(target_type, cell, False):
101
198
  return
199
+ self.inner_switch = True
102
200
 
103
201
  need_exchange = True
104
202
  if target_type == BaseScope.Module_Type_Module:
@@ -114,12 +212,32 @@ class Service:
114
212
  else:
115
213
  module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
116
214
  self.data_collector.backward_data_collect(api_or_cell_name, cell, pid, module_input_output)
215
+ self.inner_switch = False
216
+
217
+ def pre_backward_hook(api_or_cell_name, cell, grad_input):
218
+ if not self.should_execute_hook(target_type, cell, False):
219
+ return
220
+ self.inner_switch = True
221
+ module_input = ModuleBackwardInputs(grad_input=grad_input)
222
+ self.data_collector.update_api_or_module_name(api_or_cell_name)
223
+ self.data_collector.backward_input_data_collect(api_or_cell_name, cell, pid, module_input)
224
+
225
+ self.inner_switch = False
117
226
 
118
227
  pid = os.getpid()
119
- forward_name_template = name + Const.FORWARD
120
- backward_name_template = name + Const.BACKWARD
121
- forward_hook = functools.partial(forward_hook, forward_name_template)
122
- backward_hook = functools.partial(backward_hook, backward_name_template)
228
+ if target_type == BaseScope.Module_Type_Module:
229
+ full_forward_name = name + Const.FORWARD
230
+ full_backward_name = name + Const.BACKWARD
231
+ else:
232
+ full_forward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.FORWARD
233
+ full_backward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.BACKWARD
234
+ pre_forward_hook = functools.partial(pre_hook, full_forward_name)
235
+ forward_hook = functools.partial(forward_hook, full_forward_name)
236
+ backward_hook = functools.partial(backward_hook, full_backward_name)
237
+ pre_backward_hook = functools.partial(pre_backward_hook, full_backward_name)
238
+
239
+ def wrap_pre_forward_hook(cell, input_data):
240
+ return pre_forward_hook(cell, input_data)
123
241
 
124
242
  def wrap_forward_hook(cell, input_data, output_data):
125
243
  return forward_hook(cell, input_data, output_data)
@@ -127,7 +245,10 @@ class Service:
127
245
  def wrap_backward_hook(cell, grad_input, grad_output):
128
246
  return backward_hook(cell, grad_input, grad_output)
129
247
 
130
- return wrap_forward_hook, wrap_backward_hook
248
+ def wrap_pre_backward_hook(cell, grad_input):
249
+ return pre_backward_hook(cell, grad_input)
250
+
251
+ return wrap_pre_forward_hook, wrap_forward_hook, wrap_backward_hook, wrap_pre_backward_hook
131
252
 
132
253
  def update_primitive_counters(self, primitive_name):
133
254
  if primitive_name not in self.primitive_counters:
@@ -135,33 +256,20 @@ class Service:
135
256
  else:
136
257
  self.primitive_counters[primitive_name] += 1
137
258
 
138
- def register_primitive_hooks(self):
139
- primitive_set = set()
140
- for _, cell in self.model.cells_and_names():
141
- for pname, primitive in cell._primitives.items():
142
- primitive_set.add((pname, primitive))
143
-
144
- for pname, primitive in primitive_set:
145
- primitive_class_name = primitive.__class__.__name__
146
- primitive_combined_name = pname + Const.SEP + primitive_class_name
147
- new_primitive = type('NewPrimitive', (primitive.__class__,),
148
- {'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__,
149
- primitive_combined_name)})
150
- primitive.__class__ = new_primitive
151
-
152
259
  def step(self):
260
+ if self.config.async_dump:
261
+ self.data_collector.fill_stack_tensor_data()
262
+ self.data_collector.data_processor.dump_async_data()
263
+ self.data_collector.write_json()
153
264
  self.current_iter += 1
154
265
  self.data_collector.update_iter(self.current_iter)
155
- self.primitive_hook_service.primitive_counters.clear()
156
- self.data_collector.data_writer.reset_cache()
157
- JitDump.jit_count = defaultdict(int)
266
+ self.reset_status()
158
267
 
159
268
  def start(self, model=None):
160
269
  self.start_call = True
161
270
  if self.should_stop_service:
162
271
  return
163
272
  if self.need_end_service():
164
- api_register.api_set_ori_func()
165
273
  self.should_stop_service = True
166
274
  self.switch = False
167
275
  self.primitive_switch = False
@@ -181,7 +289,8 @@ class Service:
181
289
 
182
290
  if self.config.rank and self.current_rank not in self.config.rank:
183
291
  return
184
- self.register_hook_new()
292
+ self.register_primitive_hook()
293
+ self.register_cell_hook()
185
294
  if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
186
295
  JitDump.set_config(self.config)
187
296
  JitDump.set_data_collector(self.data_collector)
@@ -200,25 +309,6 @@ class Service:
200
309
  logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
201
310
  JitDump.jit_dump_switch = True
202
311
 
203
- def forward_backward_dump_end(self):
204
- if self.should_stop_service:
205
- return
206
- logger.info(f"{Const.TOOL_NAME}: debugger.forward_backward_dump_end() is set successfully. ")
207
- if not self.start_call:
208
- logger.error(f"{Const.TOOL_NAME}: debugger.start() is not set in the current scope.")
209
- raise Exception("debugger.start() is not set in the current scope.")
210
- if not self.switch:
211
- logger.error(f"{Const.TOOL_NAME}: debugger.forward_backward_dump_end() should be called between "
212
- "debugger.start() and debugger.stop() ")
213
- raise Exception("debugger.stop() is already called. ")
214
- if self.config.step and self.current_iter not in self.config.step:
215
- return
216
- if self.config.rank and self.current_rank not in self.config.rank:
217
- return
218
- self.primitive_switch = False
219
- api_register.api_set_ori_func()
220
- JitDump.jit_dump_switch = False
221
-
222
312
  def stop(self):
223
313
  if self.should_stop_service:
224
314
  return
@@ -234,6 +324,9 @@ class Service:
234
324
  self.switch = False
235
325
  self.primitive_switch = False
236
326
  self.start_call = False
327
+ if self.config.async_dump:
328
+ self.data_collector.fill_stack_tensor_data()
329
+ self.data_collector.data_processor.dump_async_data()
237
330
  self.data_collector.write_json()
238
331
  JitDump.jit_dump_switch = False
239
332
 
@@ -244,8 +337,16 @@ class Service:
244
337
  return True
245
338
  return False
246
339
 
247
- def should_excute_hook(self):
248
- if not self.switch:
340
+ def should_execute_hook(self, hook_type, cell, is_forward):
341
+ is_cell_hook = hook_type == BaseScope.Module_Type_Module
342
+ if is_cell_hook and not self.switch:
343
+ return False
344
+ elif not is_cell_hook and is_forward and not self.switch:
345
+ return False
346
+ elif not is_cell_hook and not is_forward and not cell.forward_data_collected:
347
+ return False
348
+
349
+ if self.inner_switch:
249
350
  return False
250
351
  if not self.data_collector or self.data_collector.data_processor.is_terminated:
251
352
  return False
@@ -255,6 +356,12 @@ class Service:
255
356
  create_directory(self.config.dump_path)
256
357
  self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
257
358
  cur_rank = self.current_rank if self.current_rank is not None else ''
359
+ if self.config.level == Const.LEVEL_L2:
360
+ create_directory(self.dump_iter_dir)
361
+ kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank)
362
+ self.config.kernel_config_path = kernel_config_path
363
+ return
364
+
258
365
  dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
259
366
  create_directory(dump_dir)
260
367
  if self.config.task in self.data_collector.tasks_need_tensor_data:
@@ -267,37 +374,96 @@ class Service:
267
374
  stack_file_path = os.path.join(dump_dir, "stack.json")
268
375
  construct_file_path = os.path.join(dump_dir, "construct.json")
269
376
  self.data_collector.update_dump_paths(
270
- dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None)
377
+ dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None
378
+ )
379
+ self.data_collector.initialize_json_file(
380
+ framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
381
+ )
271
382
 
272
383
  def empty(self, *args, **kwargs):
273
384
  pass
274
385
 
275
- def register_hook_new(self):
276
- logger.info("The {} hook function is successfully mounted to the model.".format(self.config.task))
277
- if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
386
+ def register_api_hook(self):
387
+ if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
388
+ logger.info(f"The api {self.config.task} hook function is successfully mounted to the model.")
278
389
  api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
279
390
  api_register.api_set_hook_func()
280
- if self.model and self.config.task in Const.DUMP_DATA_COLLECTION_LIST:
281
- self.register_primitive_hooks()
282
391
 
392
+ def get_cells_and_names(self):
393
+ cells_and_names_with_index = {}
394
+
395
+ def get_cell_or_module(model):
396
+ return model.named_modules() if is_mindtorch() else model.cells_and_names()
397
+
398
+ if isinstance(self.model, (list, tuple)):
399
+ for index, model in enumerate(self.model):
400
+ cells_and_names_with_index[str(index)] = get_cell_or_module(model)
401
+ else:
402
+ cells_and_names_with_index["-1"] = get_cell_or_module(self.model)
403
+ return cells_and_names_with_index
404
+
405
+ def register_primitive_hook(self):
406
+ if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]:
407
+ return
408
+ if not self.model or self.config.task not in Const.DUMP_DATA_COLLECTION_LIST:
409
+ return
410
+
411
+ primitive_set = set()
412
+ cells_and_names_with_index = self.get_cells_and_names()
413
+ for cells_and_names in cells_and_names_with_index.values():
414
+ for _, cell in cells_and_names:
415
+ for attribute, value in vars(cell).items():
416
+ if isinstance(value, Primitive):
417
+ primitive_set.add((attribute, value))
418
+
419
+ for pname, primitive in primitive_set:
420
+ primitive_class_name = primitive.__class__.__name__
421
+ primitive_combined_name = pname + Const.SEP + primitive_class_name
422
+ new_primitive = type('NewPrimitive', (primitive.__class__,),
423
+ {'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__,
424
+ primitive_combined_name)})
425
+ primitive.__class__ = new_primitive
426
+
427
+ def register_cell_hook(self):
283
428
  if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L0]:
429
+ logger.info(f"The cell {self.config.task} hook function is successfully mounted to the model.")
284
430
  if not self.model:
285
431
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
286
432
  f"The current level is {self.config.level}, the model cannot be None")
287
- for name, cell in self.model.cells_and_names():
288
- if cell == self.model:
289
- continue
290
- prefix = 'Cell' + Const.SEP + name + Const.SEP + \
291
- cell.__class__.__name__ + Const.SEP
292
- forward_hook, backward_hook = self.build_hook(BaseScope.Module_Type_Module, prefix)
293
- cell.register_forward_hook(forward_hook)
294
- cell.register_backward_hook(backward_hook)
295
-
296
- cell.register_forward_pre_hook(
297
- self.cell_processor.node_hook(prefix + Const.FORWARD, Const.START))
298
- cell.register_forward_hook(
299
- self.cell_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
300
- cell.register_backward_pre_hook(
301
- self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START))
302
- cell.register_backward_hook(
303
- self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
433
+ model_type = Const.MODULE if is_mindtorch() else Const.CELL
434
+ cells_and_names_with_index = self.get_cells_and_names()
435
+
436
+ for index, cells_and_names in cells_and_names_with_index.items():
437
+ model = self.model if index == "-1" else self.model[int(index)]
438
+ for name, cell in cells_and_names:
439
+ if cell == model:
440
+ continue
441
+ cell_index = (index + Const.SEP) if index != "-1" else ""
442
+ prefix = (model_type + Const.SEP + cell_index + name +
443
+ Const.SEP + cell.__class__.__name__ + Const.SEP)
444
+ _, forward_hook, backward_hook, _ = self.build_hook(BaseScope.Module_Type_Module, prefix)
445
+ cell.register_forward_hook(forward_hook)
446
+ cell.register_forward_pre_hook(
447
+ self.cell_processor.node_hook(prefix + Const.FORWARD, Const.START))
448
+ cell.register_forward_hook(
449
+ self.cell_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
450
+
451
+ register_backward_hook_functions["full"](cell, backward_hook)
452
+ register_backward_hook_functions["pre"](
453
+ cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START))
454
+ register_backward_hook_functions["full"](
455
+ cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
456
+
457
+ def reset_status(self):
458
+ self.primitive_hook_service.primitive_counters.clear()
459
+ self.data_collector.data_writer.reset_cache()
460
+ JitDump.jit_count = defaultdict(int)
461
+ self.params_grad_info.clear()
462
+
463
+ if self.config.level == Const.LEVEL_L2:
464
+ self.data_collector.data_processor.reset_status()
465
+ return
466
+ if self.config.step and self.current_iter not in self.config.step:
467
+ return
468
+ if self.config.rank and self.current_rank not in self.config.rank:
469
+ return
msprobe/msprobe.py CHANGED
@@ -16,10 +16,12 @@
16
16
  import argparse
17
17
  import sys
18
18
  import importlib.util
19
- from msprobe.core.compare.utils import _compare_parser
19
+
20
+ from msprobe.core.common.const import Const
20
21
  from msprobe.core.common.log import logger
22
+ from msprobe.core.compare.utils import _compare_parser
21
23
  from msprobe.core.compare.compare_cli import compare_cli
22
- from msprobe.core.common.const import Const
24
+ from msprobe.core.compare.merge_result.merge_result_cli import _merge_result_parser, merge_result_cli
23
25
 
24
26
 
25
27
  def is_module_available(module_name):
@@ -45,10 +47,15 @@ def main():
45
47
  multi_run_ut_cmd_parser = subparsers.add_parser('multi_run_ut')
46
48
  api_precision_compare_cmd_parser = subparsers.add_parser('api_precision_compare')
47
49
  run_overflow_check_cmd_parser = subparsers.add_parser('run_overflow_check')
50
+ code_mapping_cmd_parser = subparsers.add_parser('code_mapping')
48
51
  graph_service_cmd_parser = subparsers.add_parser('graph')
52
+ op_generate_cmd_parser = subparsers.add_parser('op_generate')
53
+ merge_result_parser = subparsers.add_parser('merge_result')
49
54
  _compare_parser(compare_cmd_parser)
55
+ _merge_result_parser(merge_result_parser)
56
+
50
57
  is_torch_available = is_module_available("torch")
51
- is_mindspore_available = is_module_available("mindspore")
58
+
52
59
  if len(sys.argv) < 4:
53
60
  parser.print_help()
54
61
  sys.exit(0)
@@ -62,6 +69,8 @@ def main():
62
69
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \
63
70
  _run_overflow_check_command
64
71
  from msprobe.visualization.graph_service import _pt_graph_service_parser, _pt_graph_service_command
72
+ from msprobe.pytorch.api_accuracy_checker.generate_op_script.op_generator import _op_generator_parser, \
73
+ _run_operator_generate_commond
65
74
 
66
75
  _run_ut_parser(run_ut_cmd_parser)
67
76
  _run_ut_parser(multi_run_ut_cmd_parser)
@@ -70,12 +79,15 @@ def main():
70
79
  _api_precision_compare_parser(api_precision_compare_cmd_parser)
71
80
  _run_overflow_check_parser(run_overflow_check_cmd_parser)
72
81
  _pt_graph_service_parser(graph_service_cmd_parser)
82
+ _op_generator_parser(op_generate_cmd_parser)
73
83
  elif framework_args.framework == Const.MS_FRAMEWORK:
74
84
  from msprobe.mindspore.api_accuracy_checker.cmd_parser import add_api_accuracy_checker_argument
75
85
  from msprobe.visualization.graph_service import _ms_graph_service_parser, _ms_graph_service_command
76
86
  add_api_accuracy_checker_argument(run_ut_cmd_parser)
77
87
  from msprobe.mindspore.api_accuracy_checker.cmd_parser import multi_add_api_accuracy_checker_argument
78
88
  multi_add_api_accuracy_checker_argument(multi_run_ut_cmd_parser)
89
+ from msprobe.mindspore.code_mapping.cmd_parser import add_ir_parser_arguments
90
+ add_ir_parser_arguments(code_mapping_cmd_parser)
79
91
 
80
92
  _ms_graph_service_parser(graph_service_cmd_parser)
81
93
 
@@ -97,17 +109,23 @@ def main():
97
109
  _run_overflow_check_command(args)
98
110
  elif sys.argv[3] == "graph":
99
111
  _pt_graph_service_command(args)
112
+ elif sys.argv[3] == 'op_generate':
113
+ _run_operator_generate_commond(args)
100
114
  elif sys.argv[3] == "compare":
101
115
  if args.cell_mapping is not None or args.api_mapping is not None:
102
116
  logger.error("Argument -cm or -am is not supported in PyTorch framework")
103
117
  raise Exception("Argument -cm or -am is not supported in PyTorch framework")
104
118
  compare_cli(args)
119
+ elif sys.argv[3] == "merge_result":
120
+ merge_result_cli(args)
105
121
  else:
106
122
  if not is_module_available(Const.MS_FRAMEWORK):
107
123
  logger.error("MindSpore does not exist, please install MindSpore library")
108
124
  raise Exception("MindSpore does not exist, please install MindSpore library")
109
125
  if sys.argv[3] == "compare":
110
126
  compare_cli(args)
127
+ elif sys.argv[3] == "merge_result":
128
+ merge_result_cli(args)
111
129
  elif sys.argv[3] == "run_ut":
112
130
  from msprobe.mindspore.api_accuracy_checker.main import api_checker_main
113
131
  api_checker_main(args)
@@ -116,6 +134,9 @@ def main():
116
134
  mul_api_checker_main(args)
117
135
  elif sys.argv[3] == "graph":
118
136
  _ms_graph_service_command(args)
137
+ elif sys.argv[3] == "code_mapping":
138
+ from msprobe.mindspore.code_mapping.main import code_mapping_main
139
+ code_mapping_main(args)
119
140
 
120
141
 
121
142
  if __name__ == "__main__":
@@ -1,6 +1,4 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
4
2
  # All rights reserved.
5
3
  #
6
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,9 +14,12 @@
16
14
  # limitations under the License.
17
15
 
18
16
 
19
- from msprobe.pytorch.monitor.module_hook import TrainerMon
17
+ import torch
20
18
  from .compare.distributed_compare import compare_distributed
21
19
  from .compare.pt_compare import compare
22
20
  from .common.utils import seed_all
23
- from .debugger.precision_debugger import PrecisionDebugger
24
- from .functional.module_dump import module_dump, module_dump_end
21
+ from .debugger.precision_debugger import PrecisionDebugger, module_dump, module_dump_end
22
+
23
+ torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
24
+ if torch_version_above_or_equal_2:
25
+ from msprobe.pytorch.monitor.module_hook import TrainerMon