mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0

msprobe/pytorch/bench_functions/npu_fusion_attention.py

@@ -1,6 +1,39 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+# 前向函数声明对比
+标杆实现:fusion_attention_forward: q, k, v, drop_mask, atten_mask, pse, scale, keep_prob
+融合算子:npu_fusion_attention_forward: query, key, value, head_num, input_layout, *, pse=None, padding_mask=None,
+                                      atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647,
+                                      next_tockens=2147483647, inner_precise=0, prefix=None, sparse_mode=0,
+                                      gen_mask_parallel=True, sync=False
+
+# 反向函数声明对比
+标杆实现:fusion_attention_backward: dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
+融合算子:npu_fusion_attention_backward: query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None,
+                                       atten_mask=None, softmax_max=None, softmax_sum=None, softmax_in=None,
+                                       attention_in=None, scale_value=1.0, keep_prob=1.0, pre_tockens=2147483647,
+                                       next_tockens=2147483647, inner_precise=0, seed=0, offset=0,
+                                       numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
+"""
+
 import torch
 import numpy as np
 from einops import rearrange
+
 try:
     import torch_npu
 except ImportError:
@@ -9,34 +42,16 @@ except ImportError:
         # flash_attn为gpu的fa三方库
         from flash_attn import flash_attn_func
     except ImportError:
-
+        # 如果为cpu的ut环境,则不做任何处理
         pass
 else:
     is_gpu = False

-
 from msprobe.pytorch.common.utils import logger
 from msprobe.core.common.const import Const, CompareConst

-
-
-
-"""
-# 前向函数声明对比
-标杆实现:fusion_attention_forward: q, k, v, drop_mask, atten_mask, pse, scale, keep_prob
-融合算子:npu_fusion_attention_forward: query, key, value, head_num, input_layout, *, pse=None, padding_mask=None,
-atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647,
-next_tockens=2147483647, inner_precise=0, prefix=None, sparse_mode=0,
-gen_mask_parallel=True, sync=False
-
-# 反向函数声明对比
-标杆实现:fusion_attention_backward: dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
-融合算子:npu_fusion_attention_backward: query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None,
-atten_mask=None, softmax_max=None, softmax_sum=None, softmax_in=None,
-attention_in=None, scale_value=1.0, keep_prob=1.0, pre_tockens=2147483647,
-next_tockens=2147483647, inner_precise=0, seed=0, offset=0,
-numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
-"""
+GTYPE = torch.float64 # arm host必须选择float64,x86环境选择float32即可,64也行。arm计算很慢,s=8k的场景建议使用x86
+SOFTMAX_BUILD_MODE = "QKV" # "MAX_SUM"


 def softmax_forward(x):
@@ -62,10 +77,10 @@ def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):

     factor = num_heads // num_kv_heads
     kv_shape = kv_tensor.shape
-
-
-
-    kv_res = torch.zeros([
+    b = kv_shape[0]
+    s = kv_shape[2]
+    d = kv_shape[3]
+    kv_res = torch.zeros([b, num_heads, s, d]).to(dtype)
     for i in range(num_heads):
         j = i // factor
         kv_res[:, i:i + 1, :, :] = kv_tensor[:, j:j + 1, :, :]
@@ -112,7 +127,7 @@ def fusion_attention_backward(dx, q, k, v, softmax_res, drop_mask, pse, scale, k

 def parse_bsnd_args(query, key, head_num, input_layout):
     supported_input_layout = ["BSH", "SBH", "BSND", "BNSD", "TND"]
-
+    b, s1, s2, n1, n2, d, h1, h2 = None, None, None, head_num, None, None, None, None

     if not isinstance(input_layout, str) or input_layout not in supported_input_layout:
         raise ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.")
@@ -121,35 +136,48 @@ def parse_bsnd_args(query, key, head_num, input_layout):
         raise ValueError(f"input_layout {input_layout} does not supported for now.")
     try:
         if input_layout == "BSH":
-
-            _,
-
-
+            b, s1, h1 = query.shape
+            _, s2, h2 = key.shape
+            d = h1 // n1
+            n2 = h2 // d
         elif input_layout == "SBH":
-
-
-
-
+            s1, b, h1 = query.shape
+            s2, _, h2 = key.shape
+            d = h1 // n1
+            n2 = h2 // d
         elif input_layout == "BSND":
-
-            _,
-
-
+            b, s1, n1, d = query.shape
+            _, s2, n2, _ = key.shape
+            h1 = n1 * d
+            h2 = n2 * d
         elif input_layout == "BNSD":
-
-            _,
-
-
+            b, n1, s1, d = query.shape
+            _, n2, s2, _ = key.shape
+            h1 = n1 * d
+            h2 = n2 * d
     except Exception as e:
         raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e

-    if
-    raise ValueError(f"Value
-
-
+    if d == 0:
+        raise ValueError(f"Value d must be non-zero.")
+    _dtype = query.dtype
+    ret = (b, s1, s2, n1, n2, d, h1, h2, _dtype)
+    return ret


 def convert_from_bnsd(_input, input_layout):
+    """
+    transform qkv from bnsd to input_layout.
+    B: batch_size
+    S: sequence_length
+    N: num_heads
+    D: head_dim
+    Args:
+        _input (torch.Tensor): tensor of shape (B,N,S,D)
+        input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+    Returns:
+        tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+    """
     if input_layout == "BSH":
         # (B,N,S,D)=>(B,S,N*D)
         out = rearrange(_input, 'b n s d -> b s (n d)').contiguous()
@@ -167,7 +195,19 @@ def convert_from_bnsd(_input, input_layout):


 def convert_to_bnsd(_input, n, input_layout):
-
+    """
+    transform qkv from input_layout to bnsd.
+    B: batch_size
+    S: sequence_length
+    N: num_heads
+    D: head_dim
+    Args:
+        _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+        n (int): num_heads
+        input_layout (str):"BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+    Returns:
+        tensor of shape (B,N,S,D)
+    """
     if input_layout == "BSH":
         # (B,S,N*D)=>(B,N,S,D)
         out = rearrange(_input, 'b s (n d) -> b n s d', n=n)
@@ -183,27 +223,90 @@ def convert_to_bnsd(_input, n, input_layout):
         out = _input
     if out.dim() != 4:
         raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
-    return out.to(
+    return out.to(GTYPE)


-def
+def convert_from_bsnd(_input, input_layout):
+    """
+    transform qkv from bsnd to input_layout.
+    B: batch_size
+    S: sequence_length
+    N: num_heads
+    D: head_dim
+    Args:
+        _input (torch.Tensor): tensor of shape (B,S,N,D)
+        input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+    Returns:
+        tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+    """
+    if input_layout == "BSH":
+        # (B,S,N,D)=>(B,S,N*D)
+        out = rearrange(_input, 'b s n d -> b s (n d)').contiguous()
+    elif input_layout == "SBH":
+        # (B,S,N,D)=>(S,B,N*D)
+        out = rearrange(_input, 'b s n d -> s b (n d)').contiguous()
+    elif input_layout == "BNSD":
+        # (B,S,N,D)=>(B,N,S,D)
+        out = rearrange(_input, 'b s n d -> b n s d').contiguous()
+    elif input_layout == "TND":
+        raise ValueError(f"input_layout {input_layout} does not supported for now.")
+    else:
+        out = _input
+    return out
+
+
+def convert_to_bsnd(_input, n, input_layout):
+    """
+    transform qkv from input_layout to bsnd.
+    B: batch_size
+    S: sequence_length
+    N: num_heads
+    D: head_dim
+    Args:
+        _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+        n (int): num_heads
+        input_layout (str):"BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+    Returns:
+        tensor of shape (B,S,N,D)
+    """
+    if input_layout == "BSH":
+        # (B,S,N*D)=>(B,S,N,D)
+        out = rearrange(_input, 'b s (n d) -> b s n d', n=n)
+    elif input_layout == "SBH":
+        # (S,B,N*D)=>(B,S,N,D)
+        out = rearrange(_input, 's b (n d) -> b s n d', n=n)
+    elif input_layout == "BNSD":
+        # (B,N,S,D)=>(B,S,N,D)
+        out = rearrange(_input, 'b n s d -> b s n d', n=n)
+    elif input_layout == "TND":
+        raise ValueError(f"input_layout {input_layout} does not supported for now.")
+    else:
+        out = _input
+    if out.dim() != 4:
+        raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
+    return out
+
+
+def generate_atten_mask(*args):
     """
     # 当sparse_mode=2、3、4时小算子到融合算子会走这个优化,反过来看就要拆解回原来的基本实现
     ===> atten_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype)
     """
-
+
+    sparse_mode, atten_mask, b, n1, s1, s2, pre_tocken, next_tocken, dtype = args
+    shape = [s1, s2]

     if atten_mask is not None:
         # 当FA的输入已经包含atten_mask时,可以认为已经是转换之后的mask矩阵了,有三种特殊场景,即稀疏矩阵场景,需要进行逆向还原
         if sparse_mode == 2 or sparse_mode == 3 or sparse_mode == 4:
-            logger.info(f"
+            logger.info(f"s1: {s1}, s2:{s2}, atten_mask.shape:{atten_mask.shape}, atten_mask.dtype:{atten_mask.dtype}")

             if atten_mask.dim() == 2 and atten_mask.shape[0] == 2048 and atten_mask.shape[1] == 2048:
                 if atten_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(atten_mask.dtype)):
                     if sparse_mode == 2:
                         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
                     elif sparse_mode == 3:
-                        atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=
+                        atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1))
                     elif sparse_mode == 4:
                         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
                         atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
@@ -215,14 +318,14 @@ def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next

     if atten_mask is not None:
         if atten_mask.dim() == 2:
-            if atten_mask.shape[0] !=
+            if atten_mask.shape[0] != s1 or atten_mask.shape[1] != s2:
                 raise ValueError(f"Invalid atten_mask shape `SS` {atten_mask.shape}")
-            shape = [
+            shape = [s1, s2]
         elif atten_mask.dim() == 4:
             if atten_mask.shape[1] == 1:
-                shape = [
+                shape = [b, 1, s1, s2] if b != 1 else [1, 1, s1, s2]
             else:
-                shape = [
+                shape = [b, n1, s1, s2] if b != 1 else [1, n1, s1, s2]

     if sparse_mode == 0:
         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
@@ -233,7 +336,7 @@ def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next
     elif sparse_mode == 2:
         atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
     elif sparse_mode == 3:
-        atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=
+        atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1))
     elif sparse_mode == 4:
         atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
         atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
@@ -243,11 +346,11 @@ def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next
     return atten_mask.to(dtype)


-def generate_kv(key, value,
+def generate_kv(key, value, n1, n2):
     # N不等长适配by cdy
-    if not (
-        k_new = broadcast_kv(
-        v_new = broadcast_kv(
+    if not (n1 == n2):
+        k_new = broadcast_kv(n1, n2, key, key.dtype)
+        v_new = broadcast_kv(n1, n2, value, value.dtype)
     else:
         k_new = key
         v_new = value
@@ -261,7 +364,7 @@ def rebuid_softmax_by_qkv(q, k, atten_mask, pse, scale):
     """
     logger.info("Using QKV to rebuild original softmax")
     qk = calculate_qk(q, k, atten_mask, pse, scale)
-    softmax_res,
+    softmax_res, _, _ = softmax_forward(qk)
     return softmax_res


@@ -301,30 +404,38 @@ def get_input_layout(*args, **kwargs):


 def npu_fusion_attention_forward_patch(*args, **kwargs):
+
+    if len(args) < 2:
+        raise RuntimeError("npu_fusion_attention_forward_patch: length of args should greater than or equal to 2.")
+
     # query, key, value, head_num, input_layout
     head_num = get_head_num(*args, **kwargs)
     input_layout = get_input_layout(*args, **kwargs)

-
-    if
-        logger.debug(f"running case : BNSD = {
+    b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], head_num, input_layout)
+    if n1 == n2 and s1 == s2:
+        logger.debug(f"running case : BNSD = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
     else:
-        logger.debug(f"running case: BNSD = {
-    if not (
-        raise ValueError(f"N1与N2不匹配,请检查:
-
-    dims_kwargs = {
-
-
-
-
-
-
-
-
-
-
-
+        logger.debug(f"running case: BNSD = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+    if not (n1 % n2 == 0 and n1 >= n2):
+        raise ValueError(f"N1与N2不匹配,请检查: n1 = {n1}, n2 = {n2}.")
+
+    dims_kwargs = {
+        "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2,
+        "d": d, "h1": h1, "h2": h2, "dtype": dtype
+    }
+
+    new_kwargs = {
+        "keep_prob": 1,
+        "scale": kwargs.get("scale", 1 / (d ** 0.5)),
+        "sparse_mode": kwargs.get("sparse_mode", 0),
+        "prefix": kwargs.get("prefix"),
+        "pre_tockens": kwargs.get("pre_tockens", 2147483647),
+        "next_tockens": kwargs.get("next_tockens", 2147483647),
+        "pse": kwargs.get("pse"),
+        "padding_mask": kwargs.get("padding_mask"),
+        "atten_mask": kwargs.get("atten_mask")
+    }

     return args, dims_kwargs, new_kwargs

@@ -333,33 +444,37 @@ def npu_fusion_attention_backward_patch(*args, **kwargs):
     if len(args) != 6:
         raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")

-
-    if
-        logger.info(f"running case :
+    b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], args[4], args[5])
+    if n1 == n2 and s1 == s2:
+        logger.info(f"running case : bnsd = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
     else:
-        logger.info(f"running case:
-    if not (
-        raise ValueError(f"N1与N2不匹配,请检查:
-
-    dims_kwargs = {
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        logger.info(f"running case: bnsd = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+    if not (n1 % n2 == 0 and n1 >= n2):
+        raise ValueError(f"N1与N2不匹配,请检查: n1 = {n1}, n2 = {n2}.")
+
+    dims_kwargs = {
+        "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2,
+        "d": d, "h1": h1, "h2": h2, "dtype": dtype
+    }
+
+    new_kwargs = {
+        "keep_prob": 1,
+        "scale_value": kwargs.get("scale_value", 1 / (d ** 0.5)),
+        "sparse_mode": kwargs.get("sparse_mode", 0),
+        "prefix": kwargs.get("prefix"),
+        "pre_tockens": kwargs.get("pre_tockens", 2147483647),
+        "next_tockens": kwargs.get("next_tockens", 2147483647),
+        "pse": kwargs.get("pse"),
+        "padding_mask": kwargs.get("padding_mask"),
+        "softmax_max": kwargs.get("softmax_max"),
+        "softmax_sum": kwargs.get("softmax_sum"),
+        "softmax_in": kwargs.get("softmax_in"),
+        "attention_in": kwargs.get("attention_in"),
+        "seed": kwargs.get("seed", 0),
+        "offset": kwargs.get("offset", 0),
+        "numels": kwargs.get("numels", 0),
+        "atten_mask": kwargs.get("atten_mask")
+    }

     return args, dims_kwargs, new_kwargs

@@ -368,12 +483,12 @@ def npu_fusion_attention(*args, **kwargs):
     new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*args, **kwargs)
     query, key, value = new_args[0], new_args[1], new_args[2]
     input_layout = get_input_layout(*args, **kwargs)
-
-
-
-
-
-
+    n1 = dims_kwargs.get("n1")
+    n2 = dims_kwargs.get("n2")
+    s1 = dims_kwargs.get("s1")
+    s2 = dims_kwargs.get("s2")
+    b = dims_kwargs.get("b")
+    dtype = dims_kwargs.get("dtype")
     atten_mask = new_kwargs.get("atten_mask")
     keep_prob = new_kwargs.get("keep_prob")
     sparse_mode = new_kwargs.get("sparse_mode")
@@ -381,12 +496,12 @@ def npu_fusion_attention(*args, **kwargs):
     next_tockens = new_kwargs.get("next_tockens")
     pse = new_kwargs.get("pse")
     scale = new_kwargs.get("scale")
-
-    atten_mask = generate_atten_mask(
-    query = convert_to_bnsd(query,
-    key = convert_to_bnsd(key,
-    value = convert_to_bnsd(value,
-    k_new, v_new = generate_kv(key, value,
+    args_temp = [sparse_mode, atten_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype]
+    atten_mask = generate_atten_mask(*args_temp)
+    query = convert_to_bnsd(query, n1, input_layout)
+    key = convert_to_bnsd(key, n2, input_layout)
+    value = convert_to_bnsd(value, n2, input_layout)
+    k_new, v_new = generate_kv(key, value, n1, n2)
     out_golden, softmax_max, softmax_sum = fusion_attention_forward(q=query, k=k_new, v=v_new,
                                                                     drop_mask=None, atten_mask=atten_mask,
                                                                     pse=pse, scale=scale,
@@ -403,13 +518,13 @@ def npu_fusion_attention_grad(*args, **kwargs):
     # dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
     new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*args, **kwargs)
     query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5]
-
-
-
-
-
-
-
+    n1 = dims_kwargs.get("n1")
+    n2 = dims_kwargs.get("n2")
+    s1 = dims_kwargs.get("s1")
+    s2 = dims_kwargs.get("s2")
+    b = dims_kwargs.get("b")
+    d = dims_kwargs.get("d")
+    dtype = dims_kwargs.get("dtype")
     atten_mask = new_kwargs.get("atten_mask")
     keep_prob = new_kwargs.get("keep_prob")
     sparse_mode = new_kwargs.get("sparse_mode")
@@ -420,14 +535,15 @@ def npu_fusion_attention_grad(*args, **kwargs):
     softmax_sum = new_kwargs.get("softmax_sum")
     scale_value = new_kwargs.get("scale_value")

-
-
-
-
-
-
+    args_temp = [sparse_mode, atten_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype]
+    atten_mask = generate_atten_mask(*args_temp)
+    query = convert_to_bnsd(query, n1, input_layout)
+    dx = convert_to_bnsd(dx, n1, input_layout)
+    key = convert_to_bnsd(key, n2, input_layout)
+    value = convert_to_bnsd(value, n2, input_layout)
+    k_new, v_new = generate_kv(key, value, n1, n2)

-    if
+    if SOFTMAX_BUILD_MODE == "QKV":
         softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value)
     else:
         softmax_res = rebuild_softmax_by_max_sum(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum)
@@ -435,12 +551,12 @@ def npu_fusion_attention_grad(*args, **kwargs):
     dq, dk, dv = fusion_attention_backward(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob)

     # N不等长适配by cdy
-    if not (
-        if
-            raise ValueError("dims_kwargs.
-
-        dk = torch.sum(dk.reshape(
-        dv = torch.sum(dv.reshape(
+    if not (n1 == n2):
+        if n2 == 0:
+            raise ValueError("dims_kwargs.n2 must be non-zero.")
+        g = int(n1 / n2)
+        dk = torch.sum(dk.reshape(b, n2, g, s2, d), dim=2, keepdim=True).reshape(b, n2, s2, d)
+        dv = torch.sum(dv.reshape(b, n2, g, s2, d), dim=2, keepdim=True).reshape(b, n2, s2, d)

     if dq.dim() == 5:
         dq = dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4))
@@ -460,12 +576,12 @@ def is_attention_off_due_to_mask(atten_mask_dtype):
     return not atten_mask_dtype


-def is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens,
-    return sparse_mode == 4 and (next_tockens != 0 or pre_tockens <
+def is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens, s1):
+    return sparse_mode == 4 and (next_tockens != 0 or pre_tockens < s1)


-def is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens,
-    return sparse_mode == 0 and pre_tockens >=
+def is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens, s1, s2):
+    return sparse_mode == 0 and pre_tockens >= s1 and next_tockens >= s2


 def gpu_fusion_attention(*args, **kwargs):
@@ -474,11 +590,11 @@ def gpu_fusion_attention(*args, **kwargs):
     query, key, value = new_args[0], new_args[1], new_args[2]
     keep_prob = new_kwargs.get("keep_prob", 1.0)
     scale = new_kwargs.get("scale")
-
-
-
-
-
+    n1 = dims_kwargs.get("n1")
+    n2 = dims_kwargs.get("n2")
+    s1 = dims_kwargs.get("s1")
+    s2 = dims_kwargs.get("s2")
+    b = dims_kwargs.get("b")
     pse = new_kwargs.get("pse")
     sparse_mode = new_kwargs.get("sparse_mode")
     pre_tockens = new_kwargs.get("pre_tockens")
@@ -488,22 +604,29 @@ def gpu_fusion_attention(*args, **kwargs):
     pre_tockens = min(CompareConst.MAX_TOKENS, pre_tockens)
     next_tockens = min(CompareConst.MAX_TOKENS, next_tockens)
     atten_off = (is_attention_off_due_to_mask(atten_mask_dtype) or
-
-
+                 is_attention_off_in_sparse_mode_4(sparse_mode, next_tockens, pre_tockens, s1) or
+                 is_attention_off_in_sparse_mode_0(sparse_mode, pre_tockens, next_tockens, s1, s2))
     causal_switch = not atten_off
     if sparse_mode == CompareConst.SPECIAL_SPARSE_MOED:
         window_left = pre_tockens
         window_right = next_tockens
     else:
         pre_tockens = next_tockens = CompareConst.MAX_TOKENS
-        window_left = pre_tockens -
-        window_right = next_tockens +
-
+        window_left = pre_tockens - s1 + s2
+        window_right = next_tockens + s1 - s2
+
     if pse is not None:
-        alibi_slopes = torch.rand(
+        alibi_slopes = torch.rand(b, n1, dtype=torch.float32) * 0.3
     else:
         alibi_slopes = None
-
-
-
+
+    input_layout = get_input_layout(*args, **kwargs)
+    query = convert_to_bsnd(query, n1, input_layout)
+    key = convert_to_bsnd(key, n2, input_layout)
+    value = convert_to_bsnd(value, n2, input_layout)
+    out = flash_attn_func(
+        query, key, value, dropout_p=(1 - keep_prob), softmax_scale=scale, causal=causal_switch,
+        window_size=(window_left, window_right), alibi_slopes=alibi_slopes, deterministic=deterministic
+    )
+    out = convert_from_bsnd(out, input_layout)
     return out, Const.NONE, Const.NONE
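
The largest functional change in the diff above is that the golden reference path now normalizes query/key/value layouts explicitly: `convert_to_bnsd`/`convert_from_bnsd` and the new `convert_to_bsnd`/`convert_from_bsnd` use `einops.rearrange` to move between the supported BSH, SBH, BSND and BNSD layouts. A minimal round-trip sketch of that idea follows; the dimensions are arbitrary example values, not taken from the package.

```python
# Illustrative sketch only, not part of mindstudio-probe.
import torch
from einops import rearrange

b, s, n, d = 2, 16, 8, 64            # batch, seq_len, num_heads, head_dim (example values)
x_bsnd = torch.randn(b, s, n, d)

# (B,S,N,D) -> (S,B,N*D): the "SBH" layout, as in convert_from_bsnd
x_sbh = rearrange(x_bsnd, 'b s n d -> s b (n d)').contiguous()

# back to (B,S,N,D), passing n so the packed head dimension can be split again, as in convert_to_bsnd
x_back = rearrange(x_sbh, 's b (n d) -> b s n d', n=n)

assert torch.equal(x_bsnd, x_back)   # pure layout change, values are untouched
```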
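The sparse-mode handling in `generate_atten_mask` also changes: for `sparse_mode == 3` the triangular mask is now built with a diagonal offset of `s2 - s1 + 1` rather than the previously truncated constant. A small sketch of the two mask variants, with sequence lengths chosen only for illustration:

```python
# Illustrative sketch only, not part of mindstudio-probe.
import numpy as np
import torch

s1, s2 = 4, 6                                                              # example query/key lengths
# sparse_mode == 2: ordinary causal mask, ones strictly above the main diagonal
mask_mode2 = torch.from_numpy(np.triu(np.ones([s1, s2]), k=1))
# sparse_mode == 3: diagonal shifted by s2 - s1 + 1, i.e. causality aligned to the right edge
mask_mode3 = torch.from_numpy(np.triu(np.ones([s1, s2]), k=s2 - s1 + 1))
```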
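Finally, the patch keeps the grouped-query case (`n1 != n2`) consistent between forward and backward: `broadcast_kv` repeats each KV head `n1 // n2` times before the reference attention runs, and `npu_fusion_attention_grad` sums `dk`/`dv` back over that group dimension. The following sketch reproduces the same bookkeeping under assumed example shapes:

```python
# Illustrative sketch only, not part of mindstudio-probe.
import torch

b, n1, n2, s, d = 2, 8, 2, 16, 64     # example dims; n1 is a multiple of n2
g = n1 // n2

kv = torch.randn(b, n2, s, d)
kv_broadcast = kv.repeat_interleave(g, dim=1)    # (b, n1, s, d): head i reads kv head i // g, like broadcast_kv

dk_full = torch.randn(b, n1, s, d)               # gradient w.r.t. the broadcast key
dk = dk_full.reshape(b, n2, g, s, d).sum(dim=2)  # (b, n2, s, d): collapse each group, like the dk/dv reduction above
```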

@@ -1,3 +1,18 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torch
