mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,45 +1,60 @@
1
- from abc import abstractmethod
2
- from typing import Any
3
-
4
- import torch
5
- from msprobe.pytorch.free_benchmark.common.params import DataParams
6
- from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
7
-
8
-
9
- class NpuBaseLayer(BaseLayer):
10
- def __init__(self, api_name: str) -> None:
11
- super().__init__(api_name)
12
- self.perturbed_value = None # 扰动的元素
13
- self.is_added = False # 标记当前算子输入是否调整
14
-
15
- @staticmethod
16
- def perturbed_result(params: DataParams) -> Any:
17
- args_front = params.args[: params.valid_input_index]
18
- args_rear = params.args[params.valid_input_index + 1:]
19
- # 此处会将有inplace属性的算子换为非inplace
20
- if "inplace" in params.kwargs:
21
- params.kwargs["inplace"] = False
22
- params.perturbed_result = params.origin_func(
23
- *args_front, params.perturbed_value, *args_rear, **params.kwargs
24
- )
25
- return params.perturbed_result
26
-
27
- @abstractmethod
28
- def handle(self, params: DataParams) -> Any:
29
- pass
30
-
31
- def pre_check(self, tensor_obj):
32
- """
33
- 检查张量是否符合标准(float类型且最大值大于对应精度最小值)
34
- """
35
- # 只针对第一个满足要求的添加扰动
36
- if self.is_added:
37
- return False
38
- if not torch.is_floating_point(tensor_obj):
39
- return False
40
- if not self._check_details(tensor_obj):
41
- return False
42
- return True
43
-
44
- def _check_details(self, tensor_obj):
45
- return True
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from abc import abstractmethod
17
+ from typing import Any
18
+
19
+ import torch
20
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
21
+ from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
22
+
23
+
24
+ class NpuBaseLayer(BaseLayer):
25
+ def __init__(self, api_name: str) -> None:
26
+ super().__init__(api_name)
27
+ self.perturbed_value = None # 扰动的元素
28
+ self.is_added = False # 标记当前算子输入是否调整
29
+
30
+ @staticmethod
31
+ def perturbed_result(params: DataParams) -> Any:
32
+ args_front = params.args[: params.valid_input_index]
33
+ args_rear = params.args[params.valid_input_index + 1:]
34
+ # 此处会将有inplace属性的算子换为非inplace
35
+ if "inplace" in params.kwargs:
36
+ params.kwargs["inplace"] = False
37
+ params.perturbed_result = params.origin_func(
38
+ *args_front, params.perturbed_value, *args_rear, **params.kwargs
39
+ )
40
+ return params.perturbed_result
41
+
42
+ @abstractmethod
43
+ def handle(self, params: DataParams) -> Any:
44
+ pass
45
+
46
+ def pre_check(self, tensor_obj):
47
+ """
48
+ 检查张量是否符合标准(float类型且最大值大于对应精度最小值)
49
+ """
50
+ # 只针对第一个满足要求的添加扰动
51
+ if self.is_added:
52
+ return False
53
+ if not torch.is_floating_point(tensor_obj):
54
+ return False
55
+ if not self._check_details(tensor_obj):
56
+ return False
57
+ return True
58
+
59
+ def _check_details(self, tensor_obj):
60
+ return True
@@ -1,19 +1,34 @@
1
- import torch
2
- from msprobe.pytorch.free_benchmark import logger
3
- from msprobe.pytorch.free_benchmark.common.params import DataParams
4
- from msprobe.pytorch.free_benchmark.common.utils import Tools
5
- from msprobe.pytorch.free_benchmark.common.enums import DeviceType
6
- from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
7
-
8
-
9
- class CpuLayer(BaseLayer):
10
-
11
- def handle(self, params: DataParams):
12
-
13
- logger.info_on_rank_0(
14
- f"[msprobe] Free benchmark: Perturbation is to_cpu of {self.api_name}."
15
- )
16
- new_args = Tools.convert_device_and_dtype(params.args, DeviceType.CPU, change_dtype=True)
17
- new_kwargs = Tools.convert_device_and_dtype(params.kwargs, DeviceType.CPU, change_dtype=True)
18
- params.perturbed_result = params.origin_func(*new_args, **new_kwargs)
19
- return params.perturbed_result
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from msprobe.pytorch.free_benchmark import logger
18
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
19
+ from msprobe.pytorch.free_benchmark.common.utils import Tools
20
+ from msprobe.pytorch.free_benchmark.common.enums import DeviceType
21
+ from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
22
+
23
+
24
+ class CpuLayer(BaseLayer):
25
+
26
+ def handle(self, params: DataParams):
27
+
28
+ logger.info_on_rank_0(
29
+ f"[msprobe] Free benchmark: Perturbation is to_cpu of {self.api_name}."
30
+ )
31
+ new_args = Tools.convert_device_and_dtype(params.args, DeviceType.CPU, change_dtype=True)
32
+ new_kwargs = Tools.convert_device_and_dtype(params.kwargs, DeviceType.CPU, change_dtype=True)
33
+ params.perturbed_result = params.origin_func(*new_args, **new_kwargs)
34
+ return params.perturbed_result
@@ -1,217 +1,256 @@
1
- import math
2
- from abc import ABC, abstractmethod
3
- from typing import Any, Optional, Tuple
4
- import numpy as np
5
-
6
- import torch
7
- from msprobe.core.common.const import Const
8
- from msprobe.pytorch.free_benchmark import logger
9
- from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
10
- from msprobe.pytorch.free_benchmark.common.enums import (
11
- FuzzThreshold,
12
- NormType,
13
- PerturbationMode,
14
- )
15
- from msprobe.pytorch.free_benchmark.common.params import (
16
- DataParams,
17
- HandlerParams,
18
- make_unequal_row,
19
- )
20
- from msprobe.pytorch.free_benchmark.common.utils import Tools, TorchC
21
-
22
-
23
- class FuzzHandler(ABC):
24
- def __init__(self, params: HandlerParams) -> None:
25
- self.params = params
26
- self.unequal_rows = []
27
-
28
- @staticmethod
29
- def pre_process(origin_ouput, perturbed_output):
30
- if (
31
- isinstance(origin_ouput, tuple)
32
- and hasattr(origin_ouput, "values")
33
- and hasattr(origin_ouput, "indices")
34
- ):
35
- origin_ouput = origin_ouput.values
36
- perturbed_output = perturbed_output.values
37
- if hasattr(perturbed_output, "dtype"):
38
- abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(perturbed_output.dtype, FuzzThreshold.F32_THD)
39
- else:
40
- abs_tol = FuzzThreshold.F32_THD
41
- return (
42
- origin_ouput.to(perturbed_output.dtype).to(perturbed_output.device),
43
- perturbed_output,
44
- abs_tol,
45
- )
46
-
47
- @staticmethod
48
- def tensor_split_for_error_calculate(origin_output, perturbed_output):
49
- """
50
- 对将投入误差值计算的扰动前后输出张量进行分块
51
- :param origin_output: 原始输出
52
- :param perturbed_output: 扰动后输出
53
- :return origin_output_chunks: 切块后原始输出列表
54
- :return perturbed_output_chunks: 切块后扰动后输出列表
55
- """
56
- single_output_mem = origin_output.element_size() * origin_output.nelement() / Const.ONE_MB
57
- if single_output_mem == 0 or origin_output.ndim == 0:
58
- return [origin_output], [perturbed_output]
59
- # 张量大小和批数之间的关系:chunks_exp=math.log(M,2)-4, chunks=2**chunks_exp (M为对比张量数据大小[Mb])
60
- chunks_exp = int(math.log(single_output_mem, 2)) - 4
61
- chunks = 2 ** chunks_exp
62
- chunks = max(chunks, 1)
63
- chunks = min(chunks, ThresholdConfig.TENSOR_SPLIT_MAX_CHUNK)
64
- origin_output_chunks = TorchC.tensor_split(TorchC.reshape(origin_output, (-1,)), chunks)
65
- perturbed_output_chunks = TorchC.tensor_split(TorchC.reshape(perturbed_output, (-1,)), chunks)
66
- return origin_output_chunks, perturbed_output_chunks
67
-
68
- @staticmethod
69
- def convert_overflow_ratio_to_consistent(ratio):
70
- if math.isnan(ratio) or math.isinf(ratio):
71
- return ThresholdConfig.COMP_CONSISTENT
72
- return ratio
73
-
74
- @abstractmethod
75
- def get_threshold(self, dtype):
76
- pass
77
-
78
- @abstractmethod
79
- def handle(self, data_params: DataParams) -> Any:
80
- pass
81
-
82
- def get_ratio_from_specific_norm(
83
- self, origin_output, perturbed_output, norm_type, abs_tol
84
- ):
85
- if norm_type == NormType.ENDLESS_NORM:
86
- return self.calculate_error(origin_output, perturbed_output, abs_tol)
87
- return ThresholdConfig.COMP_CONSISTENT
88
-
89
- def calculate_error(self, origin_output, perturbed_output, abs_tol):
90
- origin_output_chunks, perturbed_output_chunks = self.tensor_split_for_error_calculate(origin_output, perturbed_output)
91
- norm1 = -np.inf
92
- norm2 = -np.inf
93
- norm3 = np.inf
94
- for i, chunk_origin in enumerate(origin_output_chunks):
95
- if chunk_origin.nelement() == 0:
96
- break
97
- chunk_perturbed = perturbed_output_chunks[i]
98
- ratio_tensor1 = TorchC.where(TorchC.abs(chunk_perturbed) > abs_tol,
99
- TorchC.div(TorchC.clamp(chunk_origin, min=abs_tol), TorchC.clamp(chunk_perturbed, min=abs_tol)), 1)
100
- ratio_tensor2 = TorchC.where(TorchC.abs(chunk_origin) > abs_tol,
101
- TorchC.div(TorchC.clamp(chunk_perturbed, min=abs_tol), TorchC.clamp(chunk_origin, min=abs_tol)), 1)
102
- norm_values = TorchC.stack([TorchC.max(ratio_tensor1), TorchC.max(ratio_tensor2)])
103
- max_ratio1, max_ratio2 = norm_values.tolist()
104
- norm1 = max(norm1, self.convert_overflow_ratio_to_consistent(max_ratio1))
105
- norm2 = max(norm2, self.convert_overflow_ratio_to_consistent(max_ratio2))
106
- norm3 = min(norm3, self.convert_overflow_ratio_to_consistent(max_ratio1))
107
-
108
- if norm3 < 0:
109
- ratio = ThresholdConfig.SYMBOL_FLIPPING
110
- else:
111
- ratio = max(norm1, norm2)
112
- return ratio
113
-
114
- def ratio_calculate(self, origin_output, perturbed_output, norm_type) -> float:
115
- try:
116
- origin_output, perturbed_output, abs_tol = self.pre_process(
117
- origin_output, perturbed_output
118
- )
119
- except Exception as e:
120
- logger.warning_on_rank_0(
121
- f"[msprobe] Free Benchmark: For {self.params.api_name}, "
122
- f"when computing ratio,"
123
- f" y1 or y2 dtype is not supported {e}"
124
- )
125
- return ThresholdConfig.COMP_NAN
126
- if self.params.fuzz_stage == Const.BACKWARD:
127
- abs_tol = ThresholdConfig.BACKWARD_OUTPUT_LOWER_BOUND
128
- else:
129
- abs_tol = abs_tol ** 0.5
130
- return self.get_ratio_from_specific_norm(
131
- origin_output, perturbed_output, norm_type, abs_tol
132
- )
133
-
134
- def npu_compare(
135
- self, origin_output, perturbed_output
136
- ) -> Tuple[bool, Optional[float]]:
137
-
138
- if isinstance(perturbed_output, int):
139
- return origin_output == perturbed_output, None
140
- elif isinstance(perturbed_output, float):
141
- if perturbed_output == 0:
142
- origin_output += FuzzThreshold.F32_THD
143
- perturbed_output += FuzzThreshold.F32_THD
144
- return (
145
- math.isclose(origin_output, perturbed_output),
146
- origin_output / perturbed_output,
147
- )
148
- elif not isinstance(perturbed_output, torch.Tensor):
149
- logger.warning_on_rank_0(
150
- f"[msprobe] Free Benchmark: For {self.params.api_name} "
151
- f"The compare for output type {type(perturbed_output)} is not supported"
152
- )
153
-
154
- threshold = self.get_threshold(Tools.get_first_tensor_dtype(origin_output))
155
- ratio = self.ratio_calculate(
156
- origin_output, perturbed_output, norm_type=NormType.ENDLESS_NORM
157
- )
158
- if ratio == ThresholdConfig.SYMBOL_FLIPPING:
159
- is_consistent = False
160
- else:
161
- is_consistent = threshold >= ratio >= 1 / threshold
162
- return is_consistent, ratio
163
-
164
def cmp_output_npu(self, data_params: DataParams):
    """Compare original and perturbed results and record every mismatch.

    A single ``torch.Tensor`` result is compared as a whole; a list or tuple
    result is compared item by item. Other result types are not compared.
    For each comparison ``data_params.is_consistent`` is and-ed with the
    outcome, and when the outcome is inconsistent and
    ``data_params.grad_unequal_flag`` is set, a row built by
    ``make_unequal_row`` is appended to ``self.unequal_rows``.

    Any exception raised while comparing is logged as a warning and the
    values gathered so far are returned.

    :param data_params: carries ``original_result`` and ``perturbed_result``
    :return: ``(npu_consistent, max_fuzz_ratio)``
    """
    npu_consistent = True
    max_fuzz_ratio = 0
    try:
        if isinstance(data_params.original_result, torch.Tensor):
            # A lone tensor is treated as one un-indexed item.
            candidates = [(None, data_params.original_result)]
        elif isinstance(data_params.original_result, (list, tuple)):
            candidates = enumerate(data_params.original_result)
        else:
            candidates = ()

        for index_, origin_item in candidates:
            whole_result = index_ is None
            if whole_result:
                perturbed_item = data_params.perturbed_result
            else:
                perturbed_item = data_params.perturbed_result[index_]
            is_consistent, ratio = self.npu_compare(origin_item, perturbed_item)

            npu_consistent = npu_consistent and is_consistent
            if ratio is not None:
                max_fuzz_ratio = max(max_fuzz_ratio, ratio)
            data_params.is_consistent = is_consistent and data_params.is_consistent

            if is_consistent or not data_params.grad_unequal_flag:
                continue
            # The index is passed only for items of a list/tuple result.
            index_kwargs = {} if whole_result else {"index": index_}
            self.unequal_rows.append(
                make_unequal_row(
                    data_params, self.params, ratio=ratio, **index_kwargs
                )
            )
    except Exception as e:
        logger.warning_on_rank_0(
            f"[msprobe] Free Benchmark: For {self.params.api_name}, "
            f"when campare the result exception raise {e}"
        )
    return npu_consistent, max_fuzz_ratio
