mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
msprobe/pytorch/bench_functions/npu_fusion_attention.py
@@ -1,421 +1,538 @@
- import torch
- import numpy as np
- from einops import rearrange
-
- from msprobe.pytorch.common.utils import logger
-
- gtype = torch.float64  # on an arm host float64 is required; on x86 float32 is enough (float64 also works). arm computation is slow, so x86 is recommended for s=8k scenarios
- softmax_build_mode = "QKV"  # "MAX_SUM"
-
- """
- # Forward function signature comparison
- Golden implementation: fusion_attention_forward: q, k, v, drop_mask, atten_mask, pse, scale, keep_prob
- Fused operator: npu_fusion_attention_forward: query, key, value, head_num, input_layout, *, pse=None, padding_mask=None,
-                                               atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647,
-                                               next_tockens=2147483647, inner_precise=0, prefix=None, sparse_mode=0,
-                                               gen_mask_parallel=True, sync=False
-
- # Backward function signature comparison
- Golden implementation: fusion_attention_backward: dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
- Fused operator: npu_fusion_attention_backward: query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None,
-                                                atten_mask=None, softmax_max=None, softmax_sum=None, softmax_in=None,
-                                                attention_in=None, scale_value=1.0, keep_prob=1.0, pre_tockens=2147483647,
-                                                next_tockens=2147483647, inner_precise=0, seed=0, offset=0,
-                                                numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
- """
-
-
- def softmax_forward(x):
-     x_max = torch.max(x, dim=-1, keepdims=True)[0]
-     x_sub = x.sub(x_max)
-     y = torch.exp(x_sub)
-     x_sum = y.sum(dim=-1, keepdims=True)
-     res = y.div(x_sum)
-     return res, x_max, x_sum
-
-
- def softmax_grad(dp, softmax_res):
-     muls = dp * softmax_res
-     muls_r = muls.sum(dim=-1, keepdims=True)
-     sub_r = dp - muls_r
-     res = sub_r * softmax_res
-     return res
-
-
- def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):
-     if num_kv_heads == 0 or num_kv_heads < num_heads:
-         raise ValueError(f"num_kv_heads must be non-zero and less than num_heads.")
-
-     factor = num_heads // num_kv_heads
-     kv_shape = kv_tensor.shape
-     B = kv_shape[0]
-     S = kv_shape[2]
-     D = kv_shape[3]
-     kv_res = torch.zeros([B, num_heads, S, D]).to(dtype)
-     for i in range(num_heads):
-         j = i // factor
-         kv_res[:, i:i + 1, :, :] = kv_tensor[:, j:j + 1, :, :]
-     return kv_res
-
-
- def calculate_qk(q, k, atten_mask, pse, scale):
-     if pse is None or len(pse.shape) == 0:
-         qk = torch.matmul(q, k.permute(0, 1, 3, 2)).mul(scale)
-     else:
-         qk = (torch.matmul(q, k.permute(0, 1, 3, 2)) + pse).mul(scale)
-     if atten_mask is None or len(atten_mask.shape) == 0:
-         return qk
-     else:
-         qk = qk + atten_mask.bool() * (-40000.0)  # -10000
-         return qk
-
-
- def fusion_attention_forward(q, k, v, drop_mask, atten_mask, pse, scale, keep_prob):
-     qk = calculate_qk(q, k, atten_mask, pse, scale)
-     softmax_res, softmax_max, softmax_sum = softmax_forward(qk)
-     if drop_mask is None or len(drop_mask.shape) == 0:
-         drop_res = softmax_res
-     else:
-         drop_res = softmax_res * drop_mask * (1.0 / keep_prob)
-     y = torch.matmul(drop_res, v)
-     return y, softmax_max, softmax_sum
-
-
- def fusion_attention_backward(dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob):
-     dp = torch.matmul(dx, v.permute(0, 1, 3, 2))
-     if drop_mask is None or len(drop_mask.shape) == 0:
-         drop_res = softmax_res.permute(0, 1, 3, 2)
-         dp_drop = dp
-     else:
-         drop_res = softmax_res.mul(drop_mask).mul(1.0 / keep_prob).permute(0, 1, 3, 2)
-         dp_drop = dp * drop_mask * (1.0 / keep_prob)
-     dv = torch.matmul(drop_res, dx)
-     softmax_grad_res = (softmax_grad(dp_drop, softmax_res) * scale)
-     dq = torch.matmul(softmax_grad_res, k)
-     dk = torch.matmul(softmax_grad_res.permute(0, 1, 3, 2), q)
-     return dq, dk, dv
-
-
- def parse_bsnd_args(query, key, head_num, input_layout):
-     supported_input_layout = ["BSH", "SBH", "BSND", "BNSD", "TND"]
-     B, S1, S2, N1, N2, D, H1, H2 = None, None, None, head_num, None, None, None, None
-
-     if not isinstance(input_layout, str) or input_layout not in supported_input_layout:
-         raise ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.")
-
-     if input_layout == "TND":
-         raise ValueError(f"input_layout {input_layout} is not supported for now.")
-     try:
-         if input_layout == "BSH":
-             B, S1, H1 = query.shape
-             _, S2, H2 = key.shape
-             D = H1 // N1
-             N2 = H2 // D
-         elif input_layout == "SBH":
-             S1, B, H1 = query.shape
-             S2, _, H2 = key.shape
-             D = H1 // N1
-             N2 = H2 // D
-         elif input_layout == "BSND":
-             B, S1, N1, D = query.shape
-             _, S2, N2, _ = key.shape
-             H1 = N1 * D
-             H2 = N2 * D
-         elif input_layout == "BNSD":
-             B, N1, S1, D = query.shape
-             _, N2, S2, _ = key.shape
-             H1 = N1 * D
-             H2 = N2 * D
-     except Exception as e:
-         raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e
-
-     if D == 0:
-         raise ValueError(f"Value D must be non-zero.")
-     DTYPE = query.dtype
-     return B, S1, S2, N1, N2, D, H1, H2, DTYPE
-
-
- def convert_from_bnsd(_input, input_layout):
-     if input_layout == "BSH":
-         # (B,N,S,D)=>(B,S,N*D)
-         out = rearrange(_input, 'b n s d -> b s (n d)').contiguous()
-     elif input_layout == "SBH":
-         # (B,N,S,D)=>(S,B,N*D)
-         out = rearrange(_input, 'b n s d -> s b (n d)').contiguous()
-     elif input_layout == "BSND":
-         # (B,N,S,D)=>(B,S,N,D)
-         out = rearrange(_input, 'b n s d -> b s n d').contiguous()
-     elif input_layout == "TND":
-         raise ValueError(f"input_layout {input_layout} is not supported for now.")
-     else:
-         out = _input
-     return out
-
-
- def convert_to_bnsd(_input, n, input_layout):
-     # the default "BNSD" layout needs no conversion
-     if input_layout == "BSH":
-         # (B,S,N*D)=>(B,N,S,D)
-         out = rearrange(_input, 'b s (n d) -> b n s d', n=n)
-     elif input_layout == "SBH":
-         # (S,B,N*D)=>(B,N,S,D)
-         out = rearrange(_input, 's b (n d) -> b n s d', n=n)
-     elif input_layout == "BSND":
-         # (B,S,N,D)=>(B,N,S,D)
-         out = rearrange(_input, 'b s n d -> b n s d', n=n)
-     elif input_layout == "TND":
-         raise ValueError(f"input_layout {input_layout} is not supported for now.")
-     else:
-         out = _input
-     if out.dim() != 4:
-         raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
-     return out.to(gtype)
-
-
- def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next_tocken, dtype):
-     """
-     # When sparse_mode is 2, 3 or 4, the small-op-to-fused-op conversion applies this optimization; reversing it means expanding the mask back to the original basic implementation:
-     ===> atten_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype)
-     """
-     shape = [S1, S2]
-
-     if atten_mask is not None:
-         # When the FA input already carries an atten_mask, it can be treated as the already-converted mask matrix; three special (sparse-matrix) scenarios require reversing the conversion
-         if sparse_mode == 2 or sparse_mode == 3 or sparse_mode == 4:
-             logger.info(f"S1: {S1}, S2:{S2}, atten_mask.shape:{atten_mask.shape}, atten_mask.dtype:{atten_mask.dtype}")
-
-             if atten_mask.dim() == 2 and atten_mask.shape[0] == 2048 and atten_mask.shape[1] == 2048:
-                 if atten_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(atten_mask.dtype)):
-                     if sparse_mode == 2:
-                         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
-                     elif sparse_mode == 3:
-                         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1))
-                     elif sparse_mode == 4:
-                         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
-                         atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
-                         atten_mask = atten_mask_u + atten_mask_l
-                     logger.debug(f"reverse-converted atten_mask {atten_mask.shape}")
-                     return atten_mask.to(dtype)
-
-         return atten_mask.to(dtype)
-
-     if atten_mask is not None:
-         if atten_mask.dim() == 2:
-             if atten_mask.shape[0] != S1 or atten_mask.shape[1] != S2:
-                 raise ValueError(f"Invalid atten_mask shape `SS` {atten_mask.shape}")
-             shape = [S1, S2]
-         elif atten_mask.dim() == 4:
-             if atten_mask.shape[1] == 1:
-                 shape = [B, 1, S1, S2] if B != 1 else [1, 1, S1, S2]
-             else:
-                 shape = [B, N1, S1, S2] if B != 1 else [1, N1, S1, S2]
-
-     if sparse_mode == 0:
-         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
-         atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
-         atten_mask = atten_mask_u + atten_mask_l
-     elif sparse_mode == 1:  # no sparse
-         atten_mask = torch.from_numpy(np.zeros(shape))
-     elif sparse_mode == 2:
-         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
-     elif sparse_mode == 3:
-         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1))
-     elif sparse_mode == 4:
-         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
-         atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
-         atten_mask = atten_mask_u + atten_mask_l
-     # Note: sparse_mode=5 cannot occur here, since that mode requires atten_mask to be passed in with BNSS or B1SS layout,
-     # so the FA input can be assumed to already be the correct atten_mask
-     return atten_mask.to(dtype)
-
-
- def generate_kv(key, value, N1, N2):
-     # adaptation for unequal N (by cdy)
-     if not (N1 == N2):
-         k_new = broadcast_kv(N1, N2, key, key.dtype)
-         v_new = broadcast_kv(N1, N2, value, value.dtype)
-     else:
-         k_new = key
-         v_new = value
-     return k_new, v_new
-
-
- def rebuid_softmax_by_qkv(q, k, atten_mask, pse, scale):
-     """
-     attention = softmax(QK^T/sqrt(d))V
-     softmax(x_i) = e^(x_i - x_max) / sum(e^(x_i - x_max))
-     """
-     logger.info("Using QKV to rebuild original softmax")
-     qk = calculate_qk(q, k, atten_mask, pse, scale)
-     softmax_res, x_max, x_sum = softmax_forward(qk)
-     return softmax_res
-
-
- def rebuild_softmax_by_max_sum(q, k, atten_mask, pse, scale, softmax_max, softmax_sum):
-     """
-     attention = softmax(QK^T/sqrt(d))V
-     softmax(x_i) = e^(x_i - x_max_i) / x_sum_i
-     """
-     logger.info("Using softmax_max and softmax_sum to rebuild original softmax")
-     qk = calculate_qk(q, k, atten_mask, pse, scale)
-     if softmax_max.shape[-1] == 0:
-         raise ValueError(f"softmax_max.shape[-1] must be non-zero, softmax_max.shape: {softmax_max.shape}")
-     repeat_dim = qk.shape[-1] // softmax_max.shape[-1]
-     softmax_res = torch.exp(qk.sub(softmax_max.repeat(1, 1, 1, repeat_dim))).div(
-         softmax_sum.repeat(1, 1, 1, repeat_dim))
-     return softmax_res
-
-
- def npu_fusion_attention_forward_patch(*args, **kwargs):
-     # query, key, value, head_num, input_layout
-     if len(args) != 5:
-         raise ValueError(f"Unsupported npu_fusion_attention args {args}.")
-
-     B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], args[3], args[4])
-     if N1 == N2 and S1 == S2:
-         logger.debug(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
-     else:
-         logger.debug(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
-     if not (N1 % N2 == 0 and N1 >= N2):
-         raise ValueError(f"N1 and N2 do not match, please check: N1 = {N1}, N2 = {N2}.")
-
-     dims_kwargs = {"B": B, "S1": S1, "S2": S2, "N1": N1, "N2": N2,
-                    "D": D, "H1": H1, "H2": H2, "DTYPE": DTYPE}
-
-     new_kwargs = {"keep_prob": 1,
-                   "scale": kwargs.get("scale", 1 / (D ** 0.5)),
-                   "sparse_mode": kwargs.get("sparse_mode", 0),
-                   "prefix": kwargs.get("prefix"),
-                   "pre_tockens": kwargs.get("pre_tockens", 2147483647),
-                   "next_tockens": kwargs.get("next_tockens", 2147483647),
-                   "pse": kwargs.get("pse"),
-                   "padding_mask": kwargs.get("padding_mask"),
-                   "atten_mask": kwargs.get("atten_mask")}
-
-     return args, dims_kwargs, new_kwargs
-
-
- def npu_fusion_attention_backward_patch(*args, **kwargs):
-     if len(args) != 6:
-         raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")
-
-     B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], args[4], args[5])
-     if N1 == N2 and S1 == S2:
-         logger.info(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
-     else:
-         logger.info(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
-     if not (N1 % N2 == 0 and N1 >= N2):
-         raise ValueError(f"N1 and N2 do not match, please check: N1 = {N1}, N2 = {N2}.")
-
-     dims_kwargs = {"B": B, "S1": S1, "S2": S2, "N1": N1, "N2": N2,
-                    "D": D, "H1": H1, "H2": H2, "DTYPE": DTYPE}
-
-     new_kwargs = {"keep_prob": 1,
-                   "scale_value": kwargs.get("scale_value", 1 / (D ** 0.5)),
-                   "sparse_mode": kwargs.get("sparse_mode", 0),
-                   "prefix": kwargs.get("prefix"),
-                   "pre_tockens": kwargs.get("pre_tockens", 2147483647),
-                   "next_tockens": kwargs.get("next_tockens", 2147483647),
-                   "pse": kwargs.get("pse"),
-                   "padding_mask": kwargs.get("padding_mask"),
-                   "softmax_max": kwargs.get("softmax_max"),
-                   "softmax_sum": kwargs.get("softmax_sum"),
-                   "softmax_in": kwargs.get("softmax_in"),
-                   "attention_in": kwargs.get("attention_in"),
-                   "seed": kwargs.get("seed", 0),
-                   "offset": kwargs.get("offset", 0),
-                   "numels": kwargs.get("numels", 0),
-                   "atten_mask": kwargs.get("atten_mask")}
-
-     return args, dims_kwargs, new_kwargs
-
-
- def npu_fusion_attention(*args, **kwargs):
-     new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*args, **kwargs)
-     query, key, value, input_layout = new_args[0], new_args[1], new_args[2], new_args[4]
-     N1 = dims_kwargs.get("N1")
-     N2 = dims_kwargs.get("N2")
-     S1 = dims_kwargs.get("S1")
-     S2 = dims_kwargs.get("S2")
-     B = dims_kwargs.get("B")
-     DTYPE = dims_kwargs.get("DTYPE")
-     atten_mask = new_kwargs.get("atten_mask")
-     keep_prob = new_kwargs.get("keep_prob")
-     sparse_mode = new_kwargs.get("sparse_mode")
-     pre_tockens = new_kwargs.get("pre_tockens")
-     next_tockens = new_kwargs.get("next_tockens")
-     pse = new_kwargs.get("pse")
-     scale = new_kwargs.get("scale")
-
-     atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE)
-     query = convert_to_bnsd(query, N1, input_layout)
-     key = convert_to_bnsd(key, N2, input_layout)
-     value = convert_to_bnsd(value, N2, input_layout)
-     k_new, v_new = generate_kv(key, value, N1, N2)
-     out_golden, softmax_max, softmax_sum = fusion_attention_forward(q=query, k=k_new, v=v_new,
-                                                                     drop_mask=None, atten_mask=atten_mask,
-                                                                     pse=pse, scale=scale,
-                                                                     keep_prob=keep_prob)
-     if out_golden.dim() == 5:
-         out_golden = out_golden.reshape(out_golden.size(0), out_golden.size(1) * out_golden.size(2), out_golden.size(3),
-                                         out_golden.size(4))
-     out_golden = convert_from_bnsd(out_golden, input_layout)
-
-     return out_golden.cpu(), softmax_max.repeat(1, 1, 1, 8).cpu(), softmax_sum.repeat(1, 1, 1, 8).cpu()
-
-
- def npu_fusion_attention_grad(*args, **kwargs):
-     # dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
-     new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*args, **kwargs)
-     query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5]
-     N1 = dims_kwargs.get("N1")
-     N2 = dims_kwargs.get("N2")
-     S1 = dims_kwargs.get("S1")
-     S2 = dims_kwargs.get("S2")
-     B = dims_kwargs.get("B")
-     D = dims_kwargs.get("D")
-     DTYPE = dims_kwargs.get("DTYPE")
-     atten_mask = new_kwargs.get("atten_mask")
-     keep_prob = new_kwargs.get("keep_prob")
-     sparse_mode = new_kwargs.get("sparse_mode")
-     pre_tockens = new_kwargs.get("pre_tockens")
-     next_tockens = new_kwargs.get("next_tockens")
-     pse = new_kwargs.get("pse")
-     softmax_max = new_kwargs.get("softmax_max")
-     softmax_sum = new_kwargs.get("softmax_sum")
-     scale_value = new_kwargs.get("scale_value")
-
-     atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE)
-     query = convert_to_bnsd(query, N1, input_layout)
-     dx = convert_to_bnsd(dx, N1, input_layout)
-     key = convert_to_bnsd(key, N2, input_layout)
-     value = convert_to_bnsd(value, N2, input_layout)
-     k_new, v_new = generate_kv(key, value, N1, N2)
-
-     if softmax_build_mode == "QKV":
-         softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value)
-     else:
-         softmax_res = rebuild_softmax_by_max_sum(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum)
-
-     dq, dk, dv = fusion_attention_backward(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob)
-
-     # adaptation for unequal N (by cdy)
-     if not (N1 == N2):
-         if N2 == 0:
-             raise ValueError("dims_kwargs.N2 must be non-zero.")
-         G = int(N1 / N2)
-         dk = torch.sum(dk.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D)
-         dv = torch.sum(dv.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D)
-
-     if dq.dim() == 5:
-         dq = dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4))
-     if dk.dim() == 5:
-         dk = dk.reshape(dk.size(0), dk.size(1) * dk.size(2), dk.size(3), dk.size(4))
-     if dv.dim() == 5:
-         dv = dv.reshape(dv.size(0), dv.size(1) * dv.size(2), dv.size(3), dv.size(4))
-
-     dq = convert_from_bnsd(dq, input_layout)
-     dk = convert_from_bnsd(dk, input_layout)
-     dv = convert_from_bnsd(dv, input_layout)
-
-     return dq.cpu(), dk.cpu(), dv.cpu()
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ # Forward function signature comparison
+ Golden implementation: fusion_attention_forward: q, k, v, drop_mask, atten_mask, pse, scale, keep_prob
+ Fused operator: npu_fusion_attention_forward: query, key, value, head_num, input_layout, *, pse=None, padding_mask=None,
+                                               atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647,
+                                               next_tockens=2147483647, inner_precise=0, prefix=None, sparse_mode=0,
+                                               gen_mask_parallel=True, sync=False
+
+ # Backward function signature comparison
+ Golden implementation: fusion_attention_backward: dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
+ Fused operator: npu_fusion_attention_backward: query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None,
+                                                atten_mask=None, softmax_max=None, softmax_sum=None, softmax_in=None,
+                                                attention_in=None, scale_value=1.0, keep_prob=1.0, pre_tockens=2147483647,
+                                                next_tockens=2147483647, inner_precise=0, seed=0, offset=0,
+                                                numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
+ """
+
+ import torch
+ import numpy as np
+ from einops import rearrange
+
+ try:
+     import torch_npu
+ except ImportError:
+     is_gpu = True
+     try:
+         # flash_attn is the third-party FlashAttention library for GPU
+         from flash_attn import flash_attn_func
+     except ImportError:
+         # in a CPU-only UT environment, do nothing
+         pass
+ else:
+     is_gpu = False
+
+ from msprobe.pytorch.common.utils import logger
+ from msprobe.core.common.const import Const, CompareConst
+
+ gtype = torch.float64  # on an arm host float64 is required; on x86 float32 is enough (float64 also works). arm computation is slow, so x86 is recommended for s=8k scenarios
+ softmax_build_mode = "QKV"  # "MAX_SUM"
+
+
+ def softmax_forward(x):
+     x_max = torch.max(x, dim=-1, keepdims=True)[0]
+     x_sub = x.sub(x_max)
+     y = torch.exp(x_sub)
+     x_sum = y.sum(dim=-1, keepdims=True)
+     res = y.div(x_sum)
+     return res, x_max, x_sum
+
+
+ def softmax_grad(dp, softmax_res):
+     muls = dp * softmax_res
+     muls_r = muls.sum(dim=-1, keepdims=True)
+     sub_r = dp - muls_r
+     res = sub_r * softmax_res
+     return res
+
+
+ def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):
+     if num_kv_heads == 0 or num_kv_heads > num_heads:
+         raise ValueError(f"num_kv_heads must be non-zero and must not exceed num_heads.")
+
+     factor = num_heads // num_kv_heads
+     kv_shape = kv_tensor.shape
+     b = kv_shape[0]
+     s = kv_shape[2]
+     d = kv_shape[3]
+     kv_res = torch.zeros([b, num_heads, s, d]).to(dtype)
+     for i in range(num_heads):
+         j = i // factor
+         kv_res[:, i:i + 1, :, :] = kv_tensor[:, j:j + 1, :, :]
+     return kv_res
+
+
+ def calculate_qk(q, k, atten_mask, pse, scale):
+     if pse is None or len(pse.shape) == 0:
+         qk = torch.matmul(q, k.permute(0, 1, 3, 2)).mul(scale)
+     else:
+         qk = (torch.matmul(q, k.permute(0, 1, 3, 2)) + pse).mul(scale)
+     if atten_mask is None or len(atten_mask.shape) == 0:
+         return qk
+     else:
+         qk = qk + atten_mask.bool() * (-40000.0)  # -10000
+         return qk
+
+
+ def fusion_attention_forward(q, k, v, drop_mask, atten_mask, pse, scale, keep_prob):
+     qk = calculate_qk(q, k, atten_mask, pse, scale)
+     softmax_res, softmax_max, softmax_sum = softmax_forward(qk)
+     if drop_mask is None or len(drop_mask.shape) == 0:
+         drop_res = softmax_res
+     else:
+         drop_res = softmax_res * drop_mask * (1.0 / keep_prob)
+     y = torch.matmul(drop_res, v)
+     return y, softmax_max, softmax_sum
+
+
+ def fusion_attention_backward(dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob):
+     dp = torch.matmul(dx, v.permute(0, 1, 3, 2))
+     if drop_mask is None or len(drop_mask.shape) == 0:
+         drop_res = softmax_res.permute(0, 1, 3, 2)
+         dp_drop = dp
+     else:
+         drop_res = softmax_res.mul(drop_mask).mul(1.0 / keep_prob).permute(0, 1, 3, 2)
+         dp_drop = dp * drop_mask * (1.0 / keep_prob)
+     dv = torch.matmul(drop_res, dx)
+     softmax_grad_res = (softmax_grad(dp_drop, softmax_res) * scale)
+     dq = torch.matmul(softmax_grad_res, k)
+     dk = torch.matmul(softmax_grad_res.permute(0, 1, 3, 2), q)
+     return dq, dk, dv
+
+
+ def parse_bsnd_args(query, key, head_num, input_layout):
+     supported_input_layout = ["BSH", "SBH", "BSND", "BNSD", "TND"]
+     b, s1, s2, n1, n2, d, h1, h2 = None, None, None, head_num, None, None, None, None
+
+     if not isinstance(input_layout, str) or input_layout not in supported_input_layout:
+         raise ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.")
+
+     if input_layout == "TND":
+         raise ValueError(f"input_layout {input_layout} is not supported for now.")
+     try:
+         if input_layout == "BSH":
+             b, s1, h1 = query.shape
+             _, s2, h2 = key.shape
+             d = h1 // n1
+             n2 = h2 // d
+         elif input_layout == "SBH":
+             s1, b, h1 = query.shape
+             s2, _, h2 = key.shape
+             d = h1 // n1
+             n2 = h2 // d
+         elif input_layout == "BSND":
+             b, s1, n1, d = query.shape
+             _, s2, n2, _ = key.shape
+             h1 = n1 * d
+             h2 = n2 * d
+         elif input_layout == "BNSD":
+             b, n1, s1, d = query.shape
+             _, n2, s2, _ = key.shape
+             h1 = n1 * d
+             h2 = n2 * d
+     except Exception as e:
+         raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e
+
+     if d == 0:
+         raise ValueError(f"Value d must be non-zero.")
+     _dtype = query.dtype
+     ret = (b, s1, s2, n1, n2, d, h1, h2, _dtype)
+     return ret
+
+
+ def convert_from_bnsd(_input, input_layout):
+     if input_layout == "BSH":
+         # (B,N,S,D)=>(B,S,N*D)
+         out = rearrange(_input, 'b n s d -> b s (n d)').contiguous()
+     elif input_layout == "SBH":
+         # (B,N,S,D)=>(S,B,N*D)
+         out = rearrange(_input, 'b n s d -> s b (n d)').contiguous()
+     elif input_layout == "BSND":
+         # (B,N,S,D)=>(B,S,N,D)
+         out = rearrange(_input, 'b n s d -> b s n d').contiguous()
+     elif input_layout == "TND":
+         raise ValueError(f"input_layout {input_layout} is not supported for now.")
+     else:
+         out = _input
+     return out
+
+
+ def convert_to_bnsd(_input, n, input_layout):
+     # the default "BNSD" layout needs no conversion
+     if input_layout == "BSH":
+         # (B,S,N*D)=>(B,N,S,D)
+         out = rearrange(_input, 'b s (n d) -> b n s d', n=n)
+     elif input_layout == "SBH":
+         # (S,B,N*D)=>(B,N,S,D)
+         out = rearrange(_input, 's b (n d) -> b n s d', n=n)
+     elif input_layout == "BSND":
+         # (B,S,N,D)=>(B,N,S,D)
+         out = rearrange(_input, 'b s n d -> b n s d', n=n)
+     elif input_layout == "TND":
+         raise ValueError(f"input_layout {input_layout} is not supported for now.")
+     else:
+         out = _input
+     if out.dim() != 4:
+         raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
+     return out.to(gtype)
+
+
+ def generate_atten_mask(*args):
+     """
+     # When sparse_mode is 2, 3 or 4, the small-op-to-fused-op conversion applies this optimization; reversing it means expanding the mask back to the original basic implementation:
+     ===> atten_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype)
+     """
+
+     sparse_mode, atten_mask, b, n1, s1, s2, pre_tocken, next_tocken, dtype = args
+     shape = [s1, s2]
+
+     if atten_mask is not None:
+         # When the FA input already carries an atten_mask, it can be treated as the already-converted mask matrix; three special (sparse-matrix) scenarios require reversing the conversion
+         if sparse_mode == 2 or sparse_mode == 3 or sparse_mode == 4:
+             logger.info(f"s1: {s1}, s2:{s2}, atten_mask.shape:{atten_mask.shape}, atten_mask.dtype:{atten_mask.dtype}")
+
+             if atten_mask.dim() == 2 and atten_mask.shape[0] == 2048 and atten_mask.shape[1] == 2048:
+                 if atten_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(atten_mask.dtype)):
+                     if sparse_mode == 2:
+                         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
+                     elif sparse_mode == 3:
+                         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1))
+                     elif sparse_mode == 4:
+                         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
+                         atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
+                         atten_mask = atten_mask_u + atten_mask_l
+                     logger.debug(f"reverse-converted atten_mask {atten_mask.shape}")
+                     return atten_mask.to(dtype)
+
+         return atten_mask.to(dtype)
+
+     if atten_mask is not None:
+         if atten_mask.dim() == 2:
+             if atten_mask.shape[0] != s1 or atten_mask.shape[1] != s2:
+                 raise ValueError(f"Invalid atten_mask shape `SS` {atten_mask.shape}")
+             shape = [s1, s2]
+         elif atten_mask.dim() == 4:
+             if atten_mask.shape[1] == 1:
+                 shape = [b, 1, s1, s2] if b != 1 else [1, 1, s1, s2]
+             else:
+                 shape = [b, n1, s1, s2] if b != 1 else [1, n1, s1, s2]
+
+     if sparse_mode == 0:
+         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
+         atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
+         atten_mask = atten_mask_u + atten_mask_l
+     elif sparse_mode == 1:  # no sparse
+         atten_mask = torch.from_numpy(np.zeros(shape))
+     elif sparse_mode == 2:
+         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
+     elif sparse_mode == 3:
+         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1))
+     elif sparse_mode == 4:
+         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
+         atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
+         atten_mask = atten_mask_u + atten_mask_l
+     # Note: sparse_mode=5 cannot occur here, since that mode requires atten_mask to be passed in with BNSS or B1SS layout,
+     # so the FA input can be assumed to already be the correct atten_mask
+     return atten_mask.to(dtype)
+
+
+ def generate_kv(key, value, n1, n2):
+     # adaptation for unequal N (by cdy)
+     if not (n1 == n2):
+         k_new = broadcast_kv(n1, n2, key, key.dtype)
+         v_new = broadcast_kv(n1, n2, value, value.dtype)
+     else:
+         k_new = key
+         v_new = value
+     return k_new, v_new
+
+
+ def rebuid_softmax_by_qkv(q, k, atten_mask, pse, scale):
+     """
+     attention = softmax(QK^T/sqrt(d))V
+     softmax(x_i) = e^(x_i - x_max) / sum(e^(x_i - x_max))
+     """
+     logger.info("Using QKV to rebuild original softmax")
+     qk = calculate_qk(q, k, atten_mask, pse, scale)
+     softmax_res, x_max, x_sum = softmax_forward(qk)
+     return softmax_res
+
+
+ def rebuild_softmax_by_max_sum(q, k, atten_mask, pse, scale, softmax_max, softmax_sum):
+     """
+     attention = softmax(QK^T/sqrt(d))V
+     softmax(x_i) = e^(x_i - x_max_i) / x_sum_i
+     """
+     logger.info("Using softmax_max and softmax_sum to rebuild original softmax")
+     qk = calculate_qk(q, k, atten_mask, pse, scale)
+     if softmax_max.shape[-1] == 0:
+         raise ValueError(f"softmax_max.shape[-1] must be non-zero, softmax_max.shape: {softmax_max.shape}")
+     repeat_dim = qk.shape[-1] // softmax_max.shape[-1]
+     softmax_res = torch.exp(qk.sub(softmax_max.repeat(1, 1, 1, repeat_dim))).div(
+         softmax_sum.repeat(1, 1, 1, repeat_dim))
+     return softmax_res
+
+
+ def get_head_num(*args, **kwargs):
+     if kwargs.get("head_num", None):
+         head_num = kwargs.get("head_num")
+     elif len(args) >= 4:
+         head_num = args[3]
+     else:
+         raise ValueError(f"Unsupported npu_fusion_attention args {args}.")
+     return head_num
+
+
+ def get_input_layout(*args, **kwargs):
+     if kwargs.get("input_layout", None):
+         input_layout = kwargs.get("input_layout")
+     elif len(args) >= 5:
+         input_layout = args[4]
+     else:
+         raise ValueError(f"Unsupported npu_fusion_attention args {args}.")
+     return input_layout
+
+
+ def npu_fusion_attention_forward_patch(*args, **kwargs):
+     # query, key, value, head_num, input_layout
+     head_num = get_head_num(*args, **kwargs)
+     input_layout = get_input_layout(*args, **kwargs)
+
+     b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], head_num, input_layout)
+     if n1 == n2 and s1 == s2:
+         logger.debug(f"running case : BNSD = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+     else:
+         logger.debug(f"running case: BNSD = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+     if not (n1 % n2 == 0 and n1 >= n2):
+         raise ValueError(f"N1 and N2 do not match, please check: n1 = {n1}, n2 = {n2}.")
+
+     dims_kwargs = {
+         "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2,
+         "d": d, "h1": h1, "h2": h2, "dtype": dtype
+     }
+
+     new_kwargs = {
+         "keep_prob": 1,
+         "scale": kwargs.get("scale", 1 / (d ** 0.5)),
+         "sparse_mode": kwargs.get("sparse_mode", 0),
+         "prefix": kwargs.get("prefix"),
+         "pre_tockens": kwargs.get("pre_tockens", 2147483647),
+         "next_tockens": kwargs.get("next_tockens", 2147483647),
+         "pse": kwargs.get("pse"),
+         "padding_mask": kwargs.get("padding_mask"),
+         "atten_mask": kwargs.get("atten_mask")
+     }
+
+     return args, dims_kwargs, new_kwargs
+
+
+ def npu_fusion_attention_backward_patch(*args, **kwargs):
+     if len(args) != 6:
+         raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")
+
+     b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], args[4], args[5])
+     if n1 == n2 and s1 == s2:
+         logger.info(f"running case : bnsd = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+     else:
+         logger.info(f"running case: bnsd = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+     if not (n1 % n2 == 0 and n1 >= n2):
+         raise ValueError(f"N1 and N2 do not match, please check: n1 = {n1}, n2 = {n2}.")
+
+     dims_kwargs = {
+         "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2,
+         "d": d, "h1": h1, "h2": h2, "dtype": dtype
+     }
+
+     new_kwargs = {
+         "keep_prob": 1,
+         "scale_value": kwargs.get("scale_value", 1 / (d ** 0.5)),
+         "sparse_mode": kwargs.get("sparse_mode", 0),
+         "prefix": kwargs.get("prefix"),
+         "pre_tockens": kwargs.get("pre_tockens", 2147483647),
+         "next_tockens": kwargs.get("next_tockens", 2147483647),
+         "pse": kwargs.get("pse"),
+         "padding_mask": kwargs.get("padding_mask"),
+         "softmax_max": kwargs.get("softmax_max"),
+         "softmax_sum": kwargs.get("softmax_sum"),
+         "softmax_in": kwargs.get("softmax_in"),
+         "attention_in": kwargs.get("attention_in"),
+         "seed": kwargs.get("seed", 0),
+         "offset": kwargs.get("offset", 0),
+         "numels": kwargs.get("numels", 0),
+         "atten_mask": kwargs.get("atten_mask")
+     }
+
+     return args, dims_kwargs, new_kwargs
+
+
+ def npu_fusion_attention(*args, **kwargs):
+     new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*args, **kwargs)
+     query, key, value = new_args[0], new_args[1], new_args[2]
+     input_layout = get_input_layout(*args, **kwargs)
+     n1 = dims_kwargs.get("n1")
+     n2 = dims_kwargs.get("n2")
+     s1 = dims_kwargs.get("s1")
+     s2 = dims_kwargs.get("s2")
+     b = dims_kwargs.get("b")
+     dtype = dims_kwargs.get("dtype")
+     atten_mask = new_kwargs.get("atten_mask")
+     keep_prob = new_kwargs.get("keep_prob")
+     sparse_mode = new_kwargs.get("sparse_mode")
+     pre_tockens = new_kwargs.get("pre_tockens")
+     next_tockens = new_kwargs.get("next_tockens")
+     pse = new_kwargs.get("pse")
+     scale = new_kwargs.get("scale")
+     args_temp = [sparse_mode, atten_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype]
+     atten_mask = generate_atten_mask(*args_temp)
+     query = convert_to_bnsd(query, n1, input_layout)
+     key = convert_to_bnsd(key, n2, input_layout)
+     value = convert_to_bnsd(value, n2, input_layout)
+     k_new, v_new = generate_kv(key, value, n1, n2)
+     out_golden, softmax_max, softmax_sum = fusion_attention_forward(q=query, k=k_new, v=v_new,
+                                                                     drop_mask=None, atten_mask=atten_mask,
+                                                                     pse=pse, scale=scale,
+                                                                     keep_prob=keep_prob)
+     if out_golden.dim() == 5:
+         out_golden = out_golden.reshape(out_golden.size(0), out_golden.size(1) * out_golden.size(2), out_golden.size(3),
+                                         out_golden.size(4))
+     out_golden = convert_from_bnsd(out_golden, input_layout)
+
+     return out_golden.cpu(), softmax_max.repeat(1, 1, 1, 8).cpu(), softmax_sum.repeat(1, 1, 1, 8).cpu()
+
+
+ def npu_fusion_attention_grad(*args, **kwargs):
+     # dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
+     new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*args, **kwargs)
+     query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5]
+     n1 = dims_kwargs.get("n1")
+     n2 = dims_kwargs.get("n2")
+     s1 = dims_kwargs.get("s1")
+     s2 = dims_kwargs.get("s2")
+     b = dims_kwargs.get("b")
+     d = dims_kwargs.get("d")
+     dtype = dims_kwargs.get("dtype")
+     atten_mask = new_kwargs.get("atten_mask")
+     keep_prob = new_kwargs.get("keep_prob")
+     sparse_mode = new_kwargs.get("sparse_mode")
+     pre_tockens = new_kwargs.get("pre_tockens")
+     next_tockens = new_kwargs.get("next_tockens")
+     pse = new_kwargs.get("pse")
+     softmax_max = new_kwargs.get("softmax_max")
+     softmax_sum = new_kwargs.get("softmax_sum")
+     scale_value = new_kwargs.get("scale_value")
+
+     args_temp = [sparse_mode, atten_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype]
+     atten_mask = generate_atten_mask(*args_temp)
+     query = convert_to_bnsd(query, n1, input_layout)
+     dx = convert_to_bnsd(dx, n1, input_layout)
+     key = convert_to_bnsd(key, n2, input_layout)
+     value = convert_to_bnsd(value, n2, input_layout)
+     k_new, v_new = generate_kv(key, value, n1, n2)
+
+     if softmax_build_mode == "QKV":
+         softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value)
+     else:
+         softmax_res = rebuild_softmax_by_max_sum(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum)
+
+     dq, dk, dv = fusion_attention_backward(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob)
+
+     # adaptation for unequal N (by cdy)
+     if not (n1 == n2):
+         if n2 == 0:
+             raise ValueError("dims_kwargs.n2 must be non-zero.")
+         g = int(n1 / n2)
+         dk = torch.sum(dk.reshape(b, n2, g, s2, d), dim=2, keepdim=True).reshape(b, n2, s2, d)
+         dv = torch.sum(dv.reshape(b, n2, g, s2, d), dim=2, keepdim=True).reshape(b, n2, s2, d)
+
+     if dq.dim() == 5:
+         dq = dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4))
+     if dk.dim() == 5:
+         dk = dk.reshape(dk.size(0), dk.size(1) * dk.size(2), dk.size(3), dk.size(4))
+     if dv.dim() == 5:
+         dv = dv.reshape(dv.size(0), dv.size(1) * dv.size(2), dv.size(3), dv.size(4))
+
+     dq = convert_from_bnsd(dq, input_layout)
+     dk = convert_from_bnsd(dk, input_layout)
+     dv = convert_from_bnsd(dv, input_layout)
+
+     return dq.cpu(), dk.cpu(), dv.cpu()
+
+
+ def is_attention_off_due_to_mask(atten_mask_dtype):
+     return not atten_mask_dtype
+
+
+ def is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens, s1):
+     return sparse_mode == 4 and (next_tockens != 0 or pre_tockens < s1)
+
+
+ def is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens, s1, s2):
+     return sparse_mode == 0 and pre_tockens >= s1 and next_tockens >= s2
+
+
+ def gpu_fusion_attention(*args, **kwargs):
+     deterministic = False
+     new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*args, **kwargs)
+     query, key, value = new_args[0], new_args[1], new_args[2]
+     keep_prob = new_kwargs.get("keep_prob", 1.0)
+     scale = new_kwargs.get("scale")
+     n1 = dims_kwargs.get("n1")
+     n2 = dims_kwargs.get("n2")
+     s1 = dims_kwargs.get("s1")
+     s2 = dims_kwargs.get("s2")
+     b = dims_kwargs.get("b")
+     pse = new_kwargs.get("pse")
+     sparse_mode = new_kwargs.get("sparse_mode")
+     pre_tockens = new_kwargs.get("pre_tockens")
+     next_tockens = new_kwargs.get("next_tockens")
+     attn_mask = new_kwargs.get("atten_mask")
+     atten_mask_dtype = attn_mask.dtype if new_kwargs.get("atten_mask") is not None else None
+     pre_tockens = min(CompareConst.MAX_TOKENS, pre_tockens)
+     next_tockens = min(CompareConst.MAX_TOKENS, next_tockens)
+     atten_off = (is_attention_off_due_to_mask(atten_mask_dtype) or
+                  is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens, s1) or
+                  is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens, s1, s2))
+     causal_switch = not atten_off
+     if sparse_mode == CompareConst.SPECIAL_SPARSE_MOED:
+         window_left = pre_tockens
+         window_right = next_tockens
+     else:
+         pre_tockens = next_tockens = CompareConst.MAX_TOKENS
+         window_left = pre_tockens - s1 + s2
+         window_right = next_tockens + s1 - s2
+
+     if pse is not None:
+         alibi_slopes = torch.rand(b, n1, dtype=torch.float32) * 0.3
+     else:
+         alibi_slopes = None
+
+     out = flash_attn_func(
+         query, key, value, dropout_p=(1 - keep_prob), softmax_scale=scale, causal=causal_switch,
+         window_size=(window_left, window_right), alibi_slopes=alibi_slopes, deterministic=deterministic
+     )
+     return out, Const.NONE, Const.NONE
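
The golden implementation added above runs entirely on CPU, so it can be exercised without an NPU. A minimal usage sketch (outside the diff), assuming mindstudio-probe 1.1.0 is installed; the shapes and values are illustrative only:

import torch
from msprobe.pytorch.bench_functions.npu_fusion_attention import npu_fusion_attention

# Illustrative BNSD shapes: batch=1, heads=2, seq=128, head_dim=64.
b, n, s, d = 1, 2, 128, 64
query = torch.randn(b, n, s, d)
key = torch.randn(b, n, s, d)
value = torch.randn(b, n, s, d)

# Positional arguments mirror the fused operator's (query, key, value, head_num, input_layout);
# sparse_mode=0 with the default pre/next tockens produces an all-zero (dense) attention mask.
out_golden, softmax_max, softmax_sum = npu_fusion_attention(
    query, key, value, n, "BNSD", scale=1 / (d ** 0.5), sparse_mode=0
)
print(out_golden.shape, out_golden.dtype)  # torch.Size([1, 2, 128, 64]) torch.float64 (gtype)

Note that softmax_max and softmax_sum come back repeated to a last-dimension size of 8, matching the fused operator's output layout; this is presumably what the api_accuracy_checker compares against torch_npu.npu_fusion_attention results on an NPU device.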