mindspore 2.2.0-cp37-cp37m-manylinux1_x86_64.whl → 2.2.11-cp37-cp37m-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (170)
  1. mindspore/.commit_id +1 -1
  2. mindspore/_akg/akg/composite/build_module.py +104 -20
  3. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  4. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  5. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  6. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  7. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  8. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  9. mindspore/_akg/akg/utils/composite_op_helper.py +7 -2
  10. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  11. mindspore/_akg/akg/utils/kernel_exec.py +41 -15
  12. mindspore/_akg/akg/utils/tbe_codegen_utils.py +27 -6
  13. mindspore/_akg/akg/utils/util.py +56 -1
  14. mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
  15. mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
  16. mindspore/_checkparam.py +3 -3
  17. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  18. mindspore/_extends/graph_kernel/splitter.py +3 -2
  19. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +83 -66
  20. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -4
  21. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  22. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +2 -1
  23. mindspore/_extends/parse/__init__.py +3 -2
  24. mindspore/_extends/parse/parser.py +6 -1
  25. mindspore/_extends/parse/standard_method.py +14 -11
  26. mindspore/_extends/remote/kernel_build_server.py +2 -1
  27. mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
  28. mindspore/bin/cache_admin +0 -0
  29. mindspore/bin/cache_server +0 -0
  30. mindspore/common/_utils.py +16 -0
  31. mindspore/common/api.py +1 -1
  32. mindspore/common/auto_dynamic_shape.py +81 -85
  33. mindspore/common/dump.py +1 -1
  34. mindspore/common/tensor.py +3 -20
  35. mindspore/config/op_info.config +1 -1
  36. mindspore/context.py +11 -4
  37. mindspore/dataset/engine/cache_client.py +8 -5
  38. mindspore/dataset/engine/datasets_standard_format.py +5 -0
  39. mindspore/dataset/vision/transforms.py +21 -21
  40. mindspore/experimental/optim/adam.py +1 -1
  41. mindspore/gen_ops.py +1 -1
  42. mindspore/include/api/model.h +17 -0
  43. mindspore/include/api/status.h +8 -3
  44. mindspore/lib/libdnnl.so.2 +0 -0
  45. mindspore/lib/libmindspore.so +0 -0
  46. mindspore/lib/libmindspore_backend.so +0 -0
  47. mindspore/lib/libmindspore_common.so +0 -0
  48. mindspore/lib/libmindspore_core.so +0 -0
  49. mindspore/lib/libmindspore_glog.so.0 +0 -0
  50. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  51. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  52. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  53. mindspore/lib/libmindspore_shared_lib.so +0 -0
  54. mindspore/lib/libnnacl.so +0 -0
  55. mindspore/lib/libopencv_core.so.4.5 +0 -0
  56. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  57. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  58. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  59. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  60. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  61. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  62. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  63. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  64. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  65. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  66. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  67. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  68. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  69. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  70. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  71. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  72. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  73. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +78 -80
  74. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  75. mindspore/lib/plugin/ascend/libakg.so +0 -0
  76. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  77. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  78. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  79. mindspore/lib/plugin/cpu/libakg.so +0 -0
  80. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  81. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  82. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  83. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  84. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  85. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  86. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  87. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  88. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  89. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  90. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  91. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  92. mindspore/nn/cell.py +0 -3
  93. mindspore/nn/layer/activation.py +4 -5
  94. mindspore/nn/layer/conv.py +39 -23
  95. mindspore/nn/layer/flash_attention.py +54 -129
  96. mindspore/nn/layer/math.py +3 -7
  97. mindspore/nn/layer/rnn_cells.py +5 -5
  98. mindspore/nn/wrap/__init__.py +4 -2
  99. mindspore/nn/wrap/cell_wrapper.py +12 -3
  100. mindspore/numpy/utils_const.py +5 -5
  101. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -1
  102. mindspore/ops/_grad_experimental/grad_implementations.py +2 -2
  103. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -18
  104. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  105. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  106. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  107. mindspore/ops/_utils/utils.py +2 -0
  108. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
  109. mindspore/ops/composite/multitype_ops/getitem_impl.py +2 -2
  110. mindspore/ops/function/array_func.py +10 -7
  111. mindspore/ops/function/grad/grad_func.py +0 -1
  112. mindspore/ops/function/nn_func.py +98 -9
  113. mindspore/ops/function/random_func.py +2 -1
  114. mindspore/ops/op_info_register.py +24 -21
  115. mindspore/ops/operations/__init__.py +6 -2
  116. mindspore/ops/operations/_grad_ops.py +25 -6
  117. mindspore/ops/operations/_inner_ops.py +155 -23
  118. mindspore/ops/operations/array_ops.py +9 -7
  119. mindspore/ops/operations/comm_ops.py +2 -2
  120. mindspore/ops/operations/custom_ops.py +85 -68
  121. mindspore/ops/operations/inner_ops.py +26 -3
  122. mindspore/ops/operations/math_ops.py +7 -6
  123. mindspore/ops/operations/nn_ops.py +193 -49
  124. mindspore/parallel/_parallel_serialization.py +10 -3
  125. mindspore/parallel/_tensor.py +4 -1
  126. mindspore/parallel/checkpoint_transform.py +13 -2
  127. mindspore/parallel/shard.py +17 -10
  128. mindspore/profiler/common/util.py +1 -0
  129. mindspore/profiler/parser/ascend_hccl_generator.py +232 -0
  130. mindspore/profiler/parser/ascend_msprof_exporter.py +86 -43
  131. mindspore/profiler/parser/ascend_msprof_generator.py +196 -9
  132. mindspore/profiler/parser/ascend_op_generator.py +1 -1
  133. mindspore/profiler/parser/ascend_timeline_generator.py +6 -182
  134. mindspore/profiler/parser/base_timeline_generator.py +1 -1
  135. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -2
  136. mindspore/profiler/parser/framework_parser.py +1 -1
  137. mindspore/profiler/parser/profiler_info.py +19 -0
  138. mindspore/profiler/profiling.py +46 -24
  139. mindspore/rewrite/api/pattern_engine.py +1 -1
  140. mindspore/rewrite/parsers/for_parser.py +7 -7
  141. mindspore/rewrite/parsers/module_parser.py +4 -4
  142. mindspore/rewrite/symbol_tree.py +1 -4
  143. mindspore/run_check/_check_version.py +5 -3
  144. mindspore/safeguard/rewrite_obfuscation.py +52 -28
  145. mindspore/scipy/ops.py +55 -5
  146. mindspore/scipy/optimize/__init__.py +3 -2
  147. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  148. mindspore/train/callback/_summary_collector.py +1 -1
  149. mindspore/train/dataset_helper.py +1 -0
  150. mindspore/train/model.py +2 -2
  151. mindspore/train/serialization.py +97 -11
  152. mindspore/train/summary/_summary_adapter.py +1 -1
  153. mindspore/train/summary/summary_record.py +23 -7
  154. mindspore/version.py +1 -1
  155. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +3 -2
  156. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +160 -151
  157. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
  158. mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
  159. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
  160. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
  161. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
  162. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
  163. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  164. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  165. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  166. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  167. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  168. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  169. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
  170. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -1,467 +0,0 @@
