mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -0,0 +1,218 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import torch
19
+
20
+ from msprobe.core.common.const import CompareConst
21
+
22
+
23
class StandardConfig:
    """
    Standard configuration class for managing precision and comparison thresholds.

    This class centralizes the small value thresholds, absolute tolerances, relative tolerances (rtol),
    error bounds and ULP thresholds used by the precision standards. Lookups keyed by data type fall back
    to the table's ``"default"`` entry when the data type is not listed.

    Data types may be passed either as ``torch.dtype`` objects (e.g. ``torch.float16``) or as their string
    form (e.g. ``'torch.float16'``); both resolve to the same table entry regardless of how a particular
    table happens to be keyed.

    Attributes:
        _small_value (dict): Data type -> small value threshold.
        _threshold_small_value_atol (dict): Data type -> small value absolute tolerance used by the
            absolute threshold standard.
        _benchmark_small_value_atol (dict): Data type -> small value absolute tolerance used by the
            benchmark standard (and by any standard without its own table).
        _rtol (dict): Data type -> relative tolerance.
        _accumulative_error_bound (dict): Data type -> accumulative error bound.
        _minmum_err (dict): Data type -> minimum error.
        _accumulative_error_eb_threshold (dict): Data type -> error balance threshold of the
            accumulative error standard.
        _small_value_threshold, _rmse_threshold, _max_rel_err_threshold, _mean_rel_err_threshold,
        _eb_threshold (dict): Threshold type ('error_threshold' / 'warning_threshold') -> threshold of the
            corresponding benchmark metric.

    Example:
        >>> StandardConfig.get_small_value(torch.float32, CompareConst.BENCHMARK) == 2**-20
        True
        >>> StandardConfig.get_small_value_atol(torch.float32, CompareConst.BENCHMARK) == 2**-30
        True
        >>> StandardConfig.get_rtol('torch.float32') == 2**-20
        True

    See Also:
        torch.dtype: PyTorch data types.
    """
    _small_value = {
        torch.float16: 2**-10,
        torch.bfloat16: 2**-10,
        torch.float32: 2**-20,
        "default": 2**-20
    }
    _threshold_small_value_atol = {
        torch.float16: 2**-16,
        torch.bfloat16: 1e-16,
        torch.float32: 2**-30,
        "default": 2**-30
    }
    _benchmark_small_value_atol = {
        torch.float16: 1e-16,
        torch.bfloat16: 1e-16,
        torch.float32: 2**-30,
        "default": 2**-30
    }
    _rtol = {
        torch.float16: 2**-10,
        torch.bfloat16: 2**-8,
        torch.float32: 2**-20,
        "default": 2**-20
    }
    _accumulative_error_bound = {
        torch.float16: 2**-8,
        torch.bfloat16: 2**-7,
        torch.float32: 2**-11,
        "default": 2**-11
    }
    # Benchmark metric thresholds, keyed by threshold type rather than by data type.
    _small_value_threshold = {
        'error_threshold': 2,
        'warning_threshold': 1,
        "default": 1
    }
    _rmse_threshold = {
        'error_threshold': 2,
        'warning_threshold': 1,
        "default": 1
    }
    _max_rel_err_threshold = {
        'error_threshold': 10,
        'warning_threshold': 1,
        "default": 1
    }
    _mean_rel_err_threshold = {
        'error_threshold': 2,
        'warning_threshold': 1,
        "default": 1
    }
    _eb_threshold = {
        'error_threshold': 2,
        'warning_threshold': 1,
        "default": 1
    }
    # NOTE: unlike the dtype tables above, these two are keyed by the *string* form of the dtype.
    # _lookup_by_dtype resolves both forms, so callers may pass torch.dtype objects or strings.
    _minmum_err = {
        'torch.float16': 2**-11,
        'torch.bfloat16': 2**-8,
        'torch.float32': 2**-14,
        'default': 2**-14
    }
    _accumulative_error_eb_threshold = {
        'torch.float16': 2**-20,
        'torch.bfloat16': 2**-7,
        'torch.float32': 2**-14,
        'default': 2**-14
    }

    _fp32_mean_ulp_err_threshold = 64
    ulp_err_proportion_ratio = 1
    _fp32_ulp_err_proportion = 0.05
    _fp16_ulp_err_proportion = 0.001
    # Small value threshold returned for the accumulative error standard, independent of dtype.
    _special_samll_value = 1

    @staticmethod
    def _lookup_by_dtype(table, dtype):
        """
        Look up ``dtype`` in ``table``, accepting either a ``torch.dtype`` or its string form.

        Args:
            table (dict): Mapping of data type (``torch.dtype`` or string) to value; must contain "default".
            dtype: The data type to look up.

        Returns:
            The value stored for ``dtype``, or ``table["default"]`` if no entry matches.
        """
        if dtype in table:
            return table[dtype]
        # Match on the textual form so 'torch.float16' finds torch.float16 and vice versa.
        dtype_name = str(dtype)
        for key, value in table.items():
            if str(key) == dtype_name:
                return value
        return table["default"]

    @classmethod
    def get_small_value(cls, dtype, standard):
        """Return the small value threshold for ``dtype`` (a fixed value for the accumulative error standard)."""
        if standard == CompareConst.ACCUMULATIVE_ERROR_COMPARE:
            return cls._special_samll_value
        return cls._lookup_by_dtype(cls._small_value, dtype)

    @classmethod
    def get_small_value_atol(cls, dtype, standard):
        """
        Return the small value absolute tolerance for ``dtype`` under ``standard``.

        The absolute threshold standard has its own table; every other standard uses the benchmark table.
        """
        standard_dict = {
            CompareConst.ABSOLUTE_THRESHOLD: cls._threshold_small_value_atol,
            CompareConst.BENCHMARK: cls._benchmark_small_value_atol
        }
        small_value_atol_standard = standard_dict.get(standard, cls._benchmark_small_value_atol)
        return cls._lookup_by_dtype(small_value_atol_standard, dtype)

    @classmethod
    def get_rtol(cls, dtype):
        """Return the relative tolerance for ``dtype``."""
        return cls._lookup_by_dtype(cls._rtol, dtype)

    @classmethod
    def get_small_value_threshold(cls, threshold_type):
        """Return the small value metric threshold for ``threshold_type``."""
        return cls._small_value_threshold.get(threshold_type, cls._small_value_threshold["default"])

    @classmethod
    def get_rmse_threshold(cls, threshold_type):
        """Return the RMSE metric threshold for ``threshold_type``."""
        return cls._rmse_threshold.get(threshold_type, cls._rmse_threshold["default"])

    @classmethod
    def get_max_rel_err_threshold(cls, threshold_type):
        """Return the max relative error metric threshold for ``threshold_type``."""
        return cls._max_rel_err_threshold.get(threshold_type, cls._max_rel_err_threshold["default"])

    @classmethod
    def get_mean_rel_err_threshold(cls, threshold_type):
        """Return the mean relative error metric threshold for ``threshold_type``."""
        return cls._mean_rel_err_threshold.get(threshold_type, cls._mean_rel_err_threshold["default"])

    @classmethod
    def get_eb_threshold(cls, threshold_type):
        """Return the error balance metric threshold for ``threshold_type``."""
        return cls._eb_threshold.get(threshold_type, cls._eb_threshold["default"])

    @classmethod
    def get_benchmark_threshold(cls, metric):
        """
        Return the 'error_threshold' value of a benchmark metric.

        Args:
            metric (str): One of 'small_value', 'rmse', 'max_rel_err', 'mean_rel_err', 'eb'.

        Raises:
            ValueError: If ``metric`` is not a known benchmark metric.
        """
        metric_threshold_functions = {
            'small_value': cls.get_small_value_threshold,
            'rmse': cls.get_rmse_threshold,
            'max_rel_err': cls.get_max_rel_err_threshold,
            'mean_rel_err': cls.get_mean_rel_err_threshold,
            'eb': cls.get_eb_threshold
        }

        threshold_func = metric_threshold_functions.get(metric)
        if threshold_func is None:
            raise ValueError(
                f"Unsupported benchmark metric: {metric!r}. "
                f"Expected one of {sorted(metric_threshold_functions)}."
            )
        return threshold_func('error_threshold')

    @classmethod
    def get_fp32_mean_ulp_err_threshold(cls):
        """Return the float32 mean ULP error threshold."""
        return cls._fp32_mean_ulp_err_threshold

    @classmethod
    def get_ulp_err_proportion_ratio_threshold(cls):
        """Return the ULP error proportion ratio threshold."""
        return cls.ulp_err_proportion_ratio

    @classmethod
    def get_fp32_ulp_err_proportion_threshold(cls):
        """Return the float32 ULP error proportion threshold."""
        return cls._fp32_ulp_err_proportion

    @classmethod
    def get_fp16_ulp_err_proportion_threshold(cls):
        """Return the float16 ULP error proportion threshold."""
        return cls._fp16_ulp_err_proportion

    @classmethod
    def get_ulp_threshold(cls, dtype):
        """
        Return the ULP thresholds for ``dtype``.

        Returns:
            tuple: ``(mean_ulp_err_threshold, ulp_err_proportion_threshold, ulp_err_proportion_ratio_threshold)``.
            For float32 (``torch.float32`` or ``'torch.float32'``) the float32 thresholds are returned; for
            any other dtype the mean threshold is ``None`` and the float16 proportion threshold is used.
        """
        ulp_err_proportion_ratio_threshold = cls.get_ulp_err_proportion_ratio_threshold()
        if dtype == torch.float32 or str(dtype) == str(torch.float32):
            mean_ulp_err_threshold = cls.get_fp32_mean_ulp_err_threshold()
            ulp_err_proportion_threshold = cls.get_fp32_ulp_err_proportion_threshold()
            return mean_ulp_err_threshold, ulp_err_proportion_threshold, ulp_err_proportion_ratio_threshold
        ulp_err_proportion_threshold = cls.get_fp16_ulp_err_proportion_threshold()
        return None, ulp_err_proportion_threshold, ulp_err_proportion_ratio_threshold

    @classmethod
    def get_minmum_err(cls, dtype):
        """Return the minimum error for ``dtype``."""
        return cls._lookup_by_dtype(cls._minmum_err, dtype)

    @classmethod
    def get_accumulative_error_bound(cls, dtype):
        """Return the accumulative error bound for ``dtype``."""
        return cls._lookup_by_dtype(cls._accumulative_error_bound, dtype)

    @classmethod
    def get_accumulative_error_eb_threshold(cls, dtype):
        """Return the accumulative error standard's error balance threshold for ``dtype``."""
        return cls._lookup_by_dtype(cls._accumulative_error_eb_threshold, dtype)