206
-
207
def get_unequal_rows(self):
    """Return the list of rows recorded for inconsistent outputs."""
    recorded_rows = self.unequal_rows
    return recorded_rows
209
-
210
def _get_default_threshold(self, dtype):
    """Return the default comparison threshold for ``dtype``.

    With ``PerturbationMode.NO_CHANGE`` the threshold is
    ``ThresholdConfig.COMP_CONSISTENT``. Otherwise it is looked up in
    ``ThresholdConfig.DTYPE_PER_THD``, using the ``torch.float32`` entry for
    dtypes that are not listed.
    """
    if self.params.pert_mode == PerturbationMode.NO_CHANGE:
        return ThresholdConfig.COMP_CONSISTENT
    per_dtype = ThresholdConfig.DTYPE_PER_THD
    fallback = per_dtype.get(torch.float32)
    return per_dtype.get(dtype, fallback)
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from abc import ABC, abstractmethod
18
+ from typing import Any, Optional, Tuple
19
+
20
+ import numpy as np
21
+ import torch
22
+ from msprobe.core.common.const import Const
23
+ from msprobe.pytorch.free_benchmark import logger
24
+ from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
25
+ from msprobe.pytorch.free_benchmark.common.enums import (
26
+ FuzzThreshold,
27
+ NormType,
28
+ PerturbationMode,
29
+ )
30
+ from msprobe.pytorch.free_benchmark.common.params import (
31
+ DataParams,
32
+ HandlerParams,
33
+ make_unequal_row,
34
+ )
35
+ from msprobe.pytorch.free_benchmark.common.utils import Tools, TorchC
36
+
37
+
38
+ class FuzzHandler(ABC):
39
def __init__(self, params: HandlerParams) -> None:
    """Keep the handler parameters and start with no recorded mismatches.

    :param params: handler settings, stored on ``self.params``
    """
    # Rows describing outputs that failed the consistency check.
    self.unequal_rows = []
    self.params = params