-# Copyright 2023 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""The backward tik ops of flash attention"""
-from tbe import tik
-
-from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.strategy import TilingStrategy
-from mindspore.ops._op_impl._custom_op.flash_attention.attention import FlashAttention
-
-from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP16
-from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP32
-from mindspore.ops._op_impl._custom_op.flash_attention.constants import GM
-from mindspore.ops._op_impl._custom_op.flash_attention.constants import L1
-from mindspore.ops._op_impl._custom_op.flash_attention.constants import UB
-
-
-class FlashAttentionBwd(FlashAttention):
-    """The implementation of FlashAttention backward
-    This function contains the flash attention backward implementation used in flash attention (see paper)
-    `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/pdf/2205.14135.pdf>`
-    """
-
-    def __init__(self, query, key, value, output, dO, rowsum, rowmax, attn_mask, dropout_mask, alibi_mask,
-                 prev_block_num,
-                 next_block_num,
-                 high_precision,
-                 kernel_name,
-                 tiling_stgy: TilingStrategy,
-                 disable_debug):
-        super().__init__(query, key, value, attn_mask, dropout_mask, alibi_mask, kernel_name,
-                         tiling_stgy, prev_block_num, next_block_num, high_precision, disable_debug)
-
-        if isinstance(query, dict):
-            self.dO_shape = dO["shape"]  # [B, Nq, d]
-        else:
-            self.dO_shape = dO.shape
-
-        self.dV_shape = self.v_shape
-        self.dQ_shape = self.q_shape
-        self.dK_shape = self.k_shape
-        self.dQ_gm = None
-        self.dK_gm = None
-        self.dV_gm = None
-        self.O_gm = None
-        self.dO_gm = None
-        self.l_gm = None
-        self.m_gm = None
-
-    def define_outputs(self):
-        """define output gm tensors"""
-        self.dQ_gm = self.tik_instance.Tensor(FP32, self.dQ_shape, name="dQ_gm", scope=GM, is_atomic_add=True)
-        self.dK_gm = self.tik_instance.Tensor(FP32, self.dK_shape, name="dK_gm", scope=GM, is_atomic_add=True)
-        self.dV_gm = self.tik_instance.Tensor(FP32, self.dV_shape, name="dV_gm", scope=GM, is_atomic_add=True)
-
-    def define_custom_inputs(self):
-        """define input gm tensors"""
-        self.O_gm = self.tik_instance.Tensor(FP16, self.O_shape, name="O_gm", scope=GM)
-        self.dO_gm = self.tik_instance.Tensor(FP16, self.dO_shape, name="dO_gm", scope=GM)
-        self.l_gm = self.tik_instance.Tensor(self.precision_type, self.l_shape, name="l_gm", scope=GM)
-        self.m_gm = self.tik_instance.Tensor(FP16, self.m_shape, name="m_gm", scope=GM)
-
-    def collect_inputs(self):
-        """collect all input gm tensors into input_gm_list,
-        the input list should keep order with the para order in Primitive and init
-        """
-        input_gm_list = [
-            self.Q_gm, self.K_gm, self.V_gm, self.O_gm, self.dO_gm, self.l_gm,
-            self.m_gm
-        ]
-        if self.has_attn_mask:
-            input_gm_list.append(self.att_mask_gm)
-        if self.has_drop_mask:
-            input_gm_list.append(self.drop_mask_gm)
-        if self.has_alibi_mask:
-            input_gm_list.append(self.alibi_mask_gm)
-        return input_gm_list
-
-    def prepare_global_ones(self):
-        """Prepare global ones tensor in L1 for cube impl row_sum"""
-        self.ones_l1 = self.tik_instance.Tensor(FP16, (self.d, 16), name="ones_l1", scope=L1)
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            ones_ub = self.tik_instance.Tensor(FP16, (self.d, 16), name="ones_ub", scope=UB)
-            self.tik_instance.h_duplicate(ones_ub, 1.0)
-            self.cont_data_mv_1_bust(dst=self.ones_l1, src=ones_ub, burst=self.d)
-
-    def compute_Pij(self, Qi_l1_K1MK0_ed, KjT_l1_K1NK0_ed, m, k, n, lm_gm_offset, attn_mask_gm_offset,
-                    dropout_mask_gm_offset, alibi_mask_gm_offset):
-        """Refer to Algorithm 4 line11-14 in FlashAttention implement Pij computation"""
-        m_aligned = self.tik_ops_utils.up_align_to_K0(m)
-        n_aligned = self.tik_ops_utils.up_align_to_K0(n)
-        Sij_ub = self.tik_ops_utils.matmul_compute(Qi_l1_K1MK0_ed, KjT_l1_K1NK0_ed, m, k, n, N1MN0_to_MN=False)
-        Pij_drop_ed_ub = self.tik_instance.Tensor(FP16, (n_aligned // self.N0, m_aligned, self.N0),
-                                                  name="Pij_drop_ed_ub", scope=UB)
-
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            if self.has_alibi_mask:
-                self.do_alibi_mask(Sij_ub, alibi_mask_gm_offset, m_aligned, n_aligned)
-            if self.has_attn_mask:
-                self.do_att_mask(Sij_ub, attn_mask_gm_offset, m, n, m_aligned, n_aligned)
-
-        # move li (ith block of l_gm) and mi (ith block of m_gm) from gm to ub
-        li_ub = self.tik_instance.Tensor(self.precision_type, (m_aligned,), name="li_ub", scope=UB)
-        mi_ub = self.tik_instance.Tensor(FP16, (m_aligned,), name="mi_ub", scope=UB)
-        self.tik_ops_utils.move_vector_from_gm_to_ub(li_ub, self.l_gm, lm_gm_offset, m)
-        self.tik_ops_utils.move_vector_from_gm_to_ub(mi_ub, self.m_gm, lm_gm_offset, m)
-        n1 = n_aligned // self.N0
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            broadcast_mi_ub = self.tik_ops_utils.broadcast(mi_ub, (m, self.N0))
-            broadcast_mi_ub = broadcast_mi_ub.reshape((1, m, self.N0))
-            for idx in range(n1):
-                self.tik_instance.h_sub(Sij_ub[idx, :, :], Sij_ub[idx, :, :], broadcast_mi_ub)
-        li_rec_ub = self.tik_ops_utils.calc_vec_rec(li_ub, m)
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            if self.high_precision:
-                # fp16 -> fp32
-                Sij_ub_fp32 = self.tik_instance.Tensor(FP32, (n_aligned // self.N0, m_aligned, self.N0),
-                                                       name="Sij_ub_fp32", scope=UB)
-                self.tik_instance.h_cast(Sij_ub_fp32, Sij_ub, "none")
-                self.tik_instance.h_exp(Sij_ub_fp32, Sij_ub_fp32)
-                cur_row_sum_rec = self.tik_instance.Tensor(FP32, (m_aligned, self.N0), name="cur_row_sum_rec",
-                                                           scope=UB)
-                for i in range(m_aligned):
-                    src_scalar = self.tik_instance.Scalar(init_value=li_rec_ub[i], dtype=FP32)
-                    self.tik_instance.h_duplicate(cur_row_sum_rec[i, :], src_scalar)
-                cur_row_sum_rec = cur_row_sum_rec.reshape((1, m_aligned, self.N0))
-                with self.tik_instance.for_range(0, n_aligned // self.N0) as idx:
-                    self.tik_instance.h_mul(Sij_ub_fp32[idx, :, :], Sij_ub_fp32[idx, :, :], cur_row_sum_rec)
-                # fp32 -> fp16
-                self.tik_instance.h_cast(Sij_ub, Sij_ub_fp32, "none")
-            else:
-                self.tik_instance.h_exp(Sij_ub, Sij_ub)
-                broadcast_li_rec_ub = self.tik_ops_utils.broadcast(li_rec_ub, (m_aligned, self.N0))
-                broadcast_li_rec_ub = broadcast_li_rec_ub.reshape((1, m_aligned, self.N0))
-                for idx in range(n1):
-                    self.tik_instance.h_mul(Sij_ub[idx, :, :], Sij_ub[idx, :, :], broadcast_li_rec_ub)
-
-        if self.has_drop_mask:
-            self.do_dropout_mask(Sij_ub, dropout_mask_gm_offset, n_aligned, n, m_aligned, m,
-                                 workspace=Pij_drop_ed_ub)
-        else:
-            self.cont_data_mv_1_bust(dst=Pij_drop_ed_ub, src=Sij_ub, burst=m_aligned * n_aligned // 16)
-
-        return Sij_ub, Pij_drop_ed_ub
-
-    def compute_Di(self, Di_ub, dOi_ub, qo_gm_offset, q_blk_height):
-        """Refer to Algorithm 4 line19 in FlashAttention implement Di computation"""
-        q_blk_height_aligned = self.tik_ops_utils.up_align_to_K0(q_blk_height)
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            Oi_ub = self.tik_instance.Tensor(FP16, (self.d // self.N0, q_blk_height_aligned, self.N0),
-                                             scope=UB, name="Oi_ub")
-            self.tik_instance.data_move(dst=Oi_ub, src=self.O_gm[qo_gm_offset],
-                                        sid=0, nburst=self.N1, burst=q_blk_height * self.N0 // 16,
-                                        src_stride=(self.Nq - q_blk_height) * self.N0 // 16, dst_stride=0)
-            self.tik_instance.h_mul(Oi_ub, dOi_ub, Oi_ub)
-            dOi_Oi_l1_K1MK0 = self.tik_instance.Tensor(FP16, (self.d // self.N0, q_blk_height_aligned, self.N0),
-                                                       name="dOi_Oi_l1_K1MK0", scope=L1)
-            self.cont_data_mv_1_bust(dst=dOi_Oi_l1_K1MK0, src=Oi_ub, burst=q_blk_height_aligned * self.d // 16)
-            self.tik_ops_utils.row_sum_cube_impl(dOi_Oi_l1_K1MK0, self.ones_l1, Di_ub, q_blk_height,
-                                                 self.actual_d, precision_type=FP16)
-
-    def compute_dSij(self, Pij_ub, dOi_l1_K1MK0_ed, VjT_K1NK0_ed, Di_ub, kv_blk_height, q_blk_height,
-                     dropout_mask_gm_offset):
-        """Refer to Algorithm 4 line20 in FlashAttention implement dSij computation"""
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            dPij_ub = self.tik_ops_utils.matmul_compute(dOi_l1_K1MK0_ed, VjT_K1NK0_ed,
-                                                        q_blk_height, self.actual_d, kv_blk_height, N1MN0_to_MN=False)
-            q_blk_height_aligned = self.tik_ops_utils.up_align_to_K0(q_blk_height)
-            kv_blk_height_aligned = self.tik_ops_utils.up_align_to_K0(kv_blk_height)
-            # dropout_mask
-            if self.has_drop_mask:
-                self.do_dropout_mask(dPij_ub, dropout_mask_gm_offset, kv_blk_height_aligned, kv_blk_height,
-                                     q_blk_height_aligned, q_blk_height)
-            # dPij - Di
-            with self.tik_instance.new_stmt_scope(disable_sync=False):
-                broadcast_Di_ub = self.tik_ops_utils.broadcast(Di_ub, (q_blk_height_aligned, self.N0))
-                broadcast_Di_ub = broadcast_Di_ub.reshape((1, q_blk_height_aligned, self.N0))
-                n1 = kv_blk_height_aligned // self.N0
-                for idx in range(n1):
-                    self.tik_instance.h_sub(dPij_ub[idx, :, :], dPij_ub[idx, :, :], broadcast_Di_ub)
-            self.tik_instance.h_mul(Pij_ub, Pij_ub, dPij_ub)
-        return Pij_ub
-
-    def update_dVj(self,
-                   PijT_l1_K1MK0_ed,
-                   dOi_l1_K1NK0_ed,
-                   kv_gm_offset,
-                   kv_blk_height,
-                   q_blk_height):
-        """Refer to Algorithm 4 line16 in FlashAttention implement dVj update"""
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            PijT_Oi_ub = self.tik_ops_utils.matmul_compute(PijT_l1_K1MK0_ed, dOi_l1_K1NK0_ed,
-                                                           kv_blk_height, q_blk_height,
-                                                           self.actual_d, N1MN0_to_MN=False,
-                                                           precision_type=FP32)
-            self.tik_instance.set_atomic_add(1)
-            self.tik_instance.data_move(dst=self.dV_gm[kv_gm_offset], src=PijT_Oi_ub, sid=0,
-                                        nburst=self.N1, burst=kv_blk_height * self.N0 // 8,
-                                        src_stride=0, dst_stride=(self.Nq - kv_blk_height) * self.N0 // 8)
-            self.tik_instance.set_atomic_add(0)
-
-    def update_dQi(self,
-                   dSij_l1_K1MK0_ed,
-                   Kj_l1_K1NK0_ed,
-                   qo_gm_offset,
-                   q_blk_height,
-                   kv_blk_height):
-        """Refer to Algorithm 4 line21 in FlashAttention implement dQi update"""
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            dSij_Kj_ub = self.tik_ops_utils.matmul_compute(dSij_l1_K1MK0_ed, Kj_l1_K1NK0_ed,
-                                                           q_blk_height, kv_blk_height,
-                                                           self.actual_d, N1MN0_to_MN=False, precision_type=FP32)
-            self.tik_instance.set_atomic_add(1)
-            self.tik_instance.data_move(dst=self.dQ_gm[qo_gm_offset], src=dSij_Kj_ub, sid=0,
-                                        nburst=self.d // self.N0, burst=q_blk_height * self.N0 // 8,
-                                        src_stride=0, dst_stride=(self.Nq - q_blk_height) * self.N0 // 8)
-            self.tik_instance.set_atomic_add(0)
-
-    def update_dKj(self,
-                   dSijT_l1_K1MK0_ed,
-                   Qi_l1_K1NK0_ed,
-                   kv_gm_offset,
-                   kv_blk_height,
-                   q_blk_height):
-        """Refer to Algorithm 4 line22 in FlashAttention implement dKi update"""
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            dSijT_Qi_ub = self.tik_ops_utils.matmul_compute(dSijT_l1_K1MK0_ed, Qi_l1_K1NK0_ed,
-                                                            kv_blk_height, q_blk_height,
-                                                            self.actual_d, N1MN0_to_MN=False, precision_type=FP32)
-            self.tik_instance.set_atomic_add(1)
-            self.tik_instance.data_move(dst=self.dK_gm[kv_gm_offset], src=dSijT_Qi_ub, sid=0,
-                                        nburst=self.d // self.N0, burst=kv_blk_height * self.N0 // 8,
-                                        src_stride=0, dst_stride=(self.Nq - kv_blk_height) * self.N0 // 8)
-            self.tik_instance.set_atomic_add(0)
-
-    def compute_in_each_kv_block(self, batch_start, batch_idx, kv_blk_idx, kv_blk_height,
-                                 core_idx_to_tr_info, core_idx):
-        """The backward computation in each outer loop"""
-        kv_blk_height_aligned = self.tik_ops_utils.up_align_to_K0(kv_blk_height)
-        kv_gm_offset = self.get_gm_offset(batch_start, batch_idx, self.N, self.d,
-                                          self.Bc, kv_blk_idx)
-        # load KjT
-        Kj_l1_1_K1MK0 = self.tik_instance.Tensor(FP16, (self.d // self.N0, kv_blk_height_aligned, self.N0),
-                                                 name="Kj_l1_1_K1MK0",
-                                                 scope=L1)
-        self.tik_instance.data_move(dst=Kj_l1_1_K1MK0, src=self.K_gm[kv_gm_offset],
-                                    sid=0, nburst=self.N1, burst=kv_blk_height_aligned * self.N0 // 16,
-                                    src_stride=(self.N - kv_blk_height_aligned) * self.N0 // 16, dst_stride=0)
-
-        # load Kj
-        Kj_l1_2 = self.tik_instance.Tensor(FP16, (kv_blk_height_aligned, self.d), name="Kj_l1_2",
-                                           scope=L1)
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            Kj_ub = self.tik_instance.Tensor(FP16, (self.d // self.N0, kv_blk_height_aligned, self.N0),
-                                             name="Kj_ub", scope=UB)
-            self.tik_instance.data_move(dst=Kj_ub, src=self.K_gm[kv_gm_offset],
-                                        sid=0, nburst=self.N1, burst=kv_blk_height_aligned * self.N0 // 16,
-                                        src_stride=(self.N - kv_blk_height_aligned) * self.N0 // 16, dst_stride=0)
-            # (N1, K, N0) -> (K, N)
-            Kj_ub = self.tik_ops_utils.N1MN0_TO_MN(Kj_ub)
-            # (K, N) -> (K1, N, K0)
-            Kj_l1_2_K1NK0_ed = self.tik_ops_utils.KN_TO_K1NK0(Kj_ub, workspace_tensor=Kj_l1_2)
-
-        # load VjT
-        Vj_l1 = self.tik_instance.Tensor(FP16, (self.d // self.N0, kv_blk_height_aligned, self.N0), name="Vj_l1",
-                                         scope=L1)
-        self.tik_instance.data_move(dst=Vj_l1, src=self.V_gm[kv_gm_offset],
-                                    sid=0, nburst=self.N1, burst=kv_blk_height_aligned * self.N0 // 16,
-                                    src_stride=(self.N - kv_blk_height_aligned) * self.N0 // 16, dst_stride=0)
-
-        tr_start_s = self.tik_instance.Scalar("int32", name="tr_start_s")
-        tr_end_s = self.tik_instance.Scalar("int32", name="tr_end_s")
-        tr_start_s.set_as(core_idx_to_tr_info[core_idx, batch_start + batch_idx, 0])
-        tr_end_s.set_as(core_idx_to_tr_info[core_idx, batch_start + batch_idx, 1])
-        with self.tik_instance.for_range(tr_start_s, tr_end_s, name="q_blk_idx") as q_blk_idx:
-            with self.tik_instance.if_scope(tik.all(kv_blk_idx - self.next_block_num <= q_blk_idx,
-                                                    q_blk_idx <= kv_blk_idx + self.prev_block_num)):
-                with self.tik_instance.if_scope(q_blk_idx != self.Tr - 1):
-                    self.compute_in_each_q_block(Kj_l1_1_K1MK0,
-                                                 Kj_l1_2_K1NK0_ed,
-                                                 Vj_l1,
-                                                 batch_idx,
-                                                 batch_start,
-                                                 kv_gm_offset,
-                                                 kv_blk_height,
-                                                 self.Br,
-                                                 kv_blk_idx,
-                                                 q_blk_idx)
-                with self.tik_instance.else_scope():
-                    self.compute_in_each_q_block(Kj_l1_1_K1MK0,
-                                                 Kj_l1_2_K1NK0_ed,
-                                                 Vj_l1,
-                                                 batch_idx,
-                                                 batch_start,
-                                                 kv_gm_offset,
-                                                 kv_blk_height,
-                                                 self.last_Br,
-                                                 kv_blk_idx,
-                                                 q_blk_idx)
-
-    def compute_in_each_q_block(self, KjT_l1_K1NK0_ed, Kj_l1_K1NK0_ed, VjT_l1_K1NK0_ed,
-                                batch_idx, batch_start, kv_gm_offset, kv_blk_height,
-                                q_blk_height, kv_blk_idx, q_blk_idx):
-        """The backward computation in each inner loop"""
-        kv_blk_height_alig = self.tik_ops_utils.up_align_to_K0(kv_blk_height)
-        q_blk_height_alig = self.tik_ops_utils.up_align_to_K0(q_blk_height)
-
-        qo_gm_offset = self.get_gm_offset(batch_start, batch_idx, self.Nq, self.d, self.Br, q_blk_idx)
-        Qi_l1_K1MK0 = self.tik_instance.Tensor(FP16, (self.d // self.N0, q_blk_height_alig, self.N0),
-                                               name="Qi_l1_K1MK0",
-                                               scope=L1)
-        self.tik_instance.data_move(dst=Qi_l1_K1MK0, src=self.Q_gm[qo_gm_offset],
-                                    sid=0, nburst=self.N1, burst=q_blk_height_alig * self.N0 // 16,
-                                    src_stride=(self.Nq - q_blk_height_alig) * self.N0 // 16, dst_stride=0)
-
-        Qi_l1_right = self.tik_instance.Tensor(FP16, (q_blk_height_alig, self.d), name="Qi_l1_right",
-                                               scope=L1)
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            Qi_ub = self.tik_instance.Tensor(FP16, (self.d // self.N0, q_blk_height_alig, self.N0),
-                                             name="Qi_ub", scope=UB)
-            self.tik_instance.data_move(dst=Qi_ub, src=self.Q_gm[qo_gm_offset],
-                                        sid=0, nburst=self.N1, burst=q_blk_height_alig * self.N0 // 16,
-                                        src_stride=(self.N - q_blk_height_alig) * self.N0 // 16, dst_stride=0)
-            # (N1, K, N0) -> (K, N)
-            Qi_ub = self.tik_ops_utils.N1MN0_TO_MN(Qi_ub)
-            # (K, N) -> (K1, N, K0)
-            Qi_l1_K1NK0_ed = self.tik_ops_utils.KN_TO_K1NK0(Qi_ub, workspace_tensor=Qi_l1_right)
-
-        lm_gm_offset = self.get_l_m_gm_offset(batch_start, batch_idx, self.Nq, self.Br, q_blk_idx)
-        attn_mask_gm_offset, dropout_mask_gm_offset, alibi_mask_gm_offset = None, None, None
-        if self.has_attn_mask:
-            attn_mask_gm_offset = self.get_attn_mask_gm_offset(batch_start, batch_idx, self.Nq, self.N,
-                                                               self.Br, q_blk_idx, self.Bc, kv_blk_idx)
-        if self.has_drop_mask:
-            dropout_mask_gm_offset = self.get_drop_mask_gm_offset(batch_start, batch_idx, self.Nq, self.N,
-                                                                  self.Br, q_blk_idx, self.Bc, kv_blk_idx)
-        if self.has_alibi_mask:
-            alibi_mask_gm_offset = self.get_alibi_gm_offset(batch_start, batch_idx, self.N, self.Bc, kv_blk_idx)
-        Pij_ub, Pij_drop_ed_ub = self.compute_Pij(Qi_l1_K1MK0, KjT_l1_K1NK0_ed,
-                                                  q_blk_height, self.actual_d, kv_blk_height,
-                                                  lm_gm_offset, attn_mask_gm_offset,
-                                                  dropout_mask_gm_offset, alibi_mask_gm_offset)
-
-        dOi_l1_right = self.tik_instance.Tensor(FP16, (q_blk_height_alig, self.d), name="dOi_l1_right",
-                                                scope=L1)
-        Di_ub = self.tik_instance.Tensor(FP16, (q_blk_height_alig,), name="Di_ub", scope=UB)
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            dOi_ub = self.tik_instance.Tensor(FP16, (self.d // self.N0, q_blk_height_alig, self.N0),
-                                              name="dOi_ub", scope=UB)
-            self.tik_instance.data_move(dst=dOi_ub, src=self.dO_gm[qo_gm_offset],
-                                        sid=0, nburst=self.N1, burst=q_blk_height_alig * self.N0 // 16,
-                                        src_stride=(self.Nq - q_blk_height_alig) * self.N0 // 16, dst_stride=0)
-
-            self.compute_Di(Di_ub, dOi_ub, qo_gm_offset, q_blk_height)
-            # (N1, K, N0) -> (K, N)
-            dOi_ub = self.tik_ops_utils.N1MN0_TO_MN(dOi_ub)
-            # (K, N) -> (K1, N, K0)
-            dOi_l1_K1NK0_ed = self.tik_ops_utils.KN_TO_K1NK0(dOi_ub, workspace_tensor=dOi_l1_right)
-
-        dOi_l1_K1MK0 = self.tik_instance.Tensor(FP16, (self.d // self.N0, q_blk_height_alig, self.N0),
-                                                name="dOi_l1_K1MK0",
-                                                scope=L1)
-
-        self.tik_instance.data_move(dst=dOi_l1_K1MK0, src=self.dO_gm[qo_gm_offset],
-                                    sid=0, nburst=self.N1, burst=q_blk_height_alig * self.N0 // 16,
-                                    src_stride=(self.Nq - q_blk_height_alig) * self.N0 // 16, dst_stride=0)
-        Pij_l1 = self.tik_instance.Tensor(FP16, (q_blk_height_alig, kv_blk_height_alig), name="Pij_l1", scope=L1)
-        Pij_drop_ed_ub = self.tik_ops_utils.N1MN0_TO_MN(Pij_drop_ed_ub)
-        PijT_l1_K1MK0_ed = self.tik_ops_utils.KN_TO_K1NK0(Pij_drop_ed_ub, workspace_tensor=Pij_l1)
-        self.update_dVj(PijT_l1_K1MK0_ed, dOi_l1_K1NK0_ed,
-                        kv_gm_offset, kv_blk_height, q_blk_height)
-        # (L1: 512K)
-        dSij_l1_K1MK0_ed = self.tik_instance.Tensor(FP16, (kv_blk_height_alig // self.N0, q_blk_height_alig, self.N0),
-                                                    name="dSij_l1_1", scope=L1)
-        dSij_l1_2 = self.tik_instance.Tensor(FP16, (q_blk_height_alig, kv_blk_height_alig),
-                                             name="dSij_l1_2", scope=L1)
-        with self.tik_instance.new_stmt_scope(disable_sync=False):
-            dSij_ub = self.compute_dSij(Pij_ub,
-                                        dOi_l1_K1MK0,
-                                        VjT_l1_K1NK0_ed,
-                                        Di_ub,
-                                        kv_blk_height,
-                                        q_blk_height,
-                                        dropout_mask_gm_offset)
-            self.cont_data_mv_1_bust(dst=dSij_l1_K1MK0_ed, src=dSij_ub,
-                                     burst=kv_blk_height_alig * q_blk_height_alig // 16)
-            dSij_ub = self.tik_ops_utils.N1MN0_TO_MN(dSij_ub)
-            dSijT_l1_K1MK0_ed = self.tik_ops_utils.KN_TO_K1NK0(dSij_ub, workspace_tensor=dSij_l1_2)
-        self.update_dQi(dSij_l1_K1MK0_ed, Kj_l1_K1NK0_ed,
-                        qo_gm_offset, q_blk_height, kv_blk_height)
-        self.update_dKj(dSijT_l1_K1MK0_ed, Qi_l1_K1NK0_ed,
-                        kv_gm_offset, kv_blk_height, q_blk_height)
-
-    def compute_one_core(self, batch_start_sc, batch_num_sc, core_idx_to_tr_info, core_idx):
-        """The computation of FlashAttention backward on each core"""
-        with self.tik_instance.for_range(0, batch_num_sc, name="batch_index") as batch_idx:
-            with self.tik_instance.for_range(0, self.Tc, name="kv_blk_idx") as kv_blk_idx:
-                with self.tik_instance.if_scope(kv_blk_idx != self.Tc - 1):
-                    self.compute_in_each_kv_block(batch_start_sc, batch_idx, kv_blk_idx, self.Bc,
-                                                  core_idx_to_tr_info, core_idx)
-                with self.tik_instance.else_scope():
-                    self.compute_in_each_kv_block(batch_start_sc, batch_idx, kv_blk_idx, self.last_Bc,
-                                                  core_idx_to_tr_info, core_idx)
-
-    def collect_outputs(self):
-        """collect all output gm tensors into output_gm_list,
-        the output list should keep order with the para order in Primitive and init
-        """
-        output_gm_list = [self.dQ_gm, self.dK_gm, self.dV_gm]
-        return output_gm_list
-
-
-def flash_attention_grad(Query, Key, Value, Output, dO, rowsum, rowmax, attn_mask, dropout_mask, alibi_mask,
-                         dq, dk, dv,
-                         prev_block_num=65536,
-                         next_block_num=65536,
-                         high_precision=False,
-                         tiling_stgy_name='sparse',
-                         kernel_name="flash_attention_grad",
-                         disable_debug=True):
-    """
-    algorithm: flash_attention_backward
-
-    Parameters
-    ----------
-    Query : dict. shape and dtype of input, only support float16
-    Key : dict. shape and dtype of input, only support float16
-    Value: dict. shape and dtype of input, only support float16
-    Output: dict. shape and dtype of input, only support float16
-    dO: dict. shape and dtype of input, only support float16
-    rowsum: dict. shape and dtype of input, only support float16
-    rowmax: dict. shape and dtype of input, only support float16
-    dropout_mask: dict. shape and dtype of input, only support float16
-    dropout_mask: dict. shape and dtype of input, only support float16
-    alibi_mask: dict. shape and dtype of input, only support float16
-    dq: dict. shape and dtype of output, only support float16
-    dk: dict. shape and dtype of output, only support float16
-    dv: dict. shape and dtype of output, only support float16
-    prev_block_num: int. an attribute used to define sparse attention
-    next_block_num: int. an attribute used to define sparse attention
-    tiling_stgy_name: str. an attribute used to choose the tiling strategy
-    kernel_name: str. cce kernel name, default value is real_div
-    disable_debug: bool. whether disable debug
-
-    Returns
-    -------
-    tik_instance
-    """
-    fa_grad = FlashAttentionBwd(Query, Key, Value, Output, dO, rowsum, rowmax, attn_mask, dropout_mask,
-                                alibi_mask, prev_block_num=prev_block_num,
-                                next_block_num=next_block_num,
-                                high_precision=high_precision,
-                                kernel_name=kernel_name,
-                                tiling_stgy=TilingStrategy.from_strategy_name(tiling_stgy_name),
-                                disable_debug=disable_debug)
-    fa_grad.compute_process()
-    return fa_grad.tik_instance