mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,6 +1,39 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+# 前向函数声明对比
+标杆实现:fusion_attention_forward: q, k, v, drop_mask, atten_mask, pse, scale, keep_prob
+融合算子:npu_fusion_attention_forward: query, key, value, head_num, input_layout, *, pse=None, padding_mask=None,
+                                       atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647,
+                                       next_tockens=2147483647, inner_precise=0, prefix=None, sparse_mode=0,
+                                       gen_mask_parallel=True, sync=False
+
+# 反向函数声明对比
+标杆实现:fusion_attention_backward: dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
+融合算子:npu_fusion_attention_backward: query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None,
+                                       atten_mask=None, softmax_max=None, softmax_sum=None, softmax_in=None,
+                                       attention_in=None, scale_value=1.0, keep_prob=1.0, pre_tockens=2147483647,
+                                       next_tockens=2147483647, inner_precise=0, seed=0, offset=0,
+                                       numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
+"""
+
 import torch
 import numpy as np
 from einops import rearrange
+
 try:
     import torch_npu
 except ImportError:
@@ -9,34 +42,16 @@ except ImportError:
         # flash_attn为gpu的fa三方库
         from flash_attn import flash_attn_func
     except ImportError:
-        #如果为cpu的ut环境,则不做任何处理
+        # 如果为cpu的ut环境,则不做任何处理
         pass
 else:
     is_gpu = False
 
-
 from msprobe.pytorch.common.utils import logger
 from msprobe.core.common.const import Const, CompareConst
 
-gtype = torch.float64 # arm host必须选择float64,x86环境选择float32即可,64也行。arm计算很慢,s=8k的场景建议使用x86
-softmax_build_mode = "QKV" # "MAX_SUM"
-
-"""
-# 前向函数声明对比
-标杆实现:fusion_attention_forward: q, k, v, drop_mask, atten_mask, pse, scale, keep_prob
-融合算子:npu_fusion_attention_forward: query, key, value, head_num, input_layout, *, pse=None, padding_mask=None,
-                                       atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647,
-                                       next_tockens=2147483647, inner_precise=0, prefix=None, sparse_mode=0,
-                                       gen_mask_parallel=True, sync=False
-
-# 反向函数声明对比
-标杆实现:fusion_attention_backward: dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
-融合算子:npu_fusion_attention_backward: query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None,
-                                       atten_mask=None, softmax_max=None, softmax_sum=None, softmax_in=None,
-                                       attention_in=None, scale_value=1.0, keep_prob=1.0, pre_tockens=2147483647,
-                                       next_tockens=2147483647, inner_precise=0, seed=0, offset=0,
-                                       numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
-"""
+GTYPE = torch.float64 # arm host必须选择float64,x86环境选择float32即可,64也行。arm计算很慢,s=8k的场景建议使用x86
+SOFTMAX_BUILD_MODE = "QKV" # "MAX_SUM"
 
 
 def softmax_forward(x):
@@ -62,10 +77,10 @@ def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):
 
     factor = num_heads // num_kv_heads
     kv_shape = kv_tensor.shape
-    B = kv_shape[0]
-    S = kv_shape[2]
-    D = kv_shape[3]
-    kv_res = torch.zeros([B, num_heads, S, D]).to(dtype)
+    b = kv_shape[0]
+    s = kv_shape[2]
+    d = kv_shape[3]
+    kv_res = torch.zeros([b, num_heads, s, d]).to(dtype)
     for i in range(num_heads):
         j = i // factor
         kv_res[:, i:i + 1, :, :] = kv_tensor[:, j:j + 1, :, :]
@@ -112,7 +127,7 @@ def fusion_attention_backward(dx, q, k, v, softmax_res, drop_mask, pse, scale, k
 
 def parse_bsnd_args(query, key, head_num, input_layout):
     supported_input_layout = ["BSH", "SBH", "BSND", "BNSD", "TND"]
-    B, S1, S2, N1, N2, D, H1, H2 = None, None, None, head_num, None, None, None, None
+    b, s1, s2, n1, n2, d, h1, h2 = None, None, None, head_num, None, None, None, None
 
     if not isinstance(input_layout, str) or input_layout not in supported_input_layout:
         raise ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.")
@@ -121,35 +136,48 @@ def parse_bsnd_args(query, key, head_num, input_layout):
         raise ValueError(f"input_layout {input_layout} does not supported for now.")
     try:
         if input_layout == "BSH":
-            B, S1, H1 = query.shape
-            _, S2, H2 = key.shape
-            D = H1 // N1
-            N2 = H2 // D
+            b, s1, h1 = query.shape
+            _, s2, h2 = key.shape
+            d = h1 // n1
+            n2 = h2 // d
         elif input_layout == "SBH":
-            S1, B, H1 = query.shape
-            S2, _, H2 = key.shape
-            D = H1 // N1
-            N2 = H2 // D
+            s1, b, h1 = query.shape
+            s2, _, h2 = key.shape
+            d = h1 // n1
+            n2 = h2 // d
         elif input_layout == "BSND":
-            B, S1, N1, D = query.shape
-            _, S2, N2, _ = key.shape
-            H1 = N1 * D
-            H2 = N2 * D
+            b, s1, n1, d = query.shape
+            _, s2, n2, _ = key.shape
+            h1 = n1 * d
+            h2 = n2 * d
         elif input_layout == "BNSD":
-            B, N1, S1, D = query.shape
-            _, N2, S2, _ = key.shape
-            H1 = N1 * D
-            H2 = N2 * D
+            b, n1, s1, d = query.shape
+            _, n2, s2, _ = key.shape
+            h1 = n1 * d
+            h2 = n2 * d
     except Exception as e:
         raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e
 
-    if D == 0:
-        raise ValueError(f"Value D must be non-zero.")
-    DTYPE = query.dtype
-    return B, S1, S2, N1, N2, D, H1, H2, DTYPE
+    if d == 0:
+        raise ValueError(f"Value d must be non-zero.")
+    _dtype = query.dtype
+    ret = (b, s1, s2, n1, n2, d, h1, h2, _dtype)
+    return ret
 
 
 def convert_from_bnsd(_input, input_layout):
+    """
+    transform qkv from bnsd to input_layout.
+    B: batch_size
+    S: sequence_length
+    N: num_heads
+    D: head_dim
+    Args:
+        _input (torch.Tensor): tensor of shape (B,N,S,D)
+        input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+    Returns:
+        tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+    """
     if input_layout == "BSH":
         # (B,N,S,D)=>(B,S,N*D)
         out = rearrange(_input, 'b n s d -> b s (n d)').contiguous()
@@ -167,7 +195,19 @@ def convert_from_bnsd(_input, input_layout):
 
 
 def convert_to_bnsd(_input, n, input_layout):
-    # 默认"BNSD"无需处理
+    """
+    transform qkv from input_layout to bnsd.
+    B: batch_size
+    S: sequence_length
+    N: num_heads
+    D: head_dim
+    Args:
+        _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+        n (int): num_heads
+        input_layout (str):"BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+    Returns:
+        tensor of shape (B,N,S,D)
+    """
     if input_layout == "BSH":
         # (B,S,N*D)=>(B,N,S,D)
         out = rearrange(_input, 'b s (n d) -> b n s d', n=n)
@@ -183,27 +223,90 @@ def convert_to_bnsd(_input, n, input_layout):
         out = _input
     if out.dim() != 4:
         raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
-    return out.to(gtype)
+    return out.to(GTYPE)
 
 
-def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next_tocken, dtype):
+def convert_from_bsnd(_input, input_layout):
+    """
+    transform qkv from bsnd to input_layout.
+    B: batch_size
+    S: sequence_length
+    N: num_heads
+    D: head_dim
+    Args:
+        _input (torch.Tensor): tensor of shape (B,S,N,D)
+        input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+    Returns:
+        tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+    """
+    if input_layout == "BSH":
+        # (B,S,N,D)=>(B,S,N*D)
+        out = rearrange(_input, 'b s n d -> b s (n d)').contiguous()
+    elif input_layout == "SBH":
+        # (B,S,N,D)=>(S,B,N*D)
+        out = rearrange(_input, 'b s n d -> s b (n d)').contiguous()
+    elif input_layout == "BNSD":
+        # (B,S,N,D)=>(B,N,S,D)
+        out = rearrange(_input, 'b s n d -> b n s d').contiguous()
+    elif input_layout == "TND":
+        raise ValueError(f"input_layout {input_layout} does not supported for now.")
+    else:
+        out = _input
+    return out
+
+
+def convert_to_bsnd(_input, n, input_layout):
+    """
+    transform qkv from input_layout to bsnd.
+    B: batch_size
+    S: sequence_length
+    N: num_heads
+    D: head_dim
+    Args:
+        _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+        n (int): num_heads
+        input_layout (str):"BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+    Returns:
+        tensor of shape (B,S,N,D)
+    """
+    if input_layout == "BSH":
+        # (B,S,N*D)=>(B,S,N,D)
+        out = rearrange(_input, 'b s (n d) -> b s n d', n=n)
+    elif input_layout == "SBH":
+        # (S,B,N*D)=>(B,S,N,D)
+        out = rearrange(_input, 's b (n d) -> b s n d', n=n)
+    elif input_layout == "BNSD":
+        # (B,N,S,D)=>(B,S,N,D)
+        out = rearrange(_input, 'b n s d -> b s n d', n=n)
+    elif input_layout == "TND":
+        raise ValueError(f"input_layout {input_layout} does not supported for now.")
+    else:
+        out = _input
+    if out.dim() != 4:
+        raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
+    return out
+
+
+def generate_atten_mask(*args):
     """
     # 当sparse_mode=2、3、4时小算子到融合算子会走这个优化,反过来看就要拆解回原来的基本实现
     ===> atten_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype)
     """
-    shape = [S1, S2]
+
+    sparse_mode, atten_mask, b, n1, s1, s2, pre_tocken, next_tocken, dtype = args
+    shape = [s1, s2]
 
     if atten_mask is not None:
         # 当FA的输入已经包含atten_mask时,可以认为已经是转换之后的mask矩阵了,有三种特殊场景,即稀疏矩阵场景,需要进行逆向还原
         if sparse_mode == 2 or sparse_mode == 3 or sparse_mode == 4:
-            logger.info(f"S1: {S1}, S2:{S2}, atten_mask.shape:{atten_mask.shape}, atten_mask.dtype:{atten_mask.dtype}")
+            logger.info(f"s1: {s1}, s2:{s2}, atten_mask.shape:{atten_mask.shape}, atten_mask.dtype:{atten_mask.dtype}")
 
             if atten_mask.dim() == 2 and atten_mask.shape[0] == 2048 and atten_mask.shape[1] == 2048:
                 if atten_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(atten_mask.dtype)):
                     if sparse_mode == 2:
                         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
                     elif sparse_mode == 3:
-                        atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1))
+                        atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1))
                     elif sparse_mode == 4:
                         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
                         atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
@@ -215,14 +318,14 @@ def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next
 
     if atten_mask is not None:
         if atten_mask.dim() == 2:
-            if atten_mask.shape[0] != S1 or atten_mask.shape[1] != S2:
+            if atten_mask.shape[0] != s1 or atten_mask.shape[1] != s2:
                 raise ValueError(f"Invalid atten_mask shape `SS` {atten_mask.shape}")
-            shape = [S1, S2]
+            shape = [s1, s2]
         elif atten_mask.dim() == 4:
             if atten_mask.shape[1] == 1:
-                shape = [B, 1, S1, S2] if B != 1 else [1, 1, S1, S2]
+                shape = [b, 1, s1, s2] if b != 1 else [1, 1, s1, s2]
             else:
-                shape = [B, N1, S1, S2] if B != 1 else [1, N1, S1, S2]
+                shape = [b, n1, s1, s2] if b != 1 else [1, n1, s1, s2]
 
     if sparse_mode == 0:
         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
@@ -233,7 +336,7 @@ def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next
     elif sparse_mode == 2:
         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
     elif sparse_mode == 3:
-        atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1))
+        atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1))
     elif sparse_mode == 4:
         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
         atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
@@ -243,11 +346,11 @@ def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next
     return atten_mask.to(dtype)
 
 
-def generate_kv(key, value, N1, N2):
+def generate_kv(key, value, n1, n2):
     # N不等长适配by cdy
-    if not (N1 == N2):
-        k_new = broadcast_kv(N1, N2, key, key.dtype)
-        v_new = broadcast_kv(N1, N2, value, value.dtype)
+    if not (n1 == n2):
+        k_new = broadcast_kv(n1, n2, key, key.dtype)
+        v_new = broadcast_kv(n1, n2, value, value.dtype)
     else:
         k_new = key
         v_new = value
@@ -261,7 +364,7 @@ def rebuid_softmax_by_qkv(q, k, atten_mask, pse, scale):
     """
     logger.info("Using QKV to rebuild original softmax")
    qk = calculate_qk(q, k, atten_mask, pse, scale)