42
+
43
@staticmethod
def pre_process(origin_ouput, perturbed_output):
    """Align the original output with the perturbed one before comparison.

    A tuple output that also exposes ``values`` and ``indices`` attributes
    is reduced to its ``values`` member on both sides. The original output
    is then converted to the dtype and device of the perturbed output.

    :param origin_ouput: output obtained without perturbation
    :param perturbed_output: output obtained with perturbation
    :return: ``(aligned original output, perturbed output, abs_tol)`` where
        ``abs_tol`` comes from ``ThresholdConfig.ABS_TOL_VALUE_DICT`` for the
        perturbed dtype, defaulting to ``FuzzThreshold.F32_THD``
    """
    has_values_and_indices = (
        isinstance(origin_ouput, tuple)
        and hasattr(origin_ouput, "values")
        and hasattr(origin_ouput, "indices")
    )
    if has_values_and_indices:
        origin_ouput = origin_ouput.values
        perturbed_output = perturbed_output.values

    # Default tolerance, used when the dtype is missing or not listed.
    abs_tol = FuzzThreshold.F32_THD
    if hasattr(perturbed_output, "dtype"):
        abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(
            perturbed_output.dtype, abs_tol
        )

    aligned_origin = origin_ouput.to(perturbed_output.dtype).to(
        perturbed_output.device
    )
    return aligned_origin, perturbed_output, abs_tol
