mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,39 +1,54 @@
1
- from typing import Any
2
-
3
- from msprobe.pytorch.free_benchmark import logger
4
- from msprobe.pytorch.free_benchmark.common.enums import DeviceType
5
- from msprobe.pytorch.free_benchmark.common.params import DataParams, make_unequal_row
6
- from msprobe.pytorch.free_benchmark.common.utils import Tools
7
- from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
8
- from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
9
-
10
-
11
- class CheckerHandler(FuzzHandler):
12
- def other_compare(self, data_params: DataParams) -> bool:
13
- is_consistent = SingleCompare().compare_seq(
14
- data_params.original_result, data_params.perturbed_result
15
- )
16
- if not is_consistent:
17
- self.unequal_rows.append(
18
- make_unequal_row(data_params, self.params)
19
- )
20
-
21
- def get_threshold(self, dtype):
22
- return self._get_default_threshold(dtype)
23
-
24
- def handle(self, data_params: DataParams) -> Any:
25
- if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor(
26
- data_params.perturbed_result
27
- ):
28
- return data_params.original_result
29
- try:
30
- if self.params.fuzz_device == DeviceType.NPU:
31
- self.cmp_output_npu(data_params)
32
- else:
33
- self.other_compare(data_params)
34
- except Exception as e:
35
- logger.warning_on_rank_0(
36
- f"[msprobe] Free Benchmark: For {self.params.api_name}, "
37
- f"when campare the result exception raise {e}"
38
- )
39
- return data_params.original_result
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any
17
+
18
+ from msprobe.pytorch.free_benchmark import logger
19
+ from msprobe.pytorch.free_benchmark.common.enums import DeviceType
20
+ from msprobe.pytorch.free_benchmark.common.params import DataParams, make_unequal_row
21
+ from msprobe.pytorch.free_benchmark.common.utils import Tools
22
+ from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
23
+ from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
24
+
25
+
26
+ class CheckerHandler(FuzzHandler):
27
+ def other_compare(self, data_params: DataParams) -> bool:
28
+ is_consistent = SingleCompare().compare_seq(
29
+ data_params.original_result, data_params.perturbed_result
30
+ )
31
+ if not is_consistent:
32
+ self.unequal_rows.append(
33
+ make_unequal_row(data_params, self.params)
34
+ )
35
+
36
+ def get_threshold(self, dtype):
37
+ return self._get_default_threshold(dtype)
38
+
39
+ def handle(self, data_params: DataParams) -> Any:
40
+ if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor(
41
+ data_params.perturbed_result
42
+ ):
43
+ return data_params.original_result
44
+ try:
45
+ if self.params.fuzz_device == DeviceType.NPU:
46
+ self.cmp_output_npu(data_params)
47
+ else:
48
+ self.other_compare(data_params)
49
+ except Exception as e:
50
+ logger.warning_on_rank_0(
51
+ f"[msprobe] Free Benchmark: For {self.params.api_name}, "
52
+ f"when campare the result exception raise {e}"
53
+ )
54
+ return data_params.original_result
@@ -1,24 +1,39 @@
1
- from typing import Any
2
-
3
- from msprobe.pytorch.free_benchmark.common.params import DataParams
4
- from msprobe.pytorch.free_benchmark.common.utils import Tools
5
- from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
6
- from msprobe.pytorch.free_benchmark import logger
7
-
8
-
9
- class FixHandler(FuzzHandler):
10
-
11
- def get_threshold(self, dtype):
12
- return self._get_default_threshold(dtype)
13
-
14
- def handle(self, data_params: DataParams) -> Any:
15
- try:
16
- return Tools.convert_fuzz_output_to_origin(
17
- data_params.original_result, data_params.perturbed_result
18
- )
19
- except Exception as e:
20
- logger.warning_on_rank_0(
21
- f"[msprobe] Free Benchmark: For {self.params.api_name} "
22
- f"Fix output failed. "
23
- )
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any
17
+
18
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
19
+ from msprobe.pytorch.free_benchmark.common.utils import Tools
20
+ from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
21
+ from msprobe.pytorch.free_benchmark import logger
22
+
23
+
24
+ class FixHandler(FuzzHandler):
25
+
26
+ def get_threshold(self, dtype):
27
+ return self._get_default_threshold(dtype)
28
+
29
+ def handle(self, data_params: DataParams) -> Any:
30
+ try:
31
+ return Tools.convert_fuzz_output_to_origin(
32
+ data_params.original_result, data_params.perturbed_result
33
+ )
34
+ except Exception as e:
35
+ logger.warning_on_rank_0(
36
+ f"[msprobe] Free Benchmark: For {self.params.api_name} "
37
+ f"Fix output failed. "
38
+ )
24
39
  return data_params.original_result