-    softmax_res, x_max, x_sum = softmax_forward(qk)
+    softmax_res, _, _ = softmax_forward(qk)
     return softmax_res
 
 
@@ -301,30 +404,38 @@ def get_input_layout(*args, **kwargs):
 
 
 def npu_fusion_attention_forward_patch(*args, **kwargs):
+
+    if len(args) < 2:
+        raise RuntimeError("npu_fusion_attention_forward_patch: length of args should greater than or equal to 2.")
+
     # query, key, value, head_num, input_layout
     head_num = get_head_num(*args, **kwargs)
     input_layout = get_input_layout(*args, **kwargs)
 
-    B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], head_num, input_layout)
-    if N1 == N2 and S1 == S2:
-        logger.debug(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
+    b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], head_num, input_layout)
+    if n1 == n2 and s1 == s2:
+        logger.debug(f"running case : BNSD = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
     else:
-        logger.debug(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
-    if not (N1 % N2 == 0 and N1 >= N2):
-        raise ValueError(f"N1与N2不匹配,请检查: N1 = {N1}, N2 = {N2}.")
-
-    dims_kwargs = {"B": B, "S1": S1, "S2": S2, "N1": N1, "N2": N2,
-                   "D": D, "H1": H1, "H2": H2, "DTYPE": DTYPE}
-
-    new_kwargs = {"keep_prob": 1,
-                  "scale": kwargs.get("scale", 1 / (D ** 0.5)),
-                  "sparse_mode": kwargs.get("sparse_mode", 0),
-                  "prefix": kwargs.get("prefix"),
-                  "pre_tockens": kwargs.get("pre_tockens", 2147483647),
-                  "next_tockens": kwargs.get("next_tockens", 2147483647),
-                  "pse": kwargs.get("pse"),
-                  "padding_mask": kwargs.get("padding_mask"),
-                  "atten_mask": kwargs.get("atten_mask")}
+        logger.debug(f"running case: BNSD = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+    if not (n1 % n2 == 0 and n1 >= n2):
+        raise ValueError(f"N1与N2不匹配,请检查: n1 = {n1}, n2 = {n2}.")
+
+    dims_kwargs = {
+        "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2,
+        "d": d, "h1": h1, "h2": h2, "dtype": dtype
+    }
+
+    new_kwargs = {
+        "keep_prob": 1,
+        "scale": kwargs.get("scale", 1 / (d ** 0.5)),
+        "sparse_mode": kwargs.get("sparse_mode", 0),
+        "prefix": kwargs.get("prefix"),
+        "pre_tockens": kwargs.get("pre_tockens", 2147483647),
+        "next_tockens": kwargs.get("next_tockens", 2147483647),
+        "pse": kwargs.get("pse"),
+        "padding_mask": kwargs.get("padding_mask"),
+        "atten_mask": kwargs.get("atten_mask")
+    }
 
     return args, dims_kwargs, new_kwargs
 
@@ -333,33 +444,37 @@ def npu_fusion_attention_backward_patch(*args, **kwargs):
     if len(args) != 6:
         raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")
 
-    B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], args[4], args[5])
-    if N1 == N2 and S1 == S2:
-        logger.info(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
+    b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], args[4], args[5])
+    if n1 == n2 and s1 == s2:
+        logger.info(f"running case : bnsd = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
     else:
-        logger.info(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
-    if not (N1 % N2 == 0 and N1 >= N2):
-        raise ValueError(f"N1与N2不匹配,请检查: N1 = {N1}, N2 = {N2}.")
-
-    dims_kwargs = {"B": B, "S1": S1, "S2": S2, "N1": N1, "N2": N2,
-                   "D": D, "H1": H1, "H2": H2, "DTYPE": DTYPE}
-
-    new_kwargs = {"keep_prob": 1,
-                  "scale_value": kwargs.get("scale_value", 1 / (D ** 0.5)),
-                  "sparse_mode": kwargs.get("sparse_mode", 0),
-                  "prefix": kwargs.get("prefix"),
-                  "pre_tockens": kwargs.get("pre_tockens", 2147483647),
-                  "next_tockens": kwargs.get("next_tockens", 2147483647),
-                  "pse": kwargs.get("pse"),
-                  "padding_mask": kwargs.get("padding_mask"),
-                  "softmax_max": kwargs.get("softmax_max"),
-                  "softmax_sum": kwargs.get("softmax_sum"),
-                  "softmax_in": kwargs.get("softmax_in"),
-                  "attention_in": kwargs.get("attention_in"),
-                  "seed": kwargs.get("seed", 0),
-                  "offset": kwargs.get("offset", 0),
-                  "numels": kwargs.get("numels", 0),
-                  "atten_mask": kwargs.get("atten_mask")}
+        logger.info(f"running case: bnsd = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+    if not (n1 % n2 == 0 and n1 >= n2):
+        raise ValueError(f"N1与N2不匹配,请检查: n1 = {n1}, n2 = {n2}.")
+
+    dims_kwargs = {
+        "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2,
+        "d": d, "h1": h1, "h2": h2, "dtype": dtype
+    }
+
+    new_kwargs = {
+        "keep_prob": 1,
+        "scale_value": kwargs.get("scale_value", 1 / (d ** 0.5)),
+        "sparse_mode": kwargs.get("sparse_mode", 0),
+        "prefix": kwargs.get("prefix"),
+        "pre_tockens": kwargs.get("pre_tockens", 2147483647),
+        "next_tockens": kwargs.get("next_tockens", 2147483647),
+        "pse": kwargs.get("pse"),
+        "padding_mask": kwargs.get("padding_mask"),
+        "softmax_max": kwargs.get("softmax_max"),
+        "softmax_sum": kwargs.get("softmax_sum"),
+        "softmax_in": kwargs.get("softmax_in"),
+        "attention_in": kwargs.get("attention_in"),
+        "seed": kwargs.get("seed", 0),
+        "offset": kwargs.get("offset", 0),
+        "numels": kwargs.get("numels", 0),
+        "atten_mask": kwargs.get("atten_mask")
+    }
 
     return args, dims_kwargs, new_kwargs
 
@@ -368,12 +483,12 @@ def npu_fusion_attention(*args, **kwargs):
     new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*args, **kwargs)
     query, key, value = new_args[0], new_args[1], new_args[2]
     input_layout = get_input_layout(*args, **kwargs)
-    N1 = dims_kwargs.get("N1")
-    N2 = dims_kwargs.get("N2")
-    S1 = dims_kwargs.get("S1")
-    S2 = dims_kwargs.get("S2")
-    B = dims_kwargs.get("B")
-    DTYPE = dims_kwargs.get("DTYPE")
+    n1 = dims_kwargs.get("n1")
+    n2 = dims_kwargs.get("n2")
+    s1 = dims_kwargs.get("s1")
+    s2 = dims_kwargs.get("s2")
+    b = dims_kwargs.get("b")
+    dtype = dims_kwargs.get("dtype")
     atten_mask = new_kwargs.get("atten_mask")
     keep_prob = new_kwargs.get("keep_prob")
     sparse_mode = new_kwargs.get("sparse_mode")
@@ -381,12 +496,12 @@ def npu_fusion_attention(*args, **kwargs):
     next_tockens = new_kwargs.get("next_tockens")
     pse = new_kwargs.get("pse")
     scale = new_kwargs.get("scale")
-
-    atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE)
-    query = convert_to_bnsd(query, N1, input_layout)
-    key = convert_to_bnsd(key, N2, input_layout)
-    value = convert_to_bnsd(value, N2, input_layout)
-    k_new, v_new = generate_kv(key, value, N1, N2)
+    args_temp = [sparse_mode, atten_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype]
+    atten_mask = generate_atten_mask(*args_temp)
+    query = convert_to_bnsd(query, n1, input_layout)
+    key = convert_to_bnsd(key, n2, input_layout)
+    value = convert_to_bnsd(value, n2, input_layout)
+    k_new, v_new = generate_kv(key, value, n1, n2)
     out_golden, softmax_max, softmax_sum = fusion_attention_forward(q=query, k=k_new, v=v_new,
                                                                     drop_mask=None, atten_mask=atten_mask,
                                                                     pse=pse, scale=scale,
@@ -403,13 +518,13 @@ def npu_fusion_attention_grad(*args, **kwargs):
     # dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
     new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*args, **kwargs)
     query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5]