@@ -0,0 +1,104 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from typing import Callable
19
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import absolute_standard_api, binary_standard_api, \
20
+ ulp_standard_api, thousandth_standard_api, accumulative_error_standard_api, BINARY_COMPARE_UNSUPPORT_LIST
21
+ from msprobe.core.common.const import CompareConst
22
+
23
+
24
class StandardRegistry:
    """
    Registry that maps precision-standard categories to comparison functions.

    Comparison functions are registered under a standard name. When a comparison function is requested,
    the standard category is resolved from the API name and, optionally, the data type; the function
    registered under that category is then returned.

    Attributes:
        comparison_functions (dict): Standard name -> registered comparison callable.
        api_standard_function_map (dict): Standard name -> collection of API names governed by that
            standard. Entries are checked in insertion order.

    Methods:
        register(standard, func): Registers ``func`` as the comparison function for ``standard``.
        get_comparison_function(api_name, dtype=None): Returns the function registered for the resolved
            category, or ``None`` when nothing is registered for it.
        _get_standard_category(api_name, dtype=None): Resolves the standard category name.

    Note:
        A truthy ``dtype`` that is not in ``BINARY_COMPARE_UNSUPPORT_LIST`` always selects the binary
        consistency standard, whatever the API name. An API not listed under any standard falls back to
        the benchmark standard.

    See Also:
        BaseCompare: The base class for comparison classes.
    """
    def __init__(self):
        # Filled by register(); starts empty.
        self.comparison_functions = {}
        # Order matters: the first standard whose API collection contains the name wins.
        self.api_standard_function_map = {
            CompareConst.ABSOLUTE_THRESHOLD: absolute_standard_api,
            CompareConst.BINARY_CONSISTENCY: binary_standard_api,
            CompareConst.ULP_COMPARE: ulp_standard_api,
            CompareConst.THOUSANDTH_STANDARD: thousandth_standard_api,
            CompareConst.ACCUMULATIVE_ERROR_COMPARE: accumulative_error_standard_api
        }

    def register(self, standard: str, func: Callable) -> None:
        """
        Register a comparison function under a standard category name.

        An existing registration for the same name is overwritten.

        Args:
            standard (str): The name of the standard category.
            func (Callable): The comparison function to register.

        Raises:
            ValueError: If ``func`` is not callable.
        """
        if callable(func):
            self.comparison_functions[standard] = func
            return
        raise ValueError("The function to be registered must be callable.")

    def get_comparison_function(self, api_name, dtype=None):
        """
        Return the comparison function for ``api_name`` and ``dtype``.

        Returns ``None`` when no function is registered for the resolved category.
        """
        return self.comparison_functions.get(self._get_standard_category(api_name, dtype))

    def _get_standard_category(self, api_name, dtype=None):
        """
        Resolve the standard category name for an API name and optional data type.

        Args:
            api_name (str): The API whose standard category is needed.
            dtype (optional): Data type checked against ``BINARY_COMPARE_UNSUPPORT_LIST``. Defaults to None.

        Returns:
            str: The binary consistency category when ``dtype`` is truthy and not in the unsupported list.
            Otherwise, the first category in ``api_standard_function_map`` whose API collection contains
            ``api_name``, or the benchmark category if none does.
        """
        if dtype and dtype not in BINARY_COMPARE_UNSUPPORT_LIST:
            return CompareConst.BINARY_CONSISTENCY

        no_match = object()
        matched_standard = next(
            (standard_name
             for standard_name, api_names in self.api_standard_function_map.items()
             if api_name in api_names),
            no_match,
        )
        if matched_standard is no_match:
            return CompareConst.BENCHMARK
        return matched_standard