63
+
64
@staticmethod
def tensor_split_for_error_calculate(origin_output, perturbed_output):
    """
    Split the outputs before and after perturbation into chunks for the
    error calculation.
    :param origin_output: original output
    :param perturbed_output: output after perturbation
    :return origin_output_chunks: list of chunks of the original output
    :return perturbed_output_chunks: list of chunks of the perturbed output
    """
    single_output_mem = (
        origin_output.element_size() * origin_output.nelement() / Const.ONE_MB
    )
    # Empty or zero-dimensional tensors are returned unsplit.
    if single_output_mem == 0 or origin_output.ndim == 0:
        return [origin_output], [perturbed_output]

    # Size-to-chunk relation: chunks_exp = log2(M) - 4, chunks = 2 ** chunks_exp,
    # where M is the tensor data size divided by Const.ONE_MB.
    exponent = int(math.log(single_output_mem, 2)) - 4
    # Keep the chunk count within [1, ThresholdConfig.TENSOR_SPLIT_MAX_CHUNK].
    chunk_count = min(max(2**exponent, 1), ThresholdConfig.TENSOR_SPLIT_MAX_CHUNK)

    flat_origin = TorchC.reshape(origin_output, (-1,))
    flat_perturbed = TorchC.reshape(perturbed_output, (-1,))
    return (
        TorchC.tensor_split(flat_origin, chunk_count),
        TorchC.tensor_split(flat_perturbed, chunk_count),
    )
