mindspore 2.2.0-cp38-cp38-manylinux1_x86_64.whl → 2.2.11-cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (170)
  1. mindspore/.commit_id +1 -1
  2. mindspore/_akg/akg/composite/build_module.py +104 -20
  3. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  4. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  5. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  6. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  7. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  8. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  9. mindspore/_akg/akg/utils/composite_op_helper.py +7 -2
  10. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  11. mindspore/_akg/akg/utils/kernel_exec.py +41 -15
  12. mindspore/_akg/akg/utils/tbe_codegen_utils.py +27 -6
  13. mindspore/_akg/akg/utils/util.py +56 -1
  14. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  15. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  16. mindspore/_checkparam.py +3 -3
  17. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  18. mindspore/_extends/graph_kernel/splitter.py +3 -2
  19. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +83 -66
  20. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -4
  21. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  22. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +2 -1
  23. mindspore/_extends/parse/__init__.py +3 -2
  24. mindspore/_extends/parse/parser.py +6 -1
  25. mindspore/_extends/parse/standard_method.py +14 -11
  26. mindspore/_extends/remote/kernel_build_server.py +2 -1
  27. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  28. mindspore/bin/cache_admin +0 -0
  29. mindspore/bin/cache_server +0 -0
  30. mindspore/common/_utils.py +16 -0
  31. mindspore/common/api.py +1 -1
  32. mindspore/common/auto_dynamic_shape.py +81 -85
  33. mindspore/common/dump.py +1 -1
  34. mindspore/common/tensor.py +3 -20
  35. mindspore/config/op_info.config +1 -1
  36. mindspore/context.py +11 -4
  37. mindspore/dataset/engine/cache_client.py +8 -5
  38. mindspore/dataset/engine/datasets_standard_format.py +5 -0
  39. mindspore/dataset/vision/transforms.py +21 -21
  40. mindspore/experimental/optim/adam.py +1 -1
  41. mindspore/gen_ops.py +1 -1
  42. mindspore/include/api/model.h +17 -0
  43. mindspore/include/api/status.h +8 -3
  44. mindspore/lib/libdnnl.so.2 +0 -0
  45. mindspore/lib/libmindspore.so +0 -0
  46. mindspore/lib/libmindspore_backend.so +0 -0
  47. mindspore/lib/libmindspore_common.so +0 -0
  48. mindspore/lib/libmindspore_core.so +0 -0
  49. mindspore/lib/libmindspore_glog.so.0 +0 -0
  50. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  51. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  52. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  53. mindspore/lib/libmindspore_shared_lib.so +0 -0
  54. mindspore/lib/libnnacl.so +0 -0
  55. mindspore/lib/libopencv_core.so.4.5 +0 -0
  56. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  57. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  58. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  59. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  60. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  61. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  62. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  63. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  64. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  65. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  66. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  67. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  68. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  69. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  70. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  71. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  72. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  73. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +78 -80
  74. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  75. mindspore/lib/plugin/ascend/libakg.so +0 -0
  76. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  77. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  78. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  79. mindspore/lib/plugin/cpu/libakg.so +0 -0
  80. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  81. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  82. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  83. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  84. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  85. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  86. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  87. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  88. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  89. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  90. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  91. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  92. mindspore/nn/cell.py +0 -3
  93. mindspore/nn/layer/activation.py +4 -5
  94. mindspore/nn/layer/conv.py +39 -23
  95. mindspore/nn/layer/flash_attention.py +54 -129
  96. mindspore/nn/layer/math.py +3 -7
  97. mindspore/nn/layer/rnn_cells.py +5 -5
  98. mindspore/nn/wrap/__init__.py +4 -2
  99. mindspore/nn/wrap/cell_wrapper.py +12 -3
  100. mindspore/numpy/utils_const.py +5 -5
  101. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -1
  102. mindspore/ops/_grad_experimental/grad_implementations.py +2 -2
  103. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -18
  104. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  105. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  106. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  107. mindspore/ops/_utils/utils.py +2 -0
  108. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
  109. mindspore/ops/composite/multitype_ops/getitem_impl.py +2 -2
  110. mindspore/ops/function/array_func.py +10 -7
  111. mindspore/ops/function/grad/grad_func.py +0 -1
  112. mindspore/ops/function/nn_func.py +98 -9
  113. mindspore/ops/function/random_func.py +2 -1
  114. mindspore/ops/op_info_register.py +24 -21
  115. mindspore/ops/operations/__init__.py +6 -2
  116. mindspore/ops/operations/_grad_ops.py +25 -6
  117. mindspore/ops/operations/_inner_ops.py +155 -23
  118. mindspore/ops/operations/array_ops.py +9 -7
  119. mindspore/ops/operations/comm_ops.py +2 -2
  120. mindspore/ops/operations/custom_ops.py +85 -68
  121. mindspore/ops/operations/inner_ops.py +26 -3
  122. mindspore/ops/operations/math_ops.py +7 -6
  123. mindspore/ops/operations/nn_ops.py +193 -49
  124. mindspore/parallel/_parallel_serialization.py +10 -3
  125. mindspore/parallel/_tensor.py +4 -1
  126. mindspore/parallel/checkpoint_transform.py +13 -2
  127. mindspore/parallel/shard.py +17 -10
  128. mindspore/profiler/common/util.py +1 -0
  129. mindspore/profiler/parser/ascend_hccl_generator.py +232 -0
  130. mindspore/profiler/parser/ascend_msprof_exporter.py +86 -43
  131. mindspore/profiler/parser/ascend_msprof_generator.py +196 -9
  132. mindspore/profiler/parser/ascend_op_generator.py +1 -1
  133. mindspore/profiler/parser/ascend_timeline_generator.py +6 -182
  134. mindspore/profiler/parser/base_timeline_generator.py +1 -1
  135. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -2
  136. mindspore/profiler/parser/framework_parser.py +1 -1
  137. mindspore/profiler/parser/profiler_info.py +19 -0
  138. mindspore/profiler/profiling.py +46 -24
  139. mindspore/rewrite/api/pattern_engine.py +1 -1
  140. mindspore/rewrite/parsers/for_parser.py +7 -7
  141. mindspore/rewrite/parsers/module_parser.py +4 -4
  142. mindspore/rewrite/symbol_tree.py +1 -4
  143. mindspore/run_check/_check_version.py +5 -3
  144. mindspore/safeguard/rewrite_obfuscation.py +52 -28
  145. mindspore/scipy/ops.py +55 -5
  146. mindspore/scipy/optimize/__init__.py +3 -2
  147. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  148. mindspore/train/callback/_summary_collector.py +1 -1
  149. mindspore/train/dataset_helper.py +1 -0
  150. mindspore/train/model.py +2 -2
  151. mindspore/train/serialization.py +97 -11
  152. mindspore/train/summary/_summary_adapter.py +1 -1
  153. mindspore/train/summary/summary_record.py +23 -7
  154. mindspore/version.py +1 -1
  155. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +3 -2
  156. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +160 -151
  157. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
  158. mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
  159. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
  160. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
  161. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
  162. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
  163. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  164. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  165. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  166. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  167. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  168. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  169. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
  170. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -1,563 +0,0 @@
