mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323)
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
msprobe/pytorch/bench_functions/npu_fusion_attention.py
@@ -0,0 +1,509 @@
+ import torch
+ import numpy as np
+ from einops import rearrange
+ try:
+     import torch_npu
+ except ImportError:
+     is_gpu = True
+     try:
+         # flash_attn is the third-party FlashAttention library for GPU
+         from flash_attn import flash_attn_func
+     except ImportError:
+         # If this is a CPU-only UT environment, do nothing
+         pass
+ else:
+     is_gpu = False
+
+
+ from msprobe.pytorch.common.utils import logger
+ from msprobe.core.common.const import Const, CompareConst
+
+ gtype = torch.float64  # On ARM hosts float64 is required; on x86 float32 is enough (float64 also works). ARM is slow, so x86 is recommended for s=8k cases
+ softmax_build_mode = "QKV"  # "MAX_SUM"
+
+ """
+ # Forward function signature comparison
+ Golden implementation: fusion_attention_forward: q, k, v, drop_mask, atten_mask, pse, scale, keep_prob
+ Fused operator: npu_fusion_attention_forward: query, key, value, head_num, input_layout, *, pse=None, padding_mask=None,
+                                               atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647,
+                                               next_tockens=2147483647, inner_precise=0, prefix=None, sparse_mode=0,
+                                               gen_mask_parallel=True, sync=False
+
+ # Backward function signature comparison
+ Golden implementation: fusion_attention_backward: dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
+ Fused operator: npu_fusion_attention_backward: query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None,
+                                                atten_mask=None, softmax_max=None, softmax_sum=None, softmax_in=None,
+                                                attention_in=None, scale_value=1.0, keep_prob=1.0, pre_tockens=2147483647,
+                                                next_tockens=2147483647, inner_precise=0, seed=0, offset=0,
+                                                numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
+ """
+
+
+ def softmax_forward(x):
+     x_max = torch.max(x, dim=-1, keepdims=True)[0]
+     x_sub = x.sub(x_max)
+     y = torch.exp(x_sub)
+     x_sum = y.sum(dim=-1, keepdims=True)
+     res = y.div(x_sum)
+     return res, x_max, x_sum
+
+
+ def softmax_grad(dp, softmax_res):
+     muls = dp * softmax_res
+     muls_r = muls.sum(dim=-1, keepdims=True)
+     sub_r = dp - muls_r
+     res = sub_r * softmax_res
+     return res
+
+
+ def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):
+     if num_kv_heads == 0 or num_kv_heads > num_heads:
+         raise ValueError(f"num_kv_heads must be non-zero and no bigger than num_heads.")
+
+     factor = num_heads // num_kv_heads
+     kv_shape = kv_tensor.shape
+     B = kv_shape[0]
+     S = kv_shape[2]
+     D = kv_shape[3]
+     kv_res = torch.zeros([B, num_heads, S, D]).to(dtype)
+     for i in range(num_heads):
+         j = i // factor
+         kv_res[:, i:i + 1, :, :] = kv_tensor[:, j:j + 1, :, :]
+     return kv_res
+
+
+ def calculate_qk(q, k, atten_mask, pse, scale):
+     if pse is None or len(pse.shape) == 0:
+         qk = torch.matmul(q, k.permute(0, 1, 3, 2)).mul(scale)
+     else:
+         qk = (torch.matmul(q, k.permute(0, 1, 3, 2)) + pse).mul(scale)
+     if atten_mask is None or len(atten_mask.shape) == 0:
+         return qk
+     else:
+         qk = qk + atten_mask.bool() * (-40000.0)  # -10000
+         return qk
+
+
+ def fusion_attention_forward(q, k, v, drop_mask, atten_mask, pse, scale, keep_prob):
+     qk = calculate_qk(q, k, atten_mask, pse, scale)
+     softmax_res, softmax_max, softmax_sum = softmax_forward(qk)
+     if drop_mask is None or len(drop_mask.shape) == 0:
+         drop_res = softmax_res
+     else:
+         drop_res = softmax_res * drop_mask * (1.0 / keep_prob)
+     y = torch.matmul(drop_res, v)
+     return y, softmax_max, softmax_sum
+
+
+ def fusion_attention_backward(dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob):
+     dp = torch.matmul(dx, v.permute(0, 1, 3, 2))
+     if drop_mask is None or len(drop_mask.shape) == 0:
+         drop_res = softmax_res.permute(0, 1, 3, 2)
+         dp_drop = dp
+     else:
+         drop_res = softmax_res.mul(drop_mask).mul(1.0 / keep_prob).permute(0, 1, 3, 2)
+         dp_drop = dp * drop_mask * (1.0 / keep_prob)
+     dv = torch.matmul(drop_res, dx)
+     softmax_grad_res = (softmax_grad(dp_drop, softmax_res) * scale)
+     dq = torch.matmul(softmax_grad_res, k)
+     dk = torch.matmul(softmax_grad_res.permute(0, 1, 3, 2), q)
+     return dq, dk, dv
+
+
+ def parse_bsnd_args(query, key, head_num, input_layout):
+     supported_input_layout = ["BSH", "SBH", "BSND", "BNSD", "TND"]
+     B, S1, S2, N1, N2, D, H1, H2 = None, None, None, head_num, None, None, None, None
+
+     if not isinstance(input_layout, str) or input_layout not in supported_input_layout:
+         raise ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.")
+
+     if input_layout == "TND":
+         raise ValueError(f"input_layout {input_layout} is not supported for now.")
+     try:
+         if input_layout == "BSH":
+             B, S1, H1 = query.shape
+             _, S2, H2 = key.shape
+             D = H1 // N1
+             N2 = H2 // D
+         elif input_layout == "SBH":
+             S1, B, H1 = query.shape
+             S2, _, H2 = key.shape
+             D = H1 // N1
+             N2 = H2 // D
+         elif input_layout == "BSND":
+             B, S1, N1, D = query.shape
+             _, S2, N2, _ = key.shape
+             H1 = N1 * D
+             H2 = N2 * D
+         elif input_layout == "BNSD":
+             B, N1, S1, D = query.shape
+             _, N2, S2, _ = key.shape
+             H1 = N1 * D
+             H2 = N2 * D
+     except Exception as e:
+         raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e
+
+     if D == 0:
+         raise ValueError(f"Value D must be non-zero.")
+     DTYPE = query.dtype
+     return B, S1, S2, N1, N2, D, H1, H2, DTYPE
+
+
+ def convert_from_bnsd(_input, input_layout):
+     if input_layout == "BSH":
+         # (B,N,S,D)=>(B,S,N*D)
+         out = rearrange(_input, 'b n s d -> b s (n d)').contiguous()
+     elif input_layout == "SBH":
+         # (B,N,S,D)=>(S,B,N*D)
+         out = rearrange(_input, 'b n s d -> s b (n d)').contiguous()
+     elif input_layout == "BSND":
+         # (B,N,S,D)=>(B,S,N,D)
+         out = rearrange(_input, 'b n s d -> b s n d').contiguous()
+     elif input_layout == "TND":
+         raise ValueError(f"input_layout {input_layout} is not supported for now.")
+     else:
+         out = _input
+     return out
+
+
+ def convert_to_bnsd(_input, n, input_layout):
+     # The default "BNSD" layout needs no processing
+     if input_layout == "BSH":
+         # (B,S,N*D)=>(B,N,S,D)
+         out = rearrange(_input, 'b s (n d) -> b n s d', n=n)
+     elif input_layout == "SBH":
+         # (S,B,N*D)=>(B,N,S,D)
+         out = rearrange(_input, 's b (n d) -> b n s d', n=n)
+     elif input_layout == "BSND":
+         # (B,S,N,D)=>(B,N,S,D)
+         out = rearrange(_input, 'b s n d -> b n s d', n=n)
+     elif input_layout == "TND":
+         raise ValueError(f"input_layout {input_layout} is not supported for now.")
+     else:
+         out = _input
+     if out.dim() != 4:
+         raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
+     return out.to(gtype)
+
+
+ def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next_tocken, dtype):
+     """
+     # When sparse_mode is 2, 3 or 4, the small-op-to-fused-op path applies this optimization; to go back, it must be decomposed into the original basic implementation
+     ===> atten_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype)
+     """
+     shape = [S1, S2]
+
+     if atten_mask is not None:
+         # If FA's input already contains atten_mask, it can be treated as the already-converted mask matrix; there are three special (sparse) scenarios that must be reverted
+         if sparse_mode == 2 or sparse_mode == 3 or sparse_mode == 4:
+             logger.info(f"S1: {S1}, S2:{S2}, atten_mask.shape:{atten_mask.shape}, atten_mask.dtype:{atten_mask.dtype}")
+
+             if atten_mask.dim() == 2 and atten_mask.shape[0] == 2048 and atten_mask.shape[1] == 2048:
+                 if atten_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(atten_mask.dtype)):
+                     if sparse_mode == 2:
+                         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
+                     elif sparse_mode == 3:
+                         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1))
+                     elif sparse_mode == 4:
+                         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
+                         atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
+                         atten_mask = atten_mask_u + atten_mask_l
+                     logger.debug(f"Reverse-converted atten_mask {atten_mask.shape}")
+                     return atten_mask.to(dtype)
+
+         return atten_mask.to(dtype)
+
+     if atten_mask is not None:
+         if atten_mask.dim() == 2:
+             if atten_mask.shape[0] != S1 or atten_mask.shape[1] != S2:
+                 raise ValueError(f"Invalid atten_mask shape `SS` {atten_mask.shape}")
+             shape = [S1, S2]
+         elif atten_mask.dim() == 4:
+             if atten_mask.shape[1] == 1:
+                 shape = [B, 1, S1, S2] if B != 1 else [1, 1, S1, S2]
+             else:
+                 shape = [B, N1, S1, S2] if B != 1 else [1, N1, S1, S2]
+
+     if sparse_mode == 0:
+         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
+         atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
+         atten_mask = atten_mask_u + atten_mask_l
+     elif sparse_mode == 1:  # no sparse
+         atten_mask = torch.from_numpy(np.zeros(shape))
+     elif sparse_mode == 2:
+         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
+     elif sparse_mode == 3:
+         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1))
+     elif sparse_mode == 4:
+         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
+         atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
+         atten_mask = atten_mask_u + atten_mask_l
+     # Note: sparse_mode=5 cannot occur here; that mode requires atten_mask to be passed in with a BNSS or B1SS layout,
+     # so FA's input can already be assumed to be the correct atten_mask
+     return atten_mask.to(dtype)
+
+
+ def generate_kv(key, value, N1, N2):
+     # Adaptation for unequal N (by cdy)
+     if not (N1 == N2):
+         k_new = broadcast_kv(N1, N2, key, key.dtype)
+         v_new = broadcast_kv(N1, N2, value, value.dtype)
+     else:
+         k_new = key
+         v_new = value
+     return k_new, v_new
+
+
+ def rebuid_softmax_by_qkv(q, k, atten_mask, pse, scale):
+     """
+     attention = softmax(QK^T/sqrt(d))V
+     softmax(x_i) = e^(x_i - x_max) / sum(e^(x_i - x_max))
+     """
+     logger.info("Using QKV to rebuild original softmax")
+     qk = calculate_qk(q, k, atten_mask, pse, scale)
+     softmax_res, x_max, x_sum = softmax_forward(qk)
+     return softmax_res
+
+
+ def rebuild_softmax_by_max_sum(q, k, atten_mask, pse, scale, softmax_max, softmax_sum):
+     """
+     attention = softmax(QK^T/sqrt(d))V
+     softmax(x_i) = e^(x_i - x_max_i) / x_sum_i
+     """
+     logger.info("Using softmax_max and softmax_sum to rebuild original softmax")
+     qk = calculate_qk(q, k, atten_mask, pse, scale)
+     if softmax_max.shape[-1] == 0:
+         raise ValueError(f"softmax_max.shape[-1] must be non-zero, softmax_max.shape: {softmax_max.shape}")
+     repeat_dim = qk.shape[-1] // softmax_max.shape[-1]
+     softmax_res = torch.exp(qk.sub(softmax_max.repeat(1, 1, 1, repeat_dim))).div(
+         softmax_sum.repeat(1, 1, 1, repeat_dim))
+     return softmax_res
+
+
+ def get_head_num(*args, **kwargs):
+     if kwargs.get("head_num", None):
+         head_num = kwargs.get("head_num")
+     elif len(args) >= 4:
+         head_num = args[3]
+     else:
+         raise ValueError(f"Unsupported npu_fusion_attention args {args}.")
+     return head_num
+
+
+ def get_input_layout(*args, **kwargs):
+     if kwargs.get("input_layout", None):
+         input_layout = kwargs.get("input_layout")
+     elif len(args) >= 5:
+         input_layout = args[4]
+     else:
+         raise ValueError(f"Unsupported npu_fusion_attention args {args}.")
+     return input_layout
+
+
+ def npu_fusion_attention_forward_patch(*args, **kwargs):
+     # query, key, value, head_num, input_layout
+     head_num = get_head_num(*args, **kwargs)
+     input_layout = get_input_layout(*args, **kwargs)
+
+     B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], head_num, input_layout)
+     if N1 == N2 and S1 == S2:
+         logger.debug(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
+     else:
+         logger.debug(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
+     if not (N1 % N2 == 0 and N1 >= N2):
+         raise ValueError(f"N1 and N2 do not match, please check: N1 = {N1}, N2 = {N2}.")
+
+     dims_kwargs = {"B": B, "S1": S1, "S2": S2, "N1": N1, "N2": N2,
+                    "D": D, "H1": H1, "H2": H2, "DTYPE": DTYPE}
+
+     new_kwargs = {"keep_prob": 1,
+                   "scale": kwargs.get("scale", 1 / (D ** 0.5)),
+                   "sparse_mode": kwargs.get("sparse_mode", 0),
+                   "prefix": kwargs.get("prefix"),
+                   "pre_tockens": kwargs.get("pre_tockens", 2147483647),
+                   "next_tockens": kwargs.get("next_tockens", 2147483647),
+                   "pse": kwargs.get("pse"),
+                   "padding_mask": kwargs.get("padding_mask"),
+                   "atten_mask": kwargs.get("atten_mask")}
+
+     return args, dims_kwargs, new_kwargs
+
+
+ def npu_fusion_attention_backward_patch(*args, **kwargs):
+     if len(args) != 6:
+         raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")
+
+     B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], args[4], args[5])
+     if N1 == N2 and S1 == S2:
+         logger.info(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
+     else:
+         logger.info(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
+     if not (N1 % N2 == 0 and N1 >= N2):
+         raise ValueError(f"N1 and N2 do not match, please check: N1 = {N1}, N2 = {N2}.")
+
+     dims_kwargs = {"B": B, "S1": S1, "S2": S2, "N1": N1, "N2": N2,
+                    "D": D, "H1": H1, "H2": H2, "DTYPE": DTYPE}
+
+     new_kwargs = {"keep_prob": 1,
+                   "scale_value": kwargs.get("scale_value", 1 / (D ** 0.5)),
+                   "sparse_mode": kwargs.get("sparse_mode", 0),
+                   "prefix": kwargs.get("prefix"),
+                   "pre_tockens": kwargs.get("pre_tockens", 2147483647),
+                   "next_tockens": kwargs.get("next_tockens", 2147483647),
+                   "pse": kwargs.get("pse"),
+                   "padding_mask": kwargs.get("padding_mask"),
+                   "softmax_max": kwargs.get("softmax_max"),
+                   "softmax_sum": kwargs.get("softmax_sum"),
+                   "softmax_in": kwargs.get("softmax_in"),
+                   "attention_in": kwargs.get("attention_in"),
+                   "seed": kwargs.get("seed", 0),
+                   "offset": kwargs.get("offset", 0),
+                   "numels": kwargs.get("numels", 0),
+                   "atten_mask": kwargs.get("atten_mask")}
+
+     return args, dims_kwargs, new_kwargs
+
+
+ def npu_fusion_attention(*args, **kwargs):
+     new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*args, **kwargs)
+     query, key, value = new_args[0], new_args[1], new_args[2]
+     input_layout = get_input_layout(*args, **kwargs)
+     N1 = dims_kwargs.get("N1")
+     N2 = dims_kwargs.get("N2")
+     S1 = dims_kwargs.get("S1")
+     S2 = dims_kwargs.get("S2")
+     B = dims_kwargs.get("B")
+     DTYPE = dims_kwargs.get("DTYPE")
+     atten_mask = new_kwargs.get("atten_mask")
+     keep_prob = new_kwargs.get("keep_prob")
+     sparse_mode = new_kwargs.get("sparse_mode")
+     pre_tockens = new_kwargs.get("pre_tockens")
+     next_tockens = new_kwargs.get("next_tockens")
+     pse = new_kwargs.get("pse")
+     scale = new_kwargs.get("scale")
+
+     atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE)
+     query = convert_to_bnsd(query, N1, input_layout)
+     key = convert_to_bnsd(key, N2, input_layout)
+     value = convert_to_bnsd(value, N2, input_layout)
+     k_new, v_new = generate_kv(key, value, N1, N2)
+     out_golden, softmax_max, softmax_sum = fusion_attention_forward(q=query, k=k_new, v=v_new,
+                                                                     drop_mask=None, atten_mask=atten_mask,
+                                                                     pse=pse, scale=scale,
+                                                                     keep_prob=keep_prob)
+     if out_golden.dim() == 5:
+         out_golden = out_golden.reshape(out_golden.size(0), out_golden.size(1) * out_golden.size(2), out_golden.size(3),
+                                         out_golden.size(4))
+     out_golden = convert_from_bnsd(out_golden, input_layout)
+
+     return out_golden.cpu(), softmax_max.repeat(1, 1, 1, 8).cpu(), softmax_sum.repeat(1, 1, 1, 8).cpu()
+
+
+ def npu_fusion_attention_grad(*args, **kwargs):
+     # dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
+     new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*args, **kwargs)
+     query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5]
+     N1 = dims_kwargs.get("N1")
+     N2 = dims_kwargs.get("N2")
+     S1 = dims_kwargs.get("S1")
+     S2 = dims_kwargs.get("S2")
+     B = dims_kwargs.get("B")
+     D = dims_kwargs.get("D")
+     DTYPE = dims_kwargs.get("DTYPE")
+     atten_mask = new_kwargs.get("atten_mask")
+     keep_prob = new_kwargs.get("keep_prob")
+     sparse_mode = new_kwargs.get("sparse_mode")
+     pre_tockens = new_kwargs.get("pre_tockens")
+     next_tockens = new_kwargs.get("next_tockens")
+     pse = new_kwargs.get("pse")
+     softmax_max = new_kwargs.get("softmax_max")
+     softmax_sum = new_kwargs.get("softmax_sum")
+     scale_value = new_kwargs.get("scale_value")
+
+     atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE)
+     query = convert_to_bnsd(query, N1, input_layout)
+     dx = convert_to_bnsd(dx, N1, input_layout)
+     key = convert_to_bnsd(key, N2, input_layout)
+     value = convert_to_bnsd(value, N2, input_layout)
+     k_new, v_new = generate_kv(key, value, N1, N2)
+
+     if softmax_build_mode == "QKV":
+         softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value)
+     else:
+         softmax_res = rebuild_softmax_by_max_sum(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum)
+
+     dq, dk, dv = fusion_attention_backward(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob)
+
+     # Adaptation for unequal N (by cdy)
+     if not (N1 == N2):
+         if N2 == 0:
+             raise ValueError("dims_kwargs.N2 must be non-zero.")
+         G = int(N1 / N2)
+         dk = torch.sum(dk.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D)
+         dv = torch.sum(dv.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D)
+
+     if dq.dim() == 5:
+         dq = dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4))
+     if dk.dim() == 5:
+         dk = dk.reshape(dk.size(0), dk.size(1) * dk.size(2), dk.size(3), dk.size(4))
+     if dv.dim() == 5:
+         dv = dv.reshape(dv.size(0), dv.size(1) * dv.size(2), dv.size(3), dv.size(4))
+
+     dq = convert_from_bnsd(dq, input_layout)
+     dk = convert_from_bnsd(dk, input_layout)
+     dv = convert_from_bnsd(dv, input_layout)
+
+     return dq.cpu(), dk.cpu(), dv.cpu()
+
+
+ def is_attention_off_due_to_mask(atten_mask_dtype):
+     return not atten_mask_dtype
+
+
+ def is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens, S1):
+     return sparse_mode == 4 and (next_tockens != 0 or pre_tockens < S1)
+
+
+ def is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens, S1, S2):
+     return sparse_mode == 0 and pre_tockens >= S1 and next_tockens >= S2
+
+
+ def gpu_fusion_attention(*args, **kwargs):
+     deterministic = False
+     new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*args, **kwargs)
+     query, key, value = new_args[0], new_args[1], new_args[2]
+     keep_prob = new_kwargs.get("keep_prob", 1.0)
+     scale = new_kwargs.get("scale")
+     N1 = dims_kwargs.get("N1")
+     N2 = dims_kwargs.get("N2")
+     S1 = dims_kwargs.get("S1")
+     S2 = dims_kwargs.get("S2")
+     B = dims_kwargs.get("B")
+     pse = new_kwargs.get("pse")
+     sparse_mode = new_kwargs.get("sparse_mode")
+     pre_tockens = new_kwargs.get("pre_tockens")
+     next_tockens = new_kwargs.get("next_tockens")
+     attn_mask = new_kwargs.get("atten_mask")
+     atten_mask_dtype = attn_mask.dtype if new_kwargs.get("atten_mask") is not None else None
+     pre_tockens = min(CompareConst.MAX_TOKENS, pre_tockens)
+     next_tockens = min(CompareConst.MAX_TOKENS, next_tockens)
+     atten_off = (is_attention_off_due_to_mask(atten_mask_dtype) or
+                  is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens, S1) or
+                  is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens, S1, S2))
+     causal_switch = not atten_off
+     if sparse_mode == CompareConst.SPECIAL_SPARSE_MOED:
+         window_left = pre_tockens
+         window_right = next_tockens
+     else:
+         pre_tockens = next_tockens = CompareConst.MAX_TOKENS
+         window_left = pre_tockens - S1 + S2
+         window_right = next_tockens + S1 - S2
+
+     if pse is not None:
+         alibi_slopes = torch.rand(B, N1, dtype=torch.float32) * 0.3
+     else:
+         alibi_slopes = None
+
+     out = flash_attn_func(query, key, value, dropout_p=(1 - keep_prob), softmax_scale=scale, causal=causal_switch,
+                           window_size=(window_left, window_right), alibi_slopes=alibi_slopes, deterministic=deterministic)
+     return out, Const.NONE, Const.NONE
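
The golden path above reduces npu_fusion_attention to plain BNSD attention math, softmax(QK^T * scale)V, with the softmax row max and sum returned alongside the output. A minimal, self-contained sketch (not part of the package; assumes PyTorch 2.0+ for scaled_dot_product_attention) showing that this math agrees with PyTorch's reference attention on a tiny unmasked case:

import torch
import torch.nn.functional as F

# Toy BNSD tensors; float64 mirrors the gtype used by the golden implementation.
torch.manual_seed(0)
B, N, S, D = 1, 2, 8, 16
q, k, v = (torch.randn(B, N, S, D, dtype=torch.float64) for _ in range(3))
scale = 1 / (D ** 0.5)

# Golden math: softmax(Q K^T * scale) V, as in fusion_attention_forward with no mask or dropout.
qk = torch.matmul(q, k.transpose(-1, -2)) * scale
golden = torch.matmul(torch.softmax(qk, dim=-1), v)

# PyTorch's own reference attention (default scale is 1/sqrt(D)).
reference = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(golden, reference, atol=1e-10))  # True
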
msprobe/pytorch/bench_functions/rms_norm.py
@@ -0,0 +1,15 @@
+ import torch
+
+
+ def npu_rms_norm(x, gamma, epsilon=1e-5):
+     rstd = torch.rsqrt(torch.mean(torch.pow(x, 2), axis=-1, keepdim=True) + epsilon)
+     res = x * rstd * gamma
+     return res, rstd.float()
+
+
+ def npu_rms_norm_backward(grad, x, gamma, rstd):
+     mean_gy = (grad * x * gamma * rstd).mean(dim=-1, keepdim=True)
+     grad_x = (grad * gamma - x * rstd * mean_gy) * rstd
+     grad_gamma = x * grad * rstd
+     return grad_x.cpu(), grad_gamma.cpu()
+
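
The backward formulas above can be cross-checked against autograd. A standalone sketch (not part of the diff), using a gamma with the same shape as x so no reduction over broadcast dimensions is involved:

import torch

torch.manual_seed(0)
eps = 1e-5
x = torch.randn(2, 3, 8, dtype=torch.float64, requires_grad=True)
gamma = torch.randn(2, 3, 8, dtype=torch.float64, requires_grad=True)
grad_out = torch.randn(2, 3, 8, dtype=torch.float64)

# Forward, same math as npu_rms_norm.
rstd = torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + eps)
y = x * rstd * gamma
y.backward(grad_out)

# Closed-form gradients, same math as npu_rms_norm_backward.
mean_gy = (grad_out * x * gamma * rstd).mean(dim=-1, keepdim=True)
grad_x = (grad_out * gamma - x * rstd * mean_gy) * rstd
grad_gamma = x * grad_out * rstd

print(torch.allclose(grad_x, x.grad))          # True
print(torch.allclose(grad_gamma, gamma.grad))  # True
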
msprobe/pytorch/bench_functions/rotary_mul.py
@@ -0,0 +1,52 @@
+ import torch
+
+
+ def npu_rotary_mul(x, r1, r2):
+     x1, x2 = torch.chunk(x, 2, -1)
+     x_new = torch.cat((-x2, x1), dim=-1)
+     output = r1 * x + r2 * x_new
+     return output
+
+
+ def npu_rotary_mul_backward(dy_tensor, x, r1, r2):
+     x.requires_grad = True
+     r1.requires_grad = True
+     r2.requires_grad = True
+     # golden
+     x1, x2 = torch.chunk(x, 2, -1)
+     x_new = torch.cat((-x2, x1), dim=-1)
+     golden_tensor = r1 * x + r2 * x_new
+     golden_tensor.backward(dy_tensor)
+     r1_shape = r1.shape
+     r1_grad = torch.zeros(r1_shape).type(torch.float32)
+     r2_grad = torch.zeros(r1_shape).type(torch.float32)
+     x1, x2 = torch.chunk(x.float(), 2, -1)
+     x_new2 = torch.cat((-x2, x1), dim=-1)
+     x_shape = x.shape
+     h = x.float()
+     grad = dy_tensor.float()
+     condition_1 = (((r1_shape[0] == 1 and x_shape[0] != 1) or (r1_shape[0] == 1 and x_shape[0] == 1)) and
+                    ((r1_shape[2] == 1 and x_shape[2] != 1) or (r1_shape[2] == 1 and x_shape[2] == 1)) and
+                    (r1_shape[1] == x_shape[1]) and (r1_shape[3] == x_shape[3]))
+     condition_2 = (((r1_shape[0] == 1 and x_shape[0] != 1) or (r1_shape[0] == 1 and x_shape[0] == 1)) and
+                    ((r1_shape[1] == 1 and x_shape[1] != 1) or (r1_shape[1] == 1 and x_shape[1] == 1)) and
+                    (r1_shape[2] == x_shape[2]) and (r1_shape[3] == x_shape[3]))
+     condition_3 = (((r1_shape[2] == 1 and x_shape[2] != 1) or (r1_shape[2] == 1 and x_shape[2] == 1)) and
+                    ((r1_shape[1] == 1 and x_shape[1] != 1) or (r1_shape[1] == 1 and x_shape[1] == 1)) and
+                    (r1_shape[0] == x_shape[0]) and (r1_shape[3] == x_shape[3]))
+     if condition_1:
+         for i in range(x_shape[0]):
+             for j in range(x_shape[2]):
+                 r2_grad[0, :, 0, :] += (x_new2[i, :, j, :] * grad[i, :, j, :])
+                 r1_grad[0, :, 0, :] += (h[i, :, j, :] * grad[i, :, j, :])
+     elif condition_2:
+         for i in range(x_shape[0]):
+             for j in range(x_shape[1]):
+                 r2_grad[0, 0, :, :] += (x_new2[i, j, :, :] * grad[i, j, :, :])
+                 r1_grad[0, 0, :, :] += (h[i, j, :, :] * grad[i, j, :, :])
+     elif condition_3:
+         for i in range(x_shape[1]):
+             for j in range(x_shape[2]):
+                 r2_grad[:, 0, 0, :] += (x_new2[:, i, j, :] * grad[:, i, j, :])
+                 r1_grad[:, 0, 0, :] += (h[:, i, j, :] * grad[:, i, j, :])
+     return x.grad.cpu(), r1_grad.cpu(), r2_grad.cpu()
msprobe/pytorch/bench_functions/scaled_mask_softmax.py
@@ -0,0 +1,26 @@
+ import torch
+
+
+ def npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask):
+     if fixed_triu_mask:
+         mask = (torch.triu(torch.ones(mask.shape), diagonal=1)).bool().to(mask.device)
+     dtype = x.dtype
+     x = (x * scale).masked_fill(mask, value=-10000)
+     x = x - torch.max(x, dim=-1, keepdims=True)[0]
+     x = torch.exp(x.float())
+     y = torch.div(x, torch.sum(x, dim=-1, keepdims=True))
+     return y.to(dtype)
+
+
+ def npu_scaled_masked_softmax_backward(y_grad, y, mask, scale, fixed_triu_mask):
+     if fixed_triu_mask:
+         mask = (torch.triu(torch.ones(mask.shape), diagonal=1)).bool().to(mask.device)
+     dtype = y_grad.dtype
+     y_grad = y_grad.float()
+     y = y.float()
+     x_grad = y_grad * y
+     x_grad = y_grad - torch.sum(x_grad, dim=-1, keepdims=True)
+     x_grad = x_grad * y
+     x_grad = x_grad * scale
+     x_grad = x_grad.masked_fill(mask, value=0)
+     return x_grad.to(dtype).cpu()
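
With fixed_triu_mask disabled, the golden forward above is just a softmax over the scaled and masked scores. A hypothetical usage sketch; the import path is assumed from the file listing (msprobe/pytorch/bench_functions/scaled_mask_softmax.py) and requires the wheel to be installed:

import torch

from msprobe.pytorch.bench_functions.scaled_mask_softmax import npu_scaled_masked_softmax  # assumed path

torch.manual_seed(0)
x = torch.randn(2, 4, 8, 8)
mask = torch.rand(2, 4, 8, 8) > 0.7  # boolean mask, True means masked out
scale = 0.125

golden = npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask=False)
reference = torch.softmax((x * scale).masked_fill(mask, -10000), dim=-1)
print(torch.allclose(golden, reference, atol=1e-6))  # True
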
msprobe/pytorch/bench_functions/swiglu.py
@@ -0,0 +1,55 @@
+ import torch
+
+
+ def npu_swiglu(x, dim=-1):
+     tensor_dtype = x.dtype
+
+     inTensors = torch.chunk(x, 2, dim=dim)
+     if tensor_dtype == torch.float32:
+         tensor_scalar = torch.sigmoid(torch.mul(inTensors[0], 1.0))
+         output_data = torch.mul(torch.mul(tensor_scalar, inTensors[0]), inTensors[1])
+     else:
+         tensor_self_float = inTensors[0].type(torch.float)
+         tensor_other_float = inTensors[1].type(torch.float)
+         tensor_out_float = torch.nn.functional.silu(tensor_self_float).type(tensor_dtype).type(
+             torch.float32) * tensor_other_float
+         output_data = tensor_out_float.type(tensor_dtype)
+     return output_data
+
+
+ def npu_swiglu_backward(grad, x, dim=-1):
+     tensor_dtype = grad.dtype
+     in_tensors = torch.chunk(x, 2, dim=dim)
+     tensor_grad_out = grad
+
+     if tensor_dtype == torch.float16:
+         tensor_out1 = torch.mul(
+             torch.mul(in_tensors[1].type(torch.float32), swish_grad(1, in_tensors[0].type(torch.float32))),
+             tensor_grad_out.type(torch.float32)).type(torch.float16)
+         tensor_out2 = torch.mul(tensor_grad_out.type(torch.float32),
+                                 swish(1, in_tensors[0].type(torch.float32))).type(torch.float16)
+         output = torch.cat((tensor_out1, tensor_out2), dim)
+     elif tensor_dtype == torch.bfloat16:
+         tensor_self_float = in_tensors[0].type(torch.float)
+         tensor_other_float = in_tensors[1].type(torch.float)
+         tensor_gradout_float = tensor_grad_out.type(torch.float)
+
+         tensor_out1 = torch.mul(tensor_gradout_float, swish_grad(1.0, tensor_self_float)).type(torch.bfloat16).type(
+             torch.float32) * tensor_other_float
+         tensor_out2 = swish(1.0, tensor_self_float).type(torch.bfloat16).type(torch.float32) * tensor_gradout_float
+         tensor_out_float = torch.cat((tensor_out1, tensor_out2), dim=dim)
+         output = tensor_out_float.type(torch.bfloat16)
+     else:
+         tensor_out1 = torch.mul(torch.mul(in_tensors[1], swish_grad(1.0, in_tensors[0])), tensor_grad_out)
+         tensor_out2 = torch.mul(tensor_grad_out, swish(1.0, in_tensors[0]))
+         output = torch.cat((tensor_out1, tensor_out2), dim)
+     return output.cpu()
+
+
+ def swish_grad(beta, x):
+     return torch.sigmoid(beta * x) + x * (1 - torch.sigmoid(beta * x)) * torch.sigmoid(beta * x) * beta
+
+
+ def swish(beta, x):
+     return x * torch.sigmoid(beta * x)
+
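
For float32 inputs, npu_swiglu above computes silu(a) * b for the two halves a, b of the last-dimension split. A hypothetical usage sketch; the import path is assumed from the file listing and requires the wheel to be installed:

import torch
import torch.nn.functional as F

from msprobe.pytorch.bench_functions.swiglu import npu_swiglu  # assumed path

torch.manual_seed(0)
x = torch.randn(4, 16, dtype=torch.float32)
a, b = torch.chunk(x, 2, dim=-1)

print(torch.allclose(npu_swiglu(x, dim=-1), F.silu(a) * b, atol=1e-6))  # True
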
msprobe/pytorch/common/__init__.py
@@ -1,2 +1,2 @@
- from .parse_json import parse_json_info_forward_backward
- from .utils import seed_all
+ from .parse_json import parse_json_info_forward_backward
+ from .utils import seed_all