@@ -0,0 +1,63 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rel_err_ratio
19
+ from msprobe.core.common.const import CompareConst
20
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare
21
+
22
+
23
class ThousandthStdCompare(BaseCompare):
    """
    Comparison class implementing the thousandth precision standard.

    A BaseCompare subclass that scores the relative error between benchmark and device outputs against
    ``CompareConst.THOUSAND_RATIO_THRESHOLD`` (the thousandth, 0.001, threshold).

    Attributes:
        rel_err_orign (float or array-like): Relative error values taken from the input data.
        compare_column (object): Column object taken from the input data, used to store and update
            comparison metrics.

    Methods:
        _pre_compare(): No-op; this standard needs no preprocessing.
        _compute_metrics(): Computes the thousandth relative error ratio.
    """
    def __init__(self, input_data):
        # NOTE(review): BaseCompare.__init__ is not called; only these two fields are read from
        # input_data — confirm this is intended.
        self.rel_err_orign = input_data.rel_err_orign
        self.compare_column = input_data.compare_column

    def _pre_compare(self):
        """No preprocessing is required for the thousandth standard."""

    def _compute_metrics(self):
        """
        Compute the thousandth relative error metric.

        Calls ``get_rel_err_ratio`` with ``CompareConst.THOUSAND_RATIO_THRESHOLD`` and keeps the first of the
        two values it returns.

        Returns:
            dict: A single-entry dictionary ``{'rel_err_thousandth': ratio}``, where ``ratio`` is the value
            ``get_rel_err_ratio`` reports for the thousandth threshold (presumably the share of elements whose
            relative error falls below it — confirm against ``get_rel_err_ratio``).
        """
        thousandth_ratio, _unused_flag = get_rel_err_ratio(
            self.rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD
        )
        metrics = {'rel_err_thousandth': thousandth_ratio}
        return metrics