90
+
91
+ @staticmethod
92
+ def convert_overflow_ratio_to_consistent(ratio):
93
+ if math.isnan(ratio) or math.isinf(ratio):
94
+ return ThresholdConfig.COMP_CONSISTENT
95
+ return ratio
96
+
97
@abstractmethod
def get_threshold(self, dtype):
    """Return the comparison threshold for ``dtype``; implemented by subclasses."""
100
+
101
@abstractmethod
def handle(self, data_params: DataParams) -> Any:
    """Process ``data_params``; implemented by subclasses."""
104
+
105
def get_ratio_from_specific_norm(
    self, origin_output, perturbed_output, norm_type, abs_tol
):
    """Compute the ratio between two outputs for the requested norm type.

    Only ``NormType.ENDLESS_NORM`` is computed (through ``calculate_error``);
    every other norm type yields ``ThresholdConfig.COMP_CONSISTENT``.

    :param origin_output: output obtained without perturbation
    :param perturbed_output: output obtained with perturbation
    :param norm_type: norm selector
    :param abs_tol: absolute tolerance forwarded to ``calculate_error``
    :return: the computed ratio or ``ThresholdConfig.COMP_CONSISTENT``
    """
    is_endless_norm = norm_type == NormType.ENDLESS_NORM
    return (
        self.calculate_error(origin_output, perturbed_output, abs_tol)
        if is_endless_norm
        else ThresholdConfig.COMP_CONSISTENT
    )
