mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc/在线精度比对.md +0 -90
  232. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,39 +1,39 @@
from typing import Any

from msprobe.pytorch.free_benchmark import logger
from msprobe.pytorch.free_benchmark.common.enums import DeviceType
from msprobe.pytorch.free_benchmark.common.params import DataParams, make_unequal_row
from msprobe.pytorch.free_benchmark.common.utils import Tools
from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler


class CheckerHandler(FuzzHandler):
    """Result handler for "check" mode.

    Compares the perturbed result with the original one and records any
    inconsistency. ``handle`` always returns ``data_params.original_result``,
    so the caller's output is never altered by this handler.
    """

    def other_compare(self, data_params: DataParams) -> bool:
        """Compare original and perturbed results with ``SingleCompare.compare_seq``.

        When the comparison reports a mismatch, a row produced by
        ``make_unequal_row`` is appended to ``self.unequal_rows``.

        Returns:
            The verdict returned by ``compare_seq``.
        """
        is_consistent = SingleCompare().compare_seq(
            data_params.original_result, data_params.perturbed_result
        )
        if not is_consistent:
            self.unequal_rows.append(make_unequal_row(data_params, self.params))
        # The signature promises a bool; previously the method returned None.
        return is_consistent

    def get_threshold(self, dtype):
        """Return the default threshold for ``dtype`` (no preheat tuning here)."""
        return self._get_default_threshold(dtype)

    def handle(self, data_params: DataParams) -> Any:
        """Check the perturbed result and return the original result unchanged.

        Results that are bools or not float tensors (per
        ``Tools.is_float_tensor``) are not compared. NPU runs go through
        ``cmp_output_npu``; other devices use ``other_compare``.
        """
        if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor(
            data_params.perturbed_result
        ):
            return data_params.original_result
        try:
            if self.params.fuzz_device == DeviceType.NPU:
                self.cmp_output_npu(data_params)
            else:
                self.other_compare(data_params)
        except Exception as e:
            # Deliberate best-effort: a failed comparison is logged, never raised,
            # so the wrapped API call still returns its original output.
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"when compare the result exception raise {e}"
            )
        return data_params.original_result
@@ -1,24 +1,24 @@
from typing import Any

from msprobe.pytorch.free_benchmark.common.params import DataParams
from msprobe.pytorch.free_benchmark.common.utils import Tools
from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
from msprobe.pytorch.free_benchmark import logger


class FixHandler(FuzzHandler):
    """Result handler for "fix" mode: hand back the perturbed output instead of the original."""

    def get_threshold(self, dtype):
        """Return the default threshold for ``dtype``."""
        return self._get_default_threshold(dtype)

    def handle(self, data_params: DataParams) -> Any:
        """Return ``Tools.convert_fuzz_output_to_origin(original, perturbed)``.

        Presumably this converts the perturbed output to match the original's
        form (device/dtype) — TODO confirm against ``Tools``. If the conversion
        raises, a warning including the exception is logged and the original
        result is returned instead.
        """
        try:
            return Tools.convert_fuzz_output_to_origin(
                data_params.original_result, data_params.perturbed_result
            )
        except Exception as e:
            # Best-effort fallback; include the exception so the failure is diagnosable
            # (previously it was caught and discarded).
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name} "
                f"Fix output failed. {e}"
            )
        return data_params.original_result
@@ -1,30 +1,30 @@
from msprobe.pytorch.free_benchmark import FreeBenchmarkException
from msprobe.pytorch.free_benchmark.common.constant import PreheatConfig
from msprobe.pytorch.free_benchmark.common.enums import HandlerType
from msprobe.pytorch.free_benchmark.common.params import HandlerParams
from msprobe.pytorch.free_benchmark.result_handlers.check_handler import CheckerHandler
from msprobe.pytorch.free_benchmark.result_handlers.preheat_handler import PreheatHandler
from msprobe.pytorch.free_benchmark.result_handlers.fix_handler import FixHandler