@@ -0,0 +1,200 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from collections import namedtuple
19
+ import numpy as np
20
+ import torch
21
+
22
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
23
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare, BasePrecisionCompare
24
+ from msprobe.core.common.const import Const, CompareConst
25
+ from msprobe.pytorch.api_accuracy_checker.compare.algorithm import calc_ratio, get_ulp_err
26
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ApiPrecisionCompareColumn, check_inf_or_nan, \
27
+ is_inf_or_nan
28
+
29
+
30
# Pair of inf/nan-consistency flags built by UlpPrecisionCompare._compute_ratio
# and read by UlpPrecisionCompare._get_status: one flag for the mean ULP error
# column, one for the ULP error proportion ratio column.
UlpInfNanConsistency = namedtuple(
    'UlpInfNanConsistency',
    (
        'mean_ulp_err_inf_nan_consistency',
        'ulp_err_proportion_ratio_inf_nan_consistency',
    ),
)
32
+
33
+
34
class UlpCompare(BaseCompare):
    """
    ULP (unit in the last place) comparison of benchmark and device outputs.

    ``_pre_compare`` computes per-element ULP errors with ``get_ulp_err``.
    ``_compute_metrics`` then summarises them as the maximum, the mean, and the
    share of elements above a dtype-dependent threshold.

    Attributes read here (presumably populated by ``BaseCompare.__init__`` --
    TODO confirm):
        bench_output: Benchmark output values; ``.size`` is used as the element
            count for the proportion.
        device_output: Device output values.
        dtype: Output dtype; ``torch.float32`` selects the float32 threshold,
            any other value selects the float16 threshold.

    Attributes set here:
        ulp_err: Per-element ULP errors produced by ``get_ulp_err``.
    """

    def __init__(self, input_data):
        super().__init__(input_data)

    @staticmethod
    def _stat_max_ulp_err(ulp_err):
        """Return the largest ULP error (``np.max``)."""
        return np.max(ulp_err)

    @staticmethod
    def _stat_mean_ulp_err(ulp_err):
        """Return the mean ULP error (``np.mean``)."""
        return np.mean(ulp_err)

    def _stat_ulp_error_proportion(self, ulp_err):
        """
        Return the fraction of elements whose ULP error exceeds the dtype threshold.

        ``CompareConst.ULP_FLOAT32_THRESHOLD`` applies to ``torch.float32``.
        Every other dtype uses ``CompareConst.ULP_FLOAT16_THRESHOLD``. The number
        of exceeding elements is divided by ``bench_output.size``.
        """
        threshold = (CompareConst.ULP_FLOAT32_THRESHOLD if self.dtype == torch.float32
                     else CompareConst.ULP_FLOAT16_THRESHOLD)
        exceeding = np.sum(ulp_err > threshold)
        return exceeding / self.bench_output.size

    def _pre_compare(self):
        """Compute per-element ULP errors and store them on ``self.ulp_err``."""
        self.ulp_err = get_ulp_err(self.bench_output, self.device_output, self.dtype)

    def _compute_metrics(self):
        """
        Summarise ``self.ulp_err`` into ULP error metrics.

        Returns:
            dict: with keys
                - "max_ulp_error": largest ULP error.
                - "mean_ulp_error": mean ULP error.
                - "ulp_error_proportion": share of elements above the dtype threshold.
        """
        errors = self.ulp_err
        metrics = {}
        metrics["max_ulp_error"] = self._stat_max_ulp_err(errors)
        metrics["mean_ulp_error"] = self._stat_mean_ulp_err(errors)
        metrics["ulp_error_proportion"] = self._stat_ulp_error_proportion(errors)
        return metrics