111
+
112
def calculate_error(self, origin_output, perturbed_output, abs_tol):
    """Return the largest element-wise ratio between the two outputs.

    The outputs are split with ``tensor_split_for_error_calculate`` and
    processed chunk by chunk. For each chunk two ratio tensors are built:
    origin / perturbed where ``|perturbed| > abs_tol`` and perturbed / origin
    where ``|origin| > abs_tol``. Both operands are clamped to at least
    ``abs_tol``, and positions failing the mask contribute 1. NaN or infinite
    chunk maxima are replaced through ``convert_overflow_ratio_to_consistent``.

    :param origin_output: output obtained without perturbation
    :param perturbed_output: output obtained with perturbation
    :param abs_tol: absolute tolerance used for masking and clamping
    :return: ``ThresholdConfig.SYMBOL_FLIPPING`` when the smallest
        origin / perturbed chunk maximum is negative;
        ``ThresholdConfig.COMP_CONSISTENT`` when no element was compared
        (e.g. empty outputs); otherwise the larger of the two running maxima
    """
    origin_output_chunks, perturbed_output_chunks = (
        self.tensor_split_for_error_calculate(origin_output, perturbed_output)
    )
    norm1 = -np.inf
    norm2 = -np.inf
    norm3 = np.inf
    compared_any = False
    for chunk_origin, chunk_perturbed in zip(
        origin_output_chunks, perturbed_output_chunks
    ):
        if chunk_origin.nelement() == 0:
            break
        compared_any = True
        ratio_tensor1 = TorchC.where(
            TorchC.abs(chunk_perturbed) > abs_tol,
            TorchC.div(
                TorchC.clamp(chunk_origin, min=abs_tol),
                TorchC.clamp(chunk_perturbed, min=abs_tol),
            ),
            1,
        )
        ratio_tensor2 = TorchC.where(
            TorchC.abs(chunk_origin) > abs_tol,
            TorchC.div(
                TorchC.clamp(chunk_perturbed, min=abs_tol),
                TorchC.clamp(chunk_origin, min=abs_tol),
            ),
            1,
        )
        # Stack both maxima so they are read back with a single tolist() call.
        norm_values = TorchC.stack(
            [TorchC.max(ratio_tensor1), TorchC.max(ratio_tensor2)]
        )
        max_ratio1, max_ratio2 = norm_values.tolist()
        safe_ratio1 = self.convert_overflow_ratio_to_consistent(max_ratio1)
        safe_ratio2 = self.convert_overflow_ratio_to_consistent(max_ratio2)
        norm1 = max(norm1, safe_ratio1)
        norm2 = max(norm2, safe_ratio2)
        norm3 = min(norm3, safe_ratio1)

    if not compared_any:
        # Nothing was compared, so do not leak the -inf seed as a ratio:
        # callers would treat it as an inconsistency.
        return ThresholdConfig.COMP_CONSISTENT
    if norm3 < 0:
        return ThresholdConfig.SYMBOL_FLIPPING
    return max(norm1, norm2)