-    N1 = dims_kwargs.get("N1")
-    N2 = dims_kwargs.get("N2")
-    S1 = dims_kwargs.get("S1")
-    S2 = dims_kwargs.get("S2")
-    B = dims_kwargs.get("B")
-    D = dims_kwargs.get("D")
-    DTYPE = dims_kwargs.get("DTYPE")
+    n1 = dims_kwargs.get("n1")
+    n2 = dims_kwargs.get("n2")
+    s1 = dims_kwargs.get("s1")
+    s2 = dims_kwargs.get("s2")
+    b = dims_kwargs.get("b")
+    d = dims_kwargs.get("d")
+    dtype = dims_kwargs.get("dtype")
     atten_mask = new_kwargs.get("atten_mask")
     keep_prob = new_kwargs.get("keep_prob")
     sparse_mode = new_kwargs.get("sparse_mode")
@@ -420,14 +535,15 @@ def npu_fusion_attention_grad(*args, **kwargs):
     softmax_sum = new_kwargs.get("softmax_sum")
     scale_value = new_kwargs.get("scale_value")
 
-    atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE)
-    query = convert_to_bnsd(query, N1, input_layout)
-    dx = convert_to_bnsd(dx, N1, input_layout)
-    key = convert_to_bnsd(key, N2, input_layout)
-    value = convert_to_bnsd(value, N2, input_layout)
-    k_new, v_new = generate_kv(key, value, N1, N2)
+    args_temp = [sparse_mode, atten_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype]
+    atten_mask = generate_atten_mask(*args_temp)
+    query = convert_to_bnsd(query, n1, input_layout)
+    dx = convert_to_bnsd(dx, n1, input_layout)
+    key = convert_to_bnsd(key, n2, input_layout)
+    value = convert_to_bnsd(value, n2, input_layout)
+    k_new, v_new = generate_kv(key, value, n1, n2)
 
