mindspore-2.2.10-cp39-cp39-win_amd64.whl → mindspore-2.2.14-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (128)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +2 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_extends/parse/__init__.py +3 -2
  7. mindspore/_extends/parse/parser.py +6 -1
  8. mindspore/_extends/parse/standard_method.py +12 -2
  9. mindspore/common/_utils.py +16 -0
  10. mindspore/common/tensor.py +0 -2
  11. mindspore/communication/management.py +3 -0
  12. mindspore/context.py +34 -4
  13. mindspore/dataset/engine/cache_client.py +8 -5
  14. mindspore/dataset/engine/datasets.py +23 -0
  15. mindspore/dataset/engine/validators.py +1 -1
  16. mindspore/dataset/vision/py_transforms_util.py +2 -2
  17. mindspore/dnnl.dll +0 -0
  18. mindspore/experimental/optim/lr_scheduler.py +5 -6
  19. mindspore/jpeg62.dll +0 -0
  20. mindspore/mindrecord/tools/cifar100_to_mr.py +49 -57
  21. mindspore/mindrecord/tools/cifar10_to_mr.py +46 -55
  22. mindspore/mindrecord/tools/csv_to_mr.py +3 -8
  23. mindspore/mindrecord/tools/mnist_to_mr.py +4 -9
  24. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -4
  25. mindspore/mindspore_backend.dll +0 -0
  26. mindspore/mindspore_common.dll +0 -0
  27. mindspore/mindspore_core.dll +0 -0
  28. mindspore/mindspore_glog.dll +0 -0
  29. mindspore/mindspore_shared_lib.dll +0 -0
  30. mindspore/nn/layer/activation.py +1 -1
  31. mindspore/nn/layer/embedding.py +2 -2
  32. mindspore/nn/layer/flash_attention.py +48 -135
  33. mindspore/nn/loss/loss.py +1 -1
  34. mindspore/nn/optim/ada_grad.py +2 -2
  35. mindspore/nn/optim/sgd.py +3 -2
  36. mindspore/nn/wrap/__init__.py +4 -2
  37. mindspore/nn/wrap/cell_wrapper.py +6 -3
  38. mindspore/numpy/math_ops.py +1 -1
  39. mindspore/opencv_core452.dll +0 -0
  40. mindspore/opencv_imgcodecs452.dll +0 -0
  41. mindspore/opencv_imgproc452.dll +0 -0
  42. mindspore/ops/__init__.py +3 -0
  43. mindspore/ops/_grad_experimental/grad_array_ops.py +0 -31
  44. mindspore/ops/_grad_experimental/grad_comm_ops.py +4 -2
  45. mindspore/ops/_grad_experimental/grad_inner_ops.py +8 -0
  46. mindspore/ops/_grad_experimental/grad_math_ops.py +37 -17
  47. mindspore/ops/_op_impl/aicpu/__init__.py +1 -0
  48. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  49. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  50. mindspore/ops/function/array_func.py +6 -5
  51. mindspore/ops/function/debug_func.py +1 -1
  52. mindspore/ops/function/linalg_func.py +21 -11
  53. mindspore/ops/function/math_func.py +3 -0
  54. mindspore/ops/function/nn_func.py +13 -11
  55. mindspore/ops/function/parameter_func.py +2 -0
  56. mindspore/ops/function/sparse_unary_func.py +2 -2
  57. mindspore/ops/function/vmap_func.py +1 -0
  58. mindspore/ops/operations/__init__.py +5 -2
  59. mindspore/ops/operations/_embedding_cache_ops.py +1 -1
  60. mindspore/ops/operations/_grad_ops.py +3 -4
  61. mindspore/ops/operations/_inner_ops.py +56 -1
  62. mindspore/ops/operations/_quant_ops.py +4 -4
  63. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  64. mindspore/ops/operations/array_ops.py +15 -4
  65. mindspore/ops/operations/custom_ops.py +1 -1
  66. mindspore/ops/operations/debug_ops.py +1 -1
  67. mindspore/ops/operations/image_ops.py +3 -3
  68. mindspore/ops/operations/inner_ops.py +49 -0
  69. mindspore/ops/operations/math_ops.py +65 -3
  70. mindspore/ops/operations/nn_ops.py +95 -28
  71. mindspore/ops/operations/random_ops.py +2 -0
  72. mindspore/ops/operations/sparse_ops.py +4 -4
  73. mindspore/ops/silent_check.py +162 -0
  74. mindspore/parallel/__init__.py +3 -2
  75. mindspore/parallel/_auto_parallel_context.py +82 -3
  76. mindspore/parallel/_parallel_serialization.py +34 -2
  77. mindspore/parallel/_tensor.py +3 -1
  78. mindspore/parallel/_transformer/transformer.py +8 -8
  79. mindspore/parallel/checkpoint_transform.py +191 -45
  80. mindspore/profiler/parser/ascend_cluster_generator.py +111 -0
  81. mindspore/profiler/parser/ascend_communicate_generator.py +315 -0
  82. mindspore/profiler/parser/ascend_flops_generator.py +8 -2
  83. mindspore/profiler/parser/ascend_fpbp_generator.py +8 -2
  84. mindspore/profiler/parser/ascend_hccl_generator.py +2 -2
  85. mindspore/profiler/parser/ascend_msprof_exporter.py +30 -6
  86. mindspore/profiler/parser/ascend_msprof_generator.py +16 -5
  87. mindspore/profiler/parser/ascend_op_generator.py +15 -7
  88. mindspore/profiler/parser/ascend_timeline_generator.py +5 -2
  89. mindspore/profiler/parser/base_timeline_generator.py +11 -3
  90. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -1
  91. mindspore/profiler/parser/framework_parser.py +8 -2
  92. mindspore/profiler/parser/memory_usage_parser.py +8 -2
  93. mindspore/profiler/parser/minddata_analyzer.py +8 -2
  94. mindspore/profiler/parser/minddata_parser.py +1 -1
  95. mindspore/profiler/parser/msadvisor_analyzer.py +4 -2
  96. mindspore/profiler/parser/msadvisor_parser.py +9 -3
  97. mindspore/profiler/profiling.py +97 -25
  98. mindspore/rewrite/api/node.py +1 -1
  99. mindspore/rewrite/api/symbol_tree.py +2 -2
  100. mindspore/rewrite/parsers/for_parser.py +6 -6
  101. mindspore/rewrite/parsers/module_parser.py +4 -4
  102. mindspore/tinyxml2.dll +0 -0
  103. mindspore/train/callback/_checkpoint.py +8 -8
  104. mindspore/train/callback/_landscape.py +2 -3
  105. mindspore/train/callback/_summary_collector.py +6 -7
  106. mindspore/train/dataset_helper.py +6 -0
  107. mindspore/train/model.py +17 -5
  108. mindspore/train/serialization.py +6 -1
  109. mindspore/train/summary/_writer_pool.py +1 -1
  110. mindspore/train/summary/summary_record.py +5 -6
  111. mindspore/turbojpeg.dll +0 -0
  112. mindspore/version.py +1 -1
  113. {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/METADATA +3 -2
  114. {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/RECORD +117 -124
  115. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  116. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
  117. mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
  118. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
  119. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
  120. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
  121. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
  122. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  123. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  124. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  125. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  126. {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/WHEEL +0 -0
  127. {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/entry_points.txt +0 -0
  128. {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/top_level.txt +0 -0
mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py (deleted)
@@ -1,406 +0,0 @@
- # Copyright 2023 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """the base class of flash attention"""
- from abc import ABCMeta
- from abc import abstractmethod
- from functools import partial
- from collections import defaultdict
-
- import te.platform as tbe_platform
- from tbe import tik
- from tbe.common.platform import get_soc_spec
-
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP16
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP32
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import GM
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import MASK_FILL_VALUE
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import UB
- from mindspore.ops._op_impl._custom_op.flash_attention.tik_ops_utils import TikOpsUtils
- from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.strategy import TilingPara
- from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.strategy import TilingStrategy
- from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.sparse_tiling import SparseTiling
-
-
- class FlashAttention(metaclass=ABCMeta):
-     """The base class of FlashAttention"""
-
-     def __init__(self, q, k, v, attn_mask, dropout_mask, alibi_mask, kernel_name,
-                  tiling_stgy_cls,
-                  prev_block_num=65536,
-                  next_block_num=65536,
-                  high_precision=False,
-                  disable_debug=True):
-         """
-         Init parameter shape
-         :param q: with shape: (B, h, N, d)
-         :param k: with shape: (B, h, N, d)
-         :param v: with shape: (B, h, N, d)
-         :param attn_mask: with shape: (1, N, N) or (B, N, N)
-         :param dropout_mask: with shape: (B, h, N, N)
-         :param alibi_mask: with shape: (B, h, 1, N)
-         :param kernel_name:
-         :param tiling_stgy_cls:
-         :param prev_block_num:
-         :param next_block_num:
-         :param disable_debug:
-         """
-         self.tik_instance = tik.Tik(disable_debug=disable_debug)
-         self.core_num = get_soc_spec(tbe_platform.CORE_NUM)
-         self.M = tbe_platform.get_soc_spec(tbe_platform.L1_SIZE)
-         self.kernel_name = kernel_name
-         self.cont_data_mv_1_bust = partial(self.tik_instance.data_move, sid=0, nburst=1,
-                                            src_stride=0,
-                                            dst_stride=0)
-         self.tik_ops_utils = TikOpsUtils(self.tik_instance)
-         self.parse_input_shape(alibi_mask, attn_mask, dropout_mask, k, q, v)
-         # NZ
-         _, _, N1, M1, M0, N0 = self.q_shape
-         self.M1 = M1
-         self.N1 = N1
-         self.M0 = M0
-         self.N0 = N0
-         self.d = N1 * N0
-         # ND
-         batch_size, h, Nq, actual_d = self.q_ori_shape
-         self.head_num = h
-         self.B, self.Nq = batch_size * h, Nq
-         self.N = self.k_ori_shape[2]
-         self.actual_d = actual_d
-
-         self.l_shape = [batch_size, h, self.Nq]
-         self.m_shape = [batch_size, h, self.Nq]
-         self.O_shape = self.q_shape
-
-         self.prev_block_num = prev_block_num
-         self.next_block_num = next_block_num
-         self.high_precision = high_precision
-         if self.high_precision:
-             self.precision_type = FP32
-         else:
-             self.precision_type = FP16
-         if tiling_stgy_cls is None:
-             self.tiling_stgy = SparseTiling(self.Nq, self.N, self.d)
-         else:
-             self.tiling_stgy: TilingStrategy = tiling_stgy_cls(self.Nq, self.N, self.d)
-         self.Br = None
-         self.last_Br = None
-         self.Bc = None
-         self.last_Bc = None
-         self.Tr = None
-         self.Tc = None
-         self.Q_gm = None
-         self.K_gm = None
-         self.V_gm = None
-         self.dim_mask_gm = None
-         self.att_mask_gm = None
-         self.drop_mask_gm = None
-         self.alibi_mask_gm = None
-
-     @staticmethod
-     def get_l_m_gm_offset(batch_start, batch_idx, h, block_h, block_idx):
-         """get l m gm offset"""
-         gm_offset = (batch_start + batch_idx) * h + block_idx * block_h
-         return gm_offset
-
-     @staticmethod
-     def get_alibi_gm_offset(batch_start, batch_idx, w, block_w, block_idx):
-         """get alibi gm offset"""
-         gm_offset = (batch_start + batch_idx) * w + block_idx * block_w
-         return gm_offset
-
-     @staticmethod
-     def get_drop_mask_gm_offset(batch_start, batch_idx, h, w, block_h, block_h_idx, block_w, block_w_idx):
-         """get drop mask gm offset"""
-         gm_offset = (batch_start + batch_idx) * h * w + block_h_idx * (w * block_h) + block_w_idx * block_w
-         return gm_offset
-
-     @abstractmethod
-     def define_custom_inputs(self):
-         """define custom inputs"""
-         raise NotImplementedError
-
-     @abstractmethod
-     def define_outputs(self):
-         """define outputs"""
-         raise NotImplementedError
-
-     @abstractmethod
-     def collect_inputs(self):
-         """collect inputs"""
-         raise NotImplementedError
-
-     @abstractmethod
-     def collect_outputs(self):
-         """collect outputs"""
-         raise NotImplementedError
-
-     @abstractmethod
-     def compute_one_core(self, batch_start_s, batch_num_s, core_idx_to_tr_info, core_idx):
-         """compute one core"""
-         raise NotImplementedError
-
-     @abstractmethod
-     def prepare_global_ones(self):
-         """prepare global ones"""
-         raise NotImplementedError
-
-     def get_gm_offset(self, batch_start, batch_idx, h, w, block_h, block_idx):
-         """get gm offset"""
-         gm_offset = (batch_start + batch_idx) * h * w + block_idx * block_h * self.N0
-         return gm_offset
-
-     def get_cur_tr_block_num(self, tr_idx):
-         """get cur tr block_num"""
-         cur_prev_block_num = min(tr_idx, self.prev_block_num)
-         cur_next_block_num = min(self.next_block_num, self.Tc - tr_idx - 1)
-         block_num = cur_prev_block_num + 1 + cur_next_block_num
-         return block_num
-
-     def get_total_block_num(self):
-         """get total block num"""
-         block_num = 0
-         for b_idx in range(self.B):
-             for tr_idx in range(self.Tr):
-                 block_num += self.get_cur_tr_block_num(tr_idx)
-         return block_num
-
-     def update_core_task_map(self,
-                              core_b_map,
-                              core_b_tr_map,
-                              core_idx,
-                              b_start,
-                              b_end,
-                              tr_start,
-                              tr_end):
-         """update core task map"""
-         core_b_map[core_idx][0] = min(core_b_map[core_idx][0], b_start)
-         if tr_end == 0:  # crosses a head boundary, but the crossed head is not handled by the current core
-             core_b_map[core_idx][1] = max(core_b_map[core_idx][1], b_end - 1)
-         else:
-             core_b_map[core_idx][1] = max(core_b_map[core_idx][1], b_end)
-         for b_idx in range(b_start, b_end + 1):
-             if b_idx == b_end and tr_end == 0:  # crosses a head boundary, but the crossed head is not handled by the current core
-                 break
-             elif b_idx == b_start and b_idx == b_end:  # no head boundary crossed
-                 core_b_tr_map[core_idx][b_idx] = (tr_start, tr_end)
-             elif b_idx == b_start:  # crosses heads: the first head
-                 core_b_tr_map[core_idx][b_idx] = (tr_start, self.Tr)
-             elif b_idx == b_end:  # crosses heads: the last head
-                 core_b_tr_map[core_idx][b_idx] = (0, tr_end)
-             else:  # crosses heads: a middle head
-                 core_b_tr_map[core_idx][b_idx] = (0, self.Tr)
-
-     def convert_py_dict_to_tik_tensor(self, core_b_map, core_b_tr_map):
-         """convert py dict to tik tensor"""
-         # python dict -> tik tensor
-         # [batch_start, batch_idx_end] -> [batch_start, batch_num]
-         # [tr_start, tr_idx_end] -> [tr_start, tr_idx_end)
-         core_idx_to_batch_info = self.tik_instance.Tensor(
-             "int32", (self.core_num, 2), name="core_idx_to_batch_info", scope=UB
-         )
-         core_idx_to_tr_info = self.tik_instance.Tensor(
-             "int32", (self.core_num, self.B, 2), name="core_idx_to_tr_info", scope=UB
-         )
-         for core_idx in core_b_map.keys():
-             batch_start, batch_end = core_b_map[core_idx]
-             core_idx_to_batch_info[core_idx, 0] = batch_start
-             core_idx_to_batch_info[core_idx, 1] = batch_end - batch_start + 1
-             for batch_idx in core_b_tr_map[core_idx].keys():
-                 tr_start, tr_end = core_b_tr_map[core_idx][batch_idx]
-                 core_idx_to_tr_info[core_idx, batch_idx, 0] = tr_start
-                 core_idx_to_tr_info[core_idx, batch_idx, 1] = tr_end
-
-         return core_idx_to_batch_info, core_idx_to_tr_info
-
-     def get_core_task_info(self):
-         """
-         Get batch start and batch number of each NPU core.
-         :return: Tensor([[core_1_batch_start, core_1_batch_num],...,[core_m_batch_start,
-         core_m_batch_num]]), Tensor([[[core_1_batch_1_Tr_start, core_1_batch_1_Tr_end],...[core_1_batch_n_Tr_start,
-         core_1_batch_n_Tr_end]],...,[[core_m_batch_1_Tr_start, core_m_batch_1_Tr_end],...[core_m_batch_n_Tr_start,
-         core_m_batch_n_Tr_end]]
-         """
-         if self.core_num > self.B * self.Tr:
-             self.core_num = self.B * self.Tr
-
-         total_blk_num = self.get_total_block_num()
-         b_start = 0
-         tr_start = 0
-         remain_blk_num = total_blk_num
-         core_b_map = defaultdict(lambda: [100000, -1])
-         core_b_tr_map = defaultdict(lambda: defaultdict(list))
-         for core_idx in range(self.core_num):
-             cur_core_blk_num = 0
-             cur_each_core_blk_num = remain_blk_num // (self.core_num - core_idx)
-             cur_core_finished = False
-             b_end = b_start
-             tr_end = tr_start
-             while b_end < self.B:
-                 while tr_end < self.Tr:
-                     cur_tr_blk_num = self.get_cur_tr_block_num(tr_end)
-                     if abs(cur_core_blk_num - cur_each_core_blk_num) <= \
-                             (cur_core_blk_num + cur_tr_blk_num - cur_each_core_blk_num):
-                         self.update_core_task_map(core_b_map, core_b_tr_map, core_idx, b_start, b_end, tr_start, tr_end)
-                         remain_blk_num -= cur_core_blk_num
-                         cur_core_finished = True
-                         break
-                     else:
-                         cur_core_blk_num += cur_tr_blk_num
-                         tr_end += 1
-                         if tr_end == self.Tr:
-                             tr_end = 0
-                             b_end += 1
-                 if cur_core_finished:
-                     b_start = b_end
-                     tr_start = tr_end
-                     break
-         core_idx_to_batch_info, core_idx_to_tr_info = self.convert_py_dict_to_tik_tensor(core_b_map, core_b_tr_map)
-         return core_idx_to_batch_info, core_idx_to_tr_info
-
-     def get_attn_mask_gm_offset(self, batch_start, batch_idx, h, w, block_h, block_h_idx, block_w, block_w_idx):
-         """get attn mask gm offset"""
-         if self.att_mask_shape[0] == 1:
-             gm_offset = block_w_idx * (h * block_w) + block_h_idx * block_h * self.N0
-         else:
-             gm_offset = ((batch_start + batch_idx) // self.head_num) * h * w \
-                         + block_w_idx * (h * block_w) + block_h_idx * block_h * self.N0
-         return gm_offset
-
-     def parse_input_shape(self, alibi_mask, attn_mask, dropout_mask, k, q, v):
-         """parser input shape"""
-         self.has_attn_mask = False
-         self.has_drop_mask = False
-         self.has_alibi_mask = False
-         # NZ
-         self.q_shape = q["shape"]
-         self.k_shape = k["shape"]
-         self.v_shape = v["shape"]
-         # ND
-         self.q_ori_shape = q["ori_shape"]
-         self.k_ori_shape = k["ori_shape"]
-         if attn_mask is not None:
-             self.has_attn_mask = True
-             self.att_mask_shape = attn_mask["shape"]
-         if dropout_mask is not None:
-             self.has_drop_mask = True
-             self.drop_mask_shape = dropout_mask["shape"]
-         if alibi_mask is not None:
-             self.has_alibi_mask = True
-             self.alibi_mask_shape = alibi_mask["shape"]
-
-     def define_inputs_outputs(self):
-         """define inputs outputs"""
-         self.define_common_inputs()
-
-         self.define_custom_inputs()
-
-         self.define_outputs()
-
-     def init(self):
-         """init parameters"""
-         tiling_para: TilingPara = self.tiling_stgy.tiling()
-
-         self.Br = tiling_para.Br
-         self.last_Br = tiling_para.last_Br
-         self.Bc = tiling_para.Bc
-         self.last_Bc = tiling_para.last_Bc
-         self.Tr = tiling_para.Tr
-         self.Tc = tiling_para.Tc
-
-         self.define_inputs_outputs()
-
-     def define_common_inputs(self):
-         """define common input gm tensors"""
-         self.Q_gm = self.tik_instance.Tensor(FP16, self.q_shape, name="Q_gm", scope=GM)
-         self.K_gm = self.tik_instance.Tensor(FP16, self.k_shape, name="K_gm", scope=GM)
-         self.V_gm = self.tik_instance.Tensor(FP16, self.v_shape, name="V_gm", scope=GM)
-         if self.has_attn_mask:
-             self.att_mask_gm = self.tik_instance.Tensor(FP16, self.att_mask_shape,
-                                                         name="att_mask_gm", scope=GM)
-         if self.has_drop_mask:
-             self.drop_mask_gm = self.tik_instance.Tensor(FP16, self.drop_mask_shape,
-                                                          name="drop_mask_gm", scope=GM)
-         if self.has_alibi_mask:
-             self.alibi_mask_gm = self.tik_instance.Tensor(FP16, self.alibi_mask_shape,
-                                                           name="alibi_mask_gm", scope=GM)
-
-     def do_alibi_mask(self, Sij_ub, alibi_mask_gm_offset, m_aligned, n_aligned):
-         """load alibi mask from gm to ub, then add Sij"""
-         with self.tik_instance.new_stmt_scope(disable_sync=False):
-             alibi_mask_ub = self.tik_instance.Tensor(FP16, (1, n_aligned),
-                                                      scope=UB, name="alibi_mask_ub")
-             self.tik_instance.data_move(alibi_mask_ub, self.alibi_mask_gm[alibi_mask_gm_offset], 0, 1,
-                                         n_aligned // 16, 0, 0)
-             alibi_mask_ub_broadcast = self.tik_ops_utils.broadcast_row(alibi_mask_ub, (m_aligned, n_aligned))
-             self.tik_instance.h_add(Sij_ub, Sij_ub, alibi_mask_ub_broadcast)
-
-     def do_att_mask(self, Sij_ub_N1MN0, attn_mask_gm_offset, q_blk_height, kv_blk_height,
-                     q_blk_h_aligned, kv_blk_h_aligned):
-         """load attn mask from gm to ub, then mul it by MASK_FILL_VALUE and add Sij"""
-         with self.tik_instance.new_stmt_scope(disable_sync=False):
-             att_mask_ub = self.tik_instance.Tensor(FP16, (kv_blk_h_aligned // self.N0, q_blk_h_aligned, self.N0),
-                                                    scope=UB, name="att_mask_ub")
-             self.tik_instance.data_move(att_mask_ub, self.att_mask_gm[attn_mask_gm_offset], 0,
-                                         kv_blk_height // self.N0, q_blk_height * self.N0 // 16,
-                                         (self.Nq - q_blk_height) * self.N0 // 16, 0)
-             self.tik_instance.h_mul(att_mask_ub, att_mask_ub, MASK_FILL_VALUE)
-             self.tik_instance.h_add(Sij_ub_N1MN0, Sij_ub_N1MN0, att_mask_ub)
-
-     def do_dropout_mask(self, Pij_ub, dropout_mask_gm_offset, kv_blk_h_aligned, kv_blk_height,
-                         q_blk_h_aligned, q_blk_height, precision_type=FP16, workspace=None):
-         """load drop mask from gm to ub, then mul it by Pij"""
-         with self.tik_instance.new_stmt_scope(disable_sync=False):
-             dropout_mask_ub = self.tik_instance.Tensor(FP16, (q_blk_h_aligned, kv_blk_h_aligned),
-                                                        scope=UB, name="drop_mask_ub")
-             self.tik_instance.data_move(dropout_mask_ub, self.drop_mask_gm[dropout_mask_gm_offset], 0,
-                                         q_blk_height, kv_blk_height // 16, (self.N - kv_blk_height) // 16, 0)
-             dropout_mask_ub = dropout_mask_ub.reshape((kv_blk_height // self.N0, q_blk_height, self.N0))
-             if precision_type == FP32:
-                 dropout_mask_ub_fp32 = self.tik_instance.Tensor(FP32,
-                                                                 (kv_blk_h_aligned // self.N0, q_blk_h_aligned, self.N0),
-                                                                 scope=UB, name="dropout_mask_ub_fp32")
-                 self.tik_instance.h_cast(dropout_mask_ub_fp32, dropout_mask_ub, "none")
-                 if workspace is None:
-                     self.tik_instance.h_mul(Pij_ub, Pij_ub, dropout_mask_ub_fp32)
-                 else:
-                     self.tik_instance.h_mul(workspace, Pij_ub, dropout_mask_ub_fp32)
-             else:
-                 if workspace is None:
-                     self.tik_instance.h_mul(Pij_ub, Pij_ub, dropout_mask_ub)
-                 else:
-                     self.tik_instance.h_mul(workspace, Pij_ub, dropout_mask_ub)
-
-     def compute_process(self):
-         """The compute process of FlashAttention"""
-         self.init()
-         self.prepare_global_ones()
-         core_idx_to_batch_info, core_idx_to_tr_info = self.get_core_task_info()
-         with self.tik_instance.for_range(begint=0, endt=self.core_num, name="core_index",
-                                          block_num=self.core_num) as core_idx:
-             batch_start_s = self.tik_instance.Scalar("int32", name="batch_start_s")
-             batch_num_s = self.tik_instance.Scalar("int32", name="batch_num_s")
-
-             batch_start_s.set_as(core_idx_to_batch_info[core_idx, 0])
-             batch_num_s.set_as(core_idx_to_batch_info[core_idx, 1])
-
-             self.compute_one_core(batch_start_s, batch_num_s, core_idx_to_tr_info, core_idx)
-
-         self.tik_instance.BuildCCE(
-             kernel_name=self.kernel_name,
-             inputs=self.collect_inputs(),
-             outputs=self.collect_outputs(),
-             config={"dump_cce_code": False, "save_temp_cce_file": True, "enable_const_fold": True},
-             enable_l2=True
-         )
mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py (deleted)
@@ -1,41 +0,0 @@
- # Copyright 2023 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """define constants"""
- from tbe import tik
-
- BLOCK_NUM = 16
- FP16 = "float16"
- INT8 = "int8"
- INT32 = "int32"
- FP32 = "float32"
- REPEAT_SZ = 128
- BLK_STRIDE = 1
- REPEAT_STRIDE = 8
- TRANS_CUBE_TGT = 8
- FP16_MIN_VAL = -65504.0
- MASK_FILL_VALUE = -10000.0
- GM = tik.scope_gm
- L1 = tik.scope_cbuf
- L1OUT = tik.scope_cbuf_out
- UB = tik.scope_ubuf
- L0A = tik.scope_ca
- L0B = tik.scope_cb
- L0C = tik.scope_cc
- DTYPE_SIZE = {
-     "int8": 1,
-     "float16": 2,
-     "int16": 2,
-     "float32": 4,
- }
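
Note on the removed custom op: the deleted attention.py above drives a blocked FlashAttention kernel. It walks the attention matrix in Tr x Tc tiles, visits only a prev_block_num/next_block_num window of key/value blocks per query row-block, and then balances those blocks across NPU cores before handing each core its (batch*head, Tr) range. As a reading aid only, here is a minimal plain-Python sketch of that block-count and core-partition bookkeeping. The helper names (blocks_for_row, total_blocks, split_rows_across_cores) are illustrative and not part of MindSpore, and the greedy split is a simplification of the deleted get_core_task_info, not a reimplementation of it.

from collections import defaultdict

def blocks_for_row(tr_idx, tc, prev_block_num, next_block_num):
    # Number of key/value column blocks visited for query row-block tr_idx
    # (mirrors the deleted FlashAttention.get_cur_tr_block_num).
    prev = min(tr_idx, prev_block_num)
    nxt = min(next_block_num, tc - tr_idx - 1)
    return prev + 1 + nxt

def total_blocks(b, tr, tc, prev_block_num, next_block_num):
    # Total work units over all batch*head rows (cf. get_total_block_num).
    return b * sum(blocks_for_row(i, tc, prev_block_num, next_block_num) for i in range(tr))

def split_rows_across_cores(b, tr, tc, core_num, prev_block_num, next_block_num):
    # Greedy split of (batch, tr) row-blocks so each core gets roughly the
    # same number of work units; a simplified stand-in for get_core_task_info.
    work = [(bi, ti, blocks_for_row(ti, tc, prev_block_num, next_block_num))
            for bi in range(b) for ti in range(tr)]
    target = sum(w for _, _, w in work) / core_num
    tasks, acc, core = defaultdict(list), 0, 0
    for bi, ti, w in work:
        if acc + w > target * (core + 1) and core < core_num - 1:
            core += 1
        tasks[core].append((bi, ti))
        acc += w
    return tasks

# Example: 2 batch*head rows, Tr = Tc = 4, dense window (the 65536 defaults), 4 cores.
print(total_blocks(2, 4, 4, 65536, 65536))            # 2 * 4 * 4 = 32 work units
print(split_rows_across_cores(2, 4, 4, 4, 65536, 65536))  # 2 row-blocks per core

With a dense window (prev/next block counts at the 65536 defaults), every row-block costs exactly Tc units, so the split degenerates to an even chop of the B*Tr row-blocks; smaller windows make diagonal-adjacent rows cheaper and the split uneven in row count but balanced in work.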