mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -14,8 +14,10 @@
  # limitations under the License.

  import os
+ from dataclasses import dataclass
+ from typing import Any, Optional
  from tqdm import tqdm
-
+ import numpy as np
  from msprobe.core.common.const import Const, CompareConst
  from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, load_json, load_yaml
  from msprobe.core.common.utils import add_time_as_suffix
@@ -28,6 +30,9 @@ from msprobe.mindspore.api_accuracy_checker.utils import (check_and_get_from_jso
  from msprobe.mindspore.common.const import MsCompareConst
  from msprobe.mindspore.common.log import logger
  from msprobe.mindspore.api_accuracy_checker import torch_mindtorch_importer
+ from msprobe.core.data_dump.data_collector import build_data_collector
+ from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation
+ from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs

  cur_path = os.path.dirname(os.path.realpath(__file__))
  yaml_path = os.path.join(cur_path, MsCompareConst.SUPPORTED_API_LIST_FILE)
@@ -59,13 +64,129 @@ class ProcessResultPacket:
  self.err_msg = err_msg


+ @dataclass
+ class Config:
+ execution_mode: str
+ dump_path: str
+ task: str
+ level: str
+ scope: Optional[Any]
+ list: Optional[Any]
+ framework: str
+ data_mode: str
+ file_format: str
+ dump_tensor_data_dir: str
+ async_dump: bool
+ summary_mode: Optional[Any] = None
+
+
  class ApiAccuracyChecker:
  def __init__(self, args):
  self.api_infos = dict()
  self.data_manager = DataManager(args.out_path, args.result_csv_path) # instantiate DataManager at initialization
+ self.save_error_data = args.save_error_data
+ if self.save_error_data:
+ config, dump_path_aggregation = self.init_save_error_data(args)
+ self.data_collector = build_data_collector(config)
+ self.data_collector.update_dump_paths(dump_path_aggregation)

  @staticmethod
- def run_and_compare_helper(api_info, api_name_str, api_input_aggregation, forward_or_backward):
+ def init_save_error_data(args):
+ config = Config(
+ execution_mode="pynative",
+ dump_path=f"{args.out_path}",
+ dump_tensor_data_dir=f"{args.out_path}",
+ task="tensor", # task type; emulates saving tensor data
+ level="L1", # level
+ scope=None, # scope (None)
+ list=None, # API list (None)
+ framework=Const.MS_FRAMEWORK, # framework type
+ data_mode="all",
+ file_format="npy",
+ async_dump=False
+ )
+
+ dump_dir = f"{args.out_path}"
+ dump_data_dir = os.path.join(dump_dir, "error_data")
+ create_directory(dump_data_dir)
+ dump_path_aggregation = DumpPathAggregation()
+ dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json")
+ dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json")
+ dump_path_aggregation.dump_error_info_path = os.path.join(dump_dir, "dump_error_info.log")
+ dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
+ return config, dump_path_aggregation
+
+ @staticmethod
+ def prepare_api_input_aggregation(api_info, forward_or_backward=Const.FORWARD):
+ """
+ Args:
+ api_info: ApiInfo
+ forward_or_backward: str
+ Returns:
+ ApiInputAggregation
+ """
+ forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
+ kwargs = api_info.get_kwargs()
+ if forward_or_backward == Const.FORWARD:
+ gradient_inputs = None
+ else:
+ gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT)
+ return ApiInputAggregation(forward_inputs, kwargs, gradient_inputs)
+
+ @staticmethod
+ def is_api_checkable(api_name_str):
+ '''
+ Args:
+ api_name_str: str, e.g. "MintFunctional.relu.0.forward", key in data field of api_info.json
+ Returns:
+ is_checkable: bool
+ Description:
+ tell whether this api is checkable based on the key in "data" dict in api_info.json
+ '''
+ api_name_str_list = api_name_str.split(Const.SEP)
+ if len(api_name_str_list) < MsCompareConst.API_NAME_STR_LENGTH:
+ return False
+ api_type_str = api_name_str_list[0]
+ real_api_str = Const.SEP.join(api_name_str_list[1:-2])
+ api_list = load_yaml(yaml_path)
+ supported_tensor_api_list = api_list.get(MsCompareConst.SUPPORTED_TENSOR_LIST_KEY)
+ supported_fusion_api_list = MsCompareConst.SUPPORTED_FUSION_LIST
+ if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL) \
+ and global_context.get_framework() == Const.MS_FRAMEWORK:
+ return True
+ if api_type_str in MsCompareConst.MT_VALID_API_TYPES \
+ and global_context.get_framework() == Const.MT_FRAMEWORK:
+ return True
+ if api_type_str == MsCompareConst.TENSOR_API and real_api_str in supported_tensor_api_list \
+ and global_context.get_framework() == Const.MS_FRAMEWORK:
+ return True
+ if api_type_str == MsCompareConst.FUNCTIONAL_API and real_api_str in supported_fusion_api_list \
+ and global_context.get_framework() == Const.MS_FRAMEWORK:
+ return True
+ return False
+
+ def post_forward_hook(self, api_or_module_name, primitive_instance, args, kwargs, output):
+ self.data_collector.update_api_or_module_name(api_or_module_name)
+ module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
+ self.data_collector.forward_data_collect_only_tensor(
+ api_or_module_name,
+ primitive_instance,
+ os.getpid(),
+ module_input_output
+ )
+
+ def backward_hook(self, api_or_module_name, module, grad_input, grad_output):
+ self.data_collector.update_api_or_module_name(api_or_module_name)
+
+ module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
+ self.data_collector.backward_data_collect_only_tensor(
+ api_or_module_name,
+ module,
+ os.getpid(),
+ module_input_output
+ )
+
+ def run_and_compare_helper(self, api_info, api_name_str, api_input_aggregation, forward_or_backward):
  """
  Args:
  api_info: ApiInfo
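Note: a minimal sketch, outside the diff, of what the name parsing in is_api_checkable does to the documented example key, assuming Const.SEP is "." (neither that constant's value nor MsCompareConst.API_NAME_STR_LENGTH is shown in this hunk):

    parts = "MintFunctional.relu.0.forward".split(".")  # ["MintFunctional", "relu", "0", "forward"]
    api_type_str = parts[0]                              # "MintFunctional" -> matched against MINT / MINT_FUNCTIONAL etc.
    real_api_str = ".".join(parts[1:-2])                 # "relu" -> looked up in the supported tensor/fusion lists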
@@ -83,13 +204,22 @@ class ApiAccuracyChecker:
  """
  # get output
  if global_context.get_is_constructed():
- # constructed situation, need use constructed input to run mindspore api getting tested_output
- tested_outputs = api_runner(api_input_aggregation, api_name_str,
- forward_or_backward, global_context.get_framework())
+ if forward_or_backward == Const.FORWARD:
+ tested_outputs, inputs, kwargs, forward_result_tuple = api_runner(api_input_aggregation, api_name_str,
+ forward_or_backward,
+ global_context.get_framework())
+ elif forward_or_backward == Const.BACKWARD:
+ tested_outputs, gradient_inputs, backward_result_tuple = api_runner(api_input_aggregation, api_name_str,
+ forward_or_backward,
+ global_context.get_framework())
+ else:
+ tested_outputs = api_runner(api_input_aggregation, api_name_str,
+ forward_or_backward, global_context.get_framework())
  else:
  tested_outputs = api_info.get_compute_element_list(forward_or_backward, Const.OUTPUT)

  bench_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.PT_FRAMEWORK)
+
  tested_outputs = trim_output_compute_element_list(tested_outputs, forward_or_backward)
  bench_outputs = trim_output_compute_element_list(bench_outputs, forward_or_backward)
  if len(tested_outputs) != len(bench_outputs):
@@ -114,64 +244,26 @@ class ApiAccuracyChecker:
  compare_result_dict.get(CompareConst.MAX_ABS_ERR).pass_status == CompareConst.PASS:
  status = CompareConst.PASS
  err_msg = ""
+
  else:
  status = CompareConst.ERROR
  err_msg = (compare_result_dict.get(CompareConst.COSINE).err_msg +
  compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg)
+ if forward_or_backward == Const.FORWARD and self.save_error_data \
+ and global_context.get_is_constructed():
+ api_name_str_backward = f"{api_name_str}{Const.SEP}{Const.FORWARD}"
+ self.post_forward_hook(api_name_str_backward, None, inputs, kwargs, forward_result_tuple)
+
+ if forward_or_backward == Const.BACKWARD and self.save_error_data \
+ and global_context.get_is_constructed():
+ api_name_str_backward = f"{api_name_str}{Const.SEP}{Const.BACKWARD}"
+ self.backward_hook(api_name_str_backward, None, gradient_inputs, backward_result_tuple)
+
  basic_info_status = \
  BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg)
  output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict]))
  return output_list

- @staticmethod
- def prepare_api_input_aggregation(api_info, forward_or_backward=Const.FORWARD):
- """
- Args:
- api_info: ApiInfo
- forward_or_backward: str
- Returns:
- ApiInputAggregation
- """
- forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
- kwargs = api_info.get_kwargs()
- if forward_or_backward == Const.FORWARD:
- gradient_inputs = None
- else:
- gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT)
- return ApiInputAggregation(forward_inputs, kwargs, gradient_inputs)
-
- @staticmethod
- def is_api_checkable(api_name_str):
- '''
- Args:
- api_name_str: str, e.g. "MintFunctional.relu.0.forward", key in data field of api_info.json
- Returns:
- is_checkable: bool
- Description:
- tell whether this api is checkable based on the key in "data" dict in api_info.json
- '''
- api_name_str_list = api_name_str.split(Const.SEP)
- if len(api_name_str_list) < MsCompareConst.API_NAME_STR_LENGTH:
- return False
- api_type_str = api_name_str_list[0]
- real_api_str = Const.SEP.join(api_name_str_list[1:-2])
- api_list = load_yaml(yaml_path)
- supported_tensor_api_list = api_list.get(MsCompareConst.SUPPORTED_TENSOR_LIST_KEY)
- supported_fusion_api_list = MsCompareConst.SUPPORTED_FUSION_LIST
- if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL) \
- and global_context.get_framework() == Const.MS_FRAMEWORK:
- return True
- if api_type_str in MsCompareConst.MT_VALID_API_TYPES \
- and global_context.get_framework() == Const.MT_FRAMEWORK:
- return True
- if api_type_str == MsCompareConst.TENSOR_API and real_api_str in supported_tensor_api_list \
- and global_context.get_framework() == Const.MS_FRAMEWORK:
- return True
- if api_type_str == MsCompareConst.FUNCTIONAL_API and real_api_str in supported_fusion_api_list \
- and global_context.get_framework() == Const.MS_FRAMEWORK:
- return True
- return False
-
  def parse(self, api_info_path):

  api_info_dict = load_json(api_info_path)
@@ -183,9 +275,9 @@ class ApiAccuracyChecker:
  MsCompareConst.TENSOR_TASK))
  try:
  framework = check_and_get_from_json_dict(api_info_dict, MsCompareConst.FRAMEWORK,
- "framework field in api_info.json", accepted_type=str,
- accepted_value=(Const.MS_FRAMEWORK,
- Const.MT_FRAMEWORK))
+ "framework field in api_info.json", accepted_type=str,
+ accepted_value=(Const.MS_FRAMEWORK,
+ Const.MT_FRAMEWORK))
  except Exception as e:
  framework = Const.MS_FRAMEWORK
  logger.warning(f"JSON parsing error in framework field: {e}")
@@ -301,4 +393,4 @@ class ApiAccuracyChecker:
  elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP:
  self.data_manager.record_exception_skip(api_name_str, Const.BACKWARD, process_result_packet.err_msg)

- self.data_manager.save_results(api_name_str)
+ self.data_manager.save_results(api_name_str)
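Note: based on the paths assembled in init_save_error_data above, running with -save_error_data should roughly leave the following layout under the chosen out_path (an illustration inferred from this diff, not output copied from the tool):

    <out_path>/
        dump.json              # dump_file_path
        stack.json             # stack_file_path
        dump_error_info.log    # dump_error_info_path
        error_data/            # dump_tensor_data_dir, npy tensors for failing APIs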
@@ -13,6 +13,8 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ import os
+ import numpy as np
  import mindspore
  from mindspore import ops
  from msprobe.core.common.const import Const
@@ -38,7 +40,6 @@ else:
  import torch


-
  class ApiInputAggregation:
  def __init__(self, inputs, kwargs, gradient_inputs) -> None:
  """
@@ -148,13 +149,13 @@ class ApiRunner:
  Args:
  api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Functional"]
  api_sub_name: str, e.g. "relu"
- api_platform: str: Union["mindpore", "pytorch"]
+ api_platform: str: Union["mindspore", "pytorch"]

  Return:
  api_instance: function object

  Description:
- get mindspore.mint/torch api fucntion
+ get mindspore.mint/torch api function
  mindspore.mint.{api_sub_name} <--> torch.{api_sub_name}
  mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name}
  """
@@ -189,6 +190,8 @@ class ApiRunner:
  forward_result = api_instance(*inputs, **kwargs) # can be single tensor or tuple
  forward_result_tuple = convert_to_tuple(forward_result)
  res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in forward_result_tuple]
+ if api_platform == Const.MS_FRAMEWORK or api_platform == Const.MT_FRAMEWORK:
+ return res_compute_element_list, inputs, kwargs, forward_result_tuple
  else:
  if gradient_inputs is None:
  err_msg = f"ApiRunner.run_api failed: run backward api but gradient_inputs is missing"
@@ -206,6 +209,7 @@ class ApiRunner:
  backward_result = grad_func(*inputs, gradient_inputs) # can be single tensor or tuple
  backward_result_tuple = convert_to_tuple(backward_result)
  res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_tuple]
+ return res_compute_element_list, gradient_inputs, backward_result_tuple
  else:
  # set requires_grad
  requires_grad_index = []
@@ -52,8 +52,14 @@ def softmax_grad(dp, softmax_res):


  def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):
+ # check dimensions
+ if kv_tensor.dim() != 4:
+ raise ValueError(f"broadcast_kv: kv_tensor must be 4-dimensional (B, N_kv, S, D), but got {kv_tensor.shape}")
  if num_kv_heads == 0 or num_kv_heads > num_heads:
- raise ValueError(f"num_kv_heads must be non-zero and bigger than num_heads.")
+ raise ValueError("broadcast_kv: num_kv_heads must be greater than 0 and must not exceed num_heads")
+ if num_heads % num_kv_heads != 0:
+ raise ValueError(f"broadcast_kv: num_heads({num_heads}) must be divisible by num_kv_heads({num_kv_heads}).")
+

  factor = num_heads // num_kv_heads
  kv_shape = kv_tensor.shape
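Note: a quick worked example of the grouped-KV broadcast these checks protect; the expansion itself happens in the rest of broadcast_kv, which this hunk does not show:

    factor = 32 // 8   # num_heads = 32, num_kv_heads = 8 -> each KV head serves 4 query heads
    # (B, 8, S, D) is then expanded along the head axis to (B, 32, S, D)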
@@ -68,6 +74,13 @@ def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):


  def calculate_qk(q, k, attn_mask, pse, scalar_value):
+ # basic shape checks
+ if q.dim() < 4 or k.dim() < 4:
+ raise ValueError(f"calculate_qk: q and k must have at least 4 dimensions, q={q.dim()}, k={k.dim()}")
+ # check head_dim consistency
+ if q.size(-1) != k.size(-1):
+ raise ValueError(f"calculate_qk: q.head_dim({q.size(-1)}) != k.head_dim({k.size(-1)})")
+
  if k.dim() != 4:
  raise ValueError(f"k tensor dimension must be 4, but got {k.dim()} dimensions (shape: {k.shape})")

@@ -95,6 +108,10 @@ def fusion_attention_forward(forward_params):
95
108
  scalar_value = forward_params.scalar_value
96
109
  keep_prob = forward_params.keep_prob
97
110
 
111
+ # 拦截 keep_prob 为 0 的情况,防止除零
112
+ if keep_prob == 0:
113
+ raise ValueError("fusion_attention_forward: keep_prob 不能为 0,避免除零错误。")
114
+
98
115
  qk = calculate_qk(q, k, attn_mask, pse, scalar_value)
99
116
  softmax_res, softmax_max, softmax_sum = softmax_forward(qk)
100
117
  if drop_mask is None or len(drop_mask.shape) == 0:
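Note: why the guard matters, as a minimal sketch. The dropout branch sits outside this hunk, so the exact expression is an assumption, but inverted dropout rescales by 1 / keep_prob along the lines of:

    drop_res = softmax_res * drop_mask * (1.0 / keep_prob)   # keep_prob == 0 would divide by zero here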
@@ -115,6 +132,11 @@ def fusion_attention_backward(backward_params):
  pse = backward_params.pse
  scalar_value = backward_params.scalar_value
  keep_prob = backward_params.keep_prob
+
+ # intercept keep_prob == 0 to prevent division by zero
+ if keep_prob == 0:
+ raise ValueError("fusion_attention_backward: keep_prob must not be 0, to avoid division by zero.")
+
  dp = torch.matmul(dx, v.permute(0, 1, 3, 2))
  if drop_mask is None or len(drop_mask.shape) == 0:
  drop_res = softmax_res.permute(0, 1, 3, 2)
@@ -138,34 +160,45 @@ def parse_bsnd_args(query, key, head_num, input_layout):

  if input_layout == "TND":
  raise ValueError(f"input_layout {input_layout} does not supported for now.")
+
+ # guard against head_num being 0
+ if n1 == 0:
+ raise ValueError("parse_bsnd_args: head_num (n1) must not be 0, to avoid division by zero.")
+
  try:
  if input_layout == "BSH":
  b, s1, h1 = query.shape
  _, s2, h2 = key.shape
  d = h1 // n1
+ # intercept the case where d is 0
+ if d == 0:
+ raise ValueError("parse_bsnd_args: the computed head_dim d must not be 0.")
  n2 = h2 // d
  elif input_layout == "SBH":
  s1, b, h1 = query.shape
  s2, _, h2 = key.shape
  d = h1 // n1
+ if d == 0:
+ raise ValueError("parse_bsnd_args: the computed head_dim d must not be 0.")
  n2 = h2 // d
  elif input_layout == "BSND":
  b, s1, n1, d = query.shape
  _, s2, n2, _ = key.shape
+ if d == 0:
+ raise ValueError("parse_bsnd_args: head_dim d must not be 0.")
  h1 = n1 * d
  h2 = n2 * d
  elif input_layout == "BNSD":
  b, n1, s1, d = query.shape
  _, n2, s2, _ = key.shape
+ if d == 0:
+ raise ValueError("parse_bsnd_args: head_dim d must not be 0.")
  h1 = n1 * d
  h2 = n2 * d
  except Exception as e:
  raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e

- if d == 0:
- raise ValueError(f"Value d must be non-zero.")
- _dtype = query.dtype
- ret = (b, s1, s2, n1, n2, d, h1, h2, _dtype)
+ ret = (b, s1, s2, n1, n2, d, h1, h2, query.dtype)
  return ret


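Note: a worked example of the layout arithmetic these guards protect, using made-up shapes. For input_layout "BSH" with query.shape = (2, 1024, 4096), key.shape = (2, 1024, 1024) and head_num n1 = 32:

    d = h1 // n1    # 4096 // 32 = 128
    n2 = h2 // d    # 1024 // 128 = 8

If head_num exceeded h1, integer division would give d = 0 and n2 = h2 // d would raise a ZeroDivisionError without the new checks.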
@@ -230,67 +263,6 @@ def convert_to_bnsd(_input, n, input_layout):
  return out.to(GTYPE)


- def convert_from_bsnd(_input, input_layout):
- """
- transform qkv from bsnd to input_layout.
- B: batch_size
- S: sequence_length
- N: num_heads
- D: head_dim
- Args:
- _input (torch.Tensor): tensor of shape (B,S,N,D)
- input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
- Returns:
- tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
- """
- if input_layout == "BSH":
- # (B,S,N,D)=>(B,S,N*D)
- out = rearrange(_input, 'b s n d -> b s (n d)').contiguous()
- elif input_layout == "SBH":
- # (B,S,N,D)=>(S,B,N*D)
- out = rearrange(_input, 'b s n d -> s b (n d)').contiguous()
- elif input_layout == "BNSD":
- # (B,S,N,D)=>(B,N,S,D)
- out = rearrange(_input, 'b s n d -> b n s d').contiguous()
- elif input_layout == "TND":
- raise ValueError(f"input_layout {input_layout} does not supported for now.")
- else:
- out = _input
- return out
-
-
- def convert_to_bsnd(_input, n, input_layout):
- """
- transform qkv from input_layout to bsnd.
- B: batch_size
- S: sequence_length
- N: num_heads
- D: head_dim
- Args:
- _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
- n (int): num_heads
- input_layout (str):"BSH" or "SBH" or "BSND" or "BNSD" or "TND"
- Returns:
- tensor of shape (B,S,N,D)
- """
- if input_layout == "BSH":
- # (B,S,N*D)=>(B,S,N,D)
- out = rearrange(_input, 'b s (n d) -> b s n d', n=n)
- elif input_layout == "SBH":
- # (S,B,N*D)=>(B,S,N,D)
- out = rearrange(_input, 's b (n d) -> b s n d', n=n)
- elif input_layout == "BNSD":
- # (B,N,S,D)=>(B,S,N,D)
- out = rearrange(_input, 'b n s d -> b s n d', n=n)
- elif input_layout == "TND":
- raise ValueError(f"input_layout {input_layout} does not supported for now.")
- else:
- out = _input
- if out.dim() != 4:
- raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
- return out
-
-
  def generate_attn_mask(*args):
  """
  # When sparse_mode is 2, 3 or 4, the elementary-op to fused-op path applies this optimization; going the other way, it has to be decomposed back into the original basic implementation
@@ -417,17 +389,20 @@ def get_input_layout(*args, **kwargs):

  def npu_fusion_attention_forward_patch(*args, **kwargs):
  if len(args) < 2:
- raise RuntimeError("npu_fusion_attention_forward_patch: length of args should greater than or equal to 2.")
+ raise RuntimeError("npu_fusion_attention_forward_patch: length of args should be greater than or equal to 2.")

  # query, key, value, head_num, input_layout
  head_num = get_head_num(*args, **kwargs)
  input_layout = get_input_layout(*args, **kwargs)

  b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], head_num, input_layout)
+ # d has already been verified as non-zero in parse_bsnd_args
  if n1 == n2 and s1 == s2:
  logger.debug(f"running case : BNSD = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
  else:
  logger.debug(f"running case: BNSD = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+ if n2 == 0:
+ raise ValueError("n2 must not be 0, to avoid division by zero.")
  if not (n1 % n2 == 0 and n1 >= n2):
  raise ValueError(f"N1 and N2 do not match, please check: n1 = {n1}, n2 = {n2}.")

@@ -436,7 +411,7 @@ def npu_fusion_attention_forward_patch(*args, **kwargs):
  "d": d, "h1": h1, "h2": h2, "dtype": dtype
  }
  new_kwargs = {
- "keep_prob": 1,
+ "keep_prob": 1, # note: a keep_prob of 0 passed in from outside is also caught in fusion_attention_forward
  "scalar_value": kwargs.get("scalar_value", 1 / (d ** 0.5)),
  "sparse_mode": kwargs.get("sparse_mode", 0),
  "prefix": kwargs.get("prefix"),
@@ -455,10 +430,13 @@ def npu_fusion_attention_backward_patch(*args, **kwargs):
  raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")

  b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], args[4], args[5])
+ # d has already been verified as non-zero in parse_bsnd_args
  if n1 == n2 and s1 == s2:
  logger.info(f"running case : bnsd = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
  else:
  logger.info(f"running case: bnsd = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+ if n2 == 0:
+ raise ValueError("n2 must not be 0, to avoid division by zero.")
  if not (n1 % n2 == 0 and n1 >= n2):
  raise ValueError(f"N1 and N2 do not match, please check: n1 = {n1}, n2 = {n2}.")

@@ -468,7 +446,7 @@ def npu_fusion_attention_backward_patch(*args, **kwargs):
  }

  new_kwargs = {
- "keep_prob": 1,
+ "keep_prob": 1, # as above, fusion_attention_backward intercepts a keep_prob of 0
  "scalar_value_value": kwargs.get("scalar_value_value", 1 / (d ** 0.5)),
  "sparse_mode": kwargs.get("sparse_mode", 0),
  "prefix": kwargs.get("prefix"),
@@ -39,6 +39,8 @@ def add_api_accuracy_checker_argument(parser):
  help="<optional> The ut task result out path.")
  parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str, required=False,
  help="<optional> the exit csv for continue")
+ parser.add_argument('-save_error_data', dest="save_error_data", action="store_true",
+ help="<optional> Save compare failed api output.", required=False)


  def multi_add_api_accuracy_checker_argument(parser):
@@ -49,6 +51,8 @@ def multi_add_api_accuracy_checker_argument(parser):
  help="<optional> The ut task result out path.")
  parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str, required=False,
  help="<optional> the exit csv for continue")
+ parser.add_argument('-save_error_data', dest="save_error_data", action="store_true",
+ help="<optional> Save compare failed api output.", required=False)
  # the arguments below are multi-threading parameters
  parser.add_argument("-d", "--device", dest="device_id", nargs='+', type=int,
  help="<optional> set device id to run ut, must be unique and in range 0-7",
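Note: a minimal, self-contained sketch of what the new switch does; it re-creates only the one argument shown above, not the tool's full parser:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-save_error_data', dest="save_error_data", action="store_true",
                        help="<optional> Save compare failed api output.", required=False)

    args = parser.parse_args(["-save_error_data"])
    print(args.save_error_data)  # True; ApiAccuracyChecker.__init__ then wires up the error-data collector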
@@ -17,7 +17,6 @@ import os

  import mindspore
  import numpy as np
- import torch
  from mindspore._c_expression import typing
  from msprobe.core.common.const import Const
  from msprobe.core.common.exceptions import ApiAccuracyCheckerException
@@ -188,7 +188,7 @@ class DataManager:

  def record_exception_skip(self, api_name, forward_or_backward, err_msg):
  '''
- record exception_skip infomation into self.record_exception_skip.
+ record exception_skip information into self.record_exception_skip.
  self.record_exception_skip: dict{str: dict{"forward": str/None, "backward": str/None}}
  string in key is api_name, string in value is err_msg
  '''
@@ -270,7 +270,7 @@ class DataManager:
  entry.backward_pass_status,
  overall_err_msg
  ]
- # change row if this api has excption_skip infomation
+ # change row if this api has exception_skip information
  if api_name in self.results_exception_skip:
  if self.results_exception_skip[api_name][Const.FORWARD] is not None:
  row[1] = CompareConst.SKIP