100
+
101
+
102
class UlpPrecisionCompare(BasePrecisionCompare):
    """
    Precision verdict for APIs graded by the ULP standard.

    The flow has three steps:
      1. Read the NPU and GPU values of the ULP columns via the inherited
         ``_get_and_convert_values``.
      2. Check inf/nan consistency between the two sides.
      3. Grade the NPU values against the thresholds returned by
         ``StandardConfig.get_ulp_threshold``.
    """

    # Appended to COMPARE_MESSAGE whenever the ULP standard is not met
    # (text: "ULP error does not meet the standard").
    _ULP_ERR_MESSAGE = "ERROR: ULP误差不满足标准\n"

    def __init__(self, input_data):
        super().__init__(input_data)
        self.compare_algorithm = CompareConst.ULP_COMPARE_ALGORITHM_NAME

    @staticmethod
    def _compute_ulp_err_proportion_ratio(npu_value, gpu_value, dtype):
        """
        Return ``(ratio, inf_nan_consistency, message)`` for the ULP error proportion.

        If either value is inf/nan, the triple comes from ``check_inf_or_nan``.
        Otherwise it is ``(calc_ratio(npu_value, gpu_value, dtype), True, "")``.
        """
        column_name = ApiPrecisionCompareColumn.ULP_ERR_PROPORTION
        if is_inf_or_nan(npu_value) or is_inf_or_nan(gpu_value):
            return check_inf_or_nan(npu_value, gpu_value, column_name)
        return calc_ratio(npu_value, gpu_value, dtype), True, ""

    def _compute_mean_ulp_err(self):
        """
        Return ``(npu_mean_ulp_err, inf_nan_consistency, message)``.

        The NPU value is always the reported metric. When either side is inf/nan,
        the consistency flag and message come from ``check_inf_or_nan``;
        otherwise they are ``(True, "")``.
        """
        column_name = ApiPrecisionCompareColumn.MEAN_ULP_ERR
        npu_value, gpu_value = self._get_and_convert_values(column_name)
        if is_inf_or_nan(npu_value) or is_inf_or_nan(gpu_value):
            _, mean_ulp_err_inf_nan_consistency, message = check_inf_or_nan(npu_value, gpu_value, column_name)
            return npu_value, mean_ulp_err_inf_nan_consistency, message
        return npu_value, True, ""

    def _compute_ulp_err_proportion(self):
        """Return the ``(npu, gpu)`` values of the ULP error proportion column."""
        column_name = ApiPrecisionCompareColumn.ULP_ERR_PROPORTION
        npu_value, gpu_value = self._get_and_convert_values(column_name)
        return npu_value, gpu_value

    def _get_status(self, metrics, inf_nan_consistency):
        """
        Grade the ULP metrics and record the verdict in ``metrics``.

        Args:
            metrics (dict): Output of ``_compute_ratio``; updated in place.
            inf_nan_consistency (UlpInfNanConsistency): Flags from ``_compute_ratio``.

        Returns:
            dict: ``metrics``, always carrying ``CompareConst.ULP_ERR_STATUS`` and
            ``CompareConst.COMPARE_RESULT``, with any failure message appended to
            ``CompareConst.COMPARE_MESSAGE``.
        """
        ulp_inf_nan_consistency = inf_nan_consistency.mean_ulp_err_inf_nan_consistency and \
            inf_nan_consistency.ulp_err_proportion_ratio_inf_nan_consistency

        if not ulp_inf_nan_consistency:
            # An inf/nan mismatch between NPU and GPU fails regardless of thresholds.
            status = CompareConst.ERROR
            final_message = self._ULP_ERR_MESSAGE
        else:
            dtype = self.row_npu.get(ApiPrecisionCompareColumn.DEVICE_DTYPE)
            mean_ulp_err = metrics.get(CompareConst.MEAN_ULP_ERR)
            ulp_err_proportion = metrics.get(CompareConst.ULP_ERR_PROPORTION)
            ulp_err_proportion_ratio = metrics.get(CompareConst.ULP_ERR_PROPORTION_RATIO)
            if dtype == Const.TORCH_FLOAT32:
                status, final_message = \
                    self._get_fp32_ulp_err_status(mean_ulp_err, ulp_err_proportion, ulp_err_proportion_ratio)
            else:
                status, final_message = \
                    self._get_fp16_ulp_err_status(ulp_err_proportion, ulp_err_proportion_ratio)

        metrics[CompareConst.COMPARE_MESSAGE] = metrics.get(CompareConst.COMPARE_MESSAGE, "") + final_message
        # Previously the inconsistency branch built this status but never stored it,
        # leaving ULP_ERR_STATUS absent exactly when the result was ERROR.
        metrics.update({CompareConst.ULP_ERR_STATUS: status})
        metrics.update({CompareConst.COMPARE_RESULT: status})
        return metrics

    def _get_fp32_ulp_err_status(self, mean_ulp_err, ulp_err_proportion, ulp_err_proportion_ratio):
        """
        Grade a float32 result: PASS if any one metric is below its threshold.

        Returns:
            tuple: ``(CompareConst.PASS, "")`` or ``(CompareConst.ERROR, message)``.
        """
        mean_ulp_err_threshold, ulp_err_proportion_threshold, ulp_err_proportion_ratio_threshold = \
            StandardConfig.get_ulp_threshold(torch.float32)
        # Short-circuit order matches the checks' priority: mean, proportion, ratio.
        if mean_ulp_err < mean_ulp_err_threshold \
                or ulp_err_proportion < ulp_err_proportion_threshold \
                or ulp_err_proportion_ratio < ulp_err_proportion_ratio_threshold:
            return CompareConst.PASS, ""
        return CompareConst.ERROR, self._ULP_ERR_MESSAGE

    def _get_fp16_ulp_err_status(self, ulp_err_proportion, ulp_err_proportion_ratio):
        """
        Grade a non-float32 result using the float16 thresholds.

        The mean ULP error threshold is not used on this path. PASS is returned if
        either the proportion or the proportion ratio is below its threshold.

        Returns:
            tuple: ``(CompareConst.PASS, "")`` or ``(CompareConst.ERROR, message)``.
        """
        _, ulp_err_proportion_threshold, ulp_err_proportion_ratio_threshold = \
            StandardConfig.get_ulp_threshold(torch.float16)
        if ulp_err_proportion < ulp_err_proportion_threshold \
                or ulp_err_proportion_ratio < ulp_err_proportion_ratio_threshold:
            return CompareConst.PASS, ""
        return CompareConst.ERROR, self._ULP_ERR_MESSAGE

    def _compute_ratio(self):
        """
        Collect the ULP metrics for this row.

        Returns:
            tuple: ``(metrics, UlpInfNanConsistency)``. ``metrics`` holds the NPU
            mean ULP error, the NPU ULP error proportion, the NPU/GPU proportion
            ratio, and the accumulated compare message.
        """
        compare_message = ""
        mean_ulp_err, mean_ulp_err_inf_nan_consistency, mean_ulp_err_message = self._compute_mean_ulp_err()
        compare_message += mean_ulp_err_message
        npu_ulp_err_proportion, gpu_ulp_err_proportion = self._compute_ulp_err_proportion()
        ulp_err_proportion_ratio, ulp_err_proportion_ratio_inf_nan_consistency, ulp_err_proportion_ratio_message = \
            self._compute_ulp_err_proportion_ratio(npu_ulp_err_proportion, gpu_ulp_err_proportion, str(self.dtype))
        compare_message += ulp_err_proportion_ratio_message
        metrics = {
            CompareConst.MEAN_ULP_ERR: mean_ulp_err,
            CompareConst.ULP_ERR_PROPORTION: npu_ulp_err_proportion,
            CompareConst.ULP_ERR_PROPORTION_RATIO: ulp_err_proportion_ratio,
            CompareConst.COMPARE_MESSAGE: compare_message
        }
        return metrics, UlpInfNanConsistency(mean_ulp_err_inf_nan_consistency,
                                             ulp_err_proportion_ratio_inf_nan_consistency)