152
+
153
def ratio_calculate(self, origin_output, perturbed_output, norm_type) -> float:
    """Pre-process both outputs and compute their ratio for ``norm_type``.

    If ``pre_process`` raises, a warning is logged and
    ``ThresholdConfig.COMP_NAN`` is returned. The tolerance passed on is
    ``ThresholdConfig.BACKWARD_OUTPUT_LOWER_BOUND`` when
    ``self.params.fuzz_stage`` equals ``Const.BACKWARD``, otherwise the
    square root of the tolerance returned by ``pre_process``.

    :param origin_output: output obtained without perturbation
    :param perturbed_output: output obtained with perturbation
    :param norm_type: norm selector forwarded to ``get_ratio_from_specific_norm``
    :return: the ratio, or ``ThresholdConfig.COMP_NAN`` on pre-processing failure
    """
    try:
        prepared = self.pre_process(origin_output, perturbed_output)
        origin_output, perturbed_output, abs_tol = prepared
    except Exception as e:
        logger.warning_on_rank_0(
            f"[msprobe] Free Benchmark: For {self.params.api_name}, "
            f"when computing ratio,"
            f" y1 or y2 dtype is not supported {e}"
        )
        return ThresholdConfig.COMP_NAN

    is_backward = self.params.fuzz_stage == Const.BACKWARD
    abs_tol = (
        ThresholdConfig.BACKWARD_OUTPUT_LOWER_BOUND if is_backward else abs_tol**0.5
    )
    return self.get_ratio_from_specific_norm(
        origin_output, perturbed_output, norm_type, abs_tol
    )
172
+
173
def npu_compare(
    self, origin_output, perturbed_output
) -> Tuple[bool, Optional[float]]:
    """Judge whether a perturbed output is consistent with the original one.

    The type of ``perturbed_output`` selects the comparison:

    * ``int``: plain equality, no ratio is produced.
    * ``float``: ``math.isclose`` plus the quotient origin / perturbed; when
      the perturbed value is 0 both operands are first shifted by
      ``FuzzThreshold.F32_THD``.
    * anything else: the ratio from ``ratio_calculate`` is checked against
      the threshold returned by ``get_threshold``. A value that is not a
      ``torch.Tensor`` only triggers a warning; it is still compared.

    :param origin_output: output obtained without perturbation
    :param perturbed_output: output obtained with perturbation
    :return: ``(is_consistent, ratio)``
    """
    if isinstance(perturbed_output, int):
        return origin_output == perturbed_output, None

    if isinstance(perturbed_output, float):
        if perturbed_output == 0:
            # Shift both sides by the same offset so the quotient is defined.
            origin_output += FuzzThreshold.F32_THD
            perturbed_output += FuzzThreshold.F32_THD
        values_close = math.isclose(origin_output, perturbed_output)
        return values_close, origin_output / perturbed_output

    if not isinstance(perturbed_output, torch.Tensor):
        # Warning only: execution deliberately continues below.
        logger.warning_on_rank_0(
            f"[msprobe] Free Benchmark: For {self.params.api_name} "
            f"The compare for output type {type(perturbed_output)} is not supported"
        )

    threshold = self.get_threshold(Tools.get_first_tensor_dtype(origin_output))
    ratio = self.ratio_calculate(
        origin_output, perturbed_output, norm_type=NormType.ENDLESS_NORM
    )
    if ratio == ThresholdConfig.SYMBOL_FLIPPING:
        return False, ratio
    # Consistent when the ratio lies inside [1 / threshold, threshold].
    return threshold >= ratio >= 1 / threshold, ratio
