mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,116 +1,131 @@
1
- import os
2
- import inspect
3
- import importlib
4
-
5
- import mindspore as ms
6
- from mindspore.communication import comm_func
7
-
8
- from msprobe.core.common.utils import load_yaml
9
- from msprobe.core.common.const import Const
10
- from msprobe.mindspore.common.const import FreeBenchmarkConst
11
- from msprobe.mindspore.free_benchmark.common.config import Config
12
- from msprobe.core.common.file_check import check_path_length
13
- from msprobe.mindspore.common.log import logger
14
- from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
15
- from msprobe.mindspore.free_benchmark.decorator.decorator_factory import decorate_forward_function
16
-
17
-
18
- class ApiPyNativeSelFCheck:
19
- def __init__(self, config: DebuggerConfig):
20
- Config.is_enable = True
21
- Config.handler_type = config.handler_type
22
- Config.pert_type = config.pert_type
23
- Config.stage = config.stage
24
- Config.dump_level = config.dump_level
25
- Config.steps = config.step
26
- Config.ranks = config.rank
27
- Config.dump_path = os.path.join(config.dump_path, "free_benchmark.csv")
28
- check_path_length(Config.dump_path)
29
-
30
- self.api_list = config.list
31
- all_api = get_supported_ops()
32
- if not self.api_list:
33
- self.api_list = all_api
34
- else:
35
- self.api_list = set(self.api_list) & all_api
36
-
37
- def handle(self):
38
- for api_name in self.api_list:
39
- hijack(api_name)
40
-
41
-
42
- def get_supported_ops():
43
- supported_ops = []
44
- cur_path = os.path.dirname(os.path.realpath(__file__))
45
- yaml_path = os.path.join(cur_path, "data", "support_wrap_ops.yaml")
46
-
47
- yaml_data = load_yaml(yaml_path)
48
- for k, v in FreeBenchmarkConst.API_PREFIX_DICT.items():
49
- ops = yaml_data.get(k)
50
- if ops:
51
- ops = [v + i for i in ops]
52
- supported_ops += ops
53
-
54
- _all_functional_ops = []
55
- ms_ops = dir(ms.ops)
56
- ms_ops = [FreeBenchmarkConst.OPS_PREFIX + i for i in ms_ops]
57
- _all_functional_ops += ms_ops
58
-
59
- ms_tensor = dir(ms.Tensor)
60
- ms_tensor = [FreeBenchmarkConst.Tensor_PREFIX + i for i in ms_tensor]
61
- _all_functional_ops += ms_tensor
62
-
63
- ms_mint = dir(ms.mint)
64
- ms_mint = [FreeBenchmarkConst.MINT_PREFIX + i for i in ms_mint]
65
- _all_functional_ops += ms_mint
66
-
67
- ms_mint_nn_func = dir(ms.mint.nn.functional)
68
- ms_mint_nn_func = [FreeBenchmarkConst.MINT_NN_FUNC_PREFIX + i for i in ms_mint_nn_func]
69
- _all_functional_ops += ms_mint_nn_func
70
-
71
- ms_communication = dir(comm_func)
72
- ms_communication = [FreeBenchmarkConst.COMM_PREFIX + i for i in ms_communication]
73
- _all_functional_ops += ms_communication
74
-
75
- return set(supported_ops) & set(_all_functional_ops)
76
-
77
-
78
- def get_decorate_func():
79
- return decorate_forward_function
80
-
81
-
82
- def is_func_support_decorate(orig_func):
83
- return not inspect.isclass(orig_func) and callable(orig_func)
84
-
85
-
86
- def get_wrapper_obj(orig_func, api_name):
87
- if is_func_support_decorate(orig_func):
88
- wrapped_obj = get_decorate_func()(orig_func, api_name)
89
- else:
90
- wrapped_obj = orig_func
91
- return wrapped_obj
92
-
93
-
94
- def get_module(api_name):
95
- func_name_list = api_name.split(Const.SEP)
96
- func_name = func_name_list[-1]
97
- module_obj = importlib.import_module(func_name_list[0])
98
- for i, module_name in enumerate(func_name_list[1:-1]):
99
- if not hasattr(module_obj, module_name):
100
- importlib.import_module(f"{Const.SEP.join(func_name_list[:i+2])}")
101
- module_obj = getattr(module_obj, module_name)
102
- orig_func = getattr(module_obj, func_name)
103
-
104
- return module_obj, orig_func
105
-
106
-
107
- def hijack(api_name):
108
- if not api_name.strip():
109
- return
110
- try:
111
- func_name = api_name.split(Const.SEP)[-1]
112
- module_obj, origin_func = get_module(api_name)
113
- wrapped_obj = get_wrapper_obj(origin_func, api_name)
114
- setattr(module_obj, func_name, wrapped_obj)
115
- except Exception as e:
116
- logger.error(f"Failed decorator {api_name}: {e}")
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import importlib
17
+ import inspect
18
+ import os
19
+
20
+ import mindspore as ms
21
+ from mindspore.communication import comm_func
22
+
23
+ from msprobe.core.common.const import Const
24
+ from msprobe.core.common.file_utils import check_path_length, load_yaml
25
+ from msprobe.mindspore.common.const import Const as MsConst
26
+ from msprobe.mindspore.common.const import FreeBenchmarkConst
27
+ from msprobe.mindspore.common.log import logger
28
+ from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
29
+ from msprobe.mindspore.free_benchmark.common.config import Config
30
+ from msprobe.mindspore.free_benchmark.decorator.decorator_factory import decorate_forward_function
31
+
32
+
33
+ class ApiPyNativeSelFCheck:
34
+ def __init__(self, config: DebuggerConfig):
35
+ Config.is_enable = True
36
+ Config.handler_type = config.handler_type
37
+ Config.pert_type = config.pert_type
38
+ Config.stage = config.stage
39
+ Config.dump_level = config.dump_level
40
+ Config.steps = config.step
41
+ Config.ranks = config.rank
42
+ Config.dump_path = os.path.join(config.dump_path, "free_benchmark.csv")
43
+ check_path_length(Config.dump_path)
44
+
45
+ self.api_list = config.list
46
+ all_api = get_supported_ops()
47
+ if not self.api_list:
48
+ self.api_list = all_api
49
+ else:
50
+ self.api_list = set(self.api_list) & all_api
51
+
52
+ def handle(self):
53
+ for api_name in self.api_list:
54
+ hijack(api_name)
55
+
56
+
57
+ def get_supported_ops():
58
+ supported_ops = []
59
+ cur_path = os.path.dirname(os.path.realpath(__file__))
60
+ yaml_path = os.path.join(cur_path, "data", "support_wrap_ops.yaml")
61
+
62
+ yaml_data = load_yaml(yaml_path)
63
+ for k, v in FreeBenchmarkConst.API_PREFIX_DICT.items():
64
+ ops = yaml_data.get(k)
65
+ if ops:
66
+ ops = [v + i for i in ops]
67
+ supported_ops += ops
68
+
69
+ _all_functional_ops = []
70
+ ms_ops = dir(ms.ops)
71
+ ms_ops = [MsConst.OPS_PREFIX + i for i in ms_ops]
72
+ _all_functional_ops += ms_ops
73
+
74
+ ms_tensor = dir(ms.Tensor)
75
+ ms_tensor = [MsConst.Tensor_PREFIX + i for i in ms_tensor]
76
+ _all_functional_ops += ms_tensor
77
+
78
+ ms_mint = dir(ms.mint)
79
+ ms_mint = [MsConst.MINT_PREFIX + i for i in ms_mint]
80
+ _all_functional_ops += ms_mint
81
+
82
+ ms_mint_nn_func = dir(ms.mint.nn.functional)
83
+ ms_mint_nn_func = [MsConst.MINT_NN_FUNC_PREFIX + i for i in ms_mint_nn_func]
84
+ _all_functional_ops += ms_mint_nn_func
85
+
86
+ ms_communication = dir(comm_func)
87
+ ms_communication = [MsConst.COMM_PREFIX + i for i in ms_communication]
88
+ _all_functional_ops += ms_communication
89
+
90
+ return set(supported_ops) & set(_all_functional_ops)
91
+
92
+
93
+ def get_decorate_func():
94
+ return decorate_forward_function
95
+
96
+
97
+ def is_func_support_decorate(orig_func):
98
+ return not inspect.isclass(orig_func) and callable(orig_func)
99
+
100
+
101
+ def get_wrapper_obj(orig_func, api_name):
102
+ if is_func_support_decorate(orig_func):
103
+ wrapped_obj = get_decorate_func()(orig_func, api_name)
104
+ else:
105
+ wrapped_obj = orig_func
106
+ return wrapped_obj
107
+
108
+
109
+ def get_module(api_name):
110
+ func_name_list = api_name.split(Const.SEP)
111
+ func_name = func_name_list[-1]
112
+ module_obj = importlib.import_module(func_name_list[0])
113
+ for i, module_name in enumerate(func_name_list[1:-1]):
114
+ if not hasattr(module_obj, module_name):
115
+ importlib.import_module(f"{Const.SEP.join(func_name_list[:i+2])}")
116
+ module_obj = getattr(module_obj, module_name)
117
+ orig_func = getattr(module_obj, func_name)
118
+
119
+ return module_obj, orig_func
120
+
121
+
122
+ def hijack(api_name):
123
+ if not api_name.strip():
124
+ return
125
+ try:
126
+ func_name = api_name.split(Const.SEP)[-1]
127
+ module_obj, origin_func = get_module(api_name)
128
+ wrapped_obj = get_wrapper_obj(origin_func, api_name)
129
+ setattr(module_obj, func_name, wrapped_obj)
130
+ except Exception as e:
131
+ logger.error(f"Failed decorator {api_name}: {e}")
@@ -1,12 +1,27 @@
1
- from msprobe.mindspore.common.const import FreeBenchmarkConst
2
-
3
-
4
- class Config:
5
- is_enable: bool = False
6
- handler_type = FreeBenchmarkConst.DEFAULT_HANDLER_TYPE
7
- pert_type = FreeBenchmarkConst.DEFAULT_PERT_TYPE
8
- stage = FreeBenchmarkConst.DEFAULT_STAGE
9
- dump_level = FreeBenchmarkConst.DEFAULT_DUMP_LEVEL
10
- steps: list = []
11
- ranks: list = []
12
- dump_path: str = ""
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from msprobe.mindspore.common.const import FreeBenchmarkConst
17
+
18
+
19
+ class Config:
20
+ is_enable: bool = False
21
+ handler_type = FreeBenchmarkConst.DEFAULT_HANDLER_TYPE
22
+ pert_type = FreeBenchmarkConst.DEFAULT_PERT_TYPE
23
+ stage = FreeBenchmarkConst.DEFAULT_STAGE
24
+ dump_level = FreeBenchmarkConst.DEFAULT_DUMP_LEVEL
25
+ steps: list = []
26
+ ranks: list = []
27
+ dump_path: str = ""
@@ -1,17 +1,32 @@
1
- from typing import Optional, Any, Tuple, Dict, Callable
2
-
3
-
4
- class HandlerParams:
5
- """
6
- 参数结合体
7
-
8
- """
9
- args: Optional[Tuple] = None
10
- kwargs: Optional[Dict] = None
11
- index: Optional[int] = None
12
- original_result: Optional[Any] = None
13
- fuzzed_result: Optional[Any] = None
14
- is_consistent: Optional[bool] = True
15
- save_flag: Optional[bool] = True
16
- fuzzed_value: Optional[Any] = None
17
- original_func: Optional[Callable] = None
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional, Any, Tuple, Dict, Callable
17
+
18
+
19
+ class HandlerParams:
20
+ """
21
+ 参数结合体
22
+
23
+ """
24
+ args: Optional[Tuple] = None
25
+ kwargs: Optional[Dict] = None
26
+ index: Optional[int] = None
27
+ original_result: Optional[Any] = None
28
+ fuzzed_result: Optional[Any] = None
29
+ is_consistent: Optional[bool] = True
30
+ save_flag: Optional[bool] = True
31
+ fuzzed_value: Optional[Any] = None
32
+ original_func: Optional[Callable] = None
@@ -1,71 +1,85 @@
1
- from typing import Any
2
- from typing import Optional
3
- from dataclasses import dataclass
4
-
5
- import mindspore as ms
6
- from mindspore import Tensor
7
-
8
- from msprobe.mindspore.runtime import Runtime
9
- from msprobe.mindspore.common.const import FreeBenchmarkConst
10
- from .config import Config
11
- from .handler_params import HandlerParams
12
-
13
-
14
- class Tools:
15
-
16
- @staticmethod
17
- def get_first_tensor_dtype(tensor_seq: Any):
18
- if isinstance(tensor_seq, Tensor):
19
- return tensor_seq.dtype
20
- if isinstance(tensor_seq, (list, tuple)):
21
- for i in tensor_seq:
22
- if isinstance(i, Tensor):
23
- return i.dtype
24
- raise Exception("The sequence does not contain tensors.")
25
-
26
- @staticmethod
27
- def get_default_error_threshold(dtype):
28
- if Config.pert_type == FreeBenchmarkConst.NO_CHANGE:
29
- return FreeBenchmarkConst.NO_CHANGE_ERROR_THRESHOLD
30
- return FreeBenchmarkConst.ERROR_THRESHOLD.get(dtype, FreeBenchmarkConst.ERROR_THRESHOLD.get(ms.float32))
31
-
32
-
33
- @dataclass
34
- class UnequalRow:
35
- rank: Optional[int] = None
36
- pert_type: Optional[str] = None
37
- stage: Optional[str] = None
38
- step: Optional[int] = None
39
- api_name: Optional[str] = None
40
- max_rel: Optional[float] = None
41
- dtype: Optional[str] = None
42
- shape: Optional[str] = None
43
- output_index: Optional[int] = None
44
-
45
-
46
- def make_unequal_row(
47
- api_name: str,
48
- params: HandlerParams,
49
- ratio: float = None,
50
- index: int = None,
51
- ):
52
- row = UnequalRow(
53
- api_name=api_name,
54
- pert_type=Config.pert_type,
55
- output_index=index,
56
- stage=Config.stage,
57
- step=Runtime.step_count
58
- )
59
- if isinstance(ratio, float):
60
- row.max_rel = ratio - 1
61
- original_tensor = params.original_result
62
- fuzzed_tensor = params.fuzzed_result
63
- if index:
64
- original_tensor = original_tensor[index]
65
- fuzzed_tensor = fuzzed_tensor[index]
66
- row.output_index = index
67
- if isinstance(original_tensor, Tensor):
68
- row.dtype = original_tensor.dtype
69
- row.shape = original_tensor.shape
70
- row.rank = Runtime.rank_id if Runtime.rank_id != -1 else None
71
- return row
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Any, Optional
18
+
19
+ import mindspore as ms
20
+ from mindspore import Tensor
21
+
22
+ from msprobe.mindspore.common.const import FreeBenchmarkConst
23
+ from msprobe.mindspore.free_benchmark.common.config import Config
24
+ from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
25
+ from msprobe.mindspore.runtime import Runtime
26
+
27
+
28
class Tools:
    """Static helpers used by the MindSpore free-benchmark code."""

    @staticmethod
    def get_first_tensor_dtype(tensor_seq: Any):
        """Return the dtype of the first ``Tensor`` in ``tensor_seq``.

        Args:
            tensor_seq: a ``Tensor``, or a list/tuple whose top-level items
                are scanned in order for a ``Tensor`` (nested sequences are
                not searched).

        Returns:
            ``tensor_seq.dtype`` if ``tensor_seq`` is a ``Tensor``; otherwise
            the ``dtype`` of its first ``Tensor`` element.

        Raises:
            TypeError: if no ``Tensor`` is found, including when
                ``tensor_seq`` is neither a ``Tensor`` nor a list/tuple.
                ``TypeError`` subclasses ``Exception``, so existing
                ``except Exception`` handlers still catch it.
        """
        if isinstance(tensor_seq, Tensor):
            return tensor_seq.dtype
        if isinstance(tensor_seq, (list, tuple)):
            for item in tensor_seq:
                if isinstance(item, Tensor):
                    return item.dtype
        raise TypeError("The sequence does not contain tensors.")

    @staticmethod
    def get_default_error_threshold(dtype):
        """Return the default error threshold for ``dtype``.

        When ``Config.pert_type`` is ``FreeBenchmarkConst.NO_CHANGE`` the
        dedicated ``NO_CHANGE_ERROR_THRESHOLD`` is returned regardless of
        dtype. Otherwise the threshold is looked up in
        ``FreeBenchmarkConst.ERROR_THRESHOLD``, falling back to the
        ``ms.float32`` entry for dtypes that are not listed.
        """
        if Config.pert_type == FreeBenchmarkConst.NO_CHANGE:
            return FreeBenchmarkConst.NO_CHANGE_ERROR_THRESHOLD
        return FreeBenchmarkConst.ERROR_THRESHOLD.get(dtype, FreeBenchmarkConst.ERROR_THRESHOLD.get(ms.float32))