-    if softmax_build_mode == "QKV":
+    if SOFTMAX_BUILD_MODE == "QKV":
         softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value)
     else:
         softmax_res = rebuild_softmax_by_max_sum(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum)
@@ -435,12 +551,12 @@ def npu_fusion_attention_grad(*args, **kwargs):
     dq, dk, dv = fusion_attention_backward(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob)
 
     # N不等长适配by cdy
-    if not (N1 == N2):
-        if N2 == 0:
-            raise ValueError("dims_kwargs.N2 must be non-zero.")
-        G = int(N1 / N2)
-        dk = torch.sum(dk.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D)
-        dv = torch.sum(dv.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D)
+    if not (n1 == n2):
+        if n2 == 0:
+            raise ValueError("dims_kwargs.n2 must be non-zero.")
+        g = int(n1 / n2)
+        dk = torch.sum(dk.reshape(b, n2, g, s2, d), dim=2, keepdim=True).reshape(b, n2, s2, d)
+        dv = torch.sum(dv.reshape(b, n2, g, s2, d), dim=2, keepdim=True).reshape(b, n2, s2, d)
 
     if dq.dim() == 5:
         dq = dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4))
@@ -460,12 +576,12 @@ def is_attention_off_due_to_mask(atten_mask_dtype):
     return not atten_mask_dtype
 
 
