mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  232. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py
@@ -1,28 +1,28 @@
- import torch
- from msprobe.pytorch.free_benchmark import logger
- from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
- from msprobe.pytorch.free_benchmark.common.params import DataParams
- from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
-     NpuBaseLayer,
- )
-
-
- class NoChangeLayer(NpuBaseLayer):
-
-     def no_change(self, tensor_obj):
-         """
-         Make no change to the input; simply execute it a second time.
-         """
-         self.is_added = True
-         return tensor_obj
-
-     def handle(self, params: DataParams):
-         """
-         Apply the perturbation to the input and return the result.
-         """
-         logger.info_on_rank_0(
-             f"[msprobe] Free benchmark: Perturbation is "
-             f"{PerturbationMode.NO_CHANGE} of {self.api_name}."
-         )
-         params.perturbed_value = self.no_change(params.args[params.valid_input_index])
-         return self.perturbed_result(params)
+ import torch
+ from msprobe.pytorch.free_benchmark import logger
+ from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
+ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
+     NpuBaseLayer,
+ )
+
+
+ class NoChangeLayer(NpuBaseLayer):
+
+     def no_change(self, tensor_obj):
+         """
+         Make no change to the input; simply execute it a second time.
+         """
+         self.is_added = True
+         return tensor_obj
+
+     def handle(self, params: DataParams):
+         """
+         Apply the perturbation to the input and return the result.
+         """
+         logger.info_on_rank_0(
+             f"[msprobe] Free benchmark: Perturbation is "
+             f"{PerturbationMode.NO_CHANGE} of {self.api_name}."
+         )
+         params.perturbed_value = self.no_change(params.args[params.valid_input_index])
+         return self.perturbed_result(params)
msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py
@@ -1,45 +1,45 @@
- from abc import abstractmethod
- from typing import Any
-
- import torch
- from msprobe.pytorch.free_benchmark.common.params import DataParams
- from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
-
-
- class NpuBaseLayer(BaseLayer):
-     def __init__(self, api_name: str) -> None:
-         super().__init__(api_name)
-         self.perturbed_value = None  # the element to be perturbed
-         self.is_added = False  # marks whether the current operator's input has been adjusted
-
-     @staticmethod
-     def perturbed_result(params: DataParams) -> Any:
-         args_front = params.args[: params.valid_input_index]
-         args_rear = params.args[params.valid_input_index + 1:]
-         # operators that have an inplace attribute are switched to non-inplace here
-         if "inplace" in params.kwargs:
-             params.kwargs["inplace"] = False
-         params.perturbed_result = params.origin_func(
-             *args_front, params.perturbed_value, *args_rear, **params.kwargs
-         )
-         return params.perturbed_result
-
-     @abstractmethod
-     def handle(self, params: DataParams) -> Any:
-         pass
-
-     def pre_check(self, tensor_obj):
-         """
-         Check whether the tensor meets the criteria (float dtype and maximum value greater than the minimum value for the corresponding precision).
-         """
-         # only the first input that meets the requirements is perturbed
-         if self.is_added:
-             return False
-         if not torch.is_floating_point(tensor_obj):
-             return False
-         if not self._check_details(tensor_obj):
-             return False
-         return True
-
-     def _check_details(self, tensor_obj):
-         return True
+ from abc import abstractmethod
+ from typing import Any
+
+ import torch
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
+ from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
+
+
+ class NpuBaseLayer(BaseLayer):
+     def __init__(self, api_name: str) -> None:
+         super().__init__(api_name)
+         self.perturbed_value = None  # the element to be perturbed
+         self.is_added = False  # marks whether the current operator's input has been adjusted
+
+     @staticmethod
+     def perturbed_result(params: DataParams) -> Any:
+         args_front = params.args[: params.valid_input_index]
+         args_rear = params.args[params.valid_input_index + 1:]
+         # operators that have an inplace attribute are switched to non-inplace here
+         if "inplace" in params.kwargs:
+             params.kwargs["inplace"] = False
+         params.perturbed_result = params.origin_func(
+             *args_front, params.perturbed_value, *args_rear, **params.kwargs
+         )
+         return params.perturbed_result
+
+     @abstractmethod
+     def handle(self, params: DataParams) -> Any:
+         pass
+
+     def pre_check(self, tensor_obj):
+         """
+         Check whether the tensor meets the criteria (float dtype and maximum value greater than the minimum value for the corresponding precision).
+         """
+         # only the first input that meets the requirements is perturbed
+         if self.is_added:
+             return False
+         if not torch.is_floating_point(tensor_obj):
+             return False
+         if not self._check_details(tensor_obj):
+             return False
+         return True
+
+     def _check_details(self, tensor_obj):
+         return True
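For context, perturbed_result above splices the perturbed value back into the positional arguments at valid_input_index, forces inplace=False so the second run cannot mutate the shared input, and then re-invokes the original operator. A minimal standalone sketch of that splice, using plain torch in place of msprobe's DataParams/origin_func plumbing (the helper name rerun_with_perturbed_arg is hypothetical):

    import torch

    def rerun_with_perturbed_arg(origin_func, args, kwargs, valid_input_index, perturbed_value):
        # Rebuild the positional arguments with the perturbed value substituted
        # at valid_input_index, mirroring args_front / args_rear above.
        args_front = args[:valid_input_index]
        args_rear = args[valid_input_index + 1:]
        if "inplace" in kwargs:
            kwargs = {**kwargs, "inplace": False}  # never mutate the original input
        return origin_func(*args_front, perturbed_value, *args_rear, **kwargs)

    x = torch.tensor([1.0, -2.0, 3.0])
    perturbed_x = x + 1e-4  # stand-in for a perturbation such as add_noise
    out = rerun_with_perturbed_arg(torch.relu, (x,), {}, 0, perturbed_x)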
msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py
@@ -1,19 +1,19 @@
- import torch
- from msprobe.pytorch.free_benchmark import logger
- from msprobe.pytorch.free_benchmark.common.params import DataParams
- from msprobe.pytorch.free_benchmark.common.utils import Tools
- from msprobe.pytorch.free_benchmark.common.enums import DeviceType
- from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
-
-
- class CpuLayer(BaseLayer):
-
-     def handle(self, params: DataParams):
-
-         logger.info_on_rank_0(
-             f"[msprobe] Free benchmark: Perturbation is to_cpu of {self.api_name}."
-         )
-         new_args = Tools.convert_device_and_dtype(params.args, DeviceType.CPU, change_dtype=True)
-         new_kwargs = Tools.convert_device_and_dtype(params.kwargs, DeviceType.CPU, change_dtype=True)
-         params.perturbed_result = params.origin_func(*new_args, **new_kwargs)
-         return params.perturbed_result
+ import torch
+ from msprobe.pytorch.free_benchmark import logger
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
+ from msprobe.pytorch.free_benchmark.common.utils import Tools
+ from msprobe.pytorch.free_benchmark.common.enums import DeviceType
+ from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
+
+
+ class CpuLayer(BaseLayer):
+
+     def handle(self, params: DataParams):
+
+         logger.info_on_rank_0(
+             f"[msprobe] Free benchmark: Perturbation is to_cpu of {self.api_name}."
+         )
+         new_args = Tools.convert_device_and_dtype(params.args, DeviceType.CPU, change_dtype=True)
+         new_kwargs = Tools.convert_device_and_dtype(params.kwargs, DeviceType.CPU, change_dtype=True)
+         params.perturbed_result = params.origin_func(*new_args, **new_kwargs)
+         return params.perturbed_result
msprobe/pytorch/free_benchmark/result_handlers/base_handler.py
@@ -1,217 +1,217 @@
- import math
- from abc import ABC, abstractmethod
- from typing import Any, Optional, Tuple
- import numpy as np
-
- import torch
- from msprobe.core.common.const import Const
- from msprobe.pytorch.free_benchmark import logger
- from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
- from msprobe.pytorch.free_benchmark.common.enums import (
-     FuzzThreshold,
-     NormType,
-     PerturbationMode,
- )
- from msprobe.pytorch.free_benchmark.common.params import (
-     DataParams,
-     HandlerParams,
-     make_unequal_row,
- )
- from msprobe.pytorch.free_benchmark.common.utils import Tools, TorchC
-
-
- class FuzzHandler(ABC):
-     def __init__(self, params: HandlerParams) -> None:
-         self.params = params
-         self.unequal_rows = []
-
-     @staticmethod
-     def pre_process(origin_ouput, perturbed_output):
-         if (
-             isinstance(origin_ouput, tuple)
-             and hasattr(origin_ouput, "values")
-             and hasattr(origin_ouput, "indices")
-         ):
-             origin_ouput = origin_ouput.values
-             perturbed_output = perturbed_output.values
-         if hasattr(perturbed_output, "dtype"):
-             abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(perturbed_output.dtype, FuzzThreshold.F32_THD)
-         else:
-             abs_tol = FuzzThreshold.F32_THD
-         return (
-             origin_ouput.to(perturbed_output.dtype).to(perturbed_output.device),
-             perturbed_output,
-             abs_tol,
-         )
-
-     @staticmethod
-     def tensor_split_for_error_calculate(origin_output, perturbed_output):
-         """
-         Split the pre- and post-perturbation output tensors used in the error calculation into chunks.
-         :param origin_output: original output
-         :param perturbed_output: output after perturbation
-         :return origin_output_chunks: list of chunks of the original output
-         :return perturbed_output_chunks: list of chunks of the perturbed output
-         """
-         single_output_mem = origin_output.element_size() * origin_output.nelement() / Const.ONE_MB
-         if single_output_mem == 0 or origin_output.ndim == 0:
-             return [origin_output], [perturbed_output]
-         # relationship between tensor size and number of chunks: chunks_exp = math.log(M, 2) - 4, chunks = 2**chunks_exp (M is the size of the compared tensor in MB)
-         chunks_exp = int(math.log(single_output_mem, 2)) - 4
-         chunks = 2 ** chunks_exp
-         chunks = max(chunks, 1)
-         chunks = min(chunks, ThresholdConfig.TENSOR_SPLIT_MAX_CHUNK)
-         origin_output_chunks = TorchC.tensor_split(TorchC.reshape(origin_output, (-1,)), chunks)
-         perturbed_output_chunks = TorchC.tensor_split(TorchC.reshape(perturbed_output, (-1,)), chunks)
-         return origin_output_chunks, perturbed_output_chunks
-
-     @staticmethod
-     def convert_overflow_ratio_to_consistent(ratio):
-         if math.isnan(ratio) or math.isinf(ratio):
-             return ThresholdConfig.COMP_CONSISTENT
-         return ratio
-
-     @abstractmethod
-     def get_threshold(self, dtype):
-         pass
-
-     @abstractmethod
-     def handle(self, data_params: DataParams) -> Any:
-         pass
-
-     def get_ratio_from_specific_norm(
-         self, origin_output, perturbed_output, norm_type, abs_tol
-     ):
-         if norm_type == NormType.ENDLESS_NORM:
-             return self.calculate_error(origin_output, perturbed_output, abs_tol)
-         return ThresholdConfig.COMP_CONSISTENT
-
-     def calculate_error(self, origin_output, perturbed_output, abs_tol):
-         origin_output_chunks, perturbed_output_chunks = self.tensor_split_for_error_calculate(origin_output, perturbed_output)
-         norm1 = -np.inf
-         norm2 = -np.inf
-         norm3 = np.inf
-         for i, chunk_origin in enumerate(origin_output_chunks):
-             if chunk_origin.nelement() == 0:
-                 break
-             chunk_perturbed = perturbed_output_chunks[i]
-             ratio_tensor1 = TorchC.where(TorchC.abs(chunk_perturbed) > abs_tol,
-                                          TorchC.div(TorchC.clamp(chunk_origin, min=abs_tol), TorchC.clamp(chunk_perturbed, min=abs_tol)), 1)
-             ratio_tensor2 = TorchC.where(TorchC.abs(chunk_origin) > abs_tol,
-                                          TorchC.div(TorchC.clamp(chunk_perturbed, min=abs_tol), TorchC.clamp(chunk_origin, min=abs_tol)), 1)
-             norm_values = TorchC.stack([TorchC.max(ratio_tensor1), TorchC.max(ratio_tensor2)])
-             max_ratio1, max_ratio2 = norm_values.tolist()
-             norm1 = max(norm1, self.convert_overflow_ratio_to_consistent(max_ratio1))
-             norm2 = max(norm2, self.convert_overflow_ratio_to_consistent(max_ratio2))
-             norm3 = min(norm3, self.convert_overflow_ratio_to_consistent(max_ratio1))
-
-         if norm3 < 0:
-             ratio = ThresholdConfig.SYMBOL_FLIPPING
-         else:
-             ratio = max(norm1, norm2)
-         return ratio
-
-     def ratio_calculate(self, origin_output, perturbed_output, norm_type) -> float:
-         try:
-             origin_output, perturbed_output, abs_tol = self.pre_process(
-                 origin_output, perturbed_output
-             )
-         except Exception as e:
-             logger.warning_on_rank_0(
-                 f"[msprobe] Free Benchmark: For {self.params.api_name}, "
-                 f"when computing ratio,"
-                 f" y1 or y2 dtype is not supported {e}"
-             )
-             return ThresholdConfig.COMP_NAN
-         if self.params.fuzz_stage == Const.BACKWARD:
-             abs_tol = ThresholdConfig.BACKWARD_OUTPUT_LOWER_BOUND
-         else:
-             abs_tol = abs_tol ** 0.5
-         return self.get_ratio_from_specific_norm(
-             origin_output, perturbed_output, norm_type, abs_tol
-         )
-
-     def npu_compare(
-         self, origin_output, perturbed_output
-     ) -> Tuple[bool, Optional[float]]:
-
-         if isinstance(perturbed_output, int):
-             return origin_output == perturbed_output, None
-         elif isinstance(perturbed_output, float):
-             if perturbed_output == 0:
-                 origin_output += FuzzThreshold.F32_THD
-                 perturbed_output += FuzzThreshold.F32_THD
-             return (
-                 math.isclose(origin_output, perturbed_output),
-                 origin_output / perturbed_output,
-             )
-         elif not isinstance(perturbed_output, torch.Tensor):
-             logger.warning_on_rank_0(
-                 f"[msprobe] Free Benchmark: For {self.params.api_name} "
-                 f"The compare for output type {type(perturbed_output)} is not supported"
-             )
-
-         threshold = self.get_threshold(Tools.get_first_tensor_dtype(origin_output))
-         ratio = self.ratio_calculate(
-             origin_output, perturbed_output, norm_type=NormType.ENDLESS_NORM
-         )
-         if ratio == ThresholdConfig.SYMBOL_FLIPPING:
-             is_consistent = False
-         else:
-             is_consistent = threshold >= ratio >= 1 / threshold
-         return is_consistent, ratio
-
-     def cmp_output_npu(self, data_params: DataParams):
-         npu_consistent = True
-         max_fuzz_ratio = 0
-         try:
-             if isinstance(data_params.original_result, torch.Tensor):
-                 is_consistent, ratio = self.npu_compare(
-                     data_params.original_result, data_params.perturbed_result
-                 )
-                 npu_consistent = is_consistent
-                 max_fuzz_ratio = (
-                     max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio)
-                 )
-                 data_params.is_consistent = is_consistent and data_params.is_consistent
-                 if not is_consistent and data_params.grad_unequal_flag:
-                     self.unequal_rows.append(
-                         make_unequal_row(data_params, self.params, ratio=ratio)
-                     )
-
-             elif isinstance(data_params.original_result, (list, tuple)):
-                 for index_, origin_item in enumerate(data_params.original_result):
-                     is_consistent, ratio = self.npu_compare(
-                         origin_item, data_params.perturbed_result[index_]
-                     )
-                     npu_consistent = npu_consistent and is_consistent
-                     max_fuzz_ratio = (
-                         max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio)
-                     )
-                     data_params.is_consistent = (
-                         is_consistent and data_params.is_consistent
-                     )
-                     if not is_consistent and data_params.grad_unequal_flag:
-                         self.unequal_rows.append(
-                             make_unequal_row(
-                                 data_params, self.params, ratio=ratio, index=index_
-                             )
-                         )
-         except Exception as e:
-             logger.warning_on_rank_0(
-                 f"[msprobe] Free Benchmark: For {self.params.api_name}, "
-                 f"when campare the result exception raise {e}"
-             )
-         return npu_consistent, max_fuzz_ratio
-
-     def get_unequal_rows(self):
-         return self.unequal_rows
-
-     def _get_default_threshold(self, dtype):
-         if self.params.pert_mode == PerturbationMode.NO_CHANGE:
-             threshold = ThresholdConfig.COMP_CONSISTENT
-         else:
-             threshold = ThresholdConfig.DTYPE_PER_THD.get(
-                 dtype, ThresholdConfig.DTYPE_PER_THD.get(torch.float32)
-             )
-         return threshold
+ import math
+ from abc import ABC, abstractmethod
+ from typing import Any, Optional, Tuple
+ import numpy as np
+
+ import torch
+ from msprobe.core.common.const import Const
+ from msprobe.pytorch.free_benchmark import logger
+ from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
+ from msprobe.pytorch.free_benchmark.common.enums import (
+     FuzzThreshold,
+     NormType,
+     PerturbationMode,
+ )
+ from msprobe.pytorch.free_benchmark.common.params import (
+     DataParams,
+     HandlerParams,
+     make_unequal_row,
+ )
+ from msprobe.pytorch.free_benchmark.common.utils import Tools, TorchC
+
+
+ class FuzzHandler(ABC):
+     def __init__(self, params: HandlerParams) -> None:
+         self.params = params
+         self.unequal_rows = []
+
+     @staticmethod
+     def pre_process(origin_ouput, perturbed_output):
+         if (
+             isinstance(origin_ouput, tuple)
+             and hasattr(origin_ouput, "values")
+             and hasattr(origin_ouput, "indices")
+         ):
+             origin_ouput = origin_ouput.values
+             perturbed_output = perturbed_output.values
+         if hasattr(perturbed_output, "dtype"):
+             abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(perturbed_output.dtype, FuzzThreshold.F32_THD)
+         else:
+             abs_tol = FuzzThreshold.F32_THD
+         return (
+             origin_ouput.to(perturbed_output.dtype).to(perturbed_output.device),
+             perturbed_output,
+             abs_tol,
+         )
+
+     @staticmethod
+     def tensor_split_for_error_calculate(origin_output, perturbed_output):
+         """
+         Split the pre- and post-perturbation output tensors used in the error calculation into chunks.
+         :param origin_output: original output
+         :param perturbed_output: output after perturbation
+         :return origin_output_chunks: list of chunks of the original output
+         :return perturbed_output_chunks: list of chunks of the perturbed output
+         """
+         single_output_mem = origin_output.element_size() * origin_output.nelement() / Const.ONE_MB
+         if single_output_mem == 0 or origin_output.ndim == 0:
+             return [origin_output], [perturbed_output]
+         # relationship between tensor size and number of chunks: chunks_exp = math.log(M, 2) - 4, chunks = 2**chunks_exp (M is the size of the compared tensor in MB)
+         chunks_exp = int(math.log(single_output_mem, 2)) - 4
+         chunks = 2 ** chunks_exp
+         chunks = max(chunks, 1)
+         chunks = min(chunks, ThresholdConfig.TENSOR_SPLIT_MAX_CHUNK)
+         origin_output_chunks = TorchC.tensor_split(TorchC.reshape(origin_output, (-1,)), chunks)
+         perturbed_output_chunks = TorchC.tensor_split(TorchC.reshape(perturbed_output, (-1,)), chunks)
+         return origin_output_chunks, perturbed_output_chunks
+
+     @staticmethod
+     def convert_overflow_ratio_to_consistent(ratio):
+         if math.isnan(ratio) or math.isinf(ratio):
+             return ThresholdConfig.COMP_CONSISTENT
+         return ratio
+
+     @abstractmethod
+     def get_threshold(self, dtype):
+         pass
+
+     @abstractmethod
+     def handle(self, data_params: DataParams) -> Any:
+         pass
+
+     def get_ratio_from_specific_norm(
+         self, origin_output, perturbed_output, norm_type, abs_tol
+     ):
+         if norm_type == NormType.ENDLESS_NORM:
+             return self.calculate_error(origin_output, perturbed_output, abs_tol)
+         return ThresholdConfig.COMP_CONSISTENT
+
+     def calculate_error(self, origin_output, perturbed_output, abs_tol):
+         origin_output_chunks, perturbed_output_chunks = self.tensor_split_for_error_calculate(origin_output, perturbed_output)
+         norm1 = -np.inf
+         norm2 = -np.inf
+         norm3 = np.inf
+         for i, chunk_origin in enumerate(origin_output_chunks):
+             if chunk_origin.nelement() == 0:
+                 break
+             chunk_perturbed = perturbed_output_chunks[i]
+             ratio_tensor1 = TorchC.where(TorchC.abs(chunk_perturbed) > abs_tol,
+                                          TorchC.div(TorchC.clamp(chunk_origin, min=abs_tol), TorchC.clamp(chunk_perturbed, min=abs_tol)), 1)
+             ratio_tensor2 = TorchC.where(TorchC.abs(chunk_origin) > abs_tol,
+                                          TorchC.div(TorchC.clamp(chunk_perturbed, min=abs_tol), TorchC.clamp(chunk_origin, min=abs_tol)), 1)
+             norm_values = TorchC.stack([TorchC.max(ratio_tensor1), TorchC.max(ratio_tensor2)])
+             max_ratio1, max_ratio2 = norm_values.tolist()
+             norm1 = max(norm1, self.convert_overflow_ratio_to_consistent(max_ratio1))
+             norm2 = max(norm2, self.convert_overflow_ratio_to_consistent(max_ratio2))
+             norm3 = min(norm3, self.convert_overflow_ratio_to_consistent(max_ratio1))
+
+         if norm3 < 0:
+             ratio = ThresholdConfig.SYMBOL_FLIPPING
+         else:
+             ratio = max(norm1, norm2)
+         return ratio
+
+     def ratio_calculate(self, origin_output, perturbed_output, norm_type) -> float:
+         try:
+             origin_output, perturbed_output, abs_tol = self.pre_process(
+                 origin_output, perturbed_output
+             )
+         except Exception as e:
+             logger.warning_on_rank_0(
+                 f"[msprobe] Free Benchmark: For {self.params.api_name}, "
+                 f"when computing ratio,"
+                 f" y1 or y2 dtype is not supported {e}"
+             )
+             return ThresholdConfig.COMP_NAN
+         if self.params.fuzz_stage == Const.BACKWARD:
+             abs_tol = ThresholdConfig.BACKWARD_OUTPUT_LOWER_BOUND
+         else:
+             abs_tol = abs_tol ** 0.5
+         return self.get_ratio_from_specific_norm(
+             origin_output, perturbed_output, norm_type, abs_tol
+         )
+
+     def npu_compare(
+         self, origin_output, perturbed_output
+     ) -> Tuple[bool, Optional[float]]:
+
+         if isinstance(perturbed_output, int):
+             return origin_output == perturbed_output, None
+         elif isinstance(perturbed_output, float):
+             if perturbed_output == 0:
+                 origin_output += FuzzThreshold.F32_THD
+                 perturbed_output += FuzzThreshold.F32_THD
+             return (
+                 math.isclose(origin_output, perturbed_output),
+                 origin_output / perturbed_output,
+             )
+         elif not isinstance(perturbed_output, torch.Tensor):
+             logger.warning_on_rank_0(
+                 f"[msprobe] Free Benchmark: For {self.params.api_name} "
+                 f"The compare for output type {type(perturbed_output)} is not supported"
+             )
+
+         threshold = self.get_threshold(Tools.get_first_tensor_dtype(origin_output))
+         ratio = self.ratio_calculate(
+             origin_output, perturbed_output, norm_type=NormType.ENDLESS_NORM
+         )
+         if ratio == ThresholdConfig.SYMBOL_FLIPPING:
+             is_consistent = False
+         else:
+             is_consistent = threshold >= ratio >= 1 / threshold
+         return is_consistent, ratio
+
+     def cmp_output_npu(self, data_params: DataParams):
+         npu_consistent = True
+         max_fuzz_ratio = 0
+         try:
+             if isinstance(data_params.original_result, torch.Tensor):
+                 is_consistent, ratio = self.npu_compare(
+                     data_params.original_result, data_params.perturbed_result
+                 )
+                 npu_consistent = is_consistent
+                 max_fuzz_ratio = (
+                     max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio)
+                 )
+                 data_params.is_consistent = is_consistent and data_params.is_consistent
+                 if not is_consistent and data_params.grad_unequal_flag:
+                     self.unequal_rows.append(
+                         make_unequal_row(data_params, self.params, ratio=ratio)
+                     )
+
+             elif isinstance(data_params.original_result, (list, tuple)):
+                 for index_, origin_item in enumerate(data_params.original_result):
+                     is_consistent, ratio = self.npu_compare(
+                         origin_item, data_params.perturbed_result[index_]
+                     )
+                     npu_consistent = npu_consistent and is_consistent
+                     max_fuzz_ratio = (
+                         max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio)
+                     )
+                     data_params.is_consistent = (
+                         is_consistent and data_params.is_consistent
+                     )
+                     if not is_consistent and data_params.grad_unequal_flag:
+                         self.unequal_rows.append(
+                             make_unequal_row(
+                                 data_params, self.params, ratio=ratio, index=index_
+                             )
+                         )
+         except Exception as e:
+             logger.warning_on_rank_0(
+                 f"[msprobe] Free Benchmark: For {self.params.api_name}, "
+                 f"when campare the result exception raise {e}"
+             )
+         return npu_consistent, max_fuzz_ratio
+
+     def get_unequal_rows(self):
+         return self.unequal_rows
+
+     def _get_default_threshold(self, dtype):
+         if self.params.pert_mode == PerturbationMode.NO_CHANGE:
+             threshold = ThresholdConfig.COMP_CONSISTENT
+         else:
+             threshold = ThresholdConfig.DTYPE_PER_THD.get(
+                 dtype, ThresholdConfig.DTYPE_PER_THD.get(torch.float32)
+             )
+         return threshold
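The consistency judgment above reduces to a max-ratio test: calculate_error flattens and chunks both outputs, computes the element-wise ratio in both directions while ignoring values below abs_tol, and npu_compare then accepts the API as consistent when the worst ratio stays within [1/threshold, threshold] (a sign flip short-circuits to inconsistent). A minimal standalone sketch of that check, using plain torch and an illustrative threshold in place of msprobe's ThresholdConfig/TorchC helpers (the helper name max_ratio is hypothetical):

    import torch

    def max_ratio(origin: torch.Tensor, perturbed: torch.Tensor, abs_tol: float = 1e-8) -> float:
        # Element-wise ratio in both directions, clamped below by abs_tol,
        # mirroring ratio_tensor1 / ratio_tensor2 in calculate_error above.
        o, p = origin.reshape(-1), perturbed.reshape(-1)
        o_c, p_c = o.clamp(min=abs_tol), p.clamp(min=abs_tol)
        r1 = torch.where(p.abs() > abs_tol, o_c / p_c, torch.ones_like(o_c))
        r2 = torch.where(o.abs() > abs_tol, p_c / o_c, torch.ones_like(o_c))
        return torch.max(r1.max(), r2.max()).item()

    origin = torch.randn(1024)
    perturbed = origin * (1 + 1e-6)   # tiny relative error between the two runs
    threshold = 1.004                 # illustrative per-dtype threshold, not msprobe's value
    ratio = max_ratio(origin, perturbed)
    is_consistent = threshold >= ratio >= 1 / threshold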