class FuzzHandlerFactory:
    """Build the result handler that matches a ``HandlerParams`` configuration."""

    # HandlerType -> handler class registry.
    result_handlers = {
        HandlerType.CHECK: CheckerHandler,
        HandlerType.FIX: FixHandler,
        HandlerType.PREHEAT: PreheatHandler,
    }

    @staticmethod
    def create(params: HandlerParams):
        """Instantiate the handler selected by ``params``.

        If ``params.preheat_config[PreheatConfig.IF_PREHEAT]`` is truthy, the
        PREHEAT handler is chosen regardless of ``params.handler_type``.

        Raises:
            FreeBenchmarkException: ``UnsupportedType`` when no handler is
                registered for the selected type. (The message states that the
                free-benchmark tool supports the CHECK and FIX modes.)
        """
        preheat_enabled = params.preheat_config.get(PreheatConfig.IF_PREHEAT)
        handler_key = HandlerType.PREHEAT if preheat_enabled else params.handler_type
        handler_cls = FuzzHandlerFactory.result_handlers.get(handler_key)
        if handler_cls:
            return handler_cls(params)
        raise FreeBenchmarkException(
            FreeBenchmarkException.UnsupportedType,
            f"无标杆工具支持 [ {HandlerType.CHECK}、{HandlerType.FIX}] 形式",
        )
@@ -1,170 +1,170 @@
import math
from typing import Any

from msprobe.pytorch.free_benchmark import logger
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
from msprobe.pytorch.free_benchmark.common.counter import preheat_counter
from msprobe.pytorch.free_benchmark.common.enums import DeviceType
from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams
from msprobe.pytorch.free_benchmark.common.utils import Tools
from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler


class PreheatHandler(FuzzHandler):
    """Result handler that tunes per-API thresholds during the preheat steps.

    Step 0 only counts how often each API is called. In later steps below
    ``preheat_config["preheat_step"]``, a subset of calls is sampled: the ratio
    from the NPU comparison is recorded together with an NPU-vs-CPU consistency
    verdict, and the API's threshold is recomputed once enough samples exist.
    The original result is always returned unchanged.
    """

    def __init__(self, params: HandlerParams) -> None:
        super().__init__(params)
        # Normalized name from Tools.get_pure_api_name; key for every
        # preheat_counter lookup in this class.
        self.pure_name = Tools.get_pure_api_name(self.params.api_name)

    def get_threshold(self, dtype):
        """Return the preheat-tuned threshold for this API and ``dtype``."""
        return preheat_counter.get_api_thd(self.pure_name, dtype)

    def compare_npu_and_cpu(self, data_params: DataParams):
        """Re-run the original function on CPU copies of the inputs and compare.

        Inputs are converted with ``Tools.convert_device_and_dtype(...,
        DeviceType.CPU, change_dtype=True)``. Returns the verdict of
        ``SingleCompare.compare_seq(original_result, cpu_result)``.
        """
        args = Tools.convert_device_and_dtype(
            data_params.args, DeviceType.CPU, change_dtype=True
        )
        kwargs = Tools.convert_device_and_dtype(
            data_params.kwargs, DeviceType.CPU, change_dtype=True
        )
        cpu_result = data_params.origin_func(*args, **kwargs)
        return SingleCompare().compare_seq(data_params.original_result, cpu_result)

    def preheat(self, max_fuzz_ratio, cpu_consistent, first_dtype):
        """Record one sample and re-tune the threshold once enough are collected."""
        # Store this step's output ratio with the matching NPU/CPU comparison result.
        preheat_counter.update_preheat_record(
            self.pure_name,
            first_dtype,
            (max_fuzz_ratio, cpu_consistent),
        )
        if self._need_adjust_threshold():
            self._adjust_threshold()

    def handle(self, data_params: DataParams) -> Any:
        """Run one preheat iteration for this call; always return the original result.

        Bool or non-float perturbed results are ignored. Step 0 only registers
        the call via ``preheat_counter.add_one_step_used_api``. Later steps
        compare NPU outputs with ``cmp_output_npu``; while the step is below
        ``preheat_config["preheat_step"]``, sampled calls are also compared
        against a CPU re-run and fed into threshold adjustment.
        """
        if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor(
            data_params.perturbed_result
        ):
            return data_params.original_result

        if self.params.step == 0:
            preheat_counter.add_one_step_used_api(self.pure_name)
            return data_params.original_result

        # The current API/step is a preheat candidate.
        npu_consistent, max_fuzz_ratio = self.cmp_output_npu(data_params)
        data_params.is_consistent = npu_consistent

        preheat_counter.check_step(self.params.step)

        if self.params.preheat_config.get("preheat_step") <= self.params.step:
            return data_params.original_result

        # NOTE(review): the first pass only sets grad_unequal_flag and marks the
        # result inconsistent; sampling happens only once the flag is already set.
        # The intent is not visible here — confirm with the caller.
        if not data_params.grad_unequal_flag:
            data_params.grad_unequal_flag = True
            data_params.is_consistent = False
            return data_params.original_result
        preheat_counter.add_api_called_time(self.pure_name)

        if not self._is_take_a_sample():
            return data_params.original_result

        cpu_consistent = True
        try:
            cpu_consistent = self.compare_npu_and_cpu(data_params)
        except Exception as e:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"when compare to cpu exception raise {e}"
            )
        try:
            first_dtype = Tools.get_first_tensor_dtype(data_params.original_result)
        except RuntimeError:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"the output sequence does not contain tensors."
            )
            # Without a dtype there is no record key. Previously execution fell
            # through and raised UnboundLocalError on first_dtype.
            return data_params.original_result
        if preheat_counter.get_api_preheat(self.pure_name, str(first_dtype)):
            self.preheat(max_fuzz_ratio, cpu_consistent, first_dtype)

        return data_params.original_result

    def _is_take_a_sample(self) -> bool:
        """Return True if the current call index is in this step's sample set.

        On a hit, the sample is logged and the API's sample count is incremented.
        """
        need_sample_set = self._get_need_sample_set()
        curr_called_seq = preheat_counter.get_api_called_time(self.pure_name)
        res = curr_called_seq in need_sample_set
        if res:
            total_count = preheat_counter.get_one_step_used_api(self.pure_name)
            logger.info_on_rank_0(
                f"[msprobe] Free benchmark: preheat sample in step{self.params.step}, "
                f"api_name {self.params.api_name}, "
                f"curr_called_seq: {curr_called_seq}/{total_count}"
            )
            preheat_counter.add_api_sample_time(self.pure_name)
        return res

    def _get_sample_count_per_step(self) -> int:
        """Return the number of samples to collect in each step.

        This is ``ceil(calls_per_step / preheat_step)``, capped at
        ``preheat_config["max_sample"]``.
        """
        total_count = preheat_counter.get_one_step_used_api(self.pure_name)
        preheat_step = self.params.preheat_config.get("preheat_step")
        max_sample = self.params.preheat_config.get("max_sample")
        return min(math.ceil(total_count / preheat_step), max_sample)

    def _get_need_sample_set(self) -> set:
        """Return the 1-based call indices to sample in the current step.

        Indices start at the step number and are spaced ``preheat_step`` apart,
        wrapping modulo the per-step call count (a remainder of 0 maps to the
        last call).
        """
        total_count = preheat_counter.get_one_step_used_api(self.pure_name)
        sample_count_per_step = self._get_sample_count_per_step()
        need_sample_set = set()
        preheat_step = self.params.preheat_config.get("preheat_step")
        for i in range(1, sample_count_per_step + 1):
            count = (preheat_step * (i - 1) + self.params.step) % total_count
            if count == 0:
                count = total_count
            need_sample_set.add(count)
        return need_sample_set

    def _need_adjust_threshold(self) -> bool:
        """Return True once this API has been sampled enough times for a step."""
        sample_count_per_step = self._get_sample_count_per_step()
        sampled_time = preheat_counter.get_api_sample_time(self.pure_name)
        return sampled_time >= sample_count_per_step

    def _adjust_threshold_for_dtype(self, dtype_str, compare_result):
        """Compute a new threshold for one dtype from its recorded samples.

        Args:
            dtype_str: key of the dtype in this API's preheat record.
            compare_result: ``(ratio, is_consistent)`` pairs as recorded by
                ``preheat`` (``is_consistent`` is the NPU-vs-CPU verdict).

        Returns:
            The new threshold. Preheating is switched off for this dtype when
            inconsistent samples fix the threshold.
        """
        con_ratio = [ratio for ratio, is_consistent in compare_result if is_consistent]
        incon_ratio = [
            ratio for ratio, is_consistent in compare_result if not is_consistent
        ]
        old_thd = preheat_counter.get_api_thd(self.pure_name, dtype_str)
        if not con_ratio and not incon_ratio:
            # No samples: keep the threshold rather than calling min() on [].
            return old_thd
        new_thd = old_thd
        if con_ratio and incon_ratio:
            # Both consistent and inconsistent samples exist. Only when they are
            # separable is the threshold pinned and preheating finished.
            if min(incon_ratio) > max(con_ratio):
                new_thd = min(min(incon_ratio), old_thd)
                preheat_counter.set_api_preheat(self.pure_name, dtype_str, is_preheat=False)
        elif con_ratio:
            # Only consistent samples: scale the threshold's margin above 1 by
            # API_THD_STEP — multiply if some ratio exceeds it, divide otherwise.
            if max(con_ratio) > old_thd:
                new_thd = 1 + ((old_thd - 1) * ThresholdConfig.API_THD_STEP)
            else:
                new_thd = 1 + ((old_thd - 1) / ThresholdConfig.API_THD_STEP)
        else:
            # Only inconsistent samples.
            new_thd = min(min(incon_ratio), old_thd)
            preheat_counter.set_api_preheat(self.pure_name, dtype_str, is_preheat=False)
        return new_thd

    def _adjust_threshold(self):
        """Recompute and store the threshold for every dtype recorded for this API."""
        for dtype_str, compare_result in preheat_counter.preheat_record[
            self.pure_name
        ].items():
            new_thd = self._adjust_threshold_for_dtype(dtype_str, compare_result)
            threshold = self._get_default_threshold(
                preheat_counter.dtype_map.get(dtype_str)
            )
            preheat_counter.update_api_thd(
                self.pure_name, dtype_str, new_thd, threshold
            )