-def is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens, S1):
-    return sparse_mode == 4 and (next_tockens != 0 or pre_tockens < S1)
+def is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens, s1):
+    return sparse_mode == 4 and (next_tockens != 0 or pre_tockens < s1)
 
 
-def is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens, S1, S2):
-    return sparse_mode == 0 and pre_tockens >= S1 and next_tockens >= S2
+def is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens, s1, s2):
+    return sparse_mode == 0 and pre_tockens >= s1 and next_tockens >= s2
 
 
 def gpu_fusion_attention(*args, **kwargs):
@@ -474,11 +590,11 @@ def gpu_fusion_attention(*args, **kwargs):
     query, key, value = new_args[0], new_args[1], new_args[2]
     keep_prob = new_kwargs.get("keep_prob", 1.0)
     scale = new_kwargs.get("scale")
-    N1 = dims_kwargs.get("N1")
-    N2 = dims_kwargs.get("N2")
-    S1 = dims_kwargs.get("S1")
-    S2 = dims_kwargs.get("S2")
-    B = dims_kwargs.get("B")
+    n1 = dims_kwargs.get("n1")
+    n2 = dims_kwargs.get("n2")
+    s1 = dims_kwargs.get("s1")
+    s2 = dims_kwargs.get("s2")
+    b = dims_kwargs.get("b")
     pse = new_kwargs.get("pse")
     sparse_mode = new_kwargs.get("sparse_mode")
     pre_tockens = new_kwargs.get("pre_tockens")
