mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -0,0 +1,580 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from collections import namedtuple
|
|
17
|
+
import torch
|
|
18
|
+
import torch.nn as nn
|
|
19
|
+
import numpy as np
|
|
20
|
+
|
|
21
|
+
from einops import rearrange
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
from msprobe.pytorch.common.utils import logger
|
|
25
|
+
|
|
26
|
+
# Ground-truth computation dtype: float64 is required on ARM hosts; on x86
# float32 suffices (float64 also works). ARM computation is slow, so x86 is
# recommended for s=8k scenarios.
GTYPE = torch.float64
# Strategy for rebuilding the softmax result: "QKV" recomputes it from the
# raw inputs, "MAX_SUM" reconstructs it from saved softmax max/sum tensors.
SOFTMAX_BUILD_MODE = "QKV"  # "MAX_SUM"

# Inputs of the reference fusion-attention forward pass.
FaForwardParams = namedtuple("FaForwardParams",
                             ["q", "k", "v", "drop_mask", "attn_mask", "pse", "scalar_value", "keep_prob"])
# Inputs of the reference fusion-attention backward pass.
FaBackwardParams = namedtuple("FaBackwardParams",
                              ["dx", "q", "k", "v", "softmax_res", "drop_mask", "pse", "scalar_value", "keep_prob"])