@@ -28,6 +28,7 @@ from msprobe.pytorch.common.log import logger
28
28
  from msprobe.pytorch.common.utils import load_pt
29
29
  from msprobe.core.common.const import Const, FileCheckConst, CompareConst
30
30
 
31
+
31
32
  TORCH_TYPE = ["torch.device", "torch.dtype"]
32
33
  TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
33
34
  FLOAT_TYPE = [
@@ -310,6 +311,19 @@ def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None):
310
311
  kwargs_params[key] = gen_list_kwargs(value, api_name, convert_type, real_data_path)
311
312
  elif value is None:
312
313
  kwargs_params[key] = None
314
+ elif key == 'atten_mask' and api_name == 'npu_fusion_attention':
315
+ sparse_mode = kwargs_params.get('sparse_mode', {})
316
+ if isinstance(sparse_mode, dict):
317
+ sparse_mode_value = sparse_mode.get('value', 0)
318
+ elif isinstance(sparse_mode, int):
319
+ sparse_mode_value = sparse_mode
320
+ else:
321
+ msg = f'The sparse_mode value is not int or dict, but {type(sparse_mode)}'
322
+ raise CompareException(CompareException.INVALID_PARAM_ERROR, msg)
323
+ if sparse_mode_value in Const.FA_SPECIAL_SPARSE_MODE:
324
+ kwargs_params[key] = gen_atten_mask(value, convert_type, real_data_path)
325
+ else:
326
+ kwargs_params[key] = gen_data(value, api_name, True, convert_type, real_data_path)
313
327
  elif value.get('type') in TENSOR_DATA_LIST or value.get('type').startswith("numpy"):
