mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  232. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
msprobe/pytorch/bench_functions/npu_fusion_attention.py
@@ -1,421 +1,509 @@
- import torch
- import numpy as np
- from einops import rearrange
-
- from msprobe.pytorch.common.utils import logger
-
- gtype = torch.float64  # An ARM host must use float64; on x86, float32 is enough (float64 also works). ARM computation is slow, so x86 is recommended for s=8k scenarios.
- softmax_build_mode = "QKV"  # "MAX_SUM"
-
- """
- # Forward signature comparison
- Golden implementation: fusion_attention_forward: q, k, v, drop_mask, atten_mask, pse, scale, keep_prob
- Fused operator: npu_fusion_attention_forward: query, key, value, head_num, input_layout, *, pse=None, padding_mask=None,
-                                               atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647,
-                                               next_tockens=2147483647, inner_precise=0, prefix=None, sparse_mode=0,
-                                               gen_mask_parallel=True, sync=False
-
- # Backward signature comparison
- Golden implementation: fusion_attention_backward: dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
- Fused operator: npu_fusion_attention_backward: query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None,
-                                                atten_mask=None, softmax_max=None, softmax_sum=None, softmax_in=None,
-                                                attention_in=None, scale_value=1.0, keep_prob=1.0, pre_tockens=2147483647,
-                                                next_tockens=2147483647, inner_precise=0, seed=0, offset=0,
-                                                numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
- """
-
-
- def softmax_forward(x):
-     x_max = torch.max(x, dim=-1, keepdims=True)[0]
-     x_sub = x.sub(x_max)
-     y = torch.exp(x_sub)
-     x_sum = y.sum(dim=-1, keepdims=True)
-     res = y.div(x_sum)
-     return res, x_max, x_sum
-
-
- def softmax_grad(dp, softmax_res):
-     muls = dp * softmax_res
-     muls_r = muls.sum(dim=-1, keepdims=True)
-     sub_r = dp - muls_r
-     res = sub_r * softmax_res
-     return res
-
-
- def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):
-     if num_kv_heads == 0 or num_kv_heads < num_heads:
-         raise ValueError(f"num_kv_heads must be non-zero and less than num_heads.")
-
-     factor = num_heads // num_kv_heads
-     kv_shape = kv_tensor.shape
-     B = kv_shape[0]
-     S = kv_shape[2]
-     D = kv_shape[3]
-     kv_res = torch.zeros([B, num_heads, S, D]).to(dtype)
-     for i in range(num_heads):
-         j = i // factor
-         kv_res[:, i:i + 1, :, :] = kv_tensor[:, j:j + 1, :, :]
-     return kv_res
-
-
- def calculate_qk(q, k, atten_mask, pse, scale):
-     if pse is None or len(pse.shape) == 0:
-         qk = torch.matmul(q, k.permute(0, 1, 3, 2)).mul(scale)
-     else:
-         qk = (torch.matmul(q, k.permute(0, 1, 3, 2)) + pse).mul(scale)
-     if atten_mask is None or len(atten_mask.shape) == 0:
-         return qk
-     else:
-         qk = qk + atten_mask.bool() * (-40000.0)  # -10000
-         return qk
-
-
- def fusion_attention_forward(q, k, v, drop_mask, atten_mask, pse, scale, keep_prob):
-     qk = calculate_qk(q, k, atten_mask, pse, scale)
-     softmax_res, softmax_max, softmax_sum = softmax_forward(qk)
-     if drop_mask is None or len(drop_mask.shape) == 0:
-         drop_res = softmax_res
-     else:
-         drop_res = softmax_res * drop_mask * (1.0 / keep_prob)
-     y = torch.matmul(drop_res, v)
-     return y, softmax_max, softmax_sum
-
-
- def fusion_attention_backward(dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob):
-     dp = torch.matmul(dx, v.permute(0, 1, 3, 2))
-     if drop_mask is None or len(drop_mask.shape) == 0:
-         drop_res = softmax_res.permute(0, 1, 3, 2)
-         dp_drop = dp
-     else:
-         drop_res = softmax_res.mul(drop_mask).mul(1.0 / keep_prob).permute(0, 1, 3, 2)
-         dp_drop = dp * drop_mask * (1.0 / keep_prob)
-     dv = torch.matmul(drop_res, dx)
-     softmax_grad_res = (softmax_grad(dp_drop, softmax_res) * scale)
-     dq = torch.matmul(softmax_grad_res, k)
-     dk = torch.matmul(softmax_grad_res.permute(0, 1, 3, 2), q)
-     return dq, dk, dv
-
-
- def parse_bsnd_args(query, key, head_num, input_layout):
-     supported_input_layout = ["BSH", "SBH", "BSND", "BNSD", "TND"]
-     B, S1, S2, N1, N2, D, H1, H2 = None, None, None, head_num, None, None, None, None
-
-     if not isinstance(input_layout, str) or input_layout not in supported_input_layout:
-         raise ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.")
-
-     if input_layout == "TND":
-         raise ValueError(f"input_layout {input_layout} is not supported for now.")
-     try:
-         if input_layout == "BSH":
-             B, S1, H1 = query.shape
-             _, S2, H2 = key.shape
-             D = H1 // N1
-             N2 = H2 // D
-         elif input_layout == "SBH":
-             S1, B, H1 = query.shape
-             S2, _, H2 = key.shape
-             D = H1 // N1
-             N2 = H2 // D
-         elif input_layout == "BSND":
-             B, S1, N1, D = query.shape
-             _, S2, N2, _ = key.shape
-             H1 = N1 * D
-             H2 = N2 * D
-         elif input_layout == "BNSD":
-             B, N1, S1, D = query.shape
-             _, N2, S2, _ = key.shape
-             H1 = N1 * D
-             H2 = N2 * D
-     except Exception as e:
-         raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e
-
-     if D == 0:
-         raise ValueError(f"Value D must be non-zero.")
-     DTYPE = query.dtype
-     return B, S1, S2, N1, N2, D, H1, H2, DTYPE
-
-
- def convert_from_bnsd(_input, input_layout):
-     if input_layout == "BSH":
-         # (B,N,S,D) => (B,S,N*D)
-         out = rearrange(_input, 'b n s d -> b s (n d)').contiguous()
-     elif input_layout == "SBH":
-         # (B,N,S,D) => (S,B,N*D)
-         out = rearrange(_input, 'b n s d -> s b (n d)').contiguous()
-     elif input_layout == "BSND":
-         # (B,N,S,D) => (B,S,N,D)
-         out = rearrange(_input, 'b n s d -> b s n d').contiguous()
-     elif input_layout == "TND":
-         raise ValueError(f"input_layout {input_layout} is not supported for now.")
-     else:
-         out = _input
-     return out
-
-
- def convert_to_bnsd(_input, n, input_layout):
-     # The default "BNSD" layout needs no conversion
-     if input_layout == "BSH":
-         # (B,S,N*D) => (B,N,S,D)
-         out = rearrange(_input, 'b s (n d) -> b n s d', n=n)
-     elif input_layout == "SBH":
-         # (S,B,N*D) => (B,N,S,D)
-         out = rearrange(_input, 's b (n d) -> b n s d', n=n)
-     elif input_layout == "BSND":
-         # (B,S,N,D) => (B,N,S,D)
-         out = rearrange(_input, 'b s n d -> b n s d', n=n)
-     elif input_layout == "TND":
-         raise ValueError(f"input_layout {input_layout} is not supported for now.")
-     else:
-         out = _input
-     if out.dim() != 4:
-         raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
-     return out.to(gtype)
-
-
- def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next_tocken, dtype):
-     """
-     # When sparse_mode is 2, 3 or 4, the small-operator-to-fused-operator conversion applies this optimization, so reversing it means expanding the mask back to the original basic implementation
-     ===> atten_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype)
-     """
-     shape = [S1, S2]
-
-     if atten_mask is not None:
-         # When the FA input already contains atten_mask, it can be assumed to be the converted mask matrix; three special scenarios (the sparse-matrix cases) require reverse restoration
-         if sparse_mode == 2 or sparse_mode == 3 or sparse_mode == 4:
-             logger.info(f"S1: {S1}, S2:{S2}, atten_mask.shape:{atten_mask.shape}, atten_mask.dtype:{atten_mask.dtype}")
-
-             if atten_mask.dim() == 2 and atten_mask.shape[0] == 2048 and atten_mask.shape[1] == 2048:
-                 if atten_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(atten_mask.dtype)):
-                     if sparse_mode == 2:
-                         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
-                     elif sparse_mode == 3:
-                         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1))
-                     elif sparse_mode == 4:
-                         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
-                         atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
-                         atten_mask = atten_mask_u + atten_mask_l
-                     logger.debug(f"Reverse-converted atten_mask {atten_mask.shape}")
-                     return atten_mask.to(dtype)
-
-         return atten_mask.to(dtype)
-
-     if atten_mask is not None:
-         if atten_mask.dim() == 2:
-             if atten_mask.shape[0] != S1 or atten_mask.shape[1] != S2:
-                 raise ValueError(f"Invalid atten_mask shape `SS` {atten_mask.shape}")
-             shape = [S1, S2]
-         elif atten_mask.dim() == 4:
-             if atten_mask.shape[1] == 1:
-                 shape = [B, 1, S1, S2] if B != 1 else [1, 1, S1, S2]
-             else:
-                 shape = [B, N1, S1, S2] if B != 1 else [1, N1, S1, S2]
-
-     if sparse_mode == 0:
-         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
-         atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
-         atten_mask = atten_mask_u + atten_mask_l
-     elif sparse_mode == 1:  # no sparse
-         atten_mask = torch.from_numpy(np.zeros(shape))
-     elif sparse_mode == 2:
-         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
-     elif sparse_mode == 3:
-         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1))
-     elif sparse_mode == 4:
-         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
-         atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
-         atten_mask = atten_mask_u + atten_mask_l
-     # Note: sparse_mode=5 cannot occur here; that mode requires atten_mask to be passed in, with the mask matrix in BNSS or B1SS format,
-     # so the FA input can be assumed to already be the correct atten_mask
-     return atten_mask.to(dtype)
-
-
- def generate_kv(key, value, N1, N2):
-     # Adaptation for unequal N (by cdy)
-     if not (N1 == N2):
-         k_new = broadcast_kv(N1, N2, key, key.dtype)
-         v_new = broadcast_kv(N1, N2, value, value.dtype)
-     else:
-         k_new = key
-         v_new = value
-     return k_new, v_new
-
-
- def rebuid_softmax_by_qkv(q, k, atten_mask, pse, scale):
-     """
-     attention = softmax(QK^T/sqrt(d))V
-     softmax(x_i) = e^(x_i - x_max) / sum(e^(x_i - x_max))
-     """
-     logger.info("Using QKV to rebuild original softmax")
-     qk = calculate_qk(q, k, atten_mask, pse, scale)
-     softmax_res, x_max, x_sum = softmax_forward(qk)
-     return softmax_res
-
-
- def rebuild_softmax_by_max_sum(q, k, atten_mask, pse, scale, softmax_max, softmax_sum):
-     """
-     attention = softmax(QK^T/sqrt(d))V
-     softmax(x_i) = e^(x_i - x_max_i) / x_sum_i
-     """
-     logger.info("Using softmax_max and softmax_sum to rebuild original softmax")
-     qk = calculate_qk(q, k, atten_mask, pse, scale)
-     if softmax_max.shape[-1] == 0:
-         raise ValueError(f"softmax_max.shape[-1] must be non-zero, softmax_max.shape: {softmax_max.shape}")
-     repeat_dim = qk.shape[-1] // softmax_max.shape[-1]
-     softmax_res = torch.exp(qk.sub(softmax_max.repeat(1, 1, 1, repeat_dim))).div(
-         softmax_sum.repeat(1, 1, 1, repeat_dim))
-     return softmax_res
-
-
- def npu_fusion_attention_forward_patch(*args, **kwargs):
-     # query, key, value, head_num, input_layout
-     if len(args) != 5:
-         raise ValueError(f"Unsupported npu_fusion_attention args {args}.")
-
-     B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], args[3], args[4])
-     if N1 == N2 and S1 == S2:
-         logger.debug(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
-     else:
-         logger.debug(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
-     if not (N1 % N2 == 0 and N1 >= N2):
-         raise ValueError(f"N1 and N2 do not match, please check: N1 = {N1}, N2 = {N2}.")
-
-     dims_kwargs = {"B": B, "S1": S1, "S2": S2, "N1": N1, "N2": N2,
-                    "D": D, "H1": H1, "H2": H2, "DTYPE": DTYPE}
-
-     new_kwargs = {"keep_prob": 1,
-                   "scale": kwargs.get("scale", 1 / (D ** 0.5)),
-                   "sparse_mode": kwargs.get("sparse_mode", 0),
-                   "prefix": kwargs.get("prefix"),
-                   "pre_tockens": kwargs.get("pre_tockens", 2147483647),
-                   "next_tockens": kwargs.get("next_tockens", 2147483647),
-                   "pse": kwargs.get("pse"),
-                   "padding_mask": kwargs.get("padding_mask"),
-                   "atten_mask": kwargs.get("atten_mask")}
-
-     return args, dims_kwargs, new_kwargs
-
-
- def npu_fusion_attention_backward_patch(*args, **kwargs):
-     if len(args) != 6:
-         raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")
-
-     B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], args[4], args[5])
-     if N1 == N2 and S1 == S2:
-         logger.info(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
-     else:
-         logger.info(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
-     if not (N1 % N2 == 0 and N1 >= N2):
-         raise ValueError(f"N1 and N2 do not match, please check: N1 = {N1}, N2 = {N2}.")
-
-     dims_kwargs = {"B": B, "S1": S1, "S2": S2, "N1": N1, "N2": N2,
-                    "D": D, "H1": H1, "H2": H2, "DTYPE": DTYPE}
-
-     new_kwargs = {"keep_prob": 1,
-                   "scale_value": kwargs.get("scale_value", 1 / (D ** 0.5)),
-                   "sparse_mode": kwargs.get("sparse_mode", 0),
-                   "prefix": kwargs.get("prefix"),
-                   "pre_tockens": kwargs.get("pre_tockens", 2147483647),
-                   "next_tockens": kwargs.get("next_tockens", 2147483647),
-                   "pse": kwargs.get("pse"),
-                   "padding_mask": kwargs.get("padding_mask"),
-                   "softmax_max": kwargs.get("softmax_max"),
-                   "softmax_sum": kwargs.get("softmax_sum"),
-                   "softmax_in": kwargs.get("softmax_in"),
-                   "attention_in": kwargs.get("attention_in"),
-                   "seed": kwargs.get("seed", 0),
-                   "offset": kwargs.get("offset", 0),
-                   "numels": kwargs.get("numels", 0),
-                   "atten_mask": kwargs.get("atten_mask")}
-
-     return args, dims_kwargs, new_kwargs
-
-
- def npu_fusion_attention(*args, **kwargs):
-     new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*args, **kwargs)
-     query, key, value, input_layout = new_args[0], new_args[1], new_args[2], new_args[4]
-     N1 = dims_kwargs.get("N1")
-     N2 = dims_kwargs.get("N2")
-     S1 = dims_kwargs.get("S1")
-     S2 = dims_kwargs.get("S2")
-     B = dims_kwargs.get("B")
-     DTYPE = dims_kwargs.get("DTYPE")
-     atten_mask = new_kwargs.get("atten_mask")
-     keep_prob = new_kwargs.get("keep_prob")
-     sparse_mode = new_kwargs.get("sparse_mode")
-     pre_tockens = new_kwargs.get("pre_tockens")
-     next_tockens = new_kwargs.get("next_tockens")
-     pse = new_kwargs.get("pse")
-     scale = new_kwargs.get("scale")
-
-     atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE)
-     query = convert_to_bnsd(query, N1, input_layout)
-     key = convert_to_bnsd(key, N2, input_layout)
-     value = convert_to_bnsd(value, N2, input_layout)
-     k_new, v_new = generate_kv(key, value, N1, N2)
-     out_golden, softmax_max, softmax_sum = fusion_attention_forward(q=query, k=k_new, v=v_new,
-                                                                     drop_mask=None, atten_mask=atten_mask,
-                                                                     pse=pse, scale=scale,
-                                                                     keep_prob=keep_prob)
-     if out_golden.dim() == 5:
-         out_golden = out_golden.reshape(out_golden.size(0), out_golden.size(1) * out_golden.size(2), out_golden.size(3),
-                                         out_golden.size(4))
-     out_golden = convert_from_bnsd(out_golden, input_layout)
-
-     return out_golden.cpu(), softmax_max.repeat(1, 1, 1, 8).cpu(), softmax_sum.repeat(1, 1, 1, 8).cpu()
-
-
- def npu_fusion_attention_grad(*args, **kwargs):
-     # dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
-     new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*args, **kwargs)
-     query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5]
-     N1 = dims_kwargs.get("N1")
-     N2 = dims_kwargs.get("N2")
-     S1 = dims_kwargs.get("S1")
-     S2 = dims_kwargs.get("S2")
-     B = dims_kwargs.get("B")
-     D = dims_kwargs.get("D")
-     DTYPE = dims_kwargs.get("DTYPE")
-     atten_mask = new_kwargs.get("atten_mask")
-     keep_prob = new_kwargs.get("keep_prob")
-     sparse_mode = new_kwargs.get("sparse_mode")
-     pre_tockens = new_kwargs.get("pre_tockens")
-     next_tockens = new_kwargs.get("next_tockens")
-     pse = new_kwargs.get("pse")
-     softmax_max = new_kwargs.get("softmax_max")
-     softmax_sum = new_kwargs.get("softmax_sum")
-     scale_value = new_kwargs.get("scale_value")
-
-     atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE)
-     query = convert_to_bnsd(query, N1, input_layout)
-     dx = convert_to_bnsd(dx, N1, input_layout)
-     key = convert_to_bnsd(key, N2, input_layout)
-     value = convert_to_bnsd(value, N2, input_layout)
-     k_new, v_new = generate_kv(key, value, N1, N2)
-
-     if softmax_build_mode == "QKV":
-         softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value)
-     else:
-         softmax_res = rebuild_softmax_by_max_sum(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum)
-
-     dq, dk, dv = fusion_attention_backward(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob)
-
-     # Adaptation for unequal N (by cdy)
-     if not (N1 == N2):
-         if N2 == 0:
-             raise ValueError("dims_kwargs.N2 must be non-zero.")
-         G = int(N1 / N2)
-         dk = torch.sum(dk.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D)
-         dv = torch.sum(dv.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D)
-
-     if dq.dim() == 5:
-         dq = dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4))
-     if dk.dim() == 5:
-         dk = dk.reshape(dk.size(0), dk.size(1) * dk.size(2), dk.size(3), dk.size(4))
-     if dv.dim() == 5:
-         dv = dv.reshape(dv.size(0), dv.size(1) * dv.size(2), dv.size(3), dv.size(4))
-
-     dq = convert_from_bnsd(dq, input_layout)
-     dk = convert_from_bnsd(dk, input_layout)
-     dv = convert_from_bnsd(dv, input_layout)
-
-     return dq.cpu(), dk.cpu(), dv.cpu()
+ import torch
+ import numpy as np
+ from einops import rearrange
+ try:
+     import torch_npu
+ except ImportError:
+     is_gpu = True
+     try:
+         # flash_attn is the third-party FA library for GPU
+         from flash_attn import flash_attn_func
+     except ImportError:
+         # In a CPU UT environment, do nothing
+         pass
+ else:
+     is_gpu = False
+
+
+ from msprobe.pytorch.common.utils import logger
+ from msprobe.core.common.const import Const, CompareConst
+
+ gtype = torch.float64  # An ARM host must use float64; on x86, float32 is enough (float64 also works). ARM computation is slow, so x86 is recommended for s=8k scenarios.
+ softmax_build_mode = "QKV"  # "MAX_SUM"
+
+ """
+ # Forward signature comparison
+ Golden implementation: fusion_attention_forward: q, k, v, drop_mask, atten_mask, pse, scale, keep_prob
+ Fused operator: npu_fusion_attention_forward: query, key, value, head_num, input_layout, *, pse=None, padding_mask=None,
+                                               atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647,
+                                               next_tockens=2147483647, inner_precise=0, prefix=None, sparse_mode=0,
+                                               gen_mask_parallel=True, sync=False
+
+ # Backward signature comparison
+ Golden implementation: fusion_attention_backward: dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
+ Fused operator: npu_fusion_attention_backward: query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None,
+                                                atten_mask=None, softmax_max=None, softmax_sum=None, softmax_in=None,
+                                                attention_in=None, scale_value=1.0, keep_prob=1.0, pre_tockens=2147483647,
+                                                next_tockens=2147483647, inner_precise=0, seed=0, offset=0,
+                                                numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
+ """
+
+
+ def softmax_forward(x):
+     x_max = torch.max(x, dim=-1, keepdims=True)[0]
+     x_sub = x.sub(x_max)
+     y = torch.exp(x_sub)
+     x_sum = y.sum(dim=-1, keepdims=True)
+     res = y.div(x_sum)
+     return res, x_max, x_sum
+
+
+ def softmax_grad(dp, softmax_res):
+     muls = dp * softmax_res
+     muls_r = muls.sum(dim=-1, keepdims=True)
+     sub_r = dp - muls_r
+     res = sub_r * softmax_res
+     return res
+
+
+ def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):
+     if num_kv_heads == 0 or num_kv_heads > num_heads:
+         raise ValueError(f"num_kv_heads must be non-zero and no bigger than num_heads.")
+
+     factor = num_heads // num_kv_heads
+     kv_shape = kv_tensor.shape
+     B = kv_shape[0]
+     S = kv_shape[2]
+     D = kv_shape[3]
+     kv_res = torch.zeros([B, num_heads, S, D]).to(dtype)
+     for i in range(num_heads):
+         j = i // factor
+         kv_res[:, i:i + 1, :, :] = kv_tensor[:, j:j + 1, :, :]
+     return kv_res
+
+
+ def calculate_qk(q, k, atten_mask, pse, scale):
+     if pse is None or len(pse.shape) == 0:
+         qk = torch.matmul(q, k.permute(0, 1, 3, 2)).mul(scale)
+     else:
+         qk = (torch.matmul(q, k.permute(0, 1, 3, 2)) + pse).mul(scale)
+     if atten_mask is None or len(atten_mask.shape) == 0:
+         return qk
+     else:
+         qk = qk + atten_mask.bool() * (-40000.0)  # -10000
+         return qk
+
+
+ def fusion_attention_forward(q, k, v, drop_mask, atten_mask, pse, scale, keep_prob):
+     qk = calculate_qk(q, k, atten_mask, pse, scale)
+     softmax_res, softmax_max, softmax_sum = softmax_forward(qk)
+     if drop_mask is None or len(drop_mask.shape) == 0:
+         drop_res = softmax_res
+     else:
+         drop_res = softmax_res * drop_mask * (1.0 / keep_prob)
+     y = torch.matmul(drop_res, v)
+     return y, softmax_max, softmax_sum
+
+
+ def fusion_attention_backward(dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob):
+     dp = torch.matmul(dx, v.permute(0, 1, 3, 2))
+     if drop_mask is None or len(drop_mask.shape) == 0:
+         drop_res = softmax_res.permute(0, 1, 3, 2)
+         dp_drop = dp
+     else:
+         drop_res = softmax_res.mul(drop_mask).mul(1.0 / keep_prob).permute(0, 1, 3, 2)
+         dp_drop = dp * drop_mask * (1.0 / keep_prob)
+     dv = torch.matmul(drop_res, dx)
+     softmax_grad_res = (softmax_grad(dp_drop, softmax_res) * scale)
+     dq = torch.matmul(softmax_grad_res, k)
+     dk = torch.matmul(softmax_grad_res.permute(0, 1, 3, 2), q)
+     return dq, dk, dv
+
+
+ def parse_bsnd_args(query, key, head_num, input_layout):
+     supported_input_layout = ["BSH", "SBH", "BSND", "BNSD", "TND"]
+     B, S1, S2, N1, N2, D, H1, H2 = None, None, None, head_num, None, None, None, None
+
+     if not isinstance(input_layout, str) or input_layout not in supported_input_layout:
+         raise ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.")
+
+     if input_layout == "TND":
+         raise ValueError(f"input_layout {input_layout} is not supported for now.")
+     try:
+         if input_layout == "BSH":
+             B, S1, H1 = query.shape
+             _, S2, H2 = key.shape
+             D = H1 // N1
+             N2 = H2 // D
+         elif input_layout == "SBH":
+             S1, B, H1 = query.shape
+             S2, _, H2 = key.shape
+             D = H1 // N1
+             N2 = H2 // D
+         elif input_layout == "BSND":
+             B, S1, N1, D = query.shape
+             _, S2, N2, _ = key.shape
+             H1 = N1 * D
+             H2 = N2 * D
+         elif input_layout == "BNSD":
+             B, N1, S1, D = query.shape
+             _, N2, S2, _ = key.shape
+             H1 = N1 * D
+             H2 = N2 * D
+     except Exception as e:
+         raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e
+
+     if D == 0:
+         raise ValueError(f"Value D must be non-zero.")
+     DTYPE = query.dtype
+     return B, S1, S2, N1, N2, D, H1, H2, DTYPE
+
+
+ def convert_from_bnsd(_input, input_layout):
+     if input_layout == "BSH":
+         # (B,N,S,D) => (B,S,N*D)
+         out = rearrange(_input, 'b n s d -> b s (n d)').contiguous()
+     elif input_layout == "SBH":
+         # (B,N,S,D) => (S,B,N*D)
+         out = rearrange(_input, 'b n s d -> s b (n d)').contiguous()
+     elif input_layout == "BSND":
+         # (B,N,S,D) => (B,S,N,D)
+         out = rearrange(_input, 'b n s d -> b s n d').contiguous()
+     elif input_layout == "TND":
+         raise ValueError(f"input_layout {input_layout} is not supported for now.")
+     else:
+         out = _input
+     return out
+
+
+ def convert_to_bnsd(_input, n, input_layout):
+     # The default "BNSD" layout needs no conversion
+     if input_layout == "BSH":
+         # (B,S,N*D) => (B,N,S,D)
+         out = rearrange(_input, 'b s (n d) -> b n s d', n=n)
+     elif input_layout == "SBH":
+         # (S,B,N*D) => (B,N,S,D)
+         out = rearrange(_input, 's b (n d) -> b n s d', n=n)
+     elif input_layout == "BSND":
+         # (B,S,N,D) => (B,N,S,D)
+         out = rearrange(_input, 'b s n d -> b n s d', n=n)
+     elif input_layout == "TND":
+         raise ValueError(f"input_layout {input_layout} is not supported for now.")
+     else:
+         out = _input
+     if out.dim() != 4:
+         raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
+     return out.to(gtype)
+
+
+ def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next_tocken, dtype):
+     """
+     # When sparse_mode is 2, 3 or 4, the small-operator-to-fused-operator conversion applies this optimization, so reversing it means expanding the mask back to the original basic implementation
+     ===> atten_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype)
+     """
+     shape = [S1, S2]
+
+     if atten_mask is not None:
+         # When the FA input already contains atten_mask, it can be assumed to be the converted mask matrix; three special scenarios (the sparse-matrix cases) require reverse restoration
+         if sparse_mode == 2 or sparse_mode == 3 or sparse_mode == 4:
+             logger.info(f"S1: {S1}, S2:{S2}, atten_mask.shape:{atten_mask.shape}, atten_mask.dtype:{atten_mask.dtype}")
+
+             if atten_mask.dim() == 2 and atten_mask.shape[0] == 2048 and atten_mask.shape[1] == 2048:
+                 if atten_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(atten_mask.dtype)):
+                     if sparse_mode == 2:
+                         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
+                     elif sparse_mode == 3:
+                         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1))
+                     elif sparse_mode == 4:
+                         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
+                         atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
+                         atten_mask = atten_mask_u + atten_mask_l
+                     logger.debug(f"Reverse-converted atten_mask {atten_mask.shape}")
+                     return atten_mask.to(dtype)
+
+         return atten_mask.to(dtype)
+
+     if atten_mask is not None:
+         if atten_mask.dim() == 2:
+             if atten_mask.shape[0] != S1 or atten_mask.shape[1] != S2:
+                 raise ValueError(f"Invalid atten_mask shape `SS` {atten_mask.shape}")
+             shape = [S1, S2]
+         elif atten_mask.dim() == 4:
+             if atten_mask.shape[1] == 1:
+                 shape = [B, 1, S1, S2] if B != 1 else [1, 1, S1, S2]
+             else:
+                 shape = [B, N1, S1, S2] if B != 1 else [1, N1, S1, S2]
+
+     if sparse_mode == 0:
+         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
+         atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
+         atten_mask = atten_mask_u + atten_mask_l
+     elif sparse_mode == 1:  # no sparse
+         atten_mask = torch.from_numpy(np.zeros(shape))
+     elif sparse_mode == 2:
+         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
+     elif sparse_mode == 3:
+         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1))
+     elif sparse_mode == 4:
+         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
+         atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
+         atten_mask = atten_mask_u + atten_mask_l
+     # Note: sparse_mode=5 cannot occur here; that mode requires atten_mask to be passed in, with the mask matrix in BNSS or B1SS format,
+     # so the FA input can be assumed to already be the correct atten_mask
+     return atten_mask.to(dtype)
+
+
+ def generate_kv(key, value, N1, N2):
+     # Adaptation for unequal N (by cdy)
+     if not (N1 == N2):
+         k_new = broadcast_kv(N1, N2, key, key.dtype)
+         v_new = broadcast_kv(N1, N2, value, value.dtype)
+     else:
+         k_new = key
+         v_new = value
+     return k_new, v_new
+
+
+ def rebuid_softmax_by_qkv(q, k, atten_mask, pse, scale):
+     """
+     attention = softmax(QK^T/sqrt(d))V
+     softmax(x_i) = e^(x_i - x_max) / sum(e^(x_i - x_max))
+     """
+     logger.info("Using QKV to rebuild original softmax")
+     qk = calculate_qk(q, k, atten_mask, pse, scale)
+     softmax_res, x_max, x_sum = softmax_forward(qk)
+     return softmax_res
+
+
+ def rebuild_softmax_by_max_sum(q, k, atten_mask, pse, scale, softmax_max, softmax_sum):
+     """
+     attention = softmax(QK^T/sqrt(d))V
+     softmax(x_i) = e^(x_i - x_max_i) / x_sum_i
+     """
+     logger.info("Using softmax_max and softmax_sum to rebuild original softmax")
+     qk = calculate_qk(q, k, atten_mask, pse, scale)
+     if softmax_max.shape[-1] == 0:
+         raise ValueError(f"softmax_max.shape[-1] must be non-zero, softmax_max.shape: {softmax_max.shape}")
+     repeat_dim = qk.shape[-1] // softmax_max.shape[-1]
+     softmax_res = torch.exp(qk.sub(softmax_max.repeat(1, 1, 1, repeat_dim))).div(
+         softmax_sum.repeat(1, 1, 1, repeat_dim))
+     return softmax_res
+
+
+ def get_head_num(*args, **kwargs):
+     if kwargs.get("head_num", None):
+         head_num = kwargs.get("head_num")
+     elif len(args) >= 4:
+         head_num = args[3]
+     else:
+         raise ValueError(f"Unsupported npu_fusion_attention args {args}.")
+     return head_num
+
+
+ def get_input_layout(*args, **kwargs):
+     if kwargs.get("input_layout", None):
+         input_layout = kwargs.get("input_layout")
+     elif len(args) >= 5:
+         input_layout = args[4]
+     else:
+         raise ValueError(f"Unsupported npu_fusion_attention args {args}.")
+     return input_layout
+
+
+ def npu_fusion_attention_forward_patch(*args, **kwargs):
+     # query, key, value, head_num, input_layout
+     head_num = get_head_num(*args, **kwargs)
+     input_layout = get_input_layout(*args, **kwargs)
+
+     B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], head_num, input_layout)
+     if N1 == N2 and S1 == S2:
+         logger.debug(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
+     else:
+         logger.debug(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
+     if not (N1 % N2 == 0 and N1 >= N2):
+         raise ValueError(f"N1 and N2 do not match, please check: N1 = {N1}, N2 = {N2}.")
+
+     dims_kwargs = {"B": B, "S1": S1, "S2": S2, "N1": N1, "N2": N2,
+                    "D": D, "H1": H1, "H2": H2, "DTYPE": DTYPE}
+
+     new_kwargs = {"keep_prob": 1,
+                   "scale": kwargs.get("scale", 1 / (D ** 0.5)),
+                   "sparse_mode": kwargs.get("sparse_mode", 0),
+                   "prefix": kwargs.get("prefix"),
+                   "pre_tockens": kwargs.get("pre_tockens", 2147483647),
+                   "next_tockens": kwargs.get("next_tockens", 2147483647),
+                   "pse": kwargs.get("pse"),
+                   "padding_mask": kwargs.get("padding_mask"),
+                   "atten_mask": kwargs.get("atten_mask")}
+
+     return args, dims_kwargs, new_kwargs
+
+
+ def npu_fusion_attention_backward_patch(*args, **kwargs):
+     if len(args) != 6:
+         raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")
+
+     B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], args[4], args[5])
+     if N1 == N2 and S1 == S2:
+         logger.info(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
+     else:
+         logger.info(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
+     if not (N1 % N2 == 0 and N1 >= N2):
+         raise ValueError(f"N1 and N2 do not match, please check: N1 = {N1}, N2 = {N2}.")
+
+     dims_kwargs = {"B": B, "S1": S1, "S2": S2, "N1": N1, "N2": N2,
+                    "D": D, "H1": H1, "H2": H2, "DTYPE": DTYPE}
+
+     new_kwargs = {"keep_prob": 1,
+                   "scale_value": kwargs.get("scale_value", 1 / (D ** 0.5)),
+                   "sparse_mode": kwargs.get("sparse_mode", 0),
+                   "prefix": kwargs.get("prefix"),
+                   "pre_tockens": kwargs.get("pre_tockens", 2147483647),
+                   "next_tockens": kwargs.get("next_tockens", 2147483647),
+                   "pse": kwargs.get("pse"),
+                   "padding_mask": kwargs.get("padding_mask"),
+                   "softmax_max": kwargs.get("softmax_max"),
+                   "softmax_sum": kwargs.get("softmax_sum"),
+                   "softmax_in": kwargs.get("softmax_in"),
+                   "attention_in": kwargs.get("attention_in"),
+                   "seed": kwargs.get("seed", 0),
+                   "offset": kwargs.get("offset", 0),
+                   "numels": kwargs.get("numels", 0),
+                   "atten_mask": kwargs.get("atten_mask")}
+
+     return args, dims_kwargs, new_kwargs
+
+
+ def npu_fusion_attention(*args, **kwargs):
+     new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*args, **kwargs)
+     query, key, value = new_args[0], new_args[1], new_args[2]
+     input_layout = get_input_layout(*args, **kwargs)
+     N1 = dims_kwargs.get("N1")
+     N2 = dims_kwargs.get("N2")
+     S1 = dims_kwargs.get("S1")
+     S2 = dims_kwargs.get("S2")
+     B = dims_kwargs.get("B")
+     DTYPE = dims_kwargs.get("DTYPE")
+     atten_mask = new_kwargs.get("atten_mask")
+     keep_prob = new_kwargs.get("keep_prob")
+     sparse_mode = new_kwargs.get("sparse_mode")
+     pre_tockens = new_kwargs.get("pre_tockens")
+     next_tockens = new_kwargs.get("next_tockens")
+     pse = new_kwargs.get("pse")
+     scale = new_kwargs.get("scale")
+
+     atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE)
+     query = convert_to_bnsd(query, N1, input_layout)
+     key = convert_to_bnsd(key, N2, input_layout)
+     value = convert_to_bnsd(value, N2, input_layout)
+     k_new, v_new = generate_kv(key, value, N1, N2)
+     out_golden, softmax_max, softmax_sum = fusion_attention_forward(q=query, k=k_new, v=v_new,
+                                                                     drop_mask=None, atten_mask=atten_mask,
+                                                                     pse=pse, scale=scale,
+                                                                     keep_prob=keep_prob)
+     if out_golden.dim() == 5:
+         out_golden = out_golden.reshape(out_golden.size(0), out_golden.size(1) * out_golden.size(2), out_golden.size(3),
+                                         out_golden.size(4))
+     out_golden = convert_from_bnsd(out_golden, input_layout)
+
+     return out_golden.cpu(), softmax_max.repeat(1, 1, 1, 8).cpu(), softmax_sum.repeat(1, 1, 1, 8).cpu()
+
+
+ def npu_fusion_attention_grad(*args, **kwargs):
+     # dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
+     new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*args, **kwargs)
+     query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5]
+     N1 = dims_kwargs.get("N1")
+     N2 = dims_kwargs.get("N2")
+     S1 = dims_kwargs.get("S1")
+     S2 = dims_kwargs.get("S2")
+     B = dims_kwargs.get("B")
+     D = dims_kwargs.get("D")
+     DTYPE = dims_kwargs.get("DTYPE")
+     atten_mask = new_kwargs.get("atten_mask")
+     keep_prob = new_kwargs.get("keep_prob")
+     sparse_mode = new_kwargs.get("sparse_mode")
+     pre_tockens = new_kwargs.get("pre_tockens")
+     next_tockens = new_kwargs.get("next_tockens")
+     pse = new_kwargs.get("pse")
+     softmax_max = new_kwargs.get("softmax_max")
+     softmax_sum = new_kwargs.get("softmax_sum")
+     scale_value = new_kwargs.get("scale_value")
+
+     atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE)
+     query = convert_to_bnsd(query, N1, input_layout)
+     dx = convert_to_bnsd(dx, N1, input_layout)
+     key = convert_to_bnsd(key, N2, input_layout)
+     value = convert_to_bnsd(value, N2, input_layout)
+     k_new, v_new = generate_kv(key, value, N1, N2)
+
+     if softmax_build_mode == "QKV":
+         softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value)
+     else:
+         softmax_res = rebuild_softmax_by_max_sum(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum)
+
+     dq, dk, dv = fusion_attention_backward(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob)
+
+     # Adaptation for unequal N (by cdy)
+     if not (N1 == N2):
+         if N2 == 0:
+             raise ValueError("dims_kwargs.N2 must be non-zero.")
+         G = int(N1 / N2)
+         dk = torch.sum(dk.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D)
+         dv = torch.sum(dv.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D)
+
+     if dq.dim() == 5:
+         dq = dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4))
+     if dk.dim() == 5:
+         dk = dk.reshape(dk.size(0), dk.size(1) * dk.size(2), dk.size(3), dk.size(4))
+     if dv.dim() == 5:
+         dv = dv.reshape(dv.size(0), dv.size(1) * dv.size(2), dv.size(3), dv.size(4))
+
+     dq = convert_from_bnsd(dq, input_layout)
+     dk = convert_from_bnsd(dk, input_layout)
+     dv = convert_from_bnsd(dv, input_layout)
+
+     return dq.cpu(), dk.cpu(), dv.cpu()
+
+
+ def is_attention_off_due_to_mask(atten_mask_dtype):
+     return not atten_mask_dtype
+
+
+ def is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens, S1):
+     return sparse_mode == 4 and (next_tockens != 0 or pre_tockens < S1)
+
+
+ def is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens, S1, S2):
+     return sparse_mode == 0 and pre_tockens >= S1 and next_tockens >= S2
+
+
+ def gpu_fusion_attention(*args, **kwargs):
+     deterministic = False
+     new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*args, **kwargs)
+     query, key, value = new_args[0], new_args[1], new_args[2]
+     keep_prob = new_kwargs.get("keep_prob", 1.0)
+     scale = new_kwargs.get("scale")
+     N1 = dims_kwargs.get("N1")
+     N2 = dims_kwargs.get("N2")
+     S1 = dims_kwargs.get("S1")
+     S2 = dims_kwargs.get("S2")
+     B = dims_kwargs.get("B")
+     pse = new_kwargs.get("pse")
+     sparse_mode = new_kwargs.get("sparse_mode")
+     pre_tockens = new_kwargs.get("pre_tockens")
+     next_tockens = new_kwargs.get("next_tockens")
+     attn_mask = new_kwargs.get("atten_mask")
+     atten_mask_dtype = attn_mask.dtype if new_kwargs.get("atten_mask") is not None else None
+     pre_tockens = min(CompareConst.MAX_TOKENS, pre_tockens)
+     next_tockens = min(CompareConst.MAX_TOKENS, next_tockens)
+     # flash_attn takes no atten_mask input; decide causality from mask presence and the sparse-mode/token-window settings
+     atten_off = (is_attention_off_due_to_mask(atten_mask_dtype) or
+                  is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens, S1) or
+                  is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens, S1, S2))
+     causal_switch = not atten_off
+     if sparse_mode == CompareConst.SPECIAL_SPARSE_MOED:
+         window_left = pre_tockens
+         window_right = next_tockens
+     else:
+         pre_tockens = next_tockens = CompareConst.MAX_TOKENS
+         window_left = pre_tockens - S1 + S2
+         window_right = next_tockens + S1 - S2
+
+     if pse is not None:
+         alibi_slopes = torch.rand(B, N1, dtype=torch.float32) * 0.3
+     else:
+         alibi_slopes = None
+
+     out = flash_attn_func(query, key, value, dropout_p=(1 - keep_prob), softmax_scale=scale, causal=causal_switch,
+                           window_size=(window_left, window_right), alibi_slopes=alibi_slopes,
+                           deterministic=deterministic)
+     return out, Const.NONE, Const.NONE
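
The hunk above carries the substance of the 1.0.4 change to this file: the imports now probe for torch_npu and fall back to the GPU flash_attn library, head_num and input_layout may arrive positionally or as keywords, and a gpu_fusion_attention golden path was added. As a quick orientation aid, the following minimal sketch (not part of the package; the import path is assumed from the file location shown above, and the shapes and values are illustrative only) drives the CPU golden forward and checks that the two softmax-rebuild modes agree:

import torch
from msprobe.pytorch.bench_functions.npu_fusion_attention import (
    fusion_attention_forward, rebuid_softmax_by_qkv, rebuild_softmax_by_max_sum)

# Small BNSD case: batch=1, heads=2, seq_len=16, head_dim=8, float64 as the golden dtype.
B, N, S, D = 1, 2, 16, 8
q, k, v = (torch.rand(B, N, S, D, dtype=torch.float64) for _ in range(3))
scale = 1 / (D ** 0.5)

# Golden forward pass with no dropout mask, attention mask, or pse.
out, softmax_max, softmax_sum = fusion_attention_forward(
    q, k, v, drop_mask=None, atten_mask=None, pse=None, scale=scale, keep_prob=1.0)
assert out.shape == (B, N, S, D)

# The "QKV" and "MAX_SUM" rebuild paths should reconstruct the same softmax.
res_qkv = rebuid_softmax_by_qkv(q, k, None, None, scale)
res_max_sum = rebuild_softmax_by_max_sum(q, k, None, None, scale, softmax_max, softmax_sum)
assert torch.allclose(res_qkv, res_max_sum)

These functions appear to serve as the CPU golden baseline that the accuracy-checker tooling in this package compares fused-operator outputs against.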