# Inputs needed to rebuild the softmax result from saved max/sum statistics.
RebuildSoftmaxParams = namedtuple("RebuildSoftmaxParams",
                                  ["q", "k", "attn_mask", "pse", "scalar_value", "softmax_max", "softmax_sum"])
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def softmax_forward(x):
    """Numerically stable softmax along the last dimension.

    Returns a tuple ``(softmax, row_max, row_sum)``; the per-row max and sum
    are the statistics the fused operator saves for its backward pass.
    """
    row_max = torch.max(x, dim=-1, keepdims=True)[0]
    # Shift by the row max before exponentiating to avoid overflow.
    shifted_exp = torch.exp(x.sub(row_max))
    row_sum = shifted_exp.sum(dim=-1, keepdims=True)
    return shifted_exp.div(row_sum), row_max, row_sum
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def softmax_grad(dp, softmax_res):
    """Backward of softmax: grad = (dp - sum(dp * y, last dim)) * y."""
    weighted_sum = (dp * softmax_res).sum(dim=-1, keepdims=True)
    return (dp - weighted_sum) * softmax_res
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):
    """Broadcast grouped KV heads (B, N_kv, S, D) up to (B, num_heads, S, D).

    Each of the ``num_kv_heads`` key/value heads is replicated
    ``num_heads // num_kv_heads`` times so grouped-query attention can be
    evaluated by the plain multi-head reference implementation.

    Args:
        num_heads: target (query) head count.
        num_kv_heads: number of key/value heads in ``kv_tensor``.
        kv_tensor: tensor of shape (B, N_kv, S, D).
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape (B, num_heads, S, D) cast to ``dtype``.

    Raises:
        ValueError: on rank or head-count mismatches.
    """
    # Shape / head-count validation (messages kept verbatim).
    if kv_tensor.dim() != 4:
        raise ValueError(f"broadcast_kv: kv_tensor 必须是 4 维 (B, N_kv, S, D),但得到 {kv_tensor.shape}")
    if num_kv_heads == 0 or num_kv_heads > num_heads:
        raise ValueError("broadcast_kv: num_kv_heads 必须大于 0 且不超过 num_heads。")
    if num_heads % num_kv_heads != 0:
        raise ValueError(f"broadcast_kv: num_heads({num_heads}) 必须能被 num_kv_heads({num_kv_heads}) 整除。")

    factor = num_heads // num_kv_heads
    # repeat_interleave replicates each KV head `factor` times along the head
    # axis in one C-level op, replacing the original per-head Python loop
    # (head i of the result is KV head i // factor, exactly as before). It
    # also stays on the input's device instead of allocating on the CPU.
    return torch.repeat_interleave(kv_tensor, factor, dim=1).to(dtype)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def calculate_qk(q, k, attn_mask, pse, scalar_value):
    """Compute the scaled (and optionally biased/masked) QK^T logits.

    Args:
        q: query tensor of rank 4, (B, N, S1, D).
        k: key tensor of rank 4, (B, N, S2, D).
        attn_mask: optional mask; truthy positions get a -40000 penalty.
        pse: optional positional bias added to QK^T before scaling.
        scalar_value: scaling factor (typically 1/sqrt(D)).

    Returns:
        Attention logits of shape (B, N, S1, S2).

    Raises:
        ValueError: on rank or head-dim mismatches.
    """
    # Basic shape validation (messages kept verbatim).
    if q.dim() < 4 or k.dim() < 4:
        raise ValueError(f"calculate_qk: q,k 必须至少 4 维,q={q.dim()},k={k.dim()}")
    # head_dim consistency between q and k.
    if q.size(-1) != k.size(-1):
        raise ValueError(f"calculate_qk: q.head_dim({q.size(-1)}) != k.head_dim({k.size(-1)})")
    if k.dim() != 4:
        raise ValueError(f"k tensor dimension must be 4, but got {k.dim()} dimensions (shape: {k.shape})")
    # NOTE: the original also contained `if k.dim() == 3: k = k.unsqueeze(1)`
    # here, which was unreachable — the check above already guarantees
    # k.dim() == 4 — so that dead branch has been removed.

    if pse is None or len(pse.shape) == 0:
        qk = torch.matmul(q, k.permute(0, 1, 3, 2)).mul(scalar_value)
    else:
        qk = (torch.matmul(q, k.permute(0, 1, 3, 2)) + pse).mul(scalar_value)
    if attn_mask is None or len(attn_mask.shape) == 0:
        return qk
    # Masked (truthy) positions receive a large negative logit so they vanish
    # after softmax; -40000 matches the fused kernel's fill value.
    qk = qk + attn_mask.bool() * (-40000.0)  # -10000
    return qk
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def fusion_attention_forward(forward_params):
    """Reference forward pass of fused attention.

    Computes ``softmax(scale * QK^T [+ pse] [+ mask]) @ V`` with optional
    dropout and returns ``(output, softmax_max, softmax_sum)``, the last two
    being the per-row statistics saved for the backward pass.

    Raises:
        ValueError: if ``keep_prob`` is zero (would divide by zero).
    """
    p = forward_params

    # Reject keep_prob == 0 up front to avoid a division by zero below.
    if p.keep_prob == 0:
        raise ValueError("fusion_attention_forward: keep_prob 不能为 0,避免除零错误。")

    logits = calculate_qk(p.q, p.k, p.attn_mask, p.pse, p.scalar_value)
    softmax_res, softmax_max, softmax_sum = softmax_forward(logits)
    # A missing / zero-dim drop mask means dropout is disabled.
    if p.drop_mask is None or len(p.drop_mask.shape) == 0:
        dropped = softmax_res
    else:
        dropped = softmax_res * p.drop_mask * (1.0 / p.keep_prob)
    return torch.matmul(dropped, p.v), softmax_max, softmax_sum
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def fusion_attention_backward(backward_params):
    """Reference backward pass of fused attention.

    Given the upstream gradient ``dx`` and the tensors saved by the forward
    pass, returns the gradients ``(dq, dk, dv)`` of query, key and value.

    Raises:
        ValueError: if ``keep_prob`` is zero (would divide by zero).
    """
    p = backward_params

    # Reject keep_prob == 0 up front to avoid a division by zero below.
    if p.keep_prob == 0:
        raise ValueError("fusion_attention_backward: keep_prob 不能为 0,避免除零错误。")

    # dP = dX @ V^T
    dp = torch.matmul(p.dx, p.v.permute(0, 1, 3, 2))
    if p.drop_mask is None or len(p.drop_mask.shape) == 0:
        # Dropout disabled: the transposed softmax result feeds dV directly.
        softmax_t = p.softmax_res.permute(0, 1, 3, 2)
        dp_dropped = dp
    else:
        softmax_t = p.softmax_res.mul(p.drop_mask).mul(1.0 / p.keep_prob).permute(0, 1, 3, 2)
        dp_dropped = dp * p.drop_mask * (1.0 / p.keep_prob)
    dv = torch.matmul(softmax_t, p.dx)
    # Note: softmax_grad takes the *undropped* softmax result, as upstream.
    scaled_grad = softmax_grad(dp_dropped, p.softmax_res) * p.scalar_value
    dq = torch.matmul(scaled_grad, p.k)
    dk = torch.matmul(scaled_grad.permute(0, 1, 3, 2), p.q)
    return dq, dk, dv
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def parse_bsnd_args(query, key, head_num, input_layout):
    """Derive (b, s1, s2, n1, n2, d, h1, h2, dtype) from q/k shapes and layout.

    Supported layouts are "BSH", "SBH", "BSND" and "BNSD"; "TND" is rejected.
    Shape errors inside the parsing are re-raised as a ValueError carrying
    both tensor shapes for easier debugging.
    """
    supported_input_layout = ["BSH", "SBH", "BSND", "BNSD", "TND"]
    b = s1 = s2 = n2 = d = h1 = h2 = None
    n1 = head_num

    if not isinstance(input_layout, str) or input_layout not in supported_input_layout:
        raise ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.")

    if input_layout == "TND":
        raise ValueError(f"input_layout {input_layout} does not supported for now.")

    # Guard against a zero head count before it is used as a divisor.
    if n1 == 0:
        raise ValueError("parse_bsnd_args: head_num (n1) 不能为 0,避免除零错误。")

    try:
        if input_layout == "BSH":
            b, s1, h1 = query.shape
            _, s2, h2 = key.shape
            d = h1 // n1
            # A zero head_dim would make the n2 division meaningless.
            if d == 0:
                raise ValueError("parse_bsnd_args: 计算得到的 head_dim d 不能为 0。")
            n2 = h2 // d
        elif input_layout == "SBH":
            s1, b, h1 = query.shape
            s2, _, h2 = key.shape
            d = h1 // n1
            if d == 0:
                raise ValueError("parse_bsnd_args: 计算得到的 head_dim d 不能为 0。")
            n2 = h2 // d
        elif input_layout == "BSND":
            b, s1, n1, d = query.shape
            _, s2, n2, _ = key.shape
            if d == 0:
                raise ValueError("parse_bsnd_args: head_dim d 不能为 0。")
            h1 = n1 * d
            h2 = n2 * d
        elif input_layout == "BNSD":
            b, n1, s1, d = query.shape
            _, n2, s2, _ = key.shape
            if d == 0:
                raise ValueError("parse_bsnd_args: head_dim d 不能为 0。")
            h1 = n1 * d
            h2 = n2 * d
    except Exception as e:
        raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e

    return b, s1, s2, n1, n2, d, h1, h2, query.dtype
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def convert_from_bnsd(_input, input_layout):
    """Transform a (B,N,S,D) qkv tensor into ``input_layout``.

    B: batch_size, S: sequence_length, N: num_heads, D: head_dim.

    Args:
        _input (torch.Tensor): tensor of shape (B,N,S,D)
        input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"

    Returns:
        tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H);
        "BNSD" (or any unrecognised value) returns the input unchanged.

    Raises:
        ValueError: for the unsupported "TND" layout.
    """
    if input_layout == "TND":
        raise ValueError(f"input_layout {input_layout} does not supported for now.")
    patterns = {
        "BSH": 'b n s d -> b s (n d)',   # (B,N,S,D) => (B,S,N*D)
        "SBH": 'b n s d -> s b (n d)',   # (B,N,S,D) => (S,B,N*D)
        "BSND": 'b n s d -> b s n d',    # (B,N,S,D) => (B,S,N,D)
    }
    if input_layout in patterns:
        return rearrange(_input, patterns[input_layout]).contiguous()
    return _input
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def convert_to_bnsd(_input, n, input_layout):
    """Transform qkv from ``input_layout`` to (B,N,S,D) in GTYPE precision.

    B: batch_size, S: sequence_length, N: num_heads, D: head_dim.

    Args:
        _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
        n (int): num_heads
        input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"

    Returns:
        tensor of shape (B,N,S,D), cast to the module-level GTYPE.

    Raises:
        ValueError: for "TND", or when the result is not rank 4.
    """
    if input_layout == "TND":
        raise ValueError(f"input_layout {input_layout} does not supported for now.")
    patterns = {
        "BSH": 'b s (n d) -> b n s d',   # (B,S,N*D) => (B,N,S,D)
        "SBH": 's b (n d) -> b n s d',   # (S,B,N*D) => (B,N,S,D)
        "BSND": 'b s n d -> b n s d',    # (B,S,N,D) => (B,N,S,D)
    }
    if input_layout in patterns:
        out = rearrange(_input, patterns[input_layout], n=n)
    else:
        out = _input
    if out.dim() != 4:
        raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
    return out.to(GTYPE)
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def generate_attn_mask(*args):
    """Rebuild the dense attention mask the small-operator reference needs.

    When sparse_mode is 2, 3 or 4 the fused operator receives an optimized
    2048x2048 upper-triangular template mask; this function reverses that
    optimization back to the basic dense form:
    ===> attn_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype)
    """

    sparse_mode, attn_mask, b, n1, s1, s2, pre_tocken, next_tocken, dtype = args
    shape = [s1, s2]

    if attn_mask is not None:
        # If FA already received an attn_mask it is treated as the converted
        # mask; only the three sparse template modes need reversing back to a
        # dense mask first.
        if sparse_mode == 2 or sparse_mode == 3 or sparse_mode == 4:
            logger.info(f"s1: {s1}, s2:{s2}, attn_mask.shape:{attn_mask.shape}, attn_mask.dtype:{attn_mask.dtype}")

            # Only the canonical 2048x2048 upper-triangular template is reversed.
            if attn_mask.dim() == 2 and attn_mask.shape[0] == 2048 and attn_mask.shape[1] == 2048:
                if attn_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(attn_mask.dtype)):
                    if sparse_mode == 2:
                        attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
                    elif sparse_mode == 3:
                        attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1))
                    elif sparse_mode == 4:
                        attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
                        attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
                        attn_mask = attn_mask_u + attn_mask_l
                    logger.debug(f"反向转换attn_mask {attn_mask.shape}")
                    return attn_mask.to(dtype)

        return attn_mask.to(dtype)

    # NOTE(review): with the nesting as reconstructed here, every non-None
    # attn_mask has already returned above, so this branch appears
    # unreachable — confirm against the upstream source.
    if attn_mask is not None:
        if attn_mask.dim() == 2:
            if attn_mask.shape[0] != s1 or attn_mask.shape[1] != s2:
                raise ValueError(f"Invalid attn_mask shape `SS` {attn_mask.shape}")
            shape = [s1, s2]
        elif attn_mask.dim() == 4:
            if attn_mask.shape[1] == 1:
                shape = [b, 1, s1, s2] if b != 1 else [1, 1, s1, s2]
            else:
                shape = [b, n1, s1, s2] if b != 1 else [1, n1, s1, s2]

    # No mask was supplied: build the dense mask from the sparse parameters.
    if sparse_mode == 0:
        attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
        attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
        attn_mask = attn_mask_u + attn_mask_l
    elif sparse_mode == 1:  # no sparse
        attn_mask = torch.from_numpy(np.zeros(shape))
    elif sparse_mode == 2:
        attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
    elif sparse_mode == 3:
        attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1))
    elif sparse_mode == 4:
        attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
        attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
        attn_mask = attn_mask_u + attn_mask_l
    # Note: sparse_mode == 5 cannot occur here — it requires an attn_mask in
    # BNSS or B1SS format to be passed in, so FA's input is assumed to already
    # carry the correct mask in that case.
    # NOTE(review): if attn_mask is None and sparse_mode is outside 0-4, the
    # line below raises AttributeError on None — confirm callers never do this.
    return attn_mask.to(dtype)
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def generate_kv(key, value, n1, n2):
    """Return key/value broadcast to n1 heads when n2 differs (grouped KV)."""
    # Head counts match: no broadcasting needed, return the inputs as-is.
    if n1 == n2:
        return key, value
    return (broadcast_kv(n1, n2, key, key.dtype),
            broadcast_kv(n1, n2, value, value.dtype))
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def rebuid_softmax_by_qkv(q, k, attn_mask, pse, scalar_value):
    """Recompute the softmax result directly from Q and K.

    attention = softmax(QK^T/sqrt(d))V
    softmax(x_i) = e^(x_i - x_max) / sum(e^(x_i - x_max))

    (The historical "rebuid" spelling is kept for caller compatibility.)
    """
    logger.info("Using QKV to rebuild original softmax")
    softmax_res, _, _ = softmax_forward(calculate_qk(q, k, attn_mask, pse, scalar_value))
    return softmax_res
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def rebuild_softmax_by_max_sum(softmax_params):
    """Rebuild the softmax result from the saved row max/sum statistics.

    attention = softmax(QK^T/sqrt(d))V
    softmax(x_i) = e^(x_i - x_max_i) / x_sum_i
    """
    p = softmax_params
    logger.info("Using softmax_max and softmax_sum to rebuild original softmax")

    qk = calculate_qk(p.q, p.k, p.attn_mask, p.pse, p.scalar_value)
    if p.softmax_max.shape[-1] == 0:
        raise ValueError(f"softmax_max.shape[-1] must be non-zero, softmax_max.shape: {p.softmax_max.shape}")
    # The saved statistics carry a smaller trailing dim; tile them across the
    # qk rows before applying the stable-softmax formula.
    repeat_dim = qk.shape[-1] // p.softmax_max.shape[-1]
    tiled_max = p.softmax_max.repeat(1, 1, 1, repeat_dim)
    tiled_sum = p.softmax_sum.repeat(1, 1, 1, repeat_dim)
    return torch.exp(qk.sub(tiled_max)).div(tiled_sum)
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def get_head_num(*args, **kwargs):
    """Extract head_num from kwargs (preferred) or positional arg index 3.

    Raises:
        ValueError: when neither source provides it.
    """
    # Truthiness check kept as upstream: a falsy head_num kwarg falls through.
    head_num = kwargs.get("head_num")
    if head_num:
        return head_num
    if len(args) >= 4:
        return args[3]
    raise ValueError(f"Unsupported npu_fusion_attention args {args}.")
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def get_input_layout(*args, **kwargs):
    """Extract input_layout from kwargs (preferred) or positional arg index 4.

    Raises:
        ValueError: when neither source provides it.
    """
    # Truthiness check kept as upstream: an empty-string kwarg falls through.
    input_layout = kwargs.get("input_layout")
    if input_layout:
        return input_layout
    if len(args) >= 5:
        return args[4]
    raise ValueError(f"Unsupported npu_fusion_attention args {args}.")
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
def npu_fusion_attention_forward_patch(*args, **kwargs):
    """Normalize npu_fusion_attention forward arguments.

    Returns (args, dims_kwargs, new_kwargs): the original positional args,
    the parsed tensor dimensions, and the normalized keyword arguments with
    defaults filled in.
    """
    if len(args) < 2:
        raise RuntimeError("npu_fusion_attention_forward_patch: length of args should be greater than or equal to 2.")

    # Positional layout: query, key, value, head_num, input_layout
    head_num = get_head_num(*args, **kwargs)
    input_layout = get_input_layout(*args, **kwargs)

    b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], head_num, input_layout)
    # d has already been validated as non-zero inside parse_bsnd_args.
    sparse = kwargs.get('sparse_mode', 0)
    if n1 == n2 and s1 == s2:
        logger.debug(f"running case : BNSD = {b}_{n1}_{s1}_{d}, sparse = {sparse}")
    else:
        logger.debug(f"running case: BNSD = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {sparse}")
    if n2 == 0:
        raise ValueError("n2 不能为 0,避免除零错误。")
    if n1 % n2 != 0 or n1 < n2:
        raise ValueError(f"N1与N2不匹配,请检查: n1 = {n1}, n2 = {n2}.")

    dims_kwargs = {
        "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2,
        "d": d, "h1": h1, "h2": h2, "dtype": dtype
    }

    max_tockens = 2147483647
    new_kwargs = {
        # keep_prob is forced to 1 here; a zero keep_prob passed from outside
        # is still intercepted inside fusion_attention_forward.
        "keep_prob": 1,
        "scalar_value": kwargs.get("scalar_value", 1 / (d ** 0.5)),
        "sparse_mode": sparse,
        "prefix": kwargs.get("prefix"),
        "pre_tockens": kwargs.get("pre_tockens", max_tockens),
        "next_tockens": kwargs.get("next_tockens", max_tockens),
        "pse": kwargs.get("pse"),
        "padding_mask": kwargs.get("padding_mask"),
        "attn_mask": kwargs.get("attn_mask"),
    }

    return args, dims_kwargs, new_kwargs
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
def npu_fusion_attention_backward_patch(*args, **kwargs):
    """Normalize npu_fusion_attention_grad arguments.

    Expects exactly 6 positional args (query, key, dy, ..., head_num, input_layout
    per parse_bsnd_args usage below). Returns (args, dims_kwargs, new_kwargs):
    the original positional args, the parsed tensor dimensions, and the
    normalized keyword arguments with defaults filled in.

    Raises:
        ValueError: when the positional arity is wrong or n1/n2 are inconsistent.
    """
    if len(args) != 6:
        raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")

    b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], args[4], args[5])
    # d has already been validated as non-zero inside parse_bsnd_args.
    if n1 == n2 and s1 == s2:
        logger.info(f"running case : bnsd = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
    else:
        logger.info(f"running case: bnsd = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
    if n2 == 0:
        raise ValueError("n2 不能为 0,避免除零错误。")
    if not (n1 % n2 == 0 and n1 >= n2):
        raise ValueError(f"N1与N2不匹配,请检查: n1 = {n1}, n2 = {n2}.")

    dims_kwargs = {
        "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2,
        "d": d, "h1": h1, "h2": h2, "dtype": dtype
    }

    new_kwargs = {
        # keep_prob is forced to 1; a zero keep_prob is intercepted inside
        # fusion_attention_backward.
        "keep_prob": 1,
        # BUGFIX: this entry was keyed "scalar_value_value", but consumers read
        # new_kwargs.get("scalar_value") (see FlashAttentionScore.backward), so
        # the scale always resolved to None. Use the correct key, while still
        # accepting the legacy "scalar_value_value" kwarg for compatibility.
        "scalar_value": kwargs.get("scalar_value", kwargs.get("scalar_value_value", 1 / (d ** 0.5))),
        "sparse_mode": kwargs.get("sparse_mode", 0),
        "prefix": kwargs.get("prefix"),
        "pre_tockens": kwargs.get("pre_tockens", 2147483647),
        "next_tockens": kwargs.get("next_tockens", 2147483647),
        "pse": kwargs.get("pse"),
        "padding_mask": kwargs.get("padding_mask"),
        "softmax_max": kwargs.get("softmax_max"),
        "softmax_sum": kwargs.get("softmax_sum"),
        "softmax_in": kwargs.get("softmax_in"),
        "attention_in": kwargs.get("attention_in"),
        "seed": kwargs.get("seed", 0),
        "offset": kwargs.get("offset", 0),
        "numels": kwargs.get("numels", 0),
        "attn_mask": kwargs.get("attn_mask")
    }

    return args, dims_kwargs, new_kwargs
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
class FlashAttentionScore(nn.Module):
    """CPU reference ("golden") implementation of the npu_fusion_attention operator.

    ``forward`` reproduces the fused attention output together with the softmax
    max/sum statistics; ``backward`` reproduces the gradients w.r.t. query, key
    and value. All results are moved to CPU for comparison against device output.
    """

    def __init__(self):
        super(FlashAttentionScore, self).__init__()
        # Stateless reference module: no parameters or buffers to register.

    def forward(self, *inputs, **kwargs):
        """Compute the golden attention output.

        Returns:
            tuple: (attention output, softmax_max, softmax_sum), all on CPU; the
            statistics are tiled by 8 on the last axis to match the NPU layout.
        """
        new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*inputs, **kwargs)
        query, key, value = new_args[0], new_args[1], new_args[2]

        input_layout = get_input_layout(*inputs, **kwargs)

        n1 = dims_kwargs.get("n1")
        n2 = dims_kwargs.get("n2")
        s1 = dims_kwargs.get("s1")
        s2 = dims_kwargs.get("s2")
        b = dims_kwargs.get("b")
        dtype = dims_kwargs.get("dtype")
        attn_mask = new_kwargs.get("attn_mask")
        keep_prob = new_kwargs.get("keep_prob")
        sparse_mode = new_kwargs.get("sparse_mode")
        pre_tockens = new_kwargs.get("pre_tockens")
        # BUGFIX: these two lookups previously used the keys "next_tokens" and
        # "real_shift", which npu_fusion_attention_forward_patch never produces
        # (it emits "next_tockens" and "pse"), so both always resolved to None.
        next_tockens = new_kwargs.get("next_tockens")
        pse = new_kwargs.get("pse")
        scalar_value = new_kwargs.get("scalar_value")

        args_temp = [sparse_mode, attn_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype]

        attn_mask = generate_attn_mask(*args_temp)
        query = convert_to_bnsd(query, n1, input_layout)
        key = convert_to_bnsd(key, n2, input_layout)
        value = convert_to_bnsd(value, n2, input_layout)

        forward_params = FaForwardParams(
            q=query,
            k=key,
            v=value,
            drop_mask=None,
            attn_mask=attn_mask,
            pse=pse,
            scalar_value=scalar_value,
            keep_prob=keep_prob
        )

        out_golden, softmax_max, softmax_sum = fusion_attention_forward(forward_params)

        # If the output is 5-D (grouped heads), fold the group axes back into one.
        if out_golden.dim() == 5:
            out_golden = out_golden.reshape(out_golden.size(0),
                                            out_golden.size(1) * out_golden.size(2),
                                            out_golden.size(3), out_golden.size(4))

        out_golden = convert_from_bnsd(out_golden, input_layout)

        # Move results to CPU; tile the statistics by 8 on the last axis to
        # mirror the NPU operator's output layout.
        out_golden = out_golden.cpu(), softmax_max.repeat(1, 1, 1, 8).cpu(), softmax_sum.repeat(1, 1, 1, 8).cpu()

        return out_golden

    def backward(self, *inputs, **kwargs):
        """Compute golden gradients for npu_fusion_attention_grad.

        Returns:
            tuple: (dq, dk, dv) on CPU, each converted back to the caller's layout.
        """
        new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*inputs, **kwargs)
        query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5]
        n1 = dims_kwargs.get("n1")
        n2 = dims_kwargs.get("n2")
        s1 = dims_kwargs.get("s1")
        s2 = dims_kwargs.get("s2")
        b = dims_kwargs.get("b")
        dtype = dims_kwargs.get("dtype")
        attn_mask = new_kwargs.get("attn_mask")
        keep_prob = new_kwargs.get("keep_prob")
        sparse_mode = new_kwargs.get("sparse_mode")
        pre_tockens = new_kwargs.get("pre_tockens")
        next_tockens = new_kwargs.get("next_tockens")
        pse = new_kwargs.get("pse")
        softmax_max = new_kwargs.get("softmax_max")
        softmax_sum = new_kwargs.get("softmax_sum")
        scalar_value = new_kwargs.get("scalar_value")

        args_temp = [sparse_mode, attn_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype]
        attn_mask = generate_attn_mask(*args_temp)

        query = convert_to_bnsd(query, n1, input_layout)
        dx = convert_to_bnsd(dx, n1, input_layout)
        key = convert_to_bnsd(key, n2, input_layout)
        value = convert_to_bnsd(value, n2, input_layout)

        # Expand K/V heads so they match the query head count (GQA/MQA support).
        k_new, v_new = generate_kv(key, value, n1, n2)

        # Rebuild the softmax either from scratch (QKV) or from the recorded
        # max/sum statistics, depending on the configured build mode.
        if SOFTMAX_BUILD_MODE == "QKV":
            softmax_res = rebuid_softmax_by_qkv(query, k_new, attn_mask, pse, scalar_value)
        else:
            softmax_params = RebuildSoftmaxParams(query, k_new, attn_mask, pse, scalar_value, softmax_max, softmax_sum)
            softmax_res = rebuild_softmax_by_max_sum(softmax_params)

        backward_params = FaBackwardParams(dx, query, k_new, v_new, softmax_res, None, pse, scalar_value, keep_prob)
        dq, dk, dv = fusion_attention_backward(backward_params)

        # If gradients are 5-D (grouped heads), fold the group axes back into one.
        if dq.dim() == 5:
            dq = dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4))
        if dk.dim() == 5:
            dk = dk.reshape(dk.size(0), dk.size(1) * dk.size(2), dk.size(3), dk.size(4))
        if dv.dim() == 5:
            dv = dv.reshape(dv.size(0), dv.size(1) * dv.size(2), dv.size(3), dv.size(4))

        dq = convert_from_bnsd(dq, input_layout)
        dk = convert_from_bnsd(dk, input_layout)
        dv = convert_from_bnsd(dv, input_layout)

        return dq.cpu(), dk.cpu(), dv.cpu()
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.mindspore.api_accuracy_checker.bench_functions.flash_attention_score import FlashAttentionScore
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class FusionOperator:
    """
    Parent class for all fusion operators; registered operators are exposed as
    attributes (e.g. ``fusion.flash_attention_score``).
    """

    def __init__(self):
        self.flash_attention_score = None  # holds the FlashAttentionScore operator
        self._register_operators()

    def __getattr__(self, name):
        """Raise a descriptive error for unknown operator names.

        BUGFIX: the previous body called ``hasattr(self, name)``, which re-enters
        ``__getattr__`` for the same missing name and recurses until
        RecursionError. ``__getattr__`` is only invoked after normal attribute
        lookup has already failed, so the name is known to be absent here.
        """
        raise AttributeError(f"'FusionOperator' object has no attribute '{name}'")

    def _register_operators(self):
        """Register operator instances so they are reachable via fusion.xxx."""
        self.flash_attention_score = FlashAttentionScore()
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Module-level singleton through which callers access the registered operators.
fusion = FusionOperator()
|
|
@@ -39,6 +39,8 @@ def add_api_accuracy_checker_argument(parser):
|
|
|
39
39
|
help="<optional> The ut task result out path.")
|
|
40
40
|
parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str, required=False,
|
|
41
41
|
help="<optional> the exit csv for continue")
|
|
42
|
+
parser.add_argument('-save_error_data', dest="save_error_data", action="store_true",
|
|
43
|
+
help="<optional> Save compare failed api output.", required=False)
|
|
42
44
|
|
|
43
45
|
|
|
44
46
|
def multi_add_api_accuracy_checker_argument(parser):
|
|
@@ -49,6 +51,8 @@ def multi_add_api_accuracy_checker_argument(parser):
|
|
|
49
51
|
help="<optional> The ut task result out path.")
|
|
50
52
|
parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str, required=False,
|
|
51
53
|
help="<optional> the exit csv for continue")
|
|
54
|
+
parser.add_argument('-save_error_data', dest="save_error_data", action="store_true",
|
|
55
|
+
help="<optional> Save compare failed api output.", required=False)
|
|
52
56
|
#以下属于多线程参数
|
|
53
57
|
parser.add_argument("-d", "--device", dest="device_id", nargs='+', type=int,
|
|
54
58
|
help="<optional> set device id to run ut, must be unique and in range 0-7",
|
|
@@ -16,12 +16,13 @@
|
|
|
16
16
|
import os
|
|
17
17
|
import csv
|
|
18
18
|
|
|
19
|
-
from msprobe.core.common.const import Const, CompareConst
|
|
19
|
+
from msprobe.core.common.const import Const, CompareConst
|
|
20
20
|
from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, read_csv
|
|
21
21
|
from msprobe.core.common.utils import add_time_as_suffix, MsprobeBaseException
|
|
22
22
|
from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
|
|
23
23
|
from msprobe.core.common.file_utils import check_file_or_directory_path
|
|
24
24
|
from msprobe.mindspore.common.log import logger
|
|
25
|
+
from msprobe.mindspore.common.const import MsCompareConst
|
|
25
26
|
|
|
26
27
|
|
|
27
28
|
class ResultCsvEntry:
|
|
@@ -187,7 +188,7 @@ class DataManager:
|
|
|
187
188
|
|
|
188
189
|
def record_exception_skip(self, api_name, forward_or_backward, err_msg):
|
|
189
190
|
'''
|
|
190
|
-
record exception_skip
|
|
191
|
+
record exception_skip information into self.record_exception_skip.
|
|
191
192
|
self.record_exception_skip: dict{str: dict{"forward": str/None, "backward": str/None}}
|
|
192
193
|
string in key is api_name, string in value is err_msg
|
|
193
194
|
'''
|
|
@@ -269,7 +270,7 @@ class DataManager:
|
|
|
269
270
|
entry.backward_pass_status,
|
|
270
271
|
overall_err_msg
|
|
271
272
|
]
|
|
272
|
-
# change row if this api has
|
|
273
|
+
# change row if this api has exception_skip information
|
|
273
274
|
if api_name in self.results_exception_skip:
|
|
274
275
|
if self.results_exception_skip[api_name][Const.FORWARD] is not None:
|
|
275
276
|
row[1] = CompareConst.SKIP
|