@@ -1,30 +1,45 @@
1
- from msprobe.pytorch.free_benchmark import FreeBenchmarkException
2
- from msprobe.pytorch.free_benchmark.common.constant import PreheatConfig
3
- from msprobe.pytorch.free_benchmark.common.enums import HandlerType
4
- from msprobe.pytorch.free_benchmark.common.params import HandlerParams
5
- from msprobe.pytorch.free_benchmark.result_handlers.check_handler import CheckerHandler
6
- from msprobe.pytorch.free_benchmark.result_handlers.preheat_handler import PreheatHandler
7
- from msprobe.pytorch.free_benchmark.result_handlers.fix_handler import FixHandler
8
-
9
-
10
- class FuzzHandlerFactory:
11
-
12
- result_handlers = {
13
- HandlerType.CHECK: CheckerHandler,
14
- HandlerType.FIX: FixHandler,
15
- HandlerType.PREHEAT: PreheatHandler,
16
- }
17
-
18
- @staticmethod
19
- def create(params: HandlerParams):
20
- if_preheat = params.preheat_config.get(PreheatConfig.IF_PREHEAT)
21
- if not if_preheat:
22
- handler = FuzzHandlerFactory.result_handlers.get(params.handler_type)
23
- else:
24
- handler = FuzzHandlerFactory.result_handlers.get(HandlerType.PREHEAT)
25
- if not handler:
26
- raise FreeBenchmarkException(
27
- FreeBenchmarkException.UnsupportedType,
28
- f"无标杆工具支持 [ {HandlerType.CHECK}、{HandlerType.FIX}] 形式",
29
- )
30
- return handler(params)
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from msprobe.pytorch.free_benchmark import FreeBenchmarkException
17
+ from msprobe.pytorch.free_benchmark.common.constant import PreheatConfig
18
+ from msprobe.pytorch.free_benchmark.common.enums import HandlerType
19
+ from msprobe.pytorch.free_benchmark.common.params import HandlerParams
20
+ from msprobe.pytorch.free_benchmark.result_handlers.check_handler import CheckerHandler
21
+ from msprobe.pytorch.free_benchmark.result_handlers.preheat_handler import PreheatHandler
22
+ from msprobe.pytorch.free_benchmark.result_handlers.fix_handler import FixHandler
23
+
24
+
25
+ class FuzzHandlerFactory:
26
+
27
+ result_handlers = {
28
+ HandlerType.CHECK: CheckerHandler,
29
+ HandlerType.FIX: FixHandler,
30
+ HandlerType.PREHEAT: PreheatHandler,
31
+ }
32
+
33
+ @staticmethod
34
+ def create(params: HandlerParams):
35
+ if_preheat = params.preheat_config.get(PreheatConfig.IF_PREHEAT)
36
+ if not if_preheat:
37
+ handler = FuzzHandlerFactory.result_handlers.get(params.handler_type)
38
+ else:
39
+ handler = FuzzHandlerFactory.result_handlers.get(HandlerType.PREHEAT)
40
+ if not handler:
41
+ raise FreeBenchmarkException(
42
+ FreeBenchmarkException.UnsupportedType,
43
+ f"无标杆工具支持 [ {HandlerType.CHECK}、{HandlerType.FIX}] 形式",
44
+ )
45
+ return handler(params)
@@ -1,170 +1,185 @@
1
- import math
2
- from typing import Any
3
-
4
- from msprobe.pytorch.free_benchmark import logger
5
- from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
6
- from msprobe.pytorch.free_benchmark.common.counter import preheat_counter
7
- from msprobe.pytorch.free_benchmark.common.enums import DeviceType
8
- from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams
9
- from msprobe.pytorch.free_benchmark.common.utils import Tools
10
- from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
11
- from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
12
-
13
-
14
class PreheatHandler(FuzzHandler):
    """Free-benchmark result handler that tunes per-API thresholds ("preheat").

    At step 0 it only counts how often each API is used in one step. For later
    steps below ``preheat_config["preheat_step"]`` it samples a subset of calls.
    For each sampled call it re-runs the API on CPU and records the NPU
    perturbation ratio together with the NPU-vs-CPU consistency verdict in
    ``preheat_counter``. Those records then drive threshold adjustment for the
    API/dtype pair. ``handle`` always returns ``data_params.original_result``.
    """

    def __init__(self, params: HandlerParams) -> None:
        super().__init__(params)
        # Normalized API name; used as the key for every preheat_counter lookup.
        self.pure_name = Tools.get_pure_api_name(self.params.api_name)

    def get_threshold(self, dtype):
        """Return the threshold preheat_counter currently holds for this API and ``dtype``."""
        return preheat_counter.get_api_thd(self.pure_name, dtype)

    def compare_npu_and_cpu(self, data_params: DataParams):
        """Re-run the original API on CPU and compare it with the original output.

        Args and kwargs are converted with ``Tools.convert_device_and_dtype``
        (target ``DeviceType.CPU``, ``change_dtype=True``) before calling
        ``data_params.origin_func``.

        Returns:
            The result of ``SingleCompare().compare_seq(original_result, cpu_result)``.
        """
        args = Tools.convert_device_and_dtype(
            data_params.args, DeviceType.CPU, change_dtype=True
        )
        kwargs = Tools.convert_device_and_dtype(
            data_params.kwargs, DeviceType.CPU, change_dtype=True
        )
        cpu_result = data_params.origin_func(*args, **kwargs)
        return SingleCompare().compare_seq(data_params.original_result, cpu_result)

    def preheat(self, max_fuzz_ratio, cpu_consistent, first_dtype):
        """Record one sample and adjust thresholds once enough samples exist.

        Args:
            max_fuzz_ratio: ratio returned by ``cmp_output_npu`` for this call.
            cpu_consistent: NPU-vs-CPU comparison verdict for this call.
            first_dtype: dtype of the first tensor in the original output.
        """
        # Store this step's output ratio together with the NPU/CPU comparison result.
        preheat_counter.update_preheat_record(
            self.pure_name,
            first_dtype,
            (max_fuzz_ratio, cpu_consistent),
        )
        if self._need_adjust_threshold():
            self._adjust_threshold()

    def handle(self, data_params: DataParams) -> Any:
        """Process one API call during preheating.

        Records sampling statistics and sets ``data_params.is_consistent`` and
        ``data_params.grad_unequal_flag`` as side effects.

        Returns:
            ``data_params.original_result`` in every case.
        """
        perturbed = data_params.perturbed_result
        if isinstance(perturbed, bool) or not Tools.is_float_tensor(perturbed):
            return data_params.original_result

        if self.params.step == 0:
            # Step 0 only counts how many times this API is called per step.
            preheat_counter.add_one_step_used_api(self.pure_name)
            return data_params.original_result

        # The current API / step may need preheating.
        npu_consistent, max_fuzz_ratio = self.cmp_output_npu(data_params)
        data_params.is_consistent = npu_consistent

        preheat_counter.check_step(self.params.step)

        # Preheating runs only while step < preheat_step.
        if self.params.preheat_config.get("preheat_step") <= self.params.step:
            return data_params.original_result

        # On the first pass (flag unset) the call is only marked inconsistent
        # and no sample is taken. NOTE(review): the intent of this flag is not
        # visible here; confirm with the caller that drives grad_unequal_flag.
        if not data_params.grad_unequal_flag:
            data_params.grad_unequal_flag = True
            data_params.is_consistent = False
            return data_params.original_result
        preheat_counter.add_api_called_time(self.pure_name)

        if not self._is_take_a_sample():
            return data_params.original_result

        # Best-effort CPU comparison: on failure the sample counts as consistent.
        cpu_consistent = True
        try:
            cpu_consistent = self.compare_npu_and_cpu(data_params)
        except Exception as e:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"when compare to cpu exception raise {e}"
            )
        try:
            first_dtype = Tools.get_first_tensor_dtype(data_params.original_result)
        except RuntimeError:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"the output sequence does not contain tensors."
            )
            # No output dtype means there is no record to update. Falling
            # through would raise UnboundLocalError on first_dtype.
            return data_params.original_result
        if preheat_counter.get_api_preheat(self.pure_name, str(first_dtype)):
            self.preheat(max_fuzz_ratio, cpu_consistent, first_dtype)

        return data_params.original_result

    def _is_take_a_sample(self) -> bool:
        """Return True if the current call sequence number is selected for sampling.

        When selected, the sample is logged and counted via
        ``preheat_counter.add_api_sample_time``.
        """
        need_sample_set = self._get_need_sample_set()
        curr_called_seq = preheat_counter.get_api_called_time(self.pure_name)
        res = curr_called_seq in need_sample_set
        if res:
            total_count = preheat_counter.get_one_step_used_api(self.pure_name)
            logger.info_on_rank_0(
                f"[msprobe] Free benchmark: preheat sample in step{self.params.step}, "
                f"api_name {self.params.api_name}, "
                f"curr_called_seq: {curr_called_seq}/{total_count}"
            )
            preheat_counter.add_api_sample_time(self.pure_name)
        return res

    def _get_sample_count_per_step(self) -> int:
        """Return the number of samples to collect in each step.

        This is ceil(per-step call count / preheat_step), capped at
        ``preheat_config["max_sample"]``.
        """
        total_count = preheat_counter.get_one_step_used_api(self.pure_name)
        preheat_step = self.params.preheat_config.get("preheat_step")
        max_sample = self.params.preheat_config.get("max_sample")
        return min(math.ceil(total_count / preheat_step), max_sample)

    def _get_need_sample_set(self):
        """Return the 1-based call sequence numbers to sample in the current step.

        Sample ``i`` (1..sample_count_per_step) targets call
        ``(preheat_step * (i - 1) + step) % total_count``, with 0 mapped to
        ``total_count``. The offset shifts with ``step``, so different steps
        pick different calls. Returns an empty set when the API was never
        counted at step 0, which avoids a modulo by zero.
        """
        total_count = preheat_counter.get_one_step_used_api(self.pure_name)
        need_sample_set = set()
        if total_count == 0:
            return need_sample_set
        sample_count_per_step = self._get_sample_count_per_step()
        preheat_step = self.params.preheat_config.get("preheat_step")
        for i in range(1, sample_count_per_step + 1):
            count = (preheat_step * (i - 1) + self.params.step) % total_count
            if count == 0:
                count = total_count
            need_sample_set.add(count)
        return need_sample_set

    def _need_adjust_threshold(self) -> bool:
        """Return True once the sampled-call count reaches the per-step sample count."""
        sample_count_per_step = self._get_sample_count_per_step()
        sampled_time = preheat_counter.get_api_sample_time(self.pure_name)
        return sampled_time >= sample_count_per_step

    def _adjust_threshold_for_dtype(self, dtype_str, compare_result):
        """Compute a new threshold for one dtype from its recorded samples.

        Args:
            dtype_str: dtype key into the preheat record.
            compare_result: iterable of ``(ratio, is_consistent)`` pairs.

        Returns:
            The new threshold. It equals the current threshold when no samples
            exist, or when consistent and inconsistent ratios overlap.

        Side effects:
            Disables further preheating for the dtype (``set_api_preheat``)
            once an inconsistent sample determines the threshold.
        """
        con_ratio = [ratio for ratio, is_consistent in compare_result if is_consistent]
        incon_ratio = [ratio for ratio, is_consistent in compare_result if not is_consistent]
        old_thd = preheat_counter.get_api_thd(self.pure_name, dtype_str)
        if not con_ratio and not incon_ratio:
            # No samples recorded: nothing to learn, and min([]) would raise.
            return old_thd
        new_thd = old_thd
        if con_ratio and incon_ratio:
            # Both consistent and inconsistent samples exist: only act when
            # they are cleanly separated.
            if min(incon_ratio) > max(con_ratio):
                new_thd = min(min(incon_ratio), old_thd)
                preheat_counter.set_api_preheat(self.pure_name, dtype_str, is_preheat=False)
        elif con_ratio:
            # Only consistent samples: scale (old_thd - 1) by API_THD_STEP,
            # multiplying when a consistent ratio exceeded the current
            # threshold and dividing otherwise.
            if max(con_ratio) > old_thd:
                new_thd = 1 + ((old_thd - 1) * ThresholdConfig.API_THD_STEP)
            else:
                new_thd = 1 + ((old_thd - 1) / ThresholdConfig.API_THD_STEP)
        else:
            # Only inconsistent samples.
            new_thd = min(min(incon_ratio), old_thd)
            preheat_counter.set_api_preheat(self.pure_name, dtype_str, is_preheat=False)
        return new_thd

    def _adjust_threshold(self):
        """Recompute and store the threshold for every dtype recorded for this API."""
        records = preheat_counter.preheat_record[self.pure_name]
        for dtype_str, compare_result in records.items():
            new_thd = self._adjust_threshold_for_dtype(dtype_str, compare_result)
            threshold = self._get_default_threshold(
                preheat_counter.dtype_map.get(dtype_str)
            )
            preheat_counter.update_api_thd(
                self.pure_name, dtype_str, new_thd, threshold
            )
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from typing import Any
18
+
19
+ from msprobe.pytorch.free_benchmark import logger
20
+ from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
21
+ from msprobe.pytorch.free_benchmark.common.counter import preheat_counter
22
+ from msprobe.pytorch.free_benchmark.common.enums import DeviceType
23
+ from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams
24
+ from msprobe.pytorch.free_benchmark.common.utils import Tools
25
+ from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
26
+ from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
27
+
28
+
29
class PreheatHandler(FuzzHandler):
    """Free-benchmark result handler that tunes per-API thresholds ("preheat").

    At step 0 it only counts how often each API is used in one step. For later
    steps below ``preheat_config["preheat_step"]`` it samples a subset of calls.
    For each sampled call it re-runs the API on CPU and records the NPU
    perturbation ratio together with the NPU-vs-CPU consistency verdict in
    ``preheat_counter``. Those records then drive threshold adjustment for the
    API/dtype pair. ``handle`` always returns ``data_params.original_result``.
    """

    def __init__(self, params: HandlerParams) -> None:
        super().__init__(params)
        # Normalized API name; used as the key for every preheat_counter lookup.
        self.pure_name = Tools.get_pure_api_name(self.params.api_name)

    def get_threshold(self, dtype):
        """Return the threshold preheat_counter currently holds for this API and ``dtype``."""
        return preheat_counter.get_api_thd(self.pure_name, dtype)

    def compare_npu_and_cpu(self, data_params: DataParams):
        """Re-run the original API on CPU and compare it with the original output.

        Args and kwargs are converted with ``Tools.convert_device_and_dtype``
        (target ``DeviceType.CPU``, ``change_dtype=True``) before calling
        ``data_params.origin_func``.

        Returns:
            The result of ``SingleCompare().compare_seq(original_result, cpu_result)``.
        """
        args = Tools.convert_device_and_dtype(
            data_params.args, DeviceType.CPU, change_dtype=True
        )
        kwargs = Tools.convert_device_and_dtype(
            data_params.kwargs, DeviceType.CPU, change_dtype=True
        )
        cpu_result = data_params.origin_func(*args, **kwargs)
        return SingleCompare().compare_seq(data_params.original_result, cpu_result)

    def preheat(self, max_fuzz_ratio, cpu_consistent, first_dtype):
        """Record one sample and adjust thresholds once enough samples exist.

        Args:
            max_fuzz_ratio: ratio returned by ``cmp_output_npu`` for this call.
            cpu_consistent: NPU-vs-CPU comparison verdict for this call.
            first_dtype: dtype of the first tensor in the original output.
        """
        # Store this step's output ratio together with the NPU/CPU comparison result.
        preheat_counter.update_preheat_record(
            self.pure_name,
            first_dtype,
            (max_fuzz_ratio, cpu_consistent),
        )
        if self._need_adjust_threshold():
            self._adjust_threshold()

    def handle(self, data_params: DataParams) -> Any:
        """Process one API call during preheating.

        Records sampling statistics and sets ``data_params.is_consistent`` and
        ``data_params.grad_unequal_flag`` as side effects.

        Returns:
            ``data_params.original_result`` in every case.
        """
        perturbed = data_params.perturbed_result
        if isinstance(perturbed, bool) or not Tools.is_float_tensor(perturbed):
            return data_params.original_result

        if self.params.step == 0:
            # Step 0 only counts how many times this API is called per step.
            preheat_counter.add_one_step_used_api(self.pure_name)
            return data_params.original_result

        # The current API / step may need preheating.
        npu_consistent, max_fuzz_ratio = self.cmp_output_npu(data_params)
        data_params.is_consistent = npu_consistent

        preheat_counter.check_step(self.params.step)

        # Preheating runs only while step < preheat_step.
        if self.params.preheat_config.get("preheat_step") <= self.params.step:
            return data_params.original_result

        # On the first pass (flag unset) the call is only marked inconsistent
        # and no sample is taken. NOTE(review): the intent of this flag is not
        # visible here; confirm with the caller that drives grad_unequal_flag.
        if not data_params.grad_unequal_flag:
            data_params.grad_unequal_flag = True
            data_params.is_consistent = False
            return data_params.original_result
        preheat_counter.add_api_called_time(self.pure_name)

        if not self._is_take_a_sample():
            return data_params.original_result

        # Best-effort CPU comparison: on failure the sample counts as consistent.
        cpu_consistent = True
        try:
            cpu_consistent = self.compare_npu_and_cpu(data_params)
        except Exception as e:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"when compare to cpu exception raise {e}"
            )
        try:
            first_dtype = Tools.get_first_tensor_dtype(data_params.original_result)
        except RuntimeError:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"the output sequence does not contain tensors."
            )
            # No output dtype means there is no record to update. Falling
            # through would raise UnboundLocalError on first_dtype.
            return data_params.original_result
        if preheat_counter.get_api_preheat(self.pure_name, str(first_dtype)):
            self.preheat(max_fuzz_ratio, cpu_consistent, first_dtype)

        return data_params.original_result

    def _is_take_a_sample(self) -> bool:
        """Return True if the current call sequence number is selected for sampling.

        When selected, the sample is logged and counted via
        ``preheat_counter.add_api_sample_time``.
        """
        need_sample_set = self._get_need_sample_set()
        curr_called_seq = preheat_counter.get_api_called_time(self.pure_name)
        res = curr_called_seq in need_sample_set
        if res:
            total_count = preheat_counter.get_one_step_used_api(self.pure_name)
            logger.info_on_rank_0(
                f"[msprobe] Free benchmark: preheat sample in step{self.params.step}, "
                f"api_name {self.params.api_name}, "
                f"curr_called_seq: {curr_called_seq}/{total_count}"
            )
            preheat_counter.add_api_sample_time(self.pure_name)
        return res

    def _get_sample_count_per_step(self) -> int:
        """Return the number of samples to collect in each step.

        This is ceil(per-step call count / preheat_step), capped at
        ``preheat_config["max_sample"]``.
        """
        total_count = preheat_counter.get_one_step_used_api(self.pure_name)
        preheat_step = self.params.preheat_config.get("preheat_step")
        max_sample = self.params.preheat_config.get("max_sample")
        return min(math.ceil(total_count / preheat_step), max_sample)

    def _get_need_sample_set(self):
        """Return the 1-based call sequence numbers to sample in the current step.

        Sample ``i`` (1..sample_count_per_step) targets call
        ``(preheat_step * (i - 1) + step) % total_count``, with 0 mapped to
        ``total_count``. The offset shifts with ``step``, so different steps
        pick different calls. Returns an empty set when the API was never
        counted at step 0, which avoids a modulo by zero.
        """
        total_count = preheat_counter.get_one_step_used_api(self.pure_name)
        need_sample_set = set()
        if total_count == 0:
            return need_sample_set
        sample_count_per_step = self._get_sample_count_per_step()
        preheat_step = self.params.preheat_config.get("preheat_step")
        for i in range(1, sample_count_per_step + 1):
            count = (preheat_step * (i - 1) + self.params.step) % total_count
            if count == 0:
                count = total_count
            need_sample_set.add(count)
        return need_sample_set

    def _need_adjust_threshold(self) -> bool:
        """Return True once the sampled-call count reaches the per-step sample count."""
        sample_count_per_step = self._get_sample_count_per_step()
        sampled_time = preheat_counter.get_api_sample_time(self.pure_name)
        return sampled_time >= sample_count_per_step

    def _adjust_threshold_for_dtype(self, dtype_str, compare_result):
        """Compute a new threshold for one dtype from its recorded samples.

        Args:
            dtype_str: dtype key into the preheat record.
            compare_result: iterable of ``(ratio, is_consistent)`` pairs.

        Returns:
            The new threshold. It equals the current threshold when no samples
            exist, or when consistent and inconsistent ratios overlap.

        Side effects:
            Disables further preheating for the dtype (``set_api_preheat``)
            once an inconsistent sample determines the threshold.
        """
        con_ratio = [ratio for ratio, is_consistent in compare_result if is_consistent]
        incon_ratio = [ratio for ratio, is_consistent in compare_result if not is_consistent]
        old_thd = preheat_counter.get_api_thd(self.pure_name, dtype_str)
        if not con_ratio and not incon_ratio:
            # No samples recorded: nothing to learn, and min([]) would raise.
            return old_thd
        new_thd = old_thd
        if con_ratio and incon_ratio:
            # Both consistent and inconsistent samples exist: only act when
            # they are cleanly separated.
            if min(incon_ratio) > max(con_ratio):
                new_thd = min(min(incon_ratio), old_thd)
                preheat_counter.set_api_preheat(self.pure_name, dtype_str, is_preheat=False)
        elif con_ratio:
            # Only consistent samples: scale (old_thd - 1) by API_THD_STEP,
            # multiplying when a consistent ratio exceeded the current
            # threshold and dividing otherwise.
            if max(con_ratio) > old_thd:
                new_thd = 1 + ((old_thd - 1) * ThresholdConfig.API_THD_STEP)
            else:
                new_thd = 1 + ((old_thd - 1) / ThresholdConfig.API_THD_STEP)
        else:
            # Only inconsistent samples.
            new_thd = min(min(incon_ratio), old_thd)
            preheat_counter.set_api_preheat(self.pure_name, dtype_str, is_preheat=False)
        return new_thd

    def _adjust_threshold(self):
        """Recompute and store the threshold for every dtype recorded for this API."""
        records = preheat_counter.preheat_record[self.pure_name]
        for dtype_str, compare_result in records.items():
            new_thd = self._adjust_threshold_for_dtype(dtype_str, compare_result)
            threshold = self._get_default_threshold(
                preheat_counter.dtype_map.get(dtype_str)
            )
            preheat_counter.update_api_thd(
                self.pure_name, dtype_str, new_thd, threshold
            )