1
+ import math
2
+ from typing import Any
3
+
4
+ from msprobe.pytorch.free_benchmark import logger
5
+ from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
6
+ from msprobe.pytorch.free_benchmark.common.counter import preheat_counter
7
+ from msprobe.pytorch.free_benchmark.common.enums import DeviceType
8
+ from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams
9
+ from msprobe.pytorch.free_benchmark.common.utils import Tools
10
+ from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
11
+ from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
12
+
13
+
14
class PreheatHandler(FuzzHandler):
    """Fuzz result handler that calibrates per-API thresholds during preheat.

    Step 0 only counts how many times each API is called. On later steps below
    ``preheat_config["preheat_step"]``, a subset of calls is sampled and the
    original output is also compared against a CPU re-run. The collected
    ``(max_fuzz_ratio, cpu_consistent)`` pairs then drive the per-dtype
    threshold kept in ``preheat_counter``.
    """

    def __init__(self, params: HandlerParams) -> None:
        super().__init__(params)
        # Normalized API name; used as the key for every preheat_counter lookup.
        self.pure_name = Tools.get_pure_api_name(self.params.api_name)

    def get_threshold(self, dtype):
        """Return the current preheat-calibrated threshold for ``dtype``."""
        return preheat_counter.get_api_thd(self.pure_name, dtype)

    def compare_npu_and_cpu(self, data_params: DataParams):
        """Re-run the original function on CPU and compare it with the original output.

        Positional and keyword arguments are converted with
        ``Tools.convert_device_and_dtype(..., DeviceType.CPU, change_dtype=True)``
        before calling ``data_params.origin_func``.

        Returns:
            The value of ``SingleCompare().compare_seq(original, cpu_result)``.
            ``handle`` uses it as the CPU consistency flag.
        """
        args = Tools.convert_device_and_dtype(
            data_params.args, DeviceType.CPU, change_dtype=True
        )
        kwargs = Tools.convert_device_and_dtype(
            data_params.kwargs, DeviceType.CPU, change_dtype=True
        )
        cpu_result = data_params.origin_func(*args, **kwargs)
        return SingleCompare().compare_seq(data_params.original_result, cpu_result)

    def preheat(self, max_fuzz_ratio, cpu_consistent, first_dtype):
        """Record one sample and re-tune thresholds once enough samples exist.

        Args:
            max_fuzz_ratio: Ratio returned by ``cmp_output_npu`` for this call.
            cpu_consistent: Result of the NPU-vs-CPU comparison.
            first_dtype: dtype of the first tensor in the original output.
        """
        # Store this call's output ratio together with its NPU/CPU comparison outcome.
        preheat_counter.update_preheat_record(
            self.pure_name,
            first_dtype,
            (max_fuzz_ratio, cpu_consistent),
        )
        if self._need_adjust_threshold():
            self._adjust_threshold()

    def handle(self, data_params: DataParams) -> Any:
        """Run one preheat iteration for this API call.

        Always returns ``data_params.original_result``. The handler's effects
        are limited to the following:
        - updating ``data_params.is_consistent``;
        - updating ``data_params.grad_unequal_flag``;
        - updating the shared ``preheat_counter`` state.
        """
        if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor(
            data_params.perturbed_result
        ):
            return data_params.original_result

        # Step 0 only counts how many times this API is called per step.
        if self.params.step == 0:
            preheat_counter.add_one_step_used_api(self.pure_name)
            return data_params.original_result

        # The current API/step takes part in preheating.
        npu_consistent, max_fuzz_ratio = self.cmp_output_npu(data_params)
        data_params.is_consistent = npu_consistent

        preheat_counter.check_step(self.params.step)

        # Preheat is finished once the configured number of preheat steps has passed.
        if self.params.preheat_config.get("preheat_step") <= self.params.step:
            return data_params.original_result

        # NOTE(review): on the first pass with grad_unequal_flag unset, the call is
        # reported as inconsistent and not sampled; presumably the caller re-runs
        # it with the flag set — confirm intent.
        if not data_params.grad_unequal_flag:
            data_params.grad_unequal_flag = True
            data_params.is_consistent = False
            return data_params.original_result
        preheat_counter.add_api_called_time(self.pure_name)

        if not self._is_take_a_sample():
            return data_params.original_result

        # Best-effort CPU comparison: a failure is logged and the sample is
        # treated as consistent.
        cpu_consistent = True
        try:
            cpu_consistent = self.compare_npu_and_cpu(data_params)
        except Exception as e:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"when compare to cpu exception raise {e}"
            )

        try:
            first_dtype = Tools.get_first_tensor_dtype(data_params.original_result)
        except RuntimeError:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                "the output sequence does not contain tensors."
            )
            # Without a dtype there is nothing to calibrate. Previously this fell
            # through and raised UnboundLocalError on `first_dtype` below.
            return data_params.original_result

        if preheat_counter.get_api_preheat(self.pure_name, str(first_dtype)):
            self.preheat(max_fuzz_ratio, cpu_consistent, first_dtype)

        return data_params.original_result

    def _is_take_a_sample(self) -> bool:
        """Return whether the current call should be sampled; if so, count the sample.

        The current call's sequence number (``get_api_called_time``) is checked
        against ``_get_need_sample_set``. On a hit, the sample is logged and
        ``preheat_counter.add_api_sample_time`` is incremented.
        """
        need_sample_set = self._get_need_sample_set()
        curr_called_seq = preheat_counter.get_api_called_time(self.pure_name)
        res = curr_called_seq in need_sample_set
        if res:
            total_count = preheat_counter.get_one_step_used_api(self.pure_name)
            logger.info_on_rank_0(
                f"[msprobe] Free benchmark: preheat sample in step{self.params.step}, "
                f"api_name {self.params.api_name}, "
                f"curr_called_seq: {curr_called_seq}/{total_count}"
            )
            preheat_counter.add_api_sample_time(self.pure_name)
        return res

    def _get_sample_count_per_step(self) -> int:
        """Return how many samples of this API to collect in each step.

        The count is ``ceil(calls_per_step / preheat_step)``, capped at
        ``preheat_config["max_sample"]``.
        """
        total_count = preheat_counter.get_one_step_used_api(self.pure_name)
        preheat_step = self.params.preheat_config.get("preheat_step")
        max_sample = self.params.preheat_config.get("max_sample")
        return min(math.ceil(total_count / preheat_step), max_sample)

    def _get_need_sample_set(self) -> set:
        """Return the call sequence numbers (1..total_count) to sample in this step.

        Candidates are ``step, step + preheat_step, step + 2 * preheat_step, ...``
        (one per sample in the per-step budget). Each candidate is wrapped
        modulo the per-step call count, with 0 mapped to ``total_count``.
        """
        total_count = preheat_counter.get_one_step_used_api(self.pure_name)
        sample_count_per_step = self._get_sample_count_per_step()
        preheat_step = self.params.preheat_config.get("preheat_step")
        need_sample_set = set()
        for i in range(1, sample_count_per_step + 1):
            count = (preheat_step * (i - 1) + self.params.step) % total_count
            if count == 0:
                count = total_count
            need_sample_set.add(count)
        return need_sample_set

    def _need_adjust_threshold(self) -> bool:
        """Return True once the number of samples taken reaches the per-step budget."""
        sample_count_per_step = self._get_sample_count_per_step()
        sampled_time = preheat_counter.get_api_sample_time(self.pure_name)
        return sampled_time >= sample_count_per_step

    def _adjust_threshold_for_dtype(self, dtype_str, compare_result):
        """Compute a new threshold for one dtype from its recorded samples.

        Args:
            dtype_str: Key identifying the dtype in ``preheat_counter``.
            compare_result: Iterable of ``(ratio, is_consistent)`` pairs.

        Returns:
            The new threshold. The current threshold is returned unchanged when
            no samples were recorded.

        Side effects:
            May call ``preheat_counter.set_api_preheat(..., is_preheat=False)``.
        """
        old_thd = preheat_counter.get_api_thd(self.pure_name, dtype_str)
        # Guard: with no samples, the "only inconsistent" branch would call
        # min() on an empty list and raise ValueError.
        if not compare_result:
            return old_thd
        con_ratio = [ratio for ratio, is_consistent in compare_result if is_consistent]
        incon_ratio = [
            ratio for ratio, is_consistent in compare_result if not is_consistent
        ]
        new_thd = old_thd
        # Both consistent and inconsistent samples exist.
        if con_ratio and incon_ratio:
            # Groups are cleanly separated: clamp to the smallest inconsistent
            # ratio and stop preheating this dtype.
            if min(incon_ratio) > max(con_ratio):
                new_thd = min(min(incon_ratio), old_thd)
                preheat_counter.set_api_preheat(self.pure_name, dtype_str, is_preheat=False)
        elif con_ratio:
            # Only consistent samples: scale the margin (old_thd - 1) by
            # API_THD_STEP, up or down depending on whether any consistent
            # ratio exceeds the current threshold.
            if max(con_ratio) > old_thd:
                new_thd = 1 + ((old_thd - 1) * ThresholdConfig.API_THD_STEP)
            else:
                new_thd = 1 + ((old_thd - 1) / ThresholdConfig.API_THD_STEP)
        else:
            # Only inconsistent samples.
            new_thd = min(min(incon_ratio), old_thd)
            preheat_counter.set_api_preheat(self.pure_name, dtype_str, is_preheat=False)
        return new_thd

    def _adjust_threshold(self):
        """Recompute and store the threshold for every dtype recorded for this API."""
        for dtype_str, compare_result in preheat_counter.preheat_record[
            self.pure_name
        ].items():
            new_thd = self._adjust_threshold_for_dtype(dtype_str, compare_result)
            threshold = self._get_default_threshold(
                preheat_counter.dtype_map.get(dtype_str)
            )
            preheat_counter.update_api_thd(
                self.pure_name, dtype_str, new_thd, threshold
            )