@@ -488,22 +604,29 @@ def gpu_fusion_attention(*args, **kwargs):
     pre_tockens = min(CompareConst.MAX_TOKENS, pre_tockens)
     next_tockens = min(CompareConst.MAX_TOKENS, next_tockens)
     atten_off = (is_attention_off_due_to_mask(atten_mask_dtype) or
-                 is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens, S1) or
-                 is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens, S1, S2))
+                 is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens, s1) or
+                 is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens, s1, s2))
     causal_switch = not atten_off
     if sparse_mode == CompareConst.SPECIAL_SPARSE_MOED:
         window_left = pre_tockens
         window_right = next_tockens
     else:
         pre_tockens = next_tockens = CompareConst.MAX_TOKENS
-        window_left = pre_tockens - S1 + S2
-        window_right = next_tockens + S1 - S2
-
+        window_left = pre_tockens - s1 + s2
+        window_right = next_tockens + s1 - s2
+
     if pse is not None:
-        alibi_slopes = torch.rand(B, N1, dtype=torch.float32) * 0.3
+        alibi_slopes = torch.rand(b, n1, dtype=torch.float32) * 0.3
     else:
         alibi_slopes = None
-
-    out = flash_attn_func(query, key, value, dropout_p=(1-keep_prob), softmax_scale=scale, causal=causal_switch,
-                          window_size=(window_left, window_right), alibi_slopes=alibi_slopes, deterministic=deterministic)
+
+    input_layout = get_input_layout(*args, **kwargs)
+    query = convert_to_bsnd(query, n1, input_layout)
+    key = convert_to_bsnd(key, n2, input_layout)
+    value = convert_to_bsnd(value, n2, input_layout)
+    out = flash_attn_func(
+        query, key, value, dropout_p=(1 - keep_prob), softmax_scale=scale, causal=causal_switch,
+        window_size=(window_left, window_right), alibi_slopes=alibi_slopes, deterministic=deterministic
+    )
+    out = convert_from_bsnd(out, input_layout)
     return out, Const.NONE, Const.NONE
@@ -1,3 +1,18 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torch
 