mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
  2. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  3. msprobe/README.md +46 -16
  4. msprobe/__init__.py +16 -1
  5. msprobe/config.json +0 -2
  6. msprobe/core/advisor/advisor.py +8 -8
  7. msprobe/core/advisor/advisor_const.py +6 -7
  8. msprobe/core/advisor/advisor_result.py +12 -12
  9. msprobe/core/common/const.py +64 -3
  10. msprobe/core/common/exceptions.py +2 -2
  11. msprobe/core/common/file_utils.py +54 -9
  12. msprobe/core/common/inplace_op_checker.py +38 -0
  13. msprobe/core/common/inplace_ops.yaml +251 -0
  14. msprobe/core/common/log.py +21 -11
  15. msprobe/core/common/utils.py +153 -167
  16. msprobe/core/common_config.py +18 -25
  17. msprobe/core/compare/acc_compare.py +209 -36
  18. msprobe/core/compare/check.py +102 -17
  19. msprobe/core/compare/compare_cli.py +21 -1
  20. msprobe/core/compare/highlight.py +41 -5
  21. msprobe/core/compare/multiprocessing_compute.py +33 -8
  22. msprobe/core/compare/npy_compare.py +21 -6
  23. msprobe/core/compare/utils.py +82 -48
  24. msprobe/core/data_dump/data_collector.py +31 -32
  25. msprobe/core/data_dump/data_processor/base.py +45 -22
  26. msprobe/core/data_dump/data_processor/factory.py +20 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
  29. msprobe/core/data_dump/json_writer.py +63 -42
  30. msprobe/core/data_dump/scope.py +32 -16
  31. msprobe/core/grad_probe/constant.py +4 -0
  32. msprobe/core/grad_probe/grad_compare.py +2 -3
  33. msprobe/core/grad_probe/utils.py +16 -3
  34. msprobe/docs/01.installation.md +19 -9
  35. msprobe/docs/02.config_introduction.md +52 -80
  36. msprobe/docs/03.config_examples.md +3 -13
  37. msprobe/docs/04.acl_config_examples.md +11 -9
  38. msprobe/docs/05.data_dump_PyTorch.md +140 -12
  39. msprobe/docs/06.data_dump_MindSpore.md +47 -5
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
  45. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  46. msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
  47. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  48. msprobe/docs/17.grad_probe.md +14 -16
  49. msprobe/docs/18.online_dispatch.md +89 -0
  50. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
  51. msprobe/docs/img/ms_dump.png +0 -0
  52. msprobe/docs/img/ms_layer.png +0 -0
  53. msprobe/docs/img/pt_dump.png +0 -0
  54. msprobe/mindspore/__init__.py +1 -0
  55. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
  56. msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
  57. msprobe/mindspore/cell_processor.py +27 -3
  58. msprobe/mindspore/common/const.py +2 -0
  59. msprobe/mindspore/common/utils.py +18 -2
  60. msprobe/mindspore/compare/distributed_compare.py +9 -22
  61. msprobe/mindspore/compare/layer_mapping.py +146 -0
  62. msprobe/mindspore/compare/modify_mapping.py +107 -0
  63. msprobe/mindspore/compare/ms_compare.py +173 -35
  64. msprobe/mindspore/compare/ms_graph_compare.py +27 -11
  65. msprobe/mindspore/debugger/debugger_config.py +16 -13
  66. msprobe/mindspore/debugger/precision_debugger.py +37 -13
  67. msprobe/mindspore/dump/dump_tool_factory.py +16 -1
  68. msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
  69. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  70. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
  71. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  72. msprobe/mindspore/dump/jit_dump.py +41 -17
  73. msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
  74. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
  75. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
  76. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  77. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
  78. msprobe/mindspore/free_benchmark/common/utils.py +19 -5
  79. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
  80. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
  81. msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
  82. msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
  83. msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
  84. msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
  85. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
  86. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
  87. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
  88. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
  89. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
  90. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  91. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
  92. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
  93. msprobe/mindspore/grad_probe/global_context.py +18 -8
  94. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
  95. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  96. msprobe/mindspore/service.py +42 -123
  97. msprobe/pytorch/__init__.py +20 -1
  98. msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
  99. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  100. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  101. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
  102. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  103. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  104. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  105. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  106. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
  107. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
  108. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
  109. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
  110. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
  111. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
  112. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
  113. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  114. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  118. msprobe/pytorch/bench_functions/__init__.py +18 -3
  119. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  120. msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
  121. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  122. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  123. msprobe/pytorch/bench_functions/linear.py +15 -0
  124. msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
  125. msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
  126. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  127. msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
  128. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  129. msprobe/pytorch/bench_functions/swiglu.py +20 -5
  130. msprobe/pytorch/common/__init__.py +15 -0
  131. msprobe/pytorch/common/log.py +18 -6
  132. msprobe/pytorch/common/parse_json.py +26 -11
  133. msprobe/pytorch/common/utils.py +40 -35
  134. msprobe/pytorch/compare/distributed_compare.py +11 -11
  135. msprobe/pytorch/compare/match.py +15 -0
  136. msprobe/pytorch/compare/pt_compare.py +38 -6
  137. msprobe/pytorch/debugger/debugger_config.py +52 -39
  138. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  139. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  140. msprobe/pytorch/free_benchmark/common/enums.py +28 -0
  141. msprobe/pytorch/free_benchmark/common/params.py +15 -0
  142. msprobe/pytorch/free_benchmark/common/utils.py +17 -1
  143. msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
  144. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
  145. msprobe/pytorch/free_benchmark/main.py +19 -4
  146. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  147. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  148. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
  149. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
  150. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
  151. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  154. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  155. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
  156. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  157. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
  158. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  159. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  160. msprobe/pytorch/function_factory.py +17 -2
  161. msprobe/pytorch/functional/module_dump.py +84 -0
  162. msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
  163. msprobe/pytorch/hook_module/__init__.py +16 -1
  164. msprobe/pytorch/hook_module/api_registry.py +13 -8
  165. msprobe/pytorch/hook_module/hook_module.py +17 -19
  166. msprobe/pytorch/hook_module/utils.py +4 -6
  167. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  168. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  169. msprobe/pytorch/hook_module/wrap_functional.py +10 -11
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  171. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  172. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  173. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  174. msprobe/pytorch/module_processer.py +17 -2
  175. msprobe/pytorch/online_dispatch/compare.py +11 -12
  176. msprobe/pytorch/online_dispatch/single_compare.py +7 -7
  177. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
  178. msprobe/pytorch/online_dispatch/utils.py +1 -4
  179. msprobe/pytorch/parse.py +15 -0
  180. msprobe/pytorch/parse_tool/cli.py +5 -6
  181. msprobe/pytorch/parse_tool/lib/compare.py +9 -10
  182. msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
  183. msprobe/pytorch/parse_tool/lib/utils.py +28 -24
  184. msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
  185. msprobe/pytorch/pt_config.py +167 -38
  186. msprobe/pytorch/service.py +97 -32
  187. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  188. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  189. msprobe/pytorch/functional/data_processor.py +0 -0
  190. msprobe/pytorch/functional/dump_module.py +0 -39
  191. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
  192. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
  193. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
  194. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,39 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+# Forward function signature comparison
+Benchmark implementation: fusion_attention_forward: q, k, v, drop_mask, atten_mask, pse, scale, keep_prob
+Fused operator: npu_fusion_attention_forward: query, key, value, head_num, input_layout, *, pse=None, padding_mask=None,
+                atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647,
+                next_tockens=2147483647, inner_precise=0, prefix=None, sparse_mode=0,
+                gen_mask_parallel=True, sync=False
+
+# Backward function signature comparison
+Benchmark implementation: fusion_attention_backward: dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
+Fused operator: npu_fusion_attention_backward: query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None,
+                atten_mask=None, softmax_max=None, softmax_sum=None, softmax_in=None,
+                attention_in=None, scale_value=1.0, keep_prob=1.0, pre_tockens=2147483647,
+                next_tockens=2147483647, inner_precise=0, seed=0, offset=0,
+                numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
+"""
+
 import torch
 import numpy as np
 from einops import rearrange
+
 try:
     import torch_npu
 except ImportError:
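The relocated module docstring contrasts the benchmark implementation with the NPU fused operator's signature. In essence the benchmark forward is plain scaled-dot-product attention; a minimal sketch (hypothetical helper name; drop_mask and pse omitted; the convention that ones mark blocked positions is assumed from the np.triu masks used later in this diff):

```python
import torch

def sketch_attention_forward(q, k, v, atten_mask=None, scale=1.0):
    # q, k, v in BNSD layout: [batch, heads, seq, head_dim]
    qk = torch.matmul(q, k.transpose(-2, -1)) * scale
    if atten_mask is not None:
        qk = qk.masked_fill(atten_mask.bool(), float("-inf"))
    softmax_res = torch.softmax(qk, dim=-1)
    return torch.matmul(softmax_res, v)

b, n, s, d = 1, 2, 4, 8
q, k, v = (torch.randn(b, n, s, d, dtype=torch.float64) for _ in range(3))
out = sketch_attention_forward(q, k, v, scale=d ** -0.5)  # [1, 2, 4, 8]
```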
@@ -9,35 +42,17 @@ except ImportError:
         # flash_attn is the third-party FA library for GPU
         from flash_attn import flash_attn_func
     except ImportError:
-        #If this is a CPU UT environment, do nothing
+        # If this is a CPU UT environment, do nothing
         pass
     else:
         is_gpu = False
 
-
 from msprobe.pytorch.common.utils import logger
 from msprobe.core.common.const import Const, CompareConst
 
 gtype = torch.float64  # float64 is required on an ARM host; float32 is enough on x86 (float64 also works). ARM computation is slow; for s=8k scenarios x86 is recommended
 softmax_build_mode = "QKV"  # "MAX_SUM"
 
-"""
-# Forward function signature comparison
-Benchmark implementation: fusion_attention_forward: q, k, v, drop_mask, atten_mask, pse, scale, keep_prob
-Fused operator: npu_fusion_attention_forward: query, key, value, head_num, input_layout, *, pse=None, padding_mask=None,
-                atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647,
-                next_tockens=2147483647, inner_precise=0, prefix=None, sparse_mode=0,
-                gen_mask_parallel=True, sync=False
-
-# Backward function signature comparison
-Benchmark implementation: fusion_attention_backward: dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
-Fused operator: npu_fusion_attention_backward: query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None,
-                atten_mask=None, softmax_max=None, softmax_sum=None, softmax_in=None,
-                attention_in=None, scale_value=1.0, keep_prob=1.0, pre_tockens=2147483647,
-                next_tockens=2147483647, inner_precise=0, seed=0, offset=0,
-                numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
-"""
-
 
 def softmax_forward(x):
     x_max = torch.max(x, dim=-1, keepdims=True)[0]
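softmax_forward subtracts the row maximum before exponentiating, the standard overflow guard; keeping the max and sum around is what lets the "MAX_SUM" build mode rebuild the softmax later. A small self-contained sketch of the pattern (illustrative name, keepdim spelled in its canonical PyTorch form):

```python
import torch

def softmax_with_stats(x):
    x_max = torch.max(x, dim=-1, keepdim=True)[0]  # subtract max for stability
    x_exp = torch.exp(x - x_max)
    x_sum = x_exp.sum(dim=-1, keepdim=True)
    return x_exp / x_sum, x_max, x_sum

res, x_max, x_sum = softmax_with_stats(torch.randn(2, 3, dtype=torch.float64))
assert torch.allclose(res.sum(dim=-1), torch.ones(2, dtype=torch.float64))
```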
@@ -62,10 +77,10 @@ def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):
 
     factor = num_heads // num_kv_heads
     kv_shape = kv_tensor.shape
-    B = kv_shape[0]
-    S = kv_shape[2]
-    D = kv_shape[3]
-    kv_res = torch.zeros([B, num_heads, S, D]).to(dtype)
+    b = kv_shape[0]
+    s = kv_shape[2]
+    d = kv_shape[3]
+    kv_res = torch.zeros([b, num_heads, s, d]).to(dtype)
     for i in range(num_heads):
         j = i // factor
         kv_res[:, i:i + 1, :, :] = kv_tensor[:, j:j + 1, :, :]
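broadcast_kv expands the grouped KV heads up to the query head count by assigning KV head i // factor to query head i, so each KV head is repeated factor times in order. A usage sketch with hypothetical sizes, showing that torch.repeat_interleave on the head axis produces the same layout:

```python
import torch

b, n_kv, s, d = 1, 2, 3, 4
num_heads = 4                          # n1; factor = 4 // 2 = 2
kv = torch.randn(b, n_kv, s, d)
# Heads come out as [kv0, kv0, kv1, kv1], matching i // factor in the loop.
expected = kv.repeat_interleave(num_heads // n_kv, dim=1)   # [1, 4, 3, 4]
```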
@@ -112,7 +127,7 @@ def fusion_attention_backward(dx, q, k, v, softmax_res, drop_mask, pse, scale, k
 
 def parse_bsnd_args(query, key, head_num, input_layout):
     supported_input_layout = ["BSH", "SBH", "BSND", "BNSD", "TND"]
-    B, S1, S2, N1, N2, D, H1, H2 = None, None, None, head_num, None, None, None, None
+    b, s1, s2, n1, n2, d, h1, h2 = None, None, None, head_num, None, None, None, None
 
     if not isinstance(input_layout, str) or input_layout not in supported_input_layout:
         raise ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.")
@@ -121,32 +136,33 @@ def parse_bsnd_args(query, key, head_num, input_layout):
         raise ValueError(f"input_layout {input_layout} does not supported for now.")
     try:
         if input_layout == "BSH":
-            B, S1, H1 = query.shape
-            _, S2, H2 = key.shape
-            D = H1 // N1
-            N2 = H2 // D
+            b, s1, h1 = query.shape
+            _, s2, h2 = key.shape
+            d = h1 // n1
+            n2 = h2 // d
         elif input_layout == "SBH":
-            S1, B, H1 = query.shape
-            S2, _, H2 = key.shape
-            D = H1 // N1
-            N2 = H2 // D
+            s1, b, h1 = query.shape
+            s2, _, h2 = key.shape
+            d = h1 // n1
+            n2 = h2 // d
         elif input_layout == "BSND":
-            B, S1, N1, D = query.shape
-            _, S2, N2, _ = key.shape
-            H1 = N1 * D
-            H2 = N2 * D
+            b, s1, n1, d = query.shape
+            _, s2, n2, _ = key.shape
+            h1 = n1 * d
+            h2 = n2 * d
         elif input_layout == "BNSD":
-            B, N1, S1, D = query.shape
-            _, N2, S2, _ = key.shape
-            H1 = N1 * D
-            H2 = N2 * D
+            b, n1, s1, d = query.shape
+            _, n2, s2, _ = key.shape
+            h1 = n1 * d
+            h2 = n2 * d
     except Exception as e:
         raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e
 
-    if D == 0:
-        raise ValueError(f"Value D must be non-zero.")
-    DTYPE = query.dtype
-    return B, S1, S2, N1, N2, D, H1, H2, DTYPE
+    if d == 0:
+        raise ValueError(f"Value d must be non-zero.")
+    _dtype = query.dtype
+    ret = (b, s1, s2, n1, n2, d, h1, h2, _dtype)
+    return ret
 
 
 def convert_from_bnsd(_input, input_layout):
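The layout branches above are pure shape algebra: packed layouts ("BSH"/"SBH") recover d = h1 // n1 and n2 = h2 // d, while unpacked layouts ("BSND"/"BNSD") read n and d directly and derive h = n * d. A worked check of the "BSH" branch with assumed sizes:

```python
import torch

query = torch.empty(2, 16, 8 * 64)   # b=2, s1=16, h1 = n1*d = 512
key = torch.empty(2, 32, 2 * 64)     # s2=32, h2 = n2*d = 128
n1 = 8                               # head_num supplied by the caller
b, s1, h1 = query.shape
_, s2, h2 = key.shape
d = h1 // n1                         # 64
n2 = h2 // d                         # 2
assert (b, s1, s2, n1, n2, d, h1, h2) == (2, 16, 32, 8, 2, 64, 512, 128)
```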
@@ -186,24 +202,26 @@ def convert_to_bnsd(_input, n, input_layout):
     return out.to(gtype)
 
 
-def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next_tocken, dtype):
+def generate_atten_mask(*args):
     """
     # When sparse_mode is 2, 3 or 4, the small-operator-to-fused-operator path takes this optimization; reversing it means decomposing back into the original basic implementation
     ===> atten_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype)
     """
-    shape = [S1, S2]
+
+    sparse_mode, atten_mask, b, n1, s1, s2, pre_tocken, next_tocken, dtype = args
+    shape = [s1, s2]
 
     if atten_mask is not None:
         # When the FA input already carries an atten_mask, it can be treated as the already-converted mask matrix; three special (sparse-matrix) scenarios have to be restored in reverse
         if sparse_mode == 2 or sparse_mode == 3 or sparse_mode == 4:
-            logger.info(f"S1: {S1}, S2:{S2}, atten_mask.shape:{atten_mask.shape}, atten_mask.dtype:{atten_mask.dtype}")
+            logger.info(f"s1: {s1}, s2:{s2}, atten_mask.shape:{atten_mask.shape}, atten_mask.dtype:{atten_mask.dtype}")
 
         if atten_mask.dim() == 2 and atten_mask.shape[0] == 2048 and atten_mask.shape[1] == 2048:
             if atten_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(atten_mask.dtype)):
                 if sparse_mode == 2:
                     atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
                 elif sparse_mode == 3:
-                    atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1))
+                    atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1))
                 elif sparse_mode == 4:
                     atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
                     atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
@@ -215,14 +233,14 @@ def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next
 
     if atten_mask is not None:
         if atten_mask.dim() == 2:
-            if atten_mask.shape[0] != S1 or atten_mask.shape[1] != S2:
+            if atten_mask.shape[0] != s1 or atten_mask.shape[1] != s2:
                 raise ValueError(f"Invalid atten_mask shape `SS` {atten_mask.shape}")
-            shape = [S1, S2]
+            shape = [s1, s2]
         elif atten_mask.dim() == 4:
             if atten_mask.shape[1] == 1:
-                shape = [B, 1, S1, S2] if B != 1 else [1, 1, S1, S2]
+                shape = [b, 1, s1, s2] if b != 1 else [1, 1, s1, s2]
             else:
-                shape = [B, N1, S1, S2] if B != 1 else [1, N1, S1, S2]
+                shape = [b, n1, s1, s2] if b != 1 else [1, n1, s1, s2]
 
     if sparse_mode == 0:
         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
@@ -233,7 +251,7 @@ def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next
     elif sparse_mode == 2:
         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
     elif sparse_mode == 3:
-        atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1))
+        atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1))
     elif sparse_mode == 4:
         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
         atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
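These branches rebuild the real [s1, s2] mask from the compressed 2048x2048 upper-triangular mask: sparse_mode 2 is a left-aligned causal mask, sparse_mode 3 a right-aligned one (the k = s2 - s1 + 1 diagonal), and sparse_mode 4 a band bounded by pre/next tokens. A quick check of the sparse_mode 3 geometry with assumed sizes, where ones mark masked positions:

```python
import numpy as np

s1, s2 = 4, 6
mask = np.triu(np.ones([s1, s2]), k=s2 - s1 + 1)
# Row i may attend to keys 0 .. (s2 - s1 + i); the last row sees everything:
# [[0. 0. 0. 1. 1. 1.]
#  [0. 0. 0. 0. 1. 1.]
#  [0. 0. 0. 0. 0. 1.]
#  [0. 0. 0. 0. 0. 0.]]
```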
@@ -243,11 +261,11 @@ def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next
     return atten_mask.to(dtype)
 
 
-def generate_kv(key, value, N1, N2):
+def generate_kv(key, value, n1, n2):
     # Adaptation for unequal N, by cdy
-    if not (N1 == N2):
-        k_new = broadcast_kv(N1, N2, key, key.dtype)
-        v_new = broadcast_kv(N1, N2, value, value.dtype)
+    if not (n1 == n2):
+        k_new = broadcast_kv(n1, n2, key, key.dtype)
+        v_new = broadcast_kv(n1, n2, value, value.dtype)
     else:
         k_new = key
         v_new = value
@@ -305,26 +323,30 @@ def npu_fusion_attention_forward_patch(*args, **kwargs):
     head_num = get_head_num(*args, **kwargs)
     input_layout = get_input_layout(*args, **kwargs)
 
-    B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], head_num, input_layout)
-    if N1 == N2 and S1 == S2:
-        logger.debug(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
+    b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], head_num, input_layout)
+    if n1 == n2 and s1 == s2:
+        logger.debug(f"running case : BNSD = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
     else:
-        logger.debug(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
-    if not (N1 % N2 == 0 and N1 >= N2):
-        raise ValueError(f"N1 and N2 do not match, please check: N1 = {N1}, N2 = {N2}.")
-
-    dims_kwargs = {"B": B, "S1": S1, "S2": S2, "N1": N1, "N2": N2,
-                   "D": D, "H1": H1, "H2": H2, "DTYPE": DTYPE}
-
-    new_kwargs = {"keep_prob": 1,
-                  "scale": kwargs.get("scale", 1 / (D ** 0.5)),
-                  "sparse_mode": kwargs.get("sparse_mode", 0),
-                  "prefix": kwargs.get("prefix"),
-                  "pre_tockens": kwargs.get("pre_tockens", 2147483647),
-                  "next_tockens": kwargs.get("next_tockens", 2147483647),
-                  "pse": kwargs.get("pse"),
-                  "padding_mask": kwargs.get("padding_mask"),
-                  "atten_mask": kwargs.get("atten_mask")}
+        logger.debug(f"running case: BNSD = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+    if not (n1 % n2 == 0 and n1 >= n2):
+        raise ValueError(f"N1 and N2 do not match, please check: n1 = {n1}, n2 = {n2}.")
+
+    dims_kwargs = {
+        "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2,
+        "d": d, "h1": h1, "h2": h2, "dtype": dtype
+    }
+
+    new_kwargs = {
+        "keep_prob": 1,
+        "scale": kwargs.get("scale", 1 / (d ** 0.5)),
+        "sparse_mode": kwargs.get("sparse_mode", 0),
+        "prefix": kwargs.get("prefix"),
+        "pre_tockens": kwargs.get("pre_tockens", 2147483647),
+        "next_tockens": kwargs.get("next_tockens", 2147483647),
+        "pse": kwargs.get("pse"),
+        "padding_mask": kwargs.get("padding_mask"),
+        "atten_mask": kwargs.get("atten_mask")
+    }
 
     return args, dims_kwargs, new_kwargs
 
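Note the default scale: when the caller does not pass one, the patch fills in 1 / sqrt(d), the usual attention normalization. For instance, with head dimension d = 64:

```python
d = 64
scale = 1 / (d ** 0.5)   # default filled in by the patch
assert scale == 0.125
```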
@@ -333,33 +355,37 @@ def npu_fusion_attention_backward_patch(*args, **kwargs):
     if len(args) != 6:
         raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")
 
-    B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], args[4], args[5])
-    if N1 == N2 and S1 == S2:
-        logger.info(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
+    b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], args[4], args[5])
+    if n1 == n2 and s1 == s2:
+        logger.info(f"running case : bnsd = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
     else:
-        logger.info(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
-    if not (N1 % N2 == 0 and N1 >= N2):
-        raise ValueError(f"N1 and N2 do not match, please check: N1 = {N1}, N2 = {N2}.")
-
-    dims_kwargs = {"B": B, "S1": S1, "S2": S2, "N1": N1, "N2": N2,
-                   "D": D, "H1": H1, "H2": H2, "DTYPE": DTYPE}
-
-    new_kwargs = {"keep_prob": 1,
-                  "scale_value": kwargs.get("scale_value", 1 / (D ** 0.5)),
-                  "sparse_mode": kwargs.get("sparse_mode", 0),
-                  "prefix": kwargs.get("prefix"),
-                  "pre_tockens": kwargs.get("pre_tockens", 2147483647),
-                  "next_tockens": kwargs.get("next_tockens", 2147483647),
-                  "pse": kwargs.get("pse"),
-                  "padding_mask": kwargs.get("padding_mask"),
-                  "softmax_max": kwargs.get("softmax_max"),
-                  "softmax_sum": kwargs.get("softmax_sum"),
-                  "softmax_in": kwargs.get("softmax_in"),
-                  "attention_in": kwargs.get("attention_in"),
-                  "seed": kwargs.get("seed", 0),
-                  "offset": kwargs.get("offset", 0),
-                  "numels": kwargs.get("numels", 0),
-                  "atten_mask": kwargs.get("atten_mask")}
+        logger.info(f"running case: bnsd = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+    if not (n1 % n2 == 0 and n1 >= n2):
+        raise ValueError(f"N1 and N2 do not match, please check: n1 = {n1}, n2 = {n2}.")
+
+    dims_kwargs = {
+        "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2,
+        "d": d, "h1": h1, "h2": h2, "dtype": dtype
+    }
+
+    new_kwargs = {
+        "keep_prob": 1,
+        "scale_value": kwargs.get("scale_value", 1 / (d ** 0.5)),
+        "sparse_mode": kwargs.get("sparse_mode", 0),
+        "prefix": kwargs.get("prefix"),
+        "pre_tockens": kwargs.get("pre_tockens", 2147483647),
+        "next_tockens": kwargs.get("next_tockens", 2147483647),
+        "pse": kwargs.get("pse"),
+        "padding_mask": kwargs.get("padding_mask"),
+        "softmax_max": kwargs.get("softmax_max"),
+        "softmax_sum": kwargs.get("softmax_sum"),
+        "softmax_in": kwargs.get("softmax_in"),
+        "attention_in": kwargs.get("attention_in"),
+        "seed": kwargs.get("seed", 0),
+        "offset": kwargs.get("offset", 0),
+        "numels": kwargs.get("numels", 0),
+        "atten_mask": kwargs.get("atten_mask")
+    }
 
     return args, dims_kwargs, new_kwargs
 
@@ -368,12 +394,12 @@ def npu_fusion_attention(*args, **kwargs):
     new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*args, **kwargs)
     query, key, value = new_args[0], new_args[1], new_args[2]
     input_layout = get_input_layout(*args, **kwargs)
-    N1 = dims_kwargs.get("N1")
-    N2 = dims_kwargs.get("N2")
-    S1 = dims_kwargs.get("S1")
-    S2 = dims_kwargs.get("S2")
-    B = dims_kwargs.get("B")
-    DTYPE = dims_kwargs.get("DTYPE")
+    n1 = dims_kwargs.get("n1")
+    n2 = dims_kwargs.get("n2")
+    s1 = dims_kwargs.get("s1")
+    s2 = dims_kwargs.get("s2")
+    b = dims_kwargs.get("b")
+    dtype = dims_kwargs.get("dtype")
     atten_mask = new_kwargs.get("atten_mask")
     keep_prob = new_kwargs.get("keep_prob")
     sparse_mode = new_kwargs.get("sparse_mode")
@@ -381,12 +407,12 @@ def npu_fusion_attention(*args, **kwargs):
     next_tockens = new_kwargs.get("next_tockens")
     pse = new_kwargs.get("pse")
     scale = new_kwargs.get("scale")
-
-    atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE)
-    query = convert_to_bnsd(query, N1, input_layout)
-    key = convert_to_bnsd(key, N2, input_layout)
-    value = convert_to_bnsd(value, N2, input_layout)
-    k_new, v_new = generate_kv(key, value, N1, N2)
+    args_temp = [sparse_mode, atten_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype]
+    atten_mask = generate_atten_mask(*args_temp)
+    query = convert_to_bnsd(query, n1, input_layout)
+    key = convert_to_bnsd(key, n2, input_layout)
+    value = convert_to_bnsd(value, n2, input_layout)
+    k_new, v_new = generate_kv(key, value, n1, n2)
     out_golden, softmax_max, softmax_sum = fusion_attention_forward(q=query, k=k_new, v=v_new,
                                                                     drop_mask=None, atten_mask=atten_mask,
                                                                     pse=pse, scale=scale,
@@ -403,13 +429,13 @@ def npu_fusion_attention_grad(*args, **kwargs):
     # dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
     new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*args, **kwargs)
     query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5]
-    N1 = dims_kwargs.get("N1")
-    N2 = dims_kwargs.get("N2")
-    S1 = dims_kwargs.get("S1")
-    S2 = dims_kwargs.get("S2")
-    B = dims_kwargs.get("B")
-    D = dims_kwargs.get("D")
-    DTYPE = dims_kwargs.get("DTYPE")
+    n1 = dims_kwargs.get("n1")
+    n2 = dims_kwargs.get("n2")
+    s1 = dims_kwargs.get("s1")
+    s2 = dims_kwargs.get("s2")
+    b = dims_kwargs.get("b")
+    d = dims_kwargs.get("d")
+    dtype = dims_kwargs.get("dtype")
     atten_mask = new_kwargs.get("atten_mask")
     keep_prob = new_kwargs.get("keep_prob")
     sparse_mode = new_kwargs.get("sparse_mode")
@@ -420,12 +446,13 @@ def npu_fusion_attention_grad(*args, **kwargs):
     softmax_sum = new_kwargs.get("softmax_sum")
     scale_value = new_kwargs.get("scale_value")
 
-    atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE)
-    query = convert_to_bnsd(query, N1, input_layout)
-    dx = convert_to_bnsd(dx, N1, input_layout)
-    key = convert_to_bnsd(key, N2, input_layout)
-    value = convert_to_bnsd(value, N2, input_layout)
-    k_new, v_new = generate_kv(key, value, N1, N2)
+    args_temp = [sparse_mode, atten_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype]
+    atten_mask = generate_atten_mask(*args_temp)
+    query = convert_to_bnsd(query, n1, input_layout)
+    dx = convert_to_bnsd(dx, n1, input_layout)
+    key = convert_to_bnsd(key, n2, input_layout)
+    value = convert_to_bnsd(value, n2, input_layout)
+    k_new, v_new = generate_kv(key, value, n1, n2)
 
     if softmax_build_mode == "QKV":
         softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value)
@@ -435,12 +462,12 @@ def npu_fusion_attention_grad(*args, **kwargs):
     dq, dk, dv = fusion_attention_backward(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob)
 
     # Adaptation for unequal N, by cdy
-    if not (N1 == N2):
-        if N2 == 0:
-            raise ValueError("dims_kwargs.N2 must be non-zero.")
-        G = int(N1 / N2)
-        dk = torch.sum(dk.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D)
-        dv = torch.sum(dv.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D)
+    if not (n1 == n2):
+        if n2 == 0:
+            raise ValueError("dims_kwargs.n2 must be non-zero.")
+        g = int(n1 / n2)
+        dk = torch.sum(dk.reshape(b, n2, g, s2, d), dim=2, keepdim=True).reshape(b, n2, s2, d)
+        dv = torch.sum(dv.reshape(b, n2, g, s2, d), dim=2, keepdim=True).reshape(b, n2, s2, d)
 
     if dq.dim() == 5:
         dq = dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4))
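The reduction mirrors the forward broadcast: each KV head was copied to g = n1 / n2 query heads, so its gradient is the sum of those g per-head gradients, which is exactly what the reshape-and-sum pair computes. A sketch with assumed sizes:

```python
import torch

b, n1, n2, s2, d = 1, 4, 2, 3, 5
g = n1 // n2
dk = torch.randn(b, n1, s2, d)                       # grad w.r.t. broadcast K
dk_reduced = dk.reshape(b, n2, g, s2, d).sum(dim=2)  # back to [b, n2, s2, d]
assert torch.allclose(dk_reduced[:, 0], dk[:, 0] + dk[:, 1])
```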
@@ -460,12 +487,12 @@ def is_attention_off_due_to_mask(atten_mask_dtype):
     return not atten_mask_dtype
 
 
-def is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens, S1):
-    return sparse_mode == 4 and (next_tockens != 0 or pre_tockens < S1)
+def is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens, s1):
+    return sparse_mode == 4 and (next_tockens != 0 or pre_tockens < s1)
 
 
-def is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens, S1, S2):
-    return sparse_mode == 0 and pre_tockens >= S1 and next_tockens >= S2
+def is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens, s1, s2):
+    return sparse_mode == 0 and pre_tockens >= s1 and next_tockens >= s2
 
 
 def gpu_fusion_attention(*args, **kwargs):
@@ -474,11 +501,11 @@ def gpu_fusion_attention(*args, **kwargs):
     query, key, value = new_args[0], new_args[1], new_args[2]
     keep_prob = new_kwargs.get("keep_prob", 1.0)
     scale = new_kwargs.get("scale")
-    N1 = dims_kwargs.get("N1")
-    N2 = dims_kwargs.get("N2")
-    S1 = dims_kwargs.get("S1")
-    S2 = dims_kwargs.get("S2")
-    B = dims_kwargs.get("B")
+    n1 = dims_kwargs.get("n1")
+    n2 = dims_kwargs.get("n2")
+    s1 = dims_kwargs.get("s1")
+    s2 = dims_kwargs.get("s2")
+    b = dims_kwargs.get("b")
     pse = new_kwargs.get("pse")
     sparse_mode = new_kwargs.get("sparse_mode")
     pre_tockens = new_kwargs.get("pre_tockens")
@@ -488,22 +515,24 @@ def gpu_fusion_attention(*args, **kwargs):
     pre_tockens = min(CompareConst.MAX_TOKENS, pre_tockens)
     next_tockens = min(CompareConst.MAX_TOKENS, next_tockens)
     atten_off = (is_attention_off_due_to_mask(atten_mask_dtype) or
-                 is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens, S1) or
-                 is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens, S1, S2))
+                 is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens, s1) or
+                 is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens, s1, s2))
     causal_switch = not atten_off
     if sparse_mode == CompareConst.SPECIAL_SPARSE_MOED:
         window_left = pre_tockens
         window_right = next_tockens
     else:
         pre_tockens = next_tockens = CompareConst.MAX_TOKENS
-        window_left = pre_tockens - S1 + S2
-        window_right = next_tockens + S1 - S2
-
+        window_left = pre_tockens - s1 + s2
+        window_right = next_tockens + s1 - s2
+
     if pse is not None:
-        alibi_slopes = torch.rand(B, N1, dtype=torch.float32) * 0.3
+        alibi_slopes = torch.rand(b, n1, dtype=torch.float32) * 0.3
     else:
         alibi_slopes = None
-
-    out = flash_attn_func(query, key, value, dropout_p=(1-keep_prob), softmax_scale=scale, causal=causal_switch,
-                          window_size=(window_left, window_right), alibi_slopes=alibi_slopes, deterministic=deterministic)
+
+    out = flash_attn_func(
+        query, key, value, dropout_p=(1 - keep_prob), softmax_scale=scale, causal=causal_switch,
+        window_size=(window_left, window_right), alibi_slopes=alibi_slopes, deterministic=deterministic
+    )
     return out, Const.NONE, Const.NONE
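For the GPU path, pre_tockens/next_tockens are clamped to CompareConst.MAX_TOKENS and then shifted by s2 - s1 to re-anchor them to flash-attn's diagonal-relative window_size. The arithmetic in isolation (the constant's value is assumed here for illustration, matching the 2147483647 defaults seen elsewhere in the diff):

```python
MAX_TOKENS = 2147483647                # assumed value of CompareConst.MAX_TOKENS
s1, s2 = 1024, 2048
pre_tockens = next_tockens = MAX_TOKENS
window_left = pre_tockens - s1 + s2    # left reach, diagonal-relative
window_right = next_tockens + s1 - s2  # right reach, diagonal-relative
```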
@@ -1,3 +1,18 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torch
 
 
@@ -1,3 +1,18 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torch
 
 
@@ -25,15 +40,19 @@ def npu_rotary_mul_backward(dy_tensor, x, r1, r2):
     x_shape = x.shape
     h = x.float()
     grad = dy_tensor.float()
-    condition_1 = (((r1_shape[0] == 1 and x_shape[0] != 1) or (r1_shape[0] == 1 and x_shape[0] == 1)) and
-                   ((r1_shape[2] == 1 and x_shape[2] != 1) or (r1_shape[2] == 1 and x_shape[2] == 1)) and
-                   (r1_shape[1] == x_shape[1]) and (r1_shape[3] == x_shape[3]))
-    condition_2 = (((r1_shape[0] == 1 and x_shape[0] != 1) or (r1_shape[0] == 1 and x_shape[0] == 1)) and
-                   ((r1_shape[1] == 1 and x_shape[1] != 1) or (r1_shape[1] == 1 and x_shape[1] == 1)) and
-                   (r1_shape[2] == x_shape[2]) and (r1_shape[3] == x_shape[3]))
-    condition_3 = (((r1_shape[2] == 1 and x_shape[2] != 1) or (r1_shape[2] == 1 and x_shape[2] == 1)) and
-                   ((r1_shape[1] == 1 and x_shape[1] != 1) or (r1_shape[1] == 1 and x_shape[1] == 1)) and
-                   (r1_shape[0] == x_shape[0]) and (r1_shape[3] == x_shape[3]))
+    condition_1 = (r1_shape[0] == 1
+                   and r1_shape[1] == x_shape[1]
+                   and r1_shape[2] == 1
+                   and r1_shape[3] == x_shape[3])
+    condition_2 = (r1_shape[0] == 1
+                   and r1_shape[1] == 1
+                   and r1_shape[2] == x_shape[2]
+                   and r1_shape[3] == x_shape[3])
+    condition_3 = (r1_shape[0] == x_shape[0]
+                   and r1_shape[1] == 1
+                   and r1_shape[2] == 1
+                   and r1_shape[3] == x_shape[3])
+
     if condition_1:
         for i in range(x_shape[0]):
             for j in range(x_shape[2]):
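The rewrite collapses each old condition's redundant (A and not B) or (A and B) clauses to plain A, which makes the three accepted r1 broadcast layouts explicit. For an x of shape (b, n, s, d) (axis meanings assumed from context) they are:

```python
x_shape = (2, 4, 8, 16)                  # (b, n, s, d), assumed sizes
accepted_r1_shapes = [
    (1, x_shape[1], 1, x_shape[3]),      # condition_1: (1, n, 1, d)
    (1, 1, x_shape[2], x_shape[3]),      # condition_2: (1, 1, s, d)
    (x_shape[0], 1, 1, x_shape[3]),      # condition_3: (b, 1, 1, d)
]
```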
@@ -1,3 +1,18 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torch
 
 
@@ -1,16 +1,31 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torch
 
 
 def npu_swiglu(x, dim=-1):
     tensor_dtype = x.dtype
 
-    inTensors = torch.chunk(x, 2, dim=dim)
+    in_tensors = torch.chunk(x, 2, dim=dim)
     if tensor_dtype == torch.float32:
-        tensor_scalar = torch.sigmoid(torch.mul(inTensors[0], 1.0))
-        output_data = torch.mul(torch.mul(tensor_scalar, inTensors[0]), inTensors[1])
+        tensor_scalar = torch.sigmoid(torch.mul(in_tensors[0], 1.0))
+        output_data = torch.mul(torch.mul(tensor_scalar, in_tensors[0]), in_tensors[1])
     else:
-        tensor_self_float = inTensors[0].type(torch.float)
-        tensor_other_float = inTensors[1].type(torch.float)
+        tensor_self_float = in_tensors[0].type(torch.float)
+        tensor_other_float = in_tensors[1].type(torch.float)
         tensor_out_float = torch.nn.functional.silu(tensor_self_float).type(tensor_dtype).type(
             torch.float32) * tensor_other_float
         output_data = tensor_out_float.type(tensor_dtype)
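Both branches of npu_swiglu compute SwiGLU over the two halves of x, i.e. silu(x1) * x2; the float32 branch just spells silu out as sigmoid(x1) * x1. A quick equivalence check:

```python
import torch

x = torch.randn(2, 8, dtype=torch.float32)
x1, x2 = torch.chunk(x, 2, dim=-1)
manual = torch.sigmoid(x1) * x1 * x2           # the float32 branch, expanded
via_silu = torch.nn.functional.silu(x1) * x2   # the same thing via F.silu
assert torch.allclose(manual, via_silu)
```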