314
328
  kwargs_params[key] = gen_data(value, api_name, True, convert_type, real_data_path)
315
329
  elif value.get('type') in TORCH_TYPE:
@@ -319,6 +333,30 @@ def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None):
319
333
  return kwargs_params
320
334
 
321
335
 
336
def gen_atten_mask(info, convert_type, real_data_path):
    """
    Function Description:
        Based on API basic information, generate the input parameter ``atten_mask``
        of npu_fusion_attention for API forward running.
    Parameter:
        info: API basic information. Dict
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
    Return:
        - The real dumped tensor (via gen_real_tensor) when a data path is recorded.
        - Otherwise a 2048x2048 bool mask built with torch.triu (see comment below).
        - None when ``info['type']`` is not in TENSOR_DATA_LIST.
    """
    check_object_type(info, dict)
    data_type = info.get('type')
    data_path = info.get('datapath', info.get('data_name'))
    data_path = get_full_data_path(data_path, real_data_path)
    data = None
    if data_type in TENSOR_DATA_LIST:
        if data_path:
            data = gen_real_tensor(data_path, convert_type)
        else:
            # No real data was dumped: rebuild the fixed 2048x2048 mask used by
            # npu_fusion_attention when sparse_mode is 2, 3 or 4.
            # torch.triu(..., diagonal=1) keeps only the elements strictly above the
            # main diagonal, so the mask is True above the diagonal and False on and
            # below it (the diagonal itself is False, not True).
            data = torch.triu(torch.ones([2048, 2048]), diagonal=1).to(torch.bool)
    return data