+
46
+
47
@dataclass
class UnequalRow:
    """Record of one output found unequal after perturbation.

    Built by ``make_unequal_row``. Keep the field order stable: positional
    construction and ``dataclasses.asdict``/``astuple`` all follow it.
    """

    # None when no rank is known (make_unequal_row maps rank_id -1 to None).
    rank: Optional[int] = None
    pert_type: Optional[str] = None
    stage: Optional[str] = None
    step: Optional[int] = None
    api_name: Optional[str] = None
    # ratio - 1 when a float ratio was supplied.
    max_rel: Optional[float] = None
    # Holds the output Tensor's ``dtype`` object as-is, not a string.
    dtype: Optional[Any] = None
    # Holds the output Tensor's ``shape`` as returned by ``Tensor.shape``, not a string.
    shape: Optional[Any] = None
    # Position of the output within a sequence result; None for a single output.
    output_index: Optional[int] = None
58
+
59
+
60
def make_unequal_row(
    api_name: str,
    params: HandlerParams,
    ratio: Optional[float] = None,
    index: Optional[int] = None,
):
    """Build an ``UnequalRow`` for an output that failed the consistency check.

    Args:
        api_name: name of the API being reported.
        params: handler parameters; only ``params.original_result`` is read.
        ratio: comparison ratio. When it is a ``float``, ``max_rel`` is set to
            ``ratio - 1``; any other value leaves ``max_rel`` as ``None``.
        index: position of the reported output when the result is a
            sequence; ``None`` means ``original_result`` is used as-is.

    Returns:
        An ``UnequalRow`` filled from ``Config`` (pert_type, stage),
        ``Runtime`` (step, rank) and the selected original output
        (dtype/shape, only when that output is a ``Tensor``).
    """
    row = UnequalRow(
        api_name=api_name,
        pert_type=Config.pert_type,
        output_index=index,
        stage=Config.stage,
        step=Runtime.step_count,
    )
    if isinstance(ratio, float):
        row.max_rel = ratio - 1

    original_tensor = params.original_result
    # Compare with None explicitly: 0 is a valid output position.
    if index is not None:
        original_tensor = original_tensor[index]

    # dtype/shape are only recorded for Tensor outputs.
    if isinstance(original_tensor, Tensor):
        row.dtype = original_tensor.dtype
        row.shape = original_tensor.shape

    # A rank_id of -1 is reported as "no rank".
    row.rank = Runtime.rank_id if Runtime.rank_id != -1 else None
    return row