mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,37 +1,65 @@
1
- class PerturbationMode:
2
- ADD_NOISE = "add_noise"
3
- CHANGE_VALUE = "change_value"
4
- IMPROVE_PRECISION = "improve_precision"
5
- NO_CHANGE = "no_change"
6
- BIT_NOISE = "bit_noise"
7
- TO_CPU = "to_cpu"
8
-
9
-
10
- class DeviceType:
11
- NPU = "npu"
12
- CPU = "cpu"
13
-
14
-
15
- class FuzzThreshold:
16
- BF16_THD = 1e-4
17
- F16_THD = 1e-6
18
- F32_THD = 1e-8
19
- F64_THD = 1e-16
20
-
21
-
22
- class NormType:
23
- ONE_NORM = (1, "one_norm")
24
- TWO_NORM = (2, "two_norm")
25
- ENDLESS_NORM = (3, "endless_norm")
26
-
27
-
28
- class HandlerType:
29
- CHECK = "check"
30
- PREHEAT = "preheat"
31
- FIX = "fix"
32
-
33
-
34
- class FuzzLevel:
35
- BASE_LEVEL = "L1"
36
- ADV_LEVEL = "L2"
37
- REAL_LEVEL = "L3"
1
+ from msprobe.core.common.const import Const
2
+
3
+
4
+ class PerturbationMode:
5
+ ADD_NOISE = "add_noise"
6
+ CHANGE_VALUE = "change_value"
7
+ IMPROVE_PRECISION = "improve_precision"
8
+ NO_CHANGE = "no_change"
9
+ BIT_NOISE = "bit_noise"
10
+ TO_CPU = "to_cpu"
11
+
12
+
13
+ class DeviceType:
14
+ NPU = "npu"
15
+ CPU = "cpu"
16
+
17
+
18
+ class FuzzThreshold:
19
+ BF16_THD = 1e-4
20
+ F16_THD = 1e-6
21
+ F32_THD = 1e-8
22
+ F64_THD = 1e-16
23
+
24
+
25
+ class NormType:
26
+ ONE_NORM = (1, "one_norm")
27
+ TWO_NORM = (2, "two_norm")
28
+ ENDLESS_NORM = (3, "endless_norm")
29
+
30
+
31
+ class HandlerType:
32
+ CHECK = "check"
33
+ PREHEAT = "preheat"
34
+ FIX = "fix"
35
+
36
+
37
+ class FuzzLevel:
38
+ BASE_LEVEL = "L1"
39
+ ADV_LEVEL = "L2"
40
+ REAL_LEVEL = "L3"
41
+
42
+
43
+ class PytorchFreeBenchmarkConst:
44
+ PERTURBATION_MODE_LIST = [
45
+ PerturbationMode.ADD_NOISE,
46
+ PerturbationMode.CHANGE_VALUE,
47
+ PerturbationMode.IMPROVE_PRECISION,
48
+ PerturbationMode.NO_CHANGE,
49
+ PerturbationMode.BIT_NOISE,
50
+ PerturbationMode.TO_CPU,
51
+ ]
52
+ DEFAULT_MODE = PerturbationMode.IMPROVE_PRECISION
53
+ DEVICE_LIST = [DeviceType.NPU, DeviceType.CPU]
54
+ DEFAULT_DEVICE = DeviceType.NPU
55
+ HANDLER_LIST = [HandlerType.CHECK, HandlerType.FIX]
56
+ DEFAULT_HANDLER = HandlerType.CHECK
57
+ FUZZ_LEVEL_LIST = [FuzzLevel.BASE_LEVEL]
58
+ DEFAULT_FUZZ_LEVEL = FuzzLevel.BASE_LEVEL
59
+ FUZZ_STAGE_LIST = [Const.FORWARD, Const.BACKWARD]
60
+ FIX_MODE_LIST = [PerturbationMode.IMPROVE_PRECISION, PerturbationMode.TO_CPU]
61
+ DEFAULT_FUZZ_STAGE = Const.FORWARD
62
+ DEFAULT_PREHEAT_STEP = 15
63
+ DEFAULT_MAX_SAMPLE = 20
64
+ CPU_MODE_LIST = [PerturbationMode.TO_CPU]
65
+ FIX_STAGE_LIST = [Const.FORWARD]
@@ -1,129 +1,144 @@
1
- from dataclasses import dataclass
2
- from typing import Any, Callable, Dict, List, Optional, Tuple
3
-
4
- import torch
5
- from msprobe.pytorch.free_benchmark import logger
6
- from msprobe.pytorch.free_benchmark.common.enums import (
7
- DeviceType,
8
- FuzzLevel,
9
- PerturbationMode,
10
- )
11
- from msprobe.pytorch.free_benchmark.common.utils import Tools
12
-
13
-
14
- @dataclass
15
- class DataParams:
16
- args: Optional[Tuple] = None
17
- kwargs: Optional[Dict] = None
18
- valid_input_index: Optional[int] = None
19
- original_result: Optional[Any] = None
20
- perturbed_result: Optional[Any] = None
21
- is_consistent: Optional[bool] = True
22
- perturbed_value: Optional[Any] = None
23
- origin_func: Optional[Callable] = None
24
- api_type: Optional[str] = None
25
- fuzz_stage: Optional[str] = None
26
- grad_unequal_flag: Optional[bool] = True
27
-
28
-
29
- @dataclass
30
- class HandlerParams:
31
- handler_type: Optional[str] = None
32
- api_name: Optional[str] = None
33
- pert_mode: Optional[PerturbationMode] = None
34
- step: Optional[int] = None
35
- fuzz_stage: Optional[str] = None
36
- fuzz_device: Optional[DeviceType] = None
37
- preheat_config: Optional[Dict] = None
38
- fuzz_level: Optional[str] = None
39
-
40
-
41
- @dataclass
42
- class UnequalRow:
43
- rank: Optional[int] = None
44
- pert_mode: Optional[PerturbationMode] = None
45
- stage: Optional[str] = None
46
- step: Optional[int] = None
47
- api_name: Optional[str] = None
48
- max_rel: Optional[float] = None
49
- dtype: Optional[str] = None
50
- shape: Optional[str] = None
51
- output_index: Optional[int] = None
52
-
53
-
54
- @dataclass
55
- class BenchmarkThd:
56
- rtol: Optional[float] = None # 相对误差阈值
57
- small_value: Optional[float] = None # 小值域
58
- small_value_atol: Optional[float] = None # 小值域绝对阈值
59
- err_balance: Optional[float] = None # 误差均衡性
60
-
61
-
62
- def check_args_type(args: Tuple) -> int:
63
- for i, arg in enumerate(args):
64
- if torch.is_tensor(arg):
65
- if arg.is_meta:
66
- continue
67
- if not torch.is_floating_point(arg):
68
- continue
69
- return i
70
- if isinstance(arg, (List, Tuple, Dict)):
71
- return i
72
- return -1
73
-
74
-
75
- def data_pre_deal(name, func, args, kwargs):
76
- data_params = DataParams(args=args, kwargs=kwargs, origin_func=func)
77
- index = check_args_type(args)
78
- data_params.valid_input_index = index
79
- if index == -1:
80
- logger.warning_on_rank_0(
81
- f"[msprobe] Free benchmark: 无标杆工具不支持当前算子的输入类型 {name}."
82
- )
83
- return data_params
84
-
85
-
86
- def make_handler_params(name, config, step):
87
- handler_params = HandlerParams()
88
- handler_params.api_name = name
89
- handler_params.step = step
90
- handler_params.handler_type = config.handler_type
91
- handler_params.fuzz_stage = config.fuzz_stage
92
- handler_params.fuzz_device = config.fuzz_device
93
- handler_params.preheat_config = config.preheat_config
94
- handler_params.fuzz_level = config.fuzz_level
95
- handler_params.pert_mode = config.pert_mode
96
- return handler_params
97
-
98
-
99
- def make_unequal_row(
100
- data_params: DataParams,
101
- handle_params: HandlerParams,
102
- ratio: float = None,
103
- index: int = None,
104
- ):
105
- row = UnequalRow(
106
- api_name=handle_params.api_name,
107
- pert_mode=handle_params.pert_mode,
108
- output_index=index,
109
- stage=handle_params.fuzz_stage,
110
- step=handle_params.step,
111
- )
112
- if isinstance(ratio, float):
113
- row.max_rel = ratio - 1
114
- origin_tensor = data_params.original_result
115
- perturbed_tensor = data_params.perturbed_result
116
- if index:
117
- origin_tensor = origin_tensor[index]
118
- perturbed_tensor = perturbed_tensor[index]
119
- row.output_index = index
120
- if isinstance(origin_tensor, torch.Tensor):
121
- row.dtype = origin_tensor.dtype
122
- row.shape = origin_tensor.shape
123
- row.rank = Tools.get_dist_rank()
124
- # 以下暂不支持
125
- if handle_params.fuzz_level == FuzzLevel.ADV_LEVEL:
126
- pass
127
- if handle_params.fuzz_level == FuzzLevel.REAL_LEVEL:
128
- pass
129
- return row
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple
18
+
19
+ import torch
20
+ from msprobe.pytorch.free_benchmark import logger
21
+ from msprobe.pytorch.free_benchmark.common.enums import (
22
+ DeviceType,
23
+ FuzzLevel,
24
+ PerturbationMode,
25
+ )
26
+ from msprobe.pytorch.free_benchmark.common.utils import Tools
27
+
28
+
29
+ @dataclass
30
+ class DataParams:
31
+ args: Optional[Tuple] = None
32
+ kwargs: Optional[Dict] = None
33
+ valid_input_index: Optional[int] = None
34
+ original_result: Optional[Any] = None
35
+ perturbed_result: Optional[Any] = None
36
+ is_consistent: Optional[bool] = True
37
+ perturbed_value: Optional[Any] = None
38
+ origin_func: Optional[Callable] = None
39
+ api_type: Optional[str] = None
40
+ fuzz_stage: Optional[str] = None
41
+ grad_unequal_flag: Optional[bool] = True
42
+
43
+
44
+ @dataclass
45
+ class HandlerParams:
46
+ handler_type: Optional[str] = None
47
+ api_name: Optional[str] = None
48
+ pert_mode: Optional[PerturbationMode] = None
49
+ step: Optional[int] = None
50
+ fuzz_stage: Optional[str] = None
51
+ fuzz_device: Optional[DeviceType] = None
52
+ preheat_config: Optional[Dict] = None
53
+ fuzz_level: Optional[str] = None
54
+
55
+
56
+ @dataclass
57
+ class UnequalRow:
58
+ rank: Optional[int] = None
59
+ pert_mode: Optional[PerturbationMode] = None
60
+ stage: Optional[str] = None
61
+ step: Optional[int] = None
62
+ api_name: Optional[str] = None
63
+ max_rel: Optional[float] = None
64
+ dtype: Optional[str] = None
65
+ shape: Optional[str] = None
66
+ output_index: Optional[int] = None
67
+
68
+
69
+ @dataclass
70
+ class BenchmarkThd:
71
+ rtol: Optional[float] = None # 相对误差阈值
72
+ small_value: Optional[float] = None # 小值域
73
+ small_value_atol: Optional[float] = None # 小值域绝对阈值
74
+ err_balance: Optional[float] = None # 误差均衡性
75
+
76
+
77
+ def check_args_type(args: Tuple) -> int:
78
+ for i, arg in enumerate(args):
79
+ if torch.is_tensor(arg):
80
+ if arg.is_meta:
81
+ continue
82
+ if not torch.is_floating_point(arg):
83
+ continue
84
+ return i
85
+ if isinstance(arg, (List, Tuple, Dict)):
86
+ return i
87
+ return -1
88
+
89
+
90
+ def data_pre_deal(name, func, args, kwargs):
91
+ data_params = DataParams(args=args, kwargs=kwargs, origin_func=func)
92
+ index = check_args_type(args)
93
+ data_params.valid_input_index = index
94
+ if index == -1:
95
+ logger.warning_on_rank_0(
96
+ f"[msprobe] Free benchmark: 无标杆工具不支持当前算子的输入类型 {name}."
97
+ )
98
+ return data_params
99
+
100
+
101
+ def make_handler_params(name, config, step):
102
+ handler_params = HandlerParams()
103
+ handler_params.api_name = name
104
+ handler_params.step = step
105
+ handler_params.handler_type = config.handler_type
106
+ handler_params.fuzz_stage = config.fuzz_stage
107
+ handler_params.fuzz_device = config.fuzz_device
108
+ handler_params.preheat_config = config.preheat_config
109
+ handler_params.fuzz_level = config.fuzz_level
110
+ handler_params.pert_mode = config.pert_mode
111
+ return handler_params
112
+
113
+
114
+ def make_unequal_row(
115
+ data_params: DataParams,
116
+ handle_params: HandlerParams,
117
+ ratio: float = None,
118
+ index: int = None,
119
+ ):
120
+ row = UnequalRow(
121
+ api_name=handle_params.api_name,
122
+ pert_mode=handle_params.pert_mode,
123
+ output_index=index,
124
+ stage=handle_params.fuzz_stage,
125
+ step=handle_params.step,
126
+ )
127
+ if isinstance(ratio, float):
128
+ row.max_rel = ratio - 1
129
+ origin_tensor = data_params.original_result
130
+ perturbed_tensor = data_params.perturbed_result
131
+ if index:
132
+ origin_tensor = origin_tensor[index]
133
+ perturbed_tensor = perturbed_tensor[index]
134
+ row.output_index = index
135
+ if isinstance(origin_tensor, torch.Tensor):
136
+ row.dtype = origin_tensor.dtype
137
+ row.shape = origin_tensor.shape
138
+ row.rank = Tools.get_dist_rank()
139
+ # 以下暂不支持
140
+ if handle_params.fuzz_level == FuzzLevel.ADV_LEVEL:
141
+ pass
142
+ if handle_params.fuzz_level == FuzzLevel.REAL_LEVEL:
143
+ pass
144
+ return row
@@ -1,102 +1,118 @@
1
- import torch
2
- from msprobe.pytorch.free_benchmark.common.enums import DeviceType
3
-
4
-
5
- class Tools:
6
-
7
- @staticmethod
8
- def is_float_tensor(tensor) -> bool:
9
- if isinstance(tensor, torch.Tensor) and torch.is_floating_point(tensor):
10
- return True
11
- if isinstance(tensor, (list, tuple)):
12
- for value in tensor:
13
- if isinstance(value, torch.Tensor) and torch.is_floating_point(value):
14
- return True
15
- return False
16
-
17
- @staticmethod
18
- def get_dist_rank():
19
- try:
20
- return torch.distributed.get_rank()
21
- except RuntimeError:
22
- return 0
23
-
24
- @staticmethod
25
- def get_first_tensor_dtype(tensor_seq):
26
- if isinstance(tensor_seq, torch.Tensor):
27
- return tensor_seq.dtype
28
- if isinstance(tensor_seq, (list, tuple)):
29
- for object_ in tensor_seq:
30
- if isinstance(object_, torch.Tensor):
31
- return object_.dtype
32
- raise RuntimeError("The sequence does not contain tensors.")
33
-
34
- @staticmethod
35
- def get_pure_api_name(api_name: str):
36
- return api_name.rsplit(".", 2)[0]
37
-
38
- @staticmethod
39
- def convert_device_and_dtype(
40
- tensor_seq, device: str = DeviceType.CPU, change_dtype: bool = False
41
- ):
42
- if isinstance(tensor_seq, torch.Tensor):
43
- if change_dtype and tensor_seq.dtype in [torch.float16, torch.bfloat16]:
44
- return tensor_seq.detach().to(device).to(torch.float32)
45
- return tensor_seq.detach().to(device)
46
- if isinstance(tensor_seq, dict):
47
- return {
48
- key: Tools.convert_device_and_dtype(value, device, change_dtype)
49
- for key, value in tensor_seq.items()
50
- }
51
- if isinstance(tensor_seq, (tuple, list)):
52
- return type(tensor_seq)(
53
- [
54
- Tools.convert_device_and_dtype(value, device, change_dtype)
55
- for value in tensor_seq
56
- ]
57
- )
58
- return tensor_seq
59
-
60
- @staticmethod
61
- def convert_fuzz_output_to_origin(origin, perturbed):
62
- if isinstance(origin, torch.Tensor):
63
- origin.data = perturbed.to(origin.dtype).to(origin.device)
64
- return origin
65
- if isinstance(origin, dict):
66
- output = dict()
67
- for key, value in origin.items():
68
- output[key] = Tools.convert_fuzz_output_to_origin(value, perturbed[key])
69
- return output
70
- if isinstance(origin, (tuple, list)):
71
- result = list()
72
- for index_, value in enumerate(origin):
73
- result.append(
74
- Tools.convert_fuzz_output_to_origin(value, perturbed[index_])
75
- )
76
- return type(origin)(result)
77
- return origin
78
-
79
- class TorchC:
80
- sum = torch._C._VariableFunctionsClass.sum
81
- isinf = torch._C._VariableFunctionsClass.isinf
82
- isfinite = torch._C._VariableFunctionsClass.isfinite
83
- isnan = torch._C._VariableFunctionsClass.isnan
84
- logical_not = torch._C._VariableFunctionsClass.logical_not
85
- subtract = torch._C._VariableFunctionsClass.subtract
86
- abs = torch._C._VariableFunctionsClass.abs
87
- where = torch._C._VariableFunctionsClass.where
88
- div = torch._C._VariableFunctionsClass.div
89
- max = torch._C._VariableFunctionsClass.max
90
- min = torch._C._VariableFunctionsClass.min
91
- gt = torch._C._VariableFunctionsClass.gt
92
- ge = torch._C._VariableFunctionsClass.ge
93
- lt = torch._C._VariableFunctionsClass.lt
94
- mean = torch._C._VariableFunctionsClass.mean
95
- full = torch._C._VariableFunctionsClass.full
96
- add = torch._C._VariableFunctionsClass.add
97
- bitwise_xor = torch._C._VariableFunctionsClass.bitwise_xor
98
- clone = torch._C._VariableFunctionsClass.clone
99
- clamp = torch._C._VariableFunctionsClass.clamp
100
- tensor_split = torch._C._VariableFunctionsClass.tensor_split
101
- stack = torch._C._VariableFunctionsClass.stack
102
- reshape = torch._C._VariableFunctionsClass.reshape
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from msprobe.pytorch.free_benchmark.common.enums import DeviceType
18
+
19
+
20
class Tools:
    """Stateless tensor helpers used by the free-benchmark tooling."""

    @staticmethod
    def is_float_tensor(tensor) -> bool:
        """Return True if *tensor* is a floating-point ``torch.Tensor``.

        A list/tuple also qualifies if at least one of its direct elements is
        a floating-point tensor. Nested containers are not searched.
        """
        if isinstance(tensor, torch.Tensor):
            return torch.is_floating_point(tensor)
        if isinstance(tensor, (list, tuple)):
            return any(
                isinstance(value, torch.Tensor) and torch.is_floating_point(value)
                for value in tensor
            )
        return False

    @staticmethod
    def get_dist_rank():
        """Return this process's distributed rank, or 0 when not running distributed.

        ``torch.distributed.get_rank`` fails when no default process group has
        been initialized. Depending on the PyTorch version it raises
        ``RuntimeError`` or ``ValueError``. When PyTorch is built without
        distributed support, ``get_rank`` is not defined at all. Both
        situations are checked up front so the fallback of 0 is reliable.
        """
        dist = torch.distributed
        if not dist.is_available() or not dist.is_initialized():
            return 0
        try:
            return dist.get_rank()
        except (RuntimeError, ValueError):
            return 0

    @staticmethod
    def get_first_tensor_dtype(tensor_seq):
        """Return the dtype of *tensor_seq*, or of the first tensor in a list/tuple.

        Raises:
            RuntimeError: if no tensor is found.
        """
        if isinstance(tensor_seq, torch.Tensor):
            return tensor_seq.dtype
        if isinstance(tensor_seq, (list, tuple)):
            for object_ in tensor_seq:
                if isinstance(object_, torch.Tensor):
                    return object_.dtype
        raise RuntimeError("The sequence does not contain tensors.")

    @staticmethod
    def get_pure_api_name(api_name: str):
        """Drop the last two dot-separated components, e.g. ``"a.b.c.d"`` -> ``"a.b"``."""
        return api_name.rsplit(".", 2)[0]

    @staticmethod
    def convert_device_and_dtype(
        tensor_seq, device: str = DeviceType.CPU, change_dtype: bool = False
    ):
        """Recursively detach tensors and move them to *device*.

        When *change_dtype* is True, float16/bfloat16 tensors are also upcast
        to float32. Dicts, lists and tuples are rebuilt with converted values.
        Any other object is returned unchanged.
        """
        if isinstance(tensor_seq, torch.Tensor):
            if change_dtype and tensor_seq.dtype in [torch.float16, torch.bfloat16]:
                return tensor_seq.detach().to(device).to(torch.float32)
            return tensor_seq.detach().to(device)
        if isinstance(tensor_seq, dict):
            return {
                key: Tools.convert_device_and_dtype(value, device, change_dtype)
                for key, value in tensor_seq.items()
            }
        if isinstance(tensor_seq, (tuple, list)):
            return type(tensor_seq)(
                [
                    Tools.convert_device_and_dtype(value, device, change_dtype)
                    for value in tensor_seq
                ]
            )
        return tensor_seq

    @staticmethod
    def convert_fuzz_output_to_origin(origin, perturbed):
        """Write *perturbed* values back into the structure of *origin*.

        For tensors, ``origin.data`` is overwritten in place with *perturbed*
        cast to origin's dtype and device, and *origin* itself is returned.
        Dicts and lists/tuples are rebuilt recursively, reading *perturbed* at
        the same keys/positions. Other objects are returned unchanged and
        *perturbed* is ignored.
        """
        if isinstance(origin, torch.Tensor):
            origin.data = perturbed.to(origin.dtype).to(origin.device)
            return origin
        if isinstance(origin, dict):
            output = dict()
            for key, value in origin.items():
                output[key] = Tools.convert_fuzz_output_to_origin(value, perturbed[key])
            return output
        if isinstance(origin, (tuple, list)):
            result = list()
            for index_, value in enumerate(origin):
                result.append(
                    Tools.convert_fuzz_output_to_origin(value, perturbed[index_])
                )
            return type(origin)(result)
        return origin
93
+
94
+
95
class TorchC:
    """Direct handles to torch's C-level functions.

    Each attribute is taken from ``torch._C._VariableFunctionsClass`` when the
    class is defined. Callers therefore reach the raw implementations even if
    the public ``torch.*`` names are later wrapped or patched.
    """

    _vf = torch._C._VariableFunctionsClass

    sum = _vf.sum
    isinf = _vf.isinf
    isfinite = _vf.isfinite
    isnan = _vf.isnan
    logical_not = _vf.logical_not
    subtract = _vf.subtract
    abs = _vf.abs
    where = _vf.where
    div = _vf.div
    max = _vf.max
    min = _vf.min
    gt = _vf.gt
    ge = _vf.ge
    lt = _vf.lt
    mean = _vf.mean
    full = _vf.full
    add = _vf.add
    bitwise_xor = _vf.bitwise_xor
    clone = _vf.clone
    clamp = _vf.clamp
    tensor_split = _vf.tensor_split
    stack = _vf.stack
    reshape = _vf.reshape

    # The alias was only a shorthand for the class body; remove it so the
    # class exposes exactly the names listed above.
    del _vf