202
+
203
def cmp_output_npu(self, data_params: DataParams):
    """Compare original and perturbed results and record every mismatch.

    A single ``torch.Tensor`` result is compared as a whole; a list or tuple
    result is compared item by item. Other result types are not compared.
    For each comparison ``data_params.is_consistent`` is and-ed with the
    outcome, and when the outcome is inconsistent and
    ``data_params.grad_unequal_flag`` is set, a row built by
    ``make_unequal_row`` is appended to ``self.unequal_rows``.

    Any exception raised while comparing is logged as a warning and the
    values gathered so far are returned.

    :param data_params: carries ``original_result`` and ``perturbed_result``
    :return: ``(npu_consistent, max_fuzz_ratio)``
    """
    npu_consistent = True
    max_fuzz_ratio = 0
    try:
        if isinstance(data_params.original_result, torch.Tensor):
            # A lone tensor is treated as one un-indexed item.
            candidates = [(None, data_params.original_result)]
        elif isinstance(data_params.original_result, (list, tuple)):
            candidates = enumerate(data_params.original_result)
        else:
            candidates = ()

        for index_, origin_item in candidates:
            whole_result = index_ is None
            if whole_result:
                perturbed_item = data_params.perturbed_result
            else:
                perturbed_item = data_params.perturbed_result[index_]
            is_consistent, ratio = self.npu_compare(origin_item, perturbed_item)

            npu_consistent = npu_consistent and is_consistent
            if ratio is not None:
                max_fuzz_ratio = max(max_fuzz_ratio, ratio)
            data_params.is_consistent = is_consistent and data_params.is_consistent

            if is_consistent or not data_params.grad_unequal_flag:
                continue
            # The index is passed only for items of a list/tuple result.
            index_kwargs = {} if whole_result else {"index": index_}
            self.unequal_rows.append(
                make_unequal_row(
                    data_params, self.params, ratio=ratio, **index_kwargs
                )
            )
    except Exception as e:
        logger.warning_on_rank_0(
            f"[msprobe] Free Benchmark: For {self.params.api_name}, "
            f"when campare the result exception raise {e}"
        )
    return npu_consistent, max_fuzz_ratio
245
+
246
def get_unequal_rows(self):
    """Return the list of rows recorded for inconsistent outputs."""
    recorded_rows = self.unequal_rows
    return recorded_rows
248
+
249
def _get_default_threshold(self, dtype):
    """Return the default comparison threshold for ``dtype``.

    With ``PerturbationMode.NO_CHANGE`` the threshold is
    ``ThresholdConfig.COMP_CONSISTENT``. Otherwise it is looked up in
    ``ThresholdConfig.DTYPE_PER_THD``, using the ``torch.float32`` entry for
    dtypes that are not listed.
    """
    if self.params.pert_mode == PerturbationMode.NO_CHANGE:
        return ThresholdConfig.COMP_CONSISTENT
    per_dtype = ThresholdConfig.DTYPE_PER_THD
    fallback = per_dtype.get(torch.float32)
    return per_dtype.get(dtype, fallback)