mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- msprobe/README.md +46 -16
- msprobe/__init__.py +16 -1
- msprobe/config.json +0 -2
- msprobe/core/advisor/advisor.py +8 -8
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +64 -3
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +54 -9
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +21 -11
- msprobe/core/common/utils.py +153 -167
- msprobe/core/common_config.py +18 -25
- msprobe/core/compare/acc_compare.py +209 -36
- msprobe/core/compare/check.py +102 -17
- msprobe/core/compare/compare_cli.py +21 -1
- msprobe/core/compare/highlight.py +41 -5
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +21 -6
- msprobe/core/compare/utils.py +82 -48
- msprobe/core/data_dump/data_collector.py +31 -32
- msprobe/core/data_dump/data_processor/base.py +45 -22
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
- msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +32 -16
- msprobe/core/grad_probe/constant.py +4 -0
- msprobe/core/grad_probe/grad_compare.py +2 -3
- msprobe/core/grad_probe/utils.py +16 -3
- msprobe/docs/01.installation.md +19 -9
- msprobe/docs/02.config_introduction.md +52 -80
- msprobe/docs/03.config_examples.md +3 -13
- msprobe/docs/04.acl_config_examples.md +11 -9
- msprobe/docs/05.data_dump_PyTorch.md +140 -12
- msprobe/docs/06.data_dump_MindSpore.md +47 -5
- msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
- msprobe/docs/13.overflow_check_MindSpore.md +1 -1
- msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/docs/17.grad_probe.md +14 -16
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
- msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
- msprobe/mindspore/cell_processor.py +27 -3
- msprobe/mindspore/common/const.py +2 -0
- msprobe/mindspore/common/utils.py +18 -2
- msprobe/mindspore/compare/distributed_compare.py +9 -22
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +173 -35
- msprobe/mindspore/compare/ms_graph_compare.py +27 -11
- msprobe/mindspore/debugger/debugger_config.py +16 -13
- msprobe/mindspore/debugger/precision_debugger.py +37 -13
- msprobe/mindspore/dump/dump_tool_factory.py +16 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +41 -17
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
- msprobe/mindspore/free_benchmark/common/utils.py +19 -5
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
- msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
- msprobe/mindspore/grad_probe/global_context.py +18 -8
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/service.py +42 -123
- msprobe/pytorch/__init__.py +20 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +20 -5
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +26 -11
- msprobe/pytorch/common/utils.py +40 -35
- msprobe/pytorch/compare/distributed_compare.py +11 -11
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +38 -6
- msprobe/pytorch/debugger/debugger_config.py +52 -39
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/enums.py +28 -0
- msprobe/pytorch/free_benchmark/common/params.py +15 -0
- msprobe/pytorch/free_benchmark/common/utils.py +17 -1
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +10 -11
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +17 -2
- msprobe/pytorch/online_dispatch/compare.py +11 -12
- msprobe/pytorch/online_dispatch/single_compare.py +7 -7
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
- msprobe/pytorch/online_dispatch/utils.py +1 -4
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +9 -10
- msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
- msprobe/pytorch/parse_tool/lib/utils.py +28 -24
- msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
- msprobe/pytorch/pt_config.py +167 -38
- msprobe/pytorch/service.py +97 -32
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,39 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
"""
|
|
17
|
+
# 前向函数声明对比
|
|
18
|
+
标杆实现:fusion_attention_forward: q, k, v, drop_mask, atten_mask, pse, scale, keep_prob
|
|
19
|
+
融合算子:npu_fusion_attention_forward: query, key, value, head_num, input_layout, *, pse=None, padding_mask=None,
|
|
20
|
+
atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647,
|
|
21
|
+
next_tockens=2147483647, inner_precise=0, prefix=None, sparse_mode=0,
|
|
22
|
+
gen_mask_parallel=True, sync=False
|
|
23
|
+
|
|
24
|
+
# 反向函数声明对比
|
|
25
|
+
标杆实现:fusion_attention_backward: dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
|
|
26
|
+
融合算子:npu_fusion_attention_backward: query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None,
|
|
27
|
+
atten_mask=None, softmax_max=None, softmax_sum=None, softmax_in=None,
|
|
28
|
+
attention_in=None, scale_value=1.0, keep_prob=1.0, pre_tockens=2147483647,
|
|
29
|
+
next_tockens=2147483647, inner_precise=0, seed=0, offset=0,
|
|
30
|
+
numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
|
|
31
|
+
"""
|
|
32
|
+
|
|
1
33
|
import torch
|
|
2
34
|
import numpy as np
|
|
3
35
|
from einops import rearrange
|
|
36
|
+
|
|
4
37
|
try:
|
|
5
38
|
import torch_npu
|
|
6
39
|
except ImportError:
|
|
@@ -9,35 +42,17 @@ except ImportError:
|
|
|
9
42
|
# flash_attn为gpu的fa三方库
|
|
10
43
|
from flash_attn import flash_attn_func
|
|
11
44
|
except ImportError:
|
|
12
|
-
|
|
45
|
+
# 如果为cpu的ut环境,则不做任何处理
|
|
13
46
|
pass
|
|
14
47
|
else:
|
|
15
48
|
is_gpu = False
|
|
16
49
|
|
|
17
|
-
|
|
18
50
|
from msprobe.pytorch.common.utils import logger
|
|
19
51
|
from msprobe.core.common.const import Const, CompareConst
|
|
20
52
|
|
|
21
53
|
gtype = torch.float64 # arm host必须选择float64,x86环境选择float32即可,64也行。arm计算很慢,s=8k的场景建议使用x86
|
|
22
54
|
softmax_build_mode = "QKV" # "MAX_SUM"
|
|
23
55
|
|
|
24
|
-
"""
|
|
25
|
-
# 前向函数声明对比
|
|
26
|
-
标杆实现:fusion_attention_forward: q, k, v, drop_mask, atten_mask, pse, scale, keep_prob
|
|
27
|
-
融合算子:npu_fusion_attention_forward: query, key, value, head_num, input_layout, *, pse=None, padding_mask=None,
|
|
28
|
-
atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647,
|
|
29
|
-
next_tockens=2147483647, inner_precise=0, prefix=None, sparse_mode=0,
|
|
30
|
-
gen_mask_parallel=True, sync=False
|
|
31
|
-
|
|
32
|
-
# 反向函数声明对比
|
|
33
|
-
标杆实现:fusion_attention_backward: dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
|
|
34
|
-
融合算子:npu_fusion_attention_backward: query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None,
|
|
35
|
-
atten_mask=None, softmax_max=None, softmax_sum=None, softmax_in=None,
|
|
36
|
-
attention_in=None, scale_value=1.0, keep_prob=1.0, pre_tockens=2147483647,
|
|
37
|
-
next_tockens=2147483647, inner_precise=0, seed=0, offset=0,
|
|
38
|
-
numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
|
|
39
|
-
"""
|
|
40
|
-
|
|
41
56
|
|
|
42
57
|
def softmax_forward(x):
|
|
43
58
|
x_max = torch.max(x, dim=-1, keepdims=True)[0]
|
|
@@ -62,10 +77,10 @@ def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):
|
|
|
62
77
|
|
|
63
78
|
factor = num_heads // num_kv_heads
|
|
64
79
|
kv_shape = kv_tensor.shape
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
kv_res = torch.zeros([
|
|
80
|
+
b = kv_shape[0]
|
|
81
|
+
s = kv_shape[2]
|
|
82
|
+
d = kv_shape[3]
|
|
83
|
+
kv_res = torch.zeros([b, num_heads, s, d]).to(dtype)
|
|
69
84
|
for i in range(num_heads):
|
|
70
85
|
j = i // factor
|
|
71
86
|
kv_res[:, i:i + 1, :, :] = kv_tensor[:, j:j + 1, :, :]
|
|
@@ -112,7 +127,7 @@ def fusion_attention_backward(dx, q, k, v, softmax_res, drop_mask, pse, scale, k
|
|
|
112
127
|
|
|
113
128
|
def parse_bsnd_args(query, key, head_num, input_layout):
|
|
114
129
|
supported_input_layout = ["BSH", "SBH", "BSND", "BNSD", "TND"]
|
|
115
|
-
|
|
130
|
+
b, s1, s2, n1, n2, d, h1, h2 = None, None, None, head_num, None, None, None, None
|
|
116
131
|
|
|
117
132
|
if not isinstance(input_layout, str) or input_layout not in supported_input_layout:
|
|
118
133
|
raise ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.")
|
|
@@ -121,32 +136,33 @@ def parse_bsnd_args(query, key, head_num, input_layout):
|
|
|
121
136
|
raise ValueError(f"input_layout {input_layout} does not supported for now.")
|
|
122
137
|
try:
|
|
123
138
|
if input_layout == "BSH":
|
|
124
|
-
|
|
125
|
-
_,
|
|
126
|
-
|
|
127
|
-
|
|
139
|
+
b, s1, h1 = query.shape
|
|
140
|
+
_, s2, h2 = key.shape
|
|
141
|
+
d = h1 // n1
|
|
142
|
+
n2 = h2 // d
|
|
128
143
|
elif input_layout == "SBH":
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
144
|
+
s1, b, h1 = query.shape
|
|
145
|
+
s2, _, h2 = key.shape
|
|
146
|
+
d = h1 // n1
|
|
147
|
+
n2 = h2 // d
|
|
133
148
|
elif input_layout == "BSND":
|
|
134
|
-
|
|
135
|
-
_,
|
|
136
|
-
|
|
137
|
-
|
|
149
|
+
b, s1, n1, d = query.shape
|
|
150
|
+
_, s2, n2, _ = key.shape
|
|
151
|
+
h1 = n1 * d
|
|
152
|
+
h2 = n2 * d
|
|
138
153
|
elif input_layout == "BNSD":
|
|
139
|
-
|
|
140
|
-
_,
|
|
141
|
-
|
|
142
|
-
|
|
154
|
+
b, n1, s1, d = query.shape
|
|
155
|
+
_, n2, s2, _ = key.shape
|
|
156
|
+
h1 = n1 * d
|
|
157
|
+
h2 = n2 * d
|
|
143
158
|
except Exception as e:
|
|
144
159
|
raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e
|
|
145
160
|
|
|
146
|
-
if
|
|
147
|
-
raise ValueError(f"Value
|
|
148
|
-
|
|
149
|
-
|
|
161
|
+
if d == 0:
|
|
162
|
+
raise ValueError(f"Value d must be non-zero.")
|
|
163
|
+
_dtype = query.dtype
|
|
164
|
+
ret = (b, s1, s2, n1, n2, d, h1, h2, _dtype)
|
|
165
|
+
return ret
|
|
150
166
|
|
|
151
167
|
|
|
152
168
|
def convert_from_bnsd(_input, input_layout):
|
|
@@ -186,24 +202,26 @@ def convert_to_bnsd(_input, n, input_layout):
|
|
|
186
202
|
return out.to(gtype)
|
|
187
203
|
|
|
188
204
|
|
|
189
|
-
def generate_atten_mask(
|
|
205
|
+
def generate_atten_mask(*args):
|
|
190
206
|
"""
|
|
191
207
|
# 当sparse_mode=2、3、4时小算子到融合算子会走这个优化,反过来看就要拆解回原来的基本实现
|
|
192
208
|
===> atten_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype)
|
|
193
209
|
"""
|
|
194
|
-
|
|
210
|
+
|
|
211
|
+
sparse_mode, atten_mask, b, n1, s1, s2, pre_tocken, next_tocken, dtype = args
|
|
212
|
+
shape = [s1, s2]
|
|
195
213
|
|
|
196
214
|
if atten_mask is not None:
|
|
197
215
|
# 当FA的输入已经包含atten_mask时,可以认为已经是转换之后的mask矩阵了,有三种特殊场景,即稀疏矩阵场景,需要进行逆向还原
|
|
198
216
|
if sparse_mode == 2 or sparse_mode == 3 or sparse_mode == 4:
|
|
199
|
-
logger.info(f"
|
|
217
|
+
logger.info(f"s1: {s1}, s2:{s2}, atten_mask.shape:{atten_mask.shape}, atten_mask.dtype:{atten_mask.dtype}")
|
|
200
218
|
|
|
201
219
|
if atten_mask.dim() == 2 and atten_mask.shape[0] == 2048 and atten_mask.shape[1] == 2048:
|
|
202
220
|
if atten_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(atten_mask.dtype)):
|
|
203
221
|
if sparse_mode == 2:
|
|
204
222
|
atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
|
|
205
223
|
elif sparse_mode == 3:
|
|
206
|
-
atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=
|
|
224
|
+
atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1))
|
|
207
225
|
elif sparse_mode == 4:
|
|
208
226
|
atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
|
|
209
227
|
atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
|
|
@@ -215,14 +233,14 @@ def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next
|
|
|
215
233
|
|
|
216
234
|
if atten_mask is not None:
|
|
217
235
|
if atten_mask.dim() == 2:
|
|
218
|
-
if atten_mask.shape[0] !=
|
|
236
|
+
if atten_mask.shape[0] != s1 or atten_mask.shape[1] != s2:
|
|
219
237
|
raise ValueError(f"Invalid atten_mask shape `SS` {atten_mask.shape}")
|
|
220
|
-
shape = [
|
|
238
|
+
shape = [s1, s2]
|
|
221
239
|
elif atten_mask.dim() == 4:
|
|
222
240
|
if atten_mask.shape[1] == 1:
|
|
223
|
-
shape = [
|
|
241
|
+
shape = [b, 1, s1, s2] if b != 1 else [1, 1, s1, s2]
|
|
224
242
|
else:
|
|
225
|
-
shape = [
|
|
243
|
+
shape = [b, n1, s1, s2] if b != 1 else [1, n1, s1, s2]
|
|
226
244
|
|
|
227
245
|
if sparse_mode == 0:
|
|
228
246
|
atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
|
|
@@ -233,7 +251,7 @@ def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next
|
|
|
233
251
|
elif sparse_mode == 2:
|
|
234
252
|
atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
|
|
235
253
|
elif sparse_mode == 3:
|
|
236
|
-
atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=
|
|
254
|
+
atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1))
|
|
237
255
|
elif sparse_mode == 4:
|
|
238
256
|
atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
|
|
239
257
|
atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
|
|
@@ -243,11 +261,11 @@ def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next
|
|
|
243
261
|
return atten_mask.to(dtype)
|
|
244
262
|
|
|
245
263
|
|
|
246
|
-
def generate_kv(key, value,
|
|
264
|
+
def generate_kv(key, value, n1, n2):
|
|
247
265
|
# N不等长适配by cdy
|
|
248
|
-
if not (
|
|
249
|
-
k_new = broadcast_kv(
|
|
250
|
-
v_new = broadcast_kv(
|
|
266
|
+
if not (n1 == n2):
|
|
267
|
+
k_new = broadcast_kv(n1, n2, key, key.dtype)
|
|
268
|
+
v_new = broadcast_kv(n1, n2, value, value.dtype)
|
|
251
269
|
else:
|
|
252
270
|
k_new = key
|
|
253
271
|
v_new = value
|
|
@@ -305,26 +323,30 @@ def npu_fusion_attention_forward_patch(*args, **kwargs):
|
|
|
305
323
|
head_num = get_head_num(*args, **kwargs)
|
|
306
324
|
input_layout = get_input_layout(*args, **kwargs)
|
|
307
325
|
|
|
308
|
-
|
|
309
|
-
if
|
|
310
|
-
logger.debug(f"running case : BNSD = {
|
|
326
|
+
b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], head_num, input_layout)
|
|
327
|
+
if n1 == n2 and s1 == s2:
|
|
328
|
+
logger.debug(f"running case : BNSD = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
|
|
311
329
|
else:
|
|
312
|
-
logger.debug(f"running case: BNSD = {
|
|
313
|
-
if not (
|
|
314
|
-
raise ValueError(f"N1与N2不匹配,请检查:
|
|
315
|
-
|
|
316
|
-
dims_kwargs = {
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
330
|
+
logger.debug(f"running case: BNSD = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
|
|
331
|
+
if not (n1 % n2 == 0 and n1 >= n2):
|
|
332
|
+
raise ValueError(f"N1与N2不匹配,请检查: n1 = {n1}, n2 = {n2}.")
|
|
333
|
+
|
|
334
|
+
dims_kwargs = {
|
|
335
|
+
"b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2,
|
|
336
|
+
"d": d, "h1": h1, "h2": h2, "dtype": dtype
|
|
337
|
+
}
|
|
338
|
+
|
|
339
|
+
new_kwargs = {
|
|
340
|
+
"keep_prob": 1,
|
|
341
|
+
"scale": kwargs.get("scale", 1 / (d ** 0.5)),
|
|
342
|
+
"sparse_mode": kwargs.get("sparse_mode", 0),
|
|
343
|
+
"prefix": kwargs.get("prefix"),
|
|
344
|
+
"pre_tockens": kwargs.get("pre_tockens", 2147483647),
|
|
345
|
+
"next_tockens": kwargs.get("next_tockens", 2147483647),
|
|
346
|
+
"pse": kwargs.get("pse"),
|
|
347
|
+
"padding_mask": kwargs.get("padding_mask"),
|
|
348
|
+
"atten_mask": kwargs.get("atten_mask")
|
|
349
|
+
}
|
|
328
350
|
|
|
329
351
|
return args, dims_kwargs, new_kwargs
|
|
330
352
|
|
|
@@ -333,33 +355,37 @@ def npu_fusion_attention_backward_patch(*args, **kwargs):
|
|
|
333
355
|
if len(args) != 6:
|
|
334
356
|
raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")
|
|
335
357
|
|
|
336
|
-
|
|
337
|
-
if
|
|
338
|
-
logger.info(f"running case :
|
|
358
|
+
b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], args[4], args[5])
|
|
359
|
+
if n1 == n2 and s1 == s2:
|
|
360
|
+
logger.info(f"running case : bnsd = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
|
|
339
361
|
else:
|
|
340
|
-
logger.info(f"running case:
|
|
341
|
-
if not (
|
|
342
|
-
raise ValueError(f"N1与N2不匹配,请检查:
|
|
343
|
-
|
|
344
|
-
dims_kwargs = {
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
362
|
+
logger.info(f"running case: bnsd = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
|
|
363
|
+
if not (n1 % n2 == 0 and n1 >= n2):
|
|
364
|
+
raise ValueError(f"N1与N2不匹配,请检查: n1 = {n1}, n2 = {n2}.")
|
|
365
|
+
|
|
366
|
+
dims_kwargs = {
|
|
367
|
+
"b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2,
|
|
368
|
+
"d": d, "h1": h1, "h2": h2, "dtype": dtype
|
|
369
|
+
}
|
|
370
|
+
|
|
371
|
+
new_kwargs = {
|
|
372
|
+
"keep_prob": 1,
|
|
373
|
+
"scale_value": kwargs.get("scale_value", 1 / (d ** 0.5)),
|
|
374
|
+
"sparse_mode": kwargs.get("sparse_mode", 0),
|
|
375
|
+
"prefix": kwargs.get("prefix"),
|
|
376
|
+
"pre_tockens": kwargs.get("pre_tockens", 2147483647),
|
|
377
|
+
"next_tockens": kwargs.get("next_tockens", 2147483647),
|
|
378
|
+
"pse": kwargs.get("pse"),
|
|
379
|
+
"padding_mask": kwargs.get("padding_mask"),
|
|
380
|
+
"softmax_max": kwargs.get("softmax_max"),
|
|
381
|
+
"softmax_sum": kwargs.get("softmax_sum"),
|
|
382
|
+
"softmax_in": kwargs.get("softmax_in"),
|
|
383
|
+
"attention_in": kwargs.get("attention_in"),
|
|
384
|
+
"seed": kwargs.get("seed", 0),
|
|
385
|
+
"offset": kwargs.get("offset", 0),
|
|
386
|
+
"numels": kwargs.get("numels", 0),
|
|
387
|
+
"atten_mask": kwargs.get("atten_mask")
|
|
388
|
+
}
|
|
363
389
|
|
|
364
390
|
return args, dims_kwargs, new_kwargs
|
|
365
391
|
|
|
@@ -368,12 +394,12 @@ def npu_fusion_attention(*args, **kwargs):
|
|
|
368
394
|
new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*args, **kwargs)
|
|
369
395
|
query, key, value = new_args[0], new_args[1], new_args[2]
|
|
370
396
|
input_layout = get_input_layout(*args, **kwargs)
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
397
|
+
n1 = dims_kwargs.get("n1")
|
|
398
|
+
n2 = dims_kwargs.get("n2")
|
|
399
|
+
s1 = dims_kwargs.get("s1")
|
|
400
|
+
s2 = dims_kwargs.get("s2")
|
|
401
|
+
b = dims_kwargs.get("b")
|
|
402
|
+
dtype = dims_kwargs.get("dtype")
|
|
377
403
|
atten_mask = new_kwargs.get("atten_mask")
|
|
378
404
|
keep_prob = new_kwargs.get("keep_prob")
|
|
379
405
|
sparse_mode = new_kwargs.get("sparse_mode")
|
|
@@ -381,12 +407,12 @@ def npu_fusion_attention(*args, **kwargs):
|
|
|
381
407
|
next_tockens = new_kwargs.get("next_tockens")
|
|
382
408
|
pse = new_kwargs.get("pse")
|
|
383
409
|
scale = new_kwargs.get("scale")
|
|
384
|
-
|
|
385
|
-
atten_mask = generate_atten_mask(
|
|
386
|
-
query = convert_to_bnsd(query,
|
|
387
|
-
key = convert_to_bnsd(key,
|
|
388
|
-
value = convert_to_bnsd(value,
|
|
389
|
-
k_new, v_new = generate_kv(key, value,
|
|
410
|
+
args_temp = [sparse_mode, atten_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype]
|
|
411
|
+
atten_mask = generate_atten_mask(*args_temp)
|
|
412
|
+
query = convert_to_bnsd(query, n1, input_layout)
|
|
413
|
+
key = convert_to_bnsd(key, n2, input_layout)
|
|
414
|
+
value = convert_to_bnsd(value, n2, input_layout)
|
|
415
|
+
k_new, v_new = generate_kv(key, value, n1, n2)
|
|
390
416
|
out_golden, softmax_max, softmax_sum = fusion_attention_forward(q=query, k=k_new, v=v_new,
|
|
391
417
|
drop_mask=None, atten_mask=atten_mask,
|
|
392
418
|
pse=pse, scale=scale,
|
|
@@ -403,13 +429,13 @@ def npu_fusion_attention_grad(*args, **kwargs):
|
|
|
403
429
|
# dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
|
|
404
430
|
new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*args, **kwargs)
|
|
405
431
|
query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5]
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
432
|
+
n1 = dims_kwargs.get("n1")
|
|
433
|
+
n2 = dims_kwargs.get("n2")
|
|
434
|
+
s1 = dims_kwargs.get("s1")
|
|
435
|
+
s2 = dims_kwargs.get("s2")
|
|
436
|
+
b = dims_kwargs.get("b")
|
|
437
|
+
d = dims_kwargs.get("d")
|
|
438
|
+
dtype = dims_kwargs.get("dtype")
|
|
413
439
|
atten_mask = new_kwargs.get("atten_mask")
|
|
414
440
|
keep_prob = new_kwargs.get("keep_prob")
|
|
415
441
|
sparse_mode = new_kwargs.get("sparse_mode")
|
|
@@ -420,12 +446,13 @@ def npu_fusion_attention_grad(*args, **kwargs):
|
|
|
420
446
|
softmax_sum = new_kwargs.get("softmax_sum")
|
|
421
447
|
scale_value = new_kwargs.get("scale_value")
|
|
422
448
|
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
449
|
+
args_temp = [sparse_mode, atten_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype]
|
|
450
|
+
atten_mask = generate_atten_mask(*args_temp)
|
|
451
|
+
query = convert_to_bnsd(query, n1, input_layout)
|
|
452
|
+
dx = convert_to_bnsd(dx, n1, input_layout)
|
|
453
|
+
key = convert_to_bnsd(key, n2, input_layout)
|
|
454
|
+
value = convert_to_bnsd(value, n2, input_layout)
|
|
455
|
+
k_new, v_new = generate_kv(key, value, n1, n2)
|
|
429
456
|
|
|
430
457
|
if softmax_build_mode == "QKV":
|
|
431
458
|
softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value)
|
|
@@ -435,12 +462,12 @@ def npu_fusion_attention_grad(*args, **kwargs):
|
|
|
435
462
|
dq, dk, dv = fusion_attention_backward(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob)
|
|
436
463
|
|
|
437
464
|
# N不等长适配by cdy
|
|
438
|
-
if not (
|
|
439
|
-
if
|
|
440
|
-
raise ValueError("dims_kwargs.
|
|
441
|
-
|
|
442
|
-
dk = torch.sum(dk.reshape(
|
|
443
|
-
dv = torch.sum(dv.reshape(
|
|
465
|
+
if not (n1 == n2):
|
|
466
|
+
if n2 == 0:
|
|
467
|
+
raise ValueError("dims_kwargs.n2 must be non-zero.")
|
|
468
|
+
g = int(n1 / n2)
|
|
469
|
+
dk = torch.sum(dk.reshape(b, n2, g, s2, d), dim=2, keepdim=True).reshape(b, n2, s2, d)
|
|
470
|
+
dv = torch.sum(dv.reshape(b, n2, g, s2, d), dim=2, keepdim=True).reshape(b, n2, s2, d)
|
|
444
471
|
|
|
445
472
|
if dq.dim() == 5:
|
|
446
473
|
dq = dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4))
|
|
@@ -460,12 +487,12 @@ def is_attention_off_due_to_mask(atten_mask_dtype):
|
|
|
460
487
|
return not atten_mask_dtype
|
|
461
488
|
|
|
462
489
|
|
|
463
|
-
def is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens,
|
|
464
|
-
return sparse_mode == 4 and (next_tockens != 0 or pre_tockens <
|
|
490
|
+
def is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens, s1):
|
|
491
|
+
return sparse_mode == 4 and (next_tockens != 0 or pre_tockens < s1)
|
|
465
492
|
|
|
466
493
|
|
|
467
|
-
def is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens,
|
|
468
|
-
return sparse_mode == 0 and pre_tockens >=
|
|
494
|
+
def is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens, s1, s2):
|
|
495
|
+
return sparse_mode == 0 and pre_tockens >= s1 and next_tockens >= s2
|
|
469
496
|
|
|
470
497
|
|
|
471
498
|
def gpu_fusion_attention(*args, **kwargs):
|
|
@@ -474,11 +501,11 @@ def gpu_fusion_attention(*args, **kwargs):
|
|
|
474
501
|
query, key, value = new_args[0], new_args[1], new_args[2]
|
|
475
502
|
keep_prob = new_kwargs.get("keep_prob", 1.0)
|
|
476
503
|
scale = new_kwargs.get("scale")
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
504
|
+
n1 = dims_kwargs.get("n1")
|
|
505
|
+
n2 = dims_kwargs.get("n2")
|
|
506
|
+
s1 = dims_kwargs.get("s1")
|
|
507
|
+
s2 = dims_kwargs.get("s2")
|
|
508
|
+
b = dims_kwargs.get("b")
|
|
482
509
|
pse = new_kwargs.get("pse")
|
|
483
510
|
sparse_mode = new_kwargs.get("sparse_mode")
|
|
484
511
|
pre_tockens = new_kwargs.get("pre_tockens")
|
|
@@ -488,22 +515,24 @@ def gpu_fusion_attention(*args, **kwargs):
|
|
|
488
515
|
pre_tockens = min(CompareConst.MAX_TOKENS, pre_tockens)
|
|
489
516
|
next_tockens = min(CompareConst.MAX_TOKENS, next_tockens)
|
|
490
517
|
atten_off = (is_attention_off_due_to_mask(atten_mask_dtype) or
|
|
491
|
-
|
|
492
|
-
|
|
518
|
+
is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens, s1) or
|
|
519
|
+
is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens, s1, s2))
|
|
493
520
|
causal_switch = not atten_off
|
|
494
521
|
if sparse_mode == CompareConst.SPECIAL_SPARSE_MOED:
|
|
495
522
|
window_left = pre_tockens
|
|
496
523
|
window_right = next_tockens
|
|
497
524
|
else:
|
|
498
525
|
pre_tockens = next_tockens = CompareConst.MAX_TOKENS
|
|
499
|
-
window_left = pre_tockens -
|
|
500
|
-
window_right = next_tockens +
|
|
501
|
-
|
|
526
|
+
window_left = pre_tockens - s1 + s2
|
|
527
|
+
window_right = next_tockens + s1 - s2
|
|
528
|
+
|
|
502
529
|
if pse is not None:
|
|
503
|
-
alibi_slopes = torch.rand(
|
|
530
|
+
alibi_slopes = torch.rand(b, n1, dtype=torch.float32) * 0.3
|
|
504
531
|
else:
|
|
505
532
|
alibi_slopes = None
|
|
506
|
-
|
|
507
|
-
out = flash_attn_func(
|
|
508
|
-
|
|
533
|
+
|
|
534
|
+
out = flash_attn_func(
|
|
535
|
+
query, key, value, dropout_p=(1 - keep_prob), softmax_scale=scale, causal=causal_switch,
|
|
536
|
+
window_size=(window_left, window_right), alibi_slopes=alibi_slopes, deterministic=deterministic
|
|
537
|
+
)
|
|
509
538
|
return out, Const.NONE, Const.NONE
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
|
|
3
18
|
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
|
|
3
18
|
|
|
@@ -25,15 +40,19 @@ def npu_rotary_mul_backward(dy_tensor, x, r1, r2):
|
|
|
25
40
|
x_shape = x.shape
|
|
26
41
|
h = x.float()
|
|
27
42
|
grad = dy_tensor.float()
|
|
28
|
-
condition_1 = (
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
43
|
+
condition_1 = (r1_shape[0] == 1
|
|
44
|
+
and r1_shape[1] == x_shape[1]
|
|
45
|
+
and r1_shape[2] == 1
|
|
46
|
+
and r1_shape[3] == x_shape[3])
|
|
47
|
+
condition_2 = (r1_shape[0] == 1
|
|
48
|
+
and r1_shape[1] == 1
|
|
49
|
+
and r1_shape[2] == x_shape[2]
|
|
50
|
+
and r1_shape[3] == x_shape[3])
|
|
51
|
+
condition_3 = (r1_shape[0] == x_shape[0]
|
|
52
|
+
and r1_shape[1] == 1
|
|
53
|
+
and r1_shape[2] == 1
|
|
54
|
+
and r1_shape[3] == x_shape[3])
|
|
55
|
+
|
|
37
56
|
if condition_1:
|
|
38
57
|
for i in range(x_shape[0]):
|
|
39
58
|
for j in range(x_shape[2]):
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
|
|
3
18
|
|
|
@@ -1,16 +1,31 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
|
|
3
18
|
|
|
4
19
|
def npu_swiglu(x, dim=-1):
|
|
5
20
|
tensor_dtype = x.dtype
|
|
6
21
|
|
|
7
|
-
|
|
22
|
+
in_tensors = torch.chunk(x, 2, dim=dim)
|
|
8
23
|
if tensor_dtype == torch.float32:
|
|
9
|
-
tensor_scalar = torch.sigmoid(torch.mul(
|
|
10
|
-
output_data = torch.mul(torch.mul(tensor_scalar,
|
|
24
|
+
tensor_scalar = torch.sigmoid(torch.mul(in_tensors[0], 1.0))
|
|
25
|
+
output_data = torch.mul(torch.mul(tensor_scalar, in_tensors[0]), in_tensors[1])
|
|
11
26
|
else:
|
|
12
|
-
tensor_self_float =
|
|
13
|
-
tensor_other_float =
|
|
27
|
+
tensor_self_float = in_tensors[0].type(torch.float)
|
|
28
|
+
tensor_other_float = in_tensors[1].type(torch.float)
|
|
14
29
|
tensor_out_float = torch.nn.functional.silu(tensor_self_float).type(tensor_dtype).type(
|
|
15
30
|
torch.float32) * tensor_other_float
|
|
16
31
|
output_data = tensor_out_float.type(tensor_dtype)
|