358
+
359
+
322
360
  def gen_torch_kwargs(kwargs_params, key, value):
323
361
  if value.get('type') != "torch.device":
324
362
  module_name, attribute_name = get_module_and_atttribute_name(value.get('value'))
@@ -346,6 +384,23 @@ def gen_list_kwargs(kwargs_item_value, api_name, convert_type, real_data_path=No
346
384
  return kwargs_item_result
347
385
 
348
386
 
387
def get_output_dtype(api_info):
    """
    Function Description:
        Based on API basic information, get the output data dtype.
    Parameter:
        api_info: API basic information. Dict
    Return:
        The torch dtype object named by the first output's dtype string, when that
        string is listed in Const.TORCH_FLOAT_DTYPE; otherwise None.
    """
    output_info = api_info.get(Const.OUTPUT)
    # Only index a non-empty list/tuple of outputs. Any other shape (e.g. a bare
    # dict) would make output_info[0] raise instead of falling back to None.
    if not output_info or not isinstance(output_info, (list, tuple)):
        return None
    first_output = output_info[0]
    if not isinstance(first_output, dict):
        return None
    output_str_dtype = first_output.get(Const.DTYPE)
    if output_str_dtype not in Const.TORCH_FLOAT_DTYPE:
        return None
    module_name, attribute_name = get_module_and_atttribute_name(output_str_dtype)
    return get_attribute(module_name, attribute_name)
402
+
403
+
349
404
  def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
350
405
  """
351
406
  Function Description:
@@ -372,4 +427,5 @@ def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_d
372
427
  else:
373
428
  logger.warning(f'Warning: No args in {api_info} ')
374
429
  args_params = []
375
- return args_params, kwargs_params
430
+ output_dtype = get_output_dtype(api_info)
431
+ return args_params, kwargs_params, output_dtype