- # Copyright 2023 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """The forward tik ops of flash attention"""
-
- from tbe import tik
-
- from mindspore.ops._op_impl._custom_op.flash_attention.attention import FlashAttention
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import DTYPE_SIZE
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP16
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP32
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import GM
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import L1
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import UB
- from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.strategy import TilingStrategy
-
-
- class FlashAttentionFwd(FlashAttention):
- """The implementation of flash attention forward
- This function contains the flash attention forward implementation used in flash attention (see paper)
- `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/pdf/2205.14135.pdf>`
- """
-
- def __init__(self, query, key, value,
- attn_mask, dropout_mask, alibi_mask,
- kernel_name,
- tiling_stgy: TilingStrategy,
- prev_block_num=65536,
- next_block_num=65536, high_precision=False, disable_debug=True):
- super(FlashAttentionFwd, self).__init__(query, key, value, attn_mask, dropout_mask, alibi_mask,
- kernel_name,
- tiling_stgy, prev_block_num, next_block_num, high_precision,
- disable_debug)
- self.O_gm = None
- self.l_gm = None
- self.m_gm = None
- self.O_gm_workspace = None
-
- def define_custom_inputs(self):
- """define custom inputs"""
-
- def define_outputs(self):
- """define outputs"""
- self.O_gm = self.tik_instance.Tensor(FP16, self.O_shape, name="O_gm", scope=GM, is_atomic_add=True)
- if self.high_precision:
- self.O_gm_workspace = self.tik_instance.Tensor(FP32, self.O_shape, name="O_gm_workspace", scope=GM,
- is_workspace=True, is_atomic_add=True)
- self.l_gm = self.tik_instance.Tensor(self.precision_type, self.l_shape, name="l_gm", scope=GM,
- is_atomic_add=True)
- self.m_gm = self.tik_instance.Tensor(FP16, self.m_shape, name="m_gm", scope=GM, is_atomic_add=True)
-
- def prepare_global_ones(self):
- """Prepare global ones tensor in L1 for cube impl row_sum"""
- Bc_aligned = (self.Bc + 15) // 16 * 16
- last_Bc_aligned = (self.last_Bc + 15) // 16 * 16
- self.ones_l1 = self.tik_instance.Tensor(FP16, (Bc_aligned, 16), name="ones_l1", scope=L1)
- self.last_ones_l1 = self.tik_instance.Tensor(FP16, (last_Bc_aligned, 16), name="last_ones_l1", scope=L1)
- with self.tik_instance.new_stmt_scope(disable_sync=False):
- ones_ub = self.tik_instance.Tensor(FP16, (Bc_aligned, 16), name="ones_ub", scope=UB)
- self.tik_instance.h_duplicate(ones_ub, 1.0)
- self.cont_data_mv_1_bust(dst=self.ones_l1, src=ones_ub, burst=Bc_aligned)
- last_ones_ub = self.tik_instance.Tensor(FP16, (last_Bc_aligned, 16), name="last_ones_ub", scope=UB)
- self.tik_instance.h_duplicate(ones_ub, 1.0)
- self.cont_data_mv_1_bust(dst=self.last_ones_l1, src=last_ones_ub, burst=last_Bc_aligned)
-
- def softmax_compute(self, Sij_ub, mij_ub, lij_ub, m, n):
- """Refer to Algorithm 2 line12"""
- m_aligned = self.tik_ops_utils.up_align_to_K0(m)
- n_aligned = self.tik_ops_utils.up_align_to_K0(n)
- n0 = 16
- n1 = n // 16
- # only support n % 16 == 0
- with self.tik_instance.new_stmt_scope(disable_sync=False):
- mn0_block_max = self.tik_instance.Tensor(FP16, (1, m, n0), name="mn0_block_max", scope=UB)
- self.cont_data_mv_1_bust(dst=mn0_block_max, src=Sij_ub, burst=m)
- with self.tik_instance.for_range(1, n1) as idx:
- self.tik_instance.h_max(mn0_block_max, mn0_block_max, Sij_ub[idx, :, :])
- mn0_block_max = mn0_block_max.reshape((m, n0))
- self.tik_instance.h_reduce_max(mij_ub, mn0_block_max, 1)
- # Sij - mij
- with self.tik_instance.new_stmt_scope(disable_sync=False):
- broadcast_mij_ub = self.tik_ops_utils.broadcast(mij_ub, (m, n0))
- broadcast_mij_ub = broadcast_mij_ub.reshape((1, m, n0))
- for idx in range(n1):
- self.tik_instance.h_sub(Sij_ub[idx, :, :], Sij_ub[idx, :, :], broadcast_mij_ub)
- # exp
- if self.high_precision:
- Sij_ub_fp32 = self.tik_instance.Tensor(
- FP32, (n_aligned // 16, m_aligned, 16), name="Sij_ub_fp32", scope=UB
- )
- with self.tik_instance.new_stmt_scope(disable_sync=False):
- self.tik_instance.h_cast(Sij_ub_fp32, Sij_ub, "none")
- self.tik_instance.h_exp(Sij_ub_fp32, Sij_ub_fp32)
- self.tik_instance.h_cast(Sij_ub, Sij_ub_fp32, "none")
- else:
- self.tik_instance.h_exp(Sij_ub, Sij_ub)
-
- # cube impl rowsum
- Sij_l1_K1MK0_ed = self.tik_instance.Tensor(FP16, (n_aligned // 16, m_aligned, 16),
- name="Sij_l1_K1MK0_ed", scope=L1)
- self.cont_data_mv_1_bust(dst=Sij_l1_K1MK0_ed, src=Sij_ub, burst=m * n // 16)
- if n == self.Bc:
- Sij_row_sum_ub = self.tik_ops_utils.row_sum_cube_impl(Sij_l1_K1MK0_ed, self.ones_l1,
- lij_ub, m, n, self.precision_type)
- else:
- Sij_row_sum_ub = self.tik_ops_utils.row_sum_cube_impl(Sij_l1_K1MK0_ed, self.last_ones_l1,
- lij_ub, m, n, self.precision_type)
-
- if self.high_precision:
- return Sij_ub_fp32, mij_ub, Sij_row_sum_ub
- if self.has_drop_mask:
- return Sij_ub, mij_ub, Sij_row_sum_ub
- return Sij_l1_K1MK0_ed, mij_ub, Sij_row_sum_ub
-
- def update_m_l(self, mi_old_ub, mij_ub, li_old_ub, lij_ub, vec_len):
- """Refer to Algorithm 2 line13
- mi_new = max(mi, mij), li_new = exp(mi-mi_new)*li + exp(mij - mi_new) * lij
- """
- dtype = li_old_ub.dtype
- vec_len_aligned = self.tik_ops_utils.up_align_to_K0(vec_len)
- mi_new_ub = self.tik_instance.Tensor(FP16, (vec_len_aligned,), name="mi_new_ub", scope=UB)
- li_new_ub = self.tik_instance.Tensor(dtype, (vec_len_aligned,), name="li_new_ub", scope=UB)
- self.tik_instance.h_max(mi_new_ub, mi_old_ub, mij_ub)
-
- self.tik_instance.h_sub(mi_old_ub, mi_old_ub, mi_new_ub) # mi-mi_new
- self.tik_instance.h_exp(mi_old_ub, mi_old_ub) # exp(mi-mi_new)
-
- self.tik_instance.h_sub(mij_ub, mij_ub, mi_new_ub) # mij-mi_new
- self.tik_instance.h_exp(mij_ub, mij_ub) # exp(mij-mi_new)
-
- with self.tik_instance.new_stmt_scope(disable_sync=False):
- mul_li_ub = self.tik_instance.Tensor(dtype, (vec_len_aligned,), scope=UB, name="mul_li_ub")
- mul_lij_ub = self.tik_instance.Tensor(dtype, (vec_len_aligned,), scope=UB, name="mul_lij_ub")
- if self.high_precision:
- self.tik_instance.h_cast(mul_li_ub, mi_old_ub, "none")
- self.tik_instance.h_cast(mul_lij_ub, mij_ub, "none")
- self.tik_instance.h_mul(mul_li_ub, mul_li_ub, li_old_ub)
- self.tik_instance.h_mul(mul_lij_ub, mul_lij_ub, lij_ub)
- else:
- self.tik_instance.h_mul(mul_li_ub, mi_old_ub, li_old_ub)
- self.tik_instance.h_mul(mul_lij_ub, mij_ub, lij_ub)
- self.tik_instance.h_add(li_new_ub, mul_li_ub, mul_lij_ub)
- return mi_new_ub, li_new_ub
-
- def update_o_m_l_fp32(self,
- Pij_ub_fp32,
- Vj_l1_K1NK0_ed,
- Pij_ub,
- mij_ub,
- lij_ub,
- batch_start,
- batch_idx,
- kv_blk_idx,
- kv_blk_height,
- q_blk_idx,
- block_h):
- """ load o m l from gm and update them in ub, then write them back to gm
- :param Pij_Vj_ub: input tensor with shape of (q_blk_h_aligned, self.d)
- :param mij_ub: input tensor with shape of (Br)
- :param lij_ub: input tensor with shape of (Br)
- :param batch_start:
- :param batch_idx:
- :param kv_blk_idx:
- :param q_blk_idx:
- :param block_h:
- :return: None
- """
- vec_gm_offset = self.get_l_m_gm_offset(batch_start, batch_idx, self.Nq, self.Br, q_blk_idx)
- o_gm_offset = self.get_gm_offset(batch_start, batch_idx, self.Nq, self.d, self.Br, q_blk_idx)
- block_h_aligned = self.tik_ops_utils.up_align_to_K0(block_h)
- block_k_aligned_aligned = self.tik_ops_utils.up_align_to_K0(kv_blk_height)
- n1 = block_k_aligned_aligned // self.N0
- with self.tik_instance.if_scope(tik.any(kv_blk_idx == 0, kv_blk_idx + self.prev_block_num == q_blk_idx)):
- self.tik_ops_utils.move_vector_from_ub_to_gm(self.l_gm, lij_ub, vec_gm_offset, block_h)
- self.tik_ops_utils.move_vector_from_ub_to_gm(self.m_gm, mij_ub, vec_gm_offset, block_h)
- li_new_rec_ub = self.tik_ops_utils.calc_vec_rec(lij_ub, block_h)
- vec_ub = self.tik_instance.Tensor(FP32, (block_h, self.N0), name="vec_ub", scope=UB)
- for i in range(block_h):
- src_scalar = self.tik_instance.Scalar(init_value=li_new_rec_ub[i], dtype=FP32)
- self.tik_instance.h_duplicate(vec_ub[i, :], src_scalar)
- vec_ub = vec_ub.reshape((1, block_h, self.N0))
- with self.tik_instance.for_range(0, n1) as idx:
- self.tik_instance.h_mul(Pij_ub_fp32[idx, :, :],
- Pij_ub_fp32[idx, :, :],
- vec_ub)
- self.tik_instance.h_cast(Pij_ub, Pij_ub_fp32, "none")
- Pij_l1_K1MK0_ed = self.tik_instance.Tensor(
- FP16, (block_k_aligned_aligned // 16, block_h_aligned, 16), name="Pij_l1_K1MK0_ed", scope=L1
- )
- self.cont_data_mv_1_bust(dst=Pij_l1_K1MK0_ed, src=Pij_ub,
- burst=block_k_aligned_aligned * block_h_aligned // 16)
- Pij_Vj_matmul_res_ub = self.tik_ops_utils.matmul_compute(Pij_l1_K1MK0_ed, Vj_l1_K1NK0_ed, block_h,
- kv_blk_height, self.actual_d, N1MN0_to_MN=False,
- precision_type=self.precision_type) # Pij*Vj
- self.tik_instance.data_move(dst=self.O_gm_workspace[o_gm_offset], src=Pij_Vj_matmul_res_ub, sid=0,
- nburst=self.N1, burst=block_h * self.N0 // 8,
- src_stride=0, dst_stride=(self.Nq - block_h_aligned) * self.N0 // 8)
- with self.tik_instance.else_scope():
- mi_ub = self.tik_instance.Tensor(FP16, (block_h_aligned,), name="mi_old_ub", scope=UB)
- li_ub = self.tik_instance.Tensor(FP32, (block_h_aligned,), name="li_ub", scope=UB)
- self.tik_ops_utils.move_vector_from_gm_to_ub(mi_ub, self.m_gm, vec_gm_offset, block_h)
- self.tik_ops_utils.move_vector_from_gm_to_ub(li_ub, self.l_gm, vec_gm_offset, block_h)
- mi_new_ub, li_new_ub = self.update_m_l(mi_ub, mij_ub, li_ub, lij_ub, block_h)
- self.tik_ops_utils.move_vector_from_ub_to_gm(self.l_gm, li_new_ub, vec_gm_offset, block_h)
- self.tik_ops_utils.move_vector_from_ub_to_gm(self.m_gm, mi_new_ub, vec_gm_offset, block_h)
-
- exp_m_old_fp32 = self.tik_instance.Tensor(FP32, (block_h_aligned,), scope=UB, name="exp_m_old_fp32")
- exp_m_cur_fp32 = self.tik_instance.Tensor(FP32, (block_h_aligned,), scope=UB, name="exp_m_cur_fp32")
- self.tik_instance.h_cast(exp_m_old_fp32, mi_ub, "none")
- self.tik_instance.h_cast(exp_m_cur_fp32, mij_ub, "none")
-
- li_new_rec_ub = self.tik_ops_utils.calc_vec_rec(li_new_ub, block_h)
- self.tik_instance.h_mul(exp_m_cur_fp32, exp_m_cur_fp32, li_new_rec_ub)
- exp_m_cur_fp32_vec_ub = self.tik_instance.Tensor(FP32, (block_h, self.N0), name="exp_m_cur_fp32_vec_ub",
- scope=UB)
- for i in range(block_h):
- src_scalar = self.tik_instance.Scalar(init_value=exp_m_cur_fp32[i], dtype=FP32)
- self.tik_instance.h_duplicate(exp_m_cur_fp32_vec_ub[i, :], src_scalar)
- exp_m_cur_fp32_vec_ub = exp_m_cur_fp32_vec_ub.reshape((1, block_h, self.N0))
- with self.tik_instance.for_range(0, n1) as idx:
- self.tik_instance.h_mul(Pij_ub_fp32[idx, :, :],
- Pij_ub_fp32[idx, :, :],
- exp_m_cur_fp32_vec_ub)
- self.tik_instance.h_cast(Pij_ub, Pij_ub_fp32, "none")
- Pij_l1_K1MK0_ed = self.tik_instance.Tensor(
- FP16, (block_k_aligned_aligned // 16, block_h_aligned, 16), name="Pij_l1_K1MK0_ed", scope=L1
- )
- self.cont_data_mv_1_bust(dst=Pij_l1_K1MK0_ed, src=Pij_ub,
- burst=block_k_aligned_aligned * block_h_aligned // 16)
- Pij_Vj_matmul_res_ub = self.tik_ops_utils.matmul_compute(Pij_l1_K1MK0_ed, Vj_l1_K1NK0_ed, block_h,
- kv_blk_height, self.actual_d, N1MN0_to_MN=False,
- precision_type=self.precision_type) # Pij*Vj
- n1, m, n0 = Pij_Vj_matmul_res_ub.shape
- Oi_ub = self.tik_instance.Tensor(FP32, (n1, m, n0), name="Oi_ub", scope=UB)
- self.tik_instance.data_move(dst=Oi_ub, src=self.O_gm_workspace[o_gm_offset],
- sid=0, nburst=self.N1, burst=m * self.N0 // 8,
- src_stride=(self.Nq - m) * self.N0 // 8, dst_stride=0)
-
- self.tik_instance.h_mul(li_new_rec_ub, li_new_rec_ub, li_ub)
- self.tik_instance.h_mul(li_new_rec_ub, li_new_rec_ub, exp_m_old_fp32)
- li_new_rec_vec_ub = self.tik_instance.Tensor(FP32, (block_h, self.N0), name="li_new_rec_vec_ub",
- scope=UB)
- for i in range(block_h):
- src_scalar = self.tik_instance.Scalar(init_value=li_new_rec_ub[i], dtype=FP32)
- self.tik_instance.h_duplicate(li_new_rec_vec_ub[i, :], src_scalar)
- li_new_rec_vec_ub = li_new_rec_vec_ub.reshape((1, block_h, self.N0))
- with self.tik_instance.for_range(0, n1) as idx:
- self.tik_instance.h_mul(Oi_ub[idx, :, :],
- Oi_ub[idx, :, :],
- li_new_rec_vec_ub)
- self.tik_instance.h_add(Oi_ub, Oi_ub, Pij_Vj_matmul_res_ub)
- self.tik_instance.data_move(dst=self.O_gm_workspace[o_gm_offset], src=Oi_ub, sid=0,
- nburst=self.N1, burst=block_h * self.N0 // 8,
- src_stride=0, dst_stride=(self.Nq - block_h_aligned) * self.N0 // 8)
-
- def exp_Pij_Vj(self, exp_mij_sub_mi_new, Pij_Vj_ub, block_h_aligned):
- """Refer to Algorithm 2 line15
- exp(mij - mi_new) * Pij * Vj
- """
- with self.tik_instance.new_stmt_scope(disable_sync=False):
- broadcast_exp_mij_sub_mi_new = self.tik_ops_utils.broadcast(exp_mij_sub_mi_new, (block_h_aligned, self.d))
- self.tik_instance.h_mul(Pij_Vj_ub, Pij_Vj_ub, broadcast_exp_mij_sub_mi_new)
- return Pij_Vj_ub
-
- def update_o_m_l(self,
- Pij_l1_K1MK0_ed,
- Vj_l1_K1NK0_ed,
- mij_ub,
- lij_ub,
- batch_start,
- batch_idx,
- kv_blk_idx,
- kv_blk_height,
- q_blk_idx,
- block_h):
- """Refer to Algorithm 2 line13 and line15 in FlashAttention"""
- vec_gm_offset = self.get_l_m_gm_offset(batch_start, batch_idx, self.Nq, self.Br, q_blk_idx)
- o_gm_offset = self.get_gm_offset(
- batch_start, batch_idx, self.Nq, self.d, self.Br, q_blk_idx
- )
- block_h_aligned = self.tik_ops_utils.up_align_to_K0(block_h)
-
- Pij_Vj_matmul_res_ub = self.tik_ops_utils.matmul_compute(Pij_l1_K1MK0_ed, Vj_l1_K1NK0_ed, block_h,
- kv_blk_height, self.actual_d,
- N1MN0_to_MN=False) # Pij*Vj
- n1, m, n0 = Pij_Vj_matmul_res_ub.shape
- with self.tik_instance.if_scope(tik.any(kv_blk_idx == 0, kv_blk_idx + self.prev_block_num == q_blk_idx)):
- self.tik_ops_utils.move_vector_from_ub_to_gm(self.l_gm, lij_ub, vec_gm_offset, block_h)
- self.tik_ops_utils.move_vector_from_ub_to_gm(self.m_gm, mij_ub, vec_gm_offset, block_h)
- li_new_rec_ub = self.tik_ops_utils.calc_vec_rec(lij_ub, block_h)
- broadcast_li_new_rec_ub = self.tik_ops_utils.broadcast(li_new_rec_ub, (m, n0))
- broadcast_li_new_rec_ub = broadcast_li_new_rec_ub.reshape((1, m, n0))
- with self.tik_instance.for_range(0, n1) as idx:
- self.tik_instance.h_mul(Pij_Vj_matmul_res_ub[idx, :, :],
- Pij_Vj_matmul_res_ub[idx, :, :],
- broadcast_li_new_rec_ub)
- self.tik_instance.data_move(dst=self.O_gm[o_gm_offset], src=Pij_Vj_matmul_res_ub, sid=0,
- nburst=self.N1, burst=block_h * self.N0 // 16,
- src_stride=0, dst_stride=(self.Nq - block_h_aligned) * self.N0 // 16)
-
- with self.tik_instance.else_scope():
- mi_ub = self.tik_instance.Tensor(FP16, (block_h_aligned,), name="mi_old_ub", scope=UB)
- li_ub = self.tik_instance.Tensor(FP16, (block_h_aligned,), name="li_ub", scope=UB)
- self.tik_ops_utils.move_vector_from_gm_to_ub(mi_ub, self.m_gm, vec_gm_offset, block_h)
- self.tik_ops_utils.move_vector_from_gm_to_ub(li_ub, self.l_gm, vec_gm_offset, block_h)
- mi_new_ub, li_new_ub = self.update_m_l(mi_ub, mij_ub, li_ub, lij_ub, block_h)
- self.tik_ops_utils.move_vector_from_ub_to_gm(self.l_gm, li_new_ub, vec_gm_offset, block_h)
- self.tik_ops_utils.move_vector_from_ub_to_gm(self.m_gm, mi_new_ub, vec_gm_offset, block_h)
- exp_mi_sub_mi_new = mi_ub
- exp_mij_sub_mi_new = mij_ub
-
- li_new_rec_ub = self.tik_ops_utils.calc_vec_rec(li_new_ub, block_h)
- self.tik_instance.h_mul(li_ub, li_ub, exp_mi_sub_mi_new)
- self.tik_instance.h_mul(li_ub, li_ub, li_new_rec_ub)
- scale1 = li_ub
- self.tik_instance.h_mul(exp_mij_sub_mi_new, exp_mij_sub_mi_new, li_new_rec_ub)
- scale2 = exp_mij_sub_mi_new
- Oi_ub = self.tik_instance.Tensor(FP16, (n1, m, n0), name="Oi_ub", scope=UB)
- self.tik_instance.data_move(dst=Oi_ub, src=self.O_gm[o_gm_offset],
- sid=0, nburst=self.N1, burst=m * self.N0 // 16,
- src_stride=(self.Nq - m) * self.N0 // 16, dst_stride=0)
- broadcast_scale1 = self.tik_ops_utils.broadcast(scale1, (m, n0))
- broadcast_scale1 = broadcast_scale1.reshape((1, m, n0))
- with self.tik_instance.for_range(0, n1) as idx:
- self.tik_instance.h_mul(Oi_ub[idx, :, :], Oi_ub[idx, :, :], broadcast_scale1)
- broadcast_scale2 = self.tik_ops_utils.broadcast(scale2, (m, n0))
- broadcast_scale2 = broadcast_scale2.reshape((1, m, n0))
- with self.tik_instance.for_range(0, n1) as idx:
- self.tik_instance.h_mul(Pij_Vj_matmul_res_ub[idx, :, :],
- Pij_Vj_matmul_res_ub[idx, :, :],
- broadcast_scale2)
- self.tik_instance.h_add(Oi_ub, Oi_ub, Pij_Vj_matmul_res_ub)
- self.tik_instance.data_move(dst=self.O_gm[o_gm_offset], src=Oi_ub, sid=0,
- nburst=self.N1, burst=block_h * self.N0 // 16,
- src_stride=0, dst_stride=(self.Nq - block_h_aligned) * self.N0 // 16)
-
- def compute_in_each_kv_block(self, batch_start, batch_idx, kv_blk_idx, kv_blk_height,
- core_idx_to_tr_info, core_idx):
- """The forward computation in each outer loop"""
- kv_blk_height_aligned = self.tik_ops_utils.up_align_to_K0(kv_blk_height)
- kv_gm_offset = self.get_gm_offset(batch_start, batch_idx, self.N, self.d, self.Bc, kv_blk_idx)
- # load Kj (kv_blk_idx_th block of K_gm)
- KjT_l1_K1MK0_ed = self.tik_instance.Tensor(FP16, (self.d // self.N0, kv_blk_height_aligned, self.N0),
- name="KjT_l1_K1MK0_ed", scope=L1)
- self.tik_instance.data_move(dst=KjT_l1_K1MK0_ed, src=self.K_gm[kv_gm_offset],
- sid=0, nburst=self.N1, burst=kv_blk_height_aligned * self.N0 // 16,
- src_stride=(self.N - kv_blk_height_aligned) * self.N0 // 16, dst_stride=0)
-
- # load Vj (kv_blk_idx_th block of V_gm), then reorder for Pij*Vj
- Vj_l1 = self.tik_instance.Tensor(FP16, (kv_blk_height_aligned, self.d), name="Vj_l1", scope=L1)
- with self.tik_instance.new_stmt_scope(disable_sync=False):
- Vj_ub = self.tik_instance.Tensor(FP16, (self.d // self.N0, kv_blk_height_aligned, self.N0),
- name="Vj_ub", scope=UB)
- self.tik_instance.data_move(dst=Vj_ub, src=self.V_gm[kv_gm_offset],
- sid=0, nburst=self.N1, burst=kv_blk_height_aligned * self.N0 // 16,
- src_stride=(self.N - kv_blk_height_aligned) * self.N0 // 16, dst_stride=0)
- # (N1, K, N0) -> (K, N)
- Vj_ub = self.tik_ops_utils.N1MN0_TO_MN(Vj_ub)
- # (K, N) -> (K1, N, K0)
- Vj_l1_K1NK0_ed = self.tik_ops_utils.KN_TO_K1NK0(Vj_ub, workspace_tensor=Vj_l1)
-
- tr_start_s = self.tik_instance.Scalar("int32", name="tr_start_s")
- tr_end_s = self.tik_instance.Scalar("int32", name="tr_end_s")
- tr_start_s.set_as(core_idx_to_tr_info[core_idx, batch_start + batch_idx, 0])
- tr_end_s.set_as(core_idx_to_tr_info[core_idx, batch_start + batch_idx, 1])
- with self.tik_instance.for_range(tr_start_s, tr_end_s, name="q_blk_idx") as q_blk_idx:
- # skip invalid computation by exploiting the lower-triangular structure of atten_mask
- with self.tik_instance.if_scope(tik.all(kv_blk_idx - self.next_block_num <= q_blk_idx,
- q_blk_idx <= kv_blk_idx + self.prev_block_num)):
- with self.tik_instance.if_scope(q_blk_idx != self.Tr - 1):
- self.compute_in_each_q_block(KjT_l1_K1MK0_ed, Vj_l1_K1NK0_ed, batch_idx,
- batch_start,
- kv_blk_height, self.Br, q_blk_idx, kv_blk_idx)
- with self.tik_instance.else_scope():
- self.compute_in_each_q_block(KjT_l1_K1MK0_ed, Vj_l1_K1NK0_ed, batch_idx,
- batch_start,
- kv_blk_height, self.last_Br, q_blk_idx, kv_blk_idx)
-
- def compute_in_each_q_block(self, KjT_l1_K1MK0_ed, Vj_l1_K1NK0_ed, batch_idx, batch_start,
- kv_blk_height, q_blk_height, q_blk_idx, kv_blk_idx):
- """The forward computation in each inner loop"""
- kv_blk_h_aligned = self.tik_ops_utils.up_align_to_K0(kv_blk_height)
- q_blk_h_aligned = self.tik_ops_utils.up_align_to_K0(q_blk_height)
- # load Qi (q_blk_idx_th block of Q_gm), then reorder it for Qi*KjT
- q_gm_offset = self.get_gm_offset(batch_start, batch_idx, self.Nq, self.d, self.Br, q_blk_idx)
- Qi_l1_K1MK0_ed = self.tik_instance.Tensor(FP16, (self.d // self.N0, q_blk_h_aligned, self.N0),
- scope=L1, name="Qi_l1_K1MK0_ed")
- self.tik_instance.data_move(dst=Qi_l1_K1MK0_ed, src=self.Q_gm[q_gm_offset],
- sid=0, nburst=self.N1, burst=q_blk_h_aligned * self.N0 // 16,
- src_stride=(self.Nq - q_blk_h_aligned) * self.N0 // 16, dst_stride=0)
-
- lij_ub = self.tik_instance.Tensor(self.precision_type, (q_blk_h_aligned,), scope=UB, name="lij_ub")
- mij_ub = self.tik_instance.Tensor(FP16, (q_blk_h_aligned,), scope=UB, name="mij_ub")
-
- Sij_ub_N1MN0 = self.tik_ops_utils.matmul_compute(Qi_l1_K1MK0_ed, KjT_l1_K1MK0_ed, m=q_blk_height,
- k=self.actual_d, n=kv_blk_height,
- N1MN0_to_MN=False) # Qi*KjT
- if self.has_alibi_mask:
- alibi_mask_gm_offset = self.get_alibi_gm_offset(batch_start, batch_idx, self.N, self.Bc, kv_blk_idx)
- self.do_alibi_mask(Sij_ub_N1MN0, alibi_mask_gm_offset, q_blk_h_aligned, kv_blk_h_aligned)
-
- # att_mask
- if self.has_attn_mask:
- attn_mask_gm_offset = self.get_attn_mask_gm_offset(batch_start, batch_idx, self.Nq, self.N,
- self.Br, q_blk_idx, self.Bc, kv_blk_idx)
- self.do_att_mask(Sij_ub_N1MN0, attn_mask_gm_offset, q_blk_height, kv_blk_height,
- q_blk_h_aligned, kv_blk_h_aligned)
-
- Pij_N1MN0, mij_ub, lij_ub = self.softmax_compute(
- Sij_ub_N1MN0, mij_ub, lij_ub, q_blk_height, kv_blk_height
- ) # self.high_precision=True, Pij_ub return type fp32
- # dropout_mask
- if self.has_drop_mask:
- dropout_mask_gm_offset = self.get_drop_mask_gm_offset(batch_start, batch_idx, self.Nq,
- self.N, self.Br, q_blk_idx, self.Bc, kv_blk_idx)
- self.do_dropout_mask(Pij_N1MN0, dropout_mask_gm_offset, kv_blk_h_aligned, kv_blk_height,
- q_blk_h_aligned, q_blk_height, precision_type=self.precision_type)
- if not self.high_precision:
- Pij_l1_K1MK0_ed = self.tik_instance.Tensor(FP16,
- (kv_blk_h_aligned // self.N0, q_blk_h_aligned, self.N0),
- name="Pij_l1_K1MK0_ed", scope=L1)
- self.cont_data_mv_1_bust(dst=Pij_l1_K1MK0_ed, src=Pij_N1MN0,
- burst=kv_blk_h_aligned * q_blk_h_aligned // 16)
- Pij_N1MN0 = Pij_l1_K1MK0_ed
- if self.high_precision:
- self.update_o_m_l_fp32(
- Pij_N1MN0,
- Vj_l1_K1NK0_ed,
- Sij_ub_N1MN0,
- mij_ub,
- lij_ub,
- batch_start,
- batch_idx,
- kv_blk_idx,
- kv_blk_height,
- q_blk_idx,
- q_blk_height
- )
- else:
- self.update_o_m_l(
- Pij_N1MN0,
- Vj_l1_K1NK0_ed,
- mij_ub,
- lij_ub,
- batch_start,
- batch_idx,
- kv_blk_idx,
- kv_blk_height,
- q_blk_idx,
- q_blk_height
- )
-
- def compute_one_core(self, batch_start_sc, batch_num_sc, core_idx_to_tr_info, core_idx):
- """The computation of FlashAttention forward on each core"""
- with self.tik_instance.for_range(0, batch_num_sc, name="batch_index") as batch_idx:
- with self.tik_instance.for_range(0, self.Tc, name="kv_blk_idx") as kv_blk_idx:
- with self.tik_instance.if_scope(kv_blk_idx != self.Tc - 1):
- self.compute_in_each_kv_block(batch_start_sc, batch_idx, kv_blk_idx, self.Bc,
- core_idx_to_tr_info, core_idx)
- with self.tik_instance.else_scope():
- self.compute_in_each_kv_block(batch_start_sc, batch_idx, kv_blk_idx, self.last_Bc,
- core_idx_to_tr_info, core_idx)
- if self.high_precision:
- block_h = 128
- gm_offset = (batch_start_sc + batch_idx) * (self.Nq * self.d)
- temp_ub = self.tik_instance.Tensor(FP32, (block_h, self.d), name="temp_ub", scope=UB)
- temp_ub_fp16 = self.tik_instance.Tensor(FP16, (block_h, self.d), name="temp_ub_fp16", scope=UB)
- try:
- dtype32_size = DTYPE_SIZE[FP32]
- except KeyError:
- raise ValueError("The argument 'FP32' is not valid.")
- try:
- dtype16_size = DTYPE_SIZE[FP16]
- except KeyError:
- raise ValueError("The argument 'FP16' is not valid.")
- if self.Nq // block_h // 2 > 0:
- with self.tik_instance.for_range(0, self.Nq // block_h // 2) as i:
- with self.tik_instance.for_range(0, 2, thread_num=2) as t_idx:
- index = i * 2 + t_idx
- gm_offset += index * (block_h * self.d)
- self.cont_data_mv_1_bust(dst=temp_ub, src=self.O_gm_workspace[gm_offset],
- burst=block_h * self.d * dtype32_size // 32)
- self.tik_instance.h_cast(temp_ub_fp16, temp_ub, "none")
- self.cont_data_mv_1_bust(dst=self.O_gm[gm_offset], src=temp_ub_fp16,
- burst=block_h * self.d * dtype16_size // 32)
- if self.Nq % (block_h * 2) > 0:
- gm_offset = (batch_start_sc + batch_idx) * (self.Nq * self.d) + \
- (self.Nq // (block_h * 2) * 2) * (block_h * self.d)
- last_block_h = self.Nq % (block_h * 2)
- self.cont_data_mv_1_bust(dst=temp_ub, src=self.O_gm_workspace[gm_offset],
- burst=last_block_h * self.d * dtype32_size // 32)
- self.tik_instance.h_cast(temp_ub_fp16, temp_ub, "none")
- self.cont_data_mv_1_bust(dst=self.O_gm[gm_offset], src=temp_ub_fp16,
- burst=last_block_h * self.d * dtype16_size // 32)
-
- def collect_inputs(self):
- """collect all input gm tensors into input_gm_list,
- the input list should keep order with the para order in Primitive and init
- """
- input_gm_list = [self.Q_gm, self.K_gm, self.V_gm]
- if self.has_attn_mask:
- input_gm_list.append(self.att_mask_gm)
- if self.has_drop_mask:
- input_gm_list.append(self.drop_mask_gm)
- if self.has_alibi_mask:
- input_gm_list.append(self.alibi_mask_gm)
-
- return input_gm_list
-
- def collect_outputs(self):
- """collect all output gm tensors into output_gm_list,
- the output list should keep order with the para order in Primitive and init
- """
- output_gm_list = [self.O_gm, self.l_gm, self.m_gm]
- return output_gm_list
-
-
- def flash_attention(query, key, value, attn_mask, dropout_mask, alibi_mask, output, rowsum, rowmax,
- prev_block_num=65536, next_block_num=65536, high_precision=False, tiling_stgy_name='sparse',
- kernel_name="flash_attention", disable_debug=True):
- """
- algorithm: flash_attention_backward
-
- Parameters
- ----------
- query : dict. shape and dtype of input, only support float16
- key : dict. shape and dtype of input, only support float16
- value: dict. shape and dtype of input, only support float16
- attn_mask: dict. shape and dtype of input, only support float16
- dropout_mask: dict. shape and dtype of input, only support float16
- dropout_mask: dict. shape and dtype of input, only support float16
- alibi_mask: dict. shape and dtype of input, only support float16
- output: dict. shape and dtype of output, only support float16
- rowsum: dict. shape and dtype of output, only support float16
- rowmax: dict. shape and dtype of output, only support float16
- prev_block_num: int. an attribute used to define sparse attention
- next_block_num: int. an attribute used to define sparse attention
- tiling_stgy_name: str. an attribute used to choose the tiling strategy
- kernel_name: str. cce kernel name, default value is real_div
- disable_debug: bool. whether disable debug
-
- Returns
- -------
- tik_instance
- """
- fa = FlashAttentionFwd(query=query, key=key, value=value, attn_mask=attn_mask,
- dropout_mask=dropout_mask, alibi_mask=alibi_mask, kernel_name=kernel_name,
- tiling_stgy=TilingStrategy.from_strategy_name(tiling_stgy_name),
- prev_block_num=prev_block_num, next_block_num=next_block_num,
- high_precision=high_precision, disable_debug=disable_debug)
- fa.compute_process()
- return fa.tik_instance