mindspore-2.2.0-cp38-cp38-manylinux1_x86_64.whl → mindspore-2.2.11-cp38-cp38-manylinux1_x86_64.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.

Potentially problematic release: this version of mindspore might be problematic.
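
Since this page covers a patch-level wheel upgrade (2.2.0 → 2.2.11), a quick sanity check after installing is to confirm which build Python actually loads; a minimal sketch using the public version attribute and the bundled run_check helper (assumes a working local install of either wheel):

    import mindspore as ms

    # Report the wheel version that is actually imported (mindspore/version.py is among the changed files).
    print(ms.__version__)

    # run_check() runs a tiny computation to verify the install and the device backend respond.
    ms.run_check()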

Files changed (170)
  1. mindspore/.commit_id +1 -1
  2. mindspore/_akg/akg/composite/build_module.py +104 -20
  3. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  4. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  5. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  6. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  7. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  8. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  9. mindspore/_akg/akg/utils/composite_op_helper.py +7 -2
  10. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  11. mindspore/_akg/akg/utils/kernel_exec.py +41 -15
  12. mindspore/_akg/akg/utils/tbe_codegen_utils.py +27 -6
  13. mindspore/_akg/akg/utils/util.py +56 -1
  14. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  15. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  16. mindspore/_checkparam.py +3 -3
  17. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  18. mindspore/_extends/graph_kernel/splitter.py +3 -2
  19. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +83 -66
  20. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -4
  21. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  22. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +2 -1
  23. mindspore/_extends/parse/__init__.py +3 -2
  24. mindspore/_extends/parse/parser.py +6 -1
  25. mindspore/_extends/parse/standard_method.py +14 -11
  26. mindspore/_extends/remote/kernel_build_server.py +2 -1
  27. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  28. mindspore/bin/cache_admin +0 -0
  29. mindspore/bin/cache_server +0 -0
  30. mindspore/common/_utils.py +16 -0
  31. mindspore/common/api.py +1 -1
  32. mindspore/common/auto_dynamic_shape.py +81 -85
  33. mindspore/common/dump.py +1 -1
  34. mindspore/common/tensor.py +3 -20
  35. mindspore/config/op_info.config +1 -1
  36. mindspore/context.py +11 -4
  37. mindspore/dataset/engine/cache_client.py +8 -5
  38. mindspore/dataset/engine/datasets_standard_format.py +5 -0
  39. mindspore/dataset/vision/transforms.py +21 -21
  40. mindspore/experimental/optim/adam.py +1 -1
  41. mindspore/gen_ops.py +1 -1
  42. mindspore/include/api/model.h +17 -0
  43. mindspore/include/api/status.h +8 -3
  44. mindspore/lib/libdnnl.so.2 +0 -0
  45. mindspore/lib/libmindspore.so +0 -0
  46. mindspore/lib/libmindspore_backend.so +0 -0
  47. mindspore/lib/libmindspore_common.so +0 -0
  48. mindspore/lib/libmindspore_core.so +0 -0
  49. mindspore/lib/libmindspore_glog.so.0 +0 -0
  50. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  51. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  52. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  53. mindspore/lib/libmindspore_shared_lib.so +0 -0
  54. mindspore/lib/libnnacl.so +0 -0
  55. mindspore/lib/libopencv_core.so.4.5 +0 -0
  56. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  57. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  58. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  59. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  60. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  61. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  62. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  63. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  64. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  65. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  66. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  67. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  68. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  69. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  70. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  71. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  72. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  73. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +78 -80
  74. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  75. mindspore/lib/plugin/ascend/libakg.so +0 -0
  76. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  77. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  78. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  79. mindspore/lib/plugin/cpu/libakg.so +0 -0
  80. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  81. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  82. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  83. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  84. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  85. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  86. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  87. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  88. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  89. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  90. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  91. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  92. mindspore/nn/cell.py +0 -3
  93. mindspore/nn/layer/activation.py +4 -5
  94. mindspore/nn/layer/conv.py +39 -23
  95. mindspore/nn/layer/flash_attention.py +54 -129
  96. mindspore/nn/layer/math.py +3 -7
  97. mindspore/nn/layer/rnn_cells.py +5 -5
  98. mindspore/nn/wrap/__init__.py +4 -2
  99. mindspore/nn/wrap/cell_wrapper.py +12 -3
  100. mindspore/numpy/utils_const.py +5 -5
  101. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -1
  102. mindspore/ops/_grad_experimental/grad_implementations.py +2 -2
  103. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -18
  104. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  105. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  106. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  107. mindspore/ops/_utils/utils.py +2 -0
  108. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
  109. mindspore/ops/composite/multitype_ops/getitem_impl.py +2 -2
  110. mindspore/ops/function/array_func.py +10 -7
  111. mindspore/ops/function/grad/grad_func.py +0 -1
  112. mindspore/ops/function/nn_func.py +98 -9
  113. mindspore/ops/function/random_func.py +2 -1
  114. mindspore/ops/op_info_register.py +24 -21
  115. mindspore/ops/operations/__init__.py +6 -2
  116. mindspore/ops/operations/_grad_ops.py +25 -6
  117. mindspore/ops/operations/_inner_ops.py +155 -23
  118. mindspore/ops/operations/array_ops.py +9 -7
  119. mindspore/ops/operations/comm_ops.py +2 -2
  120. mindspore/ops/operations/custom_ops.py +85 -68
  121. mindspore/ops/operations/inner_ops.py +26 -3
  122. mindspore/ops/operations/math_ops.py +7 -6
  123. mindspore/ops/operations/nn_ops.py +193 -49
  124. mindspore/parallel/_parallel_serialization.py +10 -3
  125. mindspore/parallel/_tensor.py +4 -1
  126. mindspore/parallel/checkpoint_transform.py +13 -2
  127. mindspore/parallel/shard.py +17 -10
  128. mindspore/profiler/common/util.py +1 -0
  129. mindspore/profiler/parser/ascend_hccl_generator.py +232 -0
  130. mindspore/profiler/parser/ascend_msprof_exporter.py +86 -43
  131. mindspore/profiler/parser/ascend_msprof_generator.py +196 -9
  132. mindspore/profiler/parser/ascend_op_generator.py +1 -1
  133. mindspore/profiler/parser/ascend_timeline_generator.py +6 -182
  134. mindspore/profiler/parser/base_timeline_generator.py +1 -1
  135. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -2
  136. mindspore/profiler/parser/framework_parser.py +1 -1
  137. mindspore/profiler/parser/profiler_info.py +19 -0
  138. mindspore/profiler/profiling.py +46 -24
  139. mindspore/rewrite/api/pattern_engine.py +1 -1
  140. mindspore/rewrite/parsers/for_parser.py +7 -7
  141. mindspore/rewrite/parsers/module_parser.py +4 -4
  142. mindspore/rewrite/symbol_tree.py +1 -4
  143. mindspore/run_check/_check_version.py +5 -3
  144. mindspore/safeguard/rewrite_obfuscation.py +52 -28
  145. mindspore/scipy/ops.py +55 -5
  146. mindspore/scipy/optimize/__init__.py +3 -2
  147. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  148. mindspore/train/callback/_summary_collector.py +1 -1
  149. mindspore/train/dataset_helper.py +1 -0
  150. mindspore/train/model.py +2 -2
  151. mindspore/train/serialization.py +97 -11
  152. mindspore/train/summary/_summary_adapter.py +1 -1
  153. mindspore/train/summary/summary_record.py +23 -7
  154. mindspore/version.py +1 -1
  155. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +3 -2
  156. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +160 -151
  157. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
  158. mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
  159. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
  160. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
  161. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
  162. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
  163. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  164. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  165. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  166. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  167. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  168. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  169. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
  170. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0

mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
@@ -1,193 +0,0 @@
- # Copyright 2023 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """The impl of flash attention"""
- from __future__ import absolute_import
- import mindspore.ops as ops
- import mindspore.common.dtype as mstype
- from mindspore.ops import Custom
- from mindspore.ops import DataType
- from mindspore.ops import TBERegOp
- from mindspore.ops._op_impl._custom_op.flash_attention.flash_attention_bwd import flash_attention_grad
- from mindspore.ops._op_impl._custom_op.flash_attention.flash_attention_fwd import flash_attention
- from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
-
- KERNEL_NAME = "flash_attention"
-
- cus_flash_atten_op_info = TBERegOp("FlashAttentionPrimitive") \
-     .fusion_type("OPAQUE") \
-     .partial_flag(True) \
-     .async_flag(False) \
-     .binfile_name("flash_attention.so") \
-     .compute_cost(10) \
-     .kernel_name(KERNEL_NAME) \
-     .attr("prev_block_num", "required", "int", "all", "65536") \
-     .attr("next_block_num", "required", "int", "all", "65536") \
-     .attr("high_precision", "required", "bool", "all", "false") \
-     .attr("tiling_stgy_name", "required", "str", "all", "sparse") \
-     .input(0, "query", False, "required", "all") \
-     .input(1, "key", False, "required", "all") \
-     .input(2, "value", False, "required", "all") \
-     .input(3, "attn_mask", False, "optional", "all") \
-     .input(4, "dropout_mask", False, "optional", "all") \
-     .input(5, "alibi_mask", False, "optional", "all") \
-     .output(0, "output", False, "required", "all") \
-     .output(1, "rowsum", False, "required", "all") \
-     .output(2, "rowmax", False, "required", "all") \
-     .dtype_format(DataType.F16_FracNZ,
-                   DataType.F16_FracNZ,
-                   DataType.F16_FracNZ,
-                   DataType.F16_FracNZ,
-                   DataType.F16_Default,
-                   DataType.F16_FracNZ,
-                   DataType.F16_FracNZ,
-                   DataType.F16_Default,
-                   DataType.F16_Default) \
-     .dtype_format(DataType.F16_FracNZ,
-                   DataType.F16_FracNZ,
-                   DataType.F16_FracNZ,
-                   DataType.F16_FracNZ,
-                   DataType.F16_Default,
-                   DataType.F16_FracNZ,
-                   DataType.F16_FracNZ,
-                   DataType.F32_Default,
-                   DataType.F16_Default) \
-     .get_op_info()
-
- GRAD_KERNEL_NAME = "flash_attention_grad"
-
- cus_flash_atten_grad_op_info = TBERegOp("FlashAttentionGradPrimitive") \
-     .fusion_type("OPAQUE") \
-     .partial_flag(True) \
-     .async_flag(False) \
-     .binfile_name("flash_attention_grad.so") \
-     .compute_cost(10) \
-     .kernel_name(GRAD_KERNEL_NAME) \
-     .attr("prev_block_num", "required", "int", "all", "65536") \
-     .attr("next_block_num", "required", "int", "all", "65536") \
-     .attr("high_precision", "required", "bool", "all", "false") \
-     .attr("tiling_stgy_name", "required", "str", "all", "sparse") \
-     .input(0, "query", False, "required", "all") \
-     .input(1, "key", False, "required", "all") \
-     .input(2, "value", False, "required", "all") \
-     .input(3, "output", False, "required", "all") \
-     .input(4, "do", False, "required", "all") \
-     .input(5, "rowsum", False, "required", "all") \
-     .input(6, "rowmax", False, "required", "all") \
-     .input(7, "attn_mask", False, "optional", "all") \
-     .input(8, "dropout_mask", False, "optional", "all") \
-     .input(9, "alibi_mask", False, "optional", "all") \
-     .output(0, "dq", False, "required", "all") \
-     .output(1, "dk", False, "required", "all") \
-     .output(2, "dv", False, "required", "all") \
-     .dtype_format(DataType.F16_FracNZ,
-                   DataType.F16_FracNZ,
-                   DataType.F16_FracNZ,
-                   DataType.F16_FracNZ,
-                   DataType.F16_FracNZ,
-                   DataType.F16_Default,
-                   DataType.F16_Default,
-                   DataType.F16_FracNZ,
-                   DataType.F16_Default,
-                   DataType.F16_FracNZ,
-                   DataType.F32_FracNZ,
-                   DataType.F32_FracNZ,
-                   DataType.F32_FracNZ) \
-     .dtype_format(DataType.F16_FracNZ,
-                   DataType.F16_FracNZ,
-                   DataType.F16_FracNZ,
-                   DataType.F16_FracNZ,
-                   DataType.F16_FracNZ,
-                   DataType.F32_Default,
-                   DataType.F16_Default,
-                   DataType.F16_FracNZ,
-                   DataType.F16_Default,
-                   DataType.F16_FracNZ,
-                   DataType.F32_FracNZ,
-                   DataType.F32_FracNZ,
-                   DataType.F32_FracNZ) \
-     .get_op_info()
-
-
- def get_flash_attention_grad(prev_block_num=65536, next_block_num=65536,
-                              tiling_stgy_name='sparse', high_precision=False):
-     """get flash attention grad"""
-
-     def infer_shape(q_shape, k_shape, v_shape, o_shape, do_shape, l_shape, m_shape,
-                     att_mask_shape, dropout_mask_shape, alibi_mask_shape):
-         return q_shape, k_shape, v_shape
-
-     def infer_dtype(q_dtype, k_dtype, v_dtype, o_dytpe, do_dtype, l_dtype, m_dtype,
-                     attn_mask_dtype, dropout_mask_dtype, alibi_mask_type):
-         return mstype.float32, mstype.float32, mstype.float32
-
-     fa_grad = Custom(flash_attention_grad, out_shape=infer_shape,
-                      out_dtype=infer_dtype, func_type="tbe", reg_info=cus_flash_atten_grad_op_info)
-     fa_grad.add_prim_attr("prev_block_num", prev_block_num)
-     fa_grad.add_prim_attr("next_block_num", next_block_num)
-     fa_grad.add_prim_attr("high_precision", high_precision)
-     fa_grad.add_prim_attr("tiling_stgy_name", tiling_stgy_name)
-     fa_grad.init_prim_io_names(
-         inputs=["query", "key", "value", "output", "do", "rowsum", "rowmax", "attn_mask", "dropout_mask",
-                 "alibi_mask"],
-         outputs=["dq", "dk", "dv"]
-     )
-
-     def bprop(query, key, value, attn_mask, dropout_mask, alibi_mask, out, douts):
-         output, rowsum, rowmax = out
-         dout, _, _ = douts
-         dq, dk, dv = fa_grad(query, key, value, output, dout, rowsum, rowmax, attn_mask, dropout_mask,
-                              alibi_mask)
-         dq = ops.cast(dq, mstype.float16)
-         dk = ops.cast(dk, mstype.float16)
-         dv = ops.cast(dv, mstype.float16)
-         return dq, dk, dv, zeros_like(attn_mask), \
-             zeros_like(dropout_mask), zeros_like(alibi_mask)
-
-     return bprop
-
-
- def get_flash_attention(prev_block_num=65536, next_block_num=65536, tiling_stgy_name='sparse', high_precision=False):
-     """get_flash_attention"""
-
-     def infer_shape(q_shape, k_shape, v_shape, attn_mask_shape=None,
-                     dropout_mask_shape=None, alibi_mask_shape=None):
-         """infer shape"""
-         batch, hidden_size, seq_len, _ = q_shape
-         l_shape = (batch, hidden_size, seq_len)
-         m_shape = (batch, hidden_size, seq_len)
-         return q_shape, l_shape, m_shape
-
-     def infer_dtype(q_dtype, k_dtype, v_dtype, attn_mask_dtype=None,
-                     dropout_mask_dtype=None, alibi_mask_type=None):
-         """infer type"""
-         l_dtype = mstype.float16
-         if high_precision:
-             l_dtype = mstype.float32
-         return q_dtype, l_dtype, q_dtype
-
-     fa_grad = get_flash_attention_grad(prev_block_num, next_block_num, tiling_stgy_name, high_precision)
-     fa_forward = Custom(flash_attention, out_shape=infer_shape,
-                         out_dtype=infer_dtype, func_type="tbe", bprop=fa_grad,
-                         reg_info=cus_flash_atten_op_info)
-     fa_forward.add_prim_attr("prev_block_num", prev_block_num)
-     fa_forward.add_prim_attr("next_block_num", next_block_num)
-     fa_forward.add_prim_attr("high_precision", high_precision)
-     fa_forward.add_prim_attr("tiling_stgy_name", tiling_stgy_name)
-     fa_forward.init_prim_io_names(
-         inputs=["query", "key", "value", "attn_mask", "dropout_mask", "alibi_mask"],
-         outputs=["output", "rowsum", "rowmax"]
-     )
-
-     return fa_forward
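
For context on what was removed: this module did not define a Primitive subclass but returned an ops.Custom wired to the TBE kernels registered above, with a hand-written bprop. A minimal sketch of how the 2.2.0 helper might have been attached to a cell (illustrative only; it assumes an Ascend device, graph mode, and the 2.2.0 wheel where this module still exists):

    from mindspore import nn
    from mindspore.ops._op_impl._custom_op.flash_attention.flash_attention_impl import get_flash_attention

    class FlashAttentionCell(nn.Cell):
        def __init__(self):
            super().__init__()
            # Registered inputs: query, key, value, attn_mask, dropout_mask, alibi_mask (fp16).
            self.fa = get_flash_attention(tiling_stgy_name="sparse", high_precision=False)

        def construct(self, q, k, v, attn_mask, dropout_mask, alibi_mask):
            # Registered outputs: output, rowsum, rowmax.
            out, _, _ = self.fa(q, k, v, attn_mask, dropout_mask, alibi_mask)
            return out

In 2.2.11 this path is gone, presumably superseded by the changes to mindspore/nn/layer/flash_attention.py and mindspore/ops/operations/nn_ops.py listed above.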

mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
@@ -1,435 +0,0 @@
- # Copyright 2023 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """the common about tik ops"""
- from functools import partial
-
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import DTYPE_SIZE
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP16
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP32
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import L0C
- from mindspore.ops._op_impl._custom_op.flash_attention.constants import UB
-
-
- class TikOpsUtils:
-     """Utils function class about tik ops"""
-
-     def __init__(self, tik_instance):
-         self.tik_instance = tik_instance
-         self.dtype = "float16"
-         self.cont_data_mv_1_bust = partial(self.tik_instance.data_move, sid=0, nburst=1,
-                                            src_stride=0,
-                                            dst_stride=0)
-
-     def MK_TO_K1MK0(self, mk_input_tensor, workspace_tensor=None):
-         """change data shape from (M, K) to (K1, M, K0), K1 = K // K0, the effect is equivalent to:
-         new_tensor = np.stack(np.hsplit(mk_input_tensor, K1), axis=0)
-
-         :param mk_input_tensor: input tensor in GM with shape: (M, K)
-         :param workspace_tensor: workspace tensor with shape: (K1, M, K0)
-         tensor will be changed, otherwise the new data will be copied to the workspace tensor,
-         and input tensor will stay unchanged.
-         :return: Tensor with shape (K1,M, K0)
-         """
-         dtype = mk_input_tensor.dtype
-         m, k = mk_input_tensor.shape
-         K0 = 16
-         K1 = k // K0
-         M = self.up_align_to_K0(m)
-         try:
-             dtype_size = DTYPE_SIZE[dtype]
-         except KeyError:
-             raise ValueError("The argument 'dtype' is not valid.")
-         if workspace_tensor is not None:
-             with self.tik_instance.for_range(0, K1) as i:
-                 self.tik_instance.data_move(
-                     workspace_tensor[i * M * K0:],
-                     mk_input_tensor[i * K0:],
-                     0,
-                     M,
-                     K0 * dtype_size // 32,
-                     (K1 - 1) * K0 * dtype_size // 32,
-                     0,
-                 )
-             return workspace_tensor.reshape((K1, M, K0))
-
-         with self.tik_instance.new_stmt_scope(disable_sync=False):
-             tmp_ub = self.tik_instance.Tensor(dtype, (K1, M, K0), name="tmp_ub", scope=UB)
-             # data_move(m,k) --> (k1,m,K0)
-             with self.tik_instance.for_range(0, K1) as i:
-                 self.tik_instance.data_move(
-                     tmp_ub[i * M * K0:],
-                     mk_input_tensor[i * K0:],
-                     0,
-                     M,
-                     K0 * dtype_size // 32,
-                     (K1 - 1) * K0 * dtype_size // 32,
-                     0,
-                 )
-             self.cont_data_mv_1_bust(
-                 dst=mk_input_tensor, src=tmp_ub, burst=K1 * M * K0 * dtype_size // 32)
-         return mk_input_tensor.reshape((K1, M, K0))
-
-     def transpose_matrix(self, src_ub, dst_ub, N, nk0=False):
-         """ transpose matrix, default support shape: (16, n) -> (n, 16)
-         if nk0 is true, support shape: (n, 16) -> (16, n)
-         """
-         K0 = 16
-         rep_times = N // K0
-         if nk0:
-             src_list = [src_ub[16 * i] for i in range(16)]
-             dst_list = [dst_ub[N * i] for i in range(16)]
-         else:
-             src_list = [src_ub[N * i] for i in range(16)]
-             dst_list = [dst_ub[16 * i] for i in range(16)]
-
-         dst_rep_stride = K0
-         src_rep_stride = 1
-         if rep_times == 1:
-             dst_rep_stride = 0
-             src_rep_stride = 0
-
-         if nk0:
-             src_rep_stride, dst_rep_stride = dst_rep_stride, src_rep_stride
-
-         self.tik_instance.vec_trans_scatter(
-             False, False, dst_list, src_list, rep_times, dst_rep_stride, src_rep_stride
-         )
-         return dst_ub
-
-     def KN_TO_K1NK0(self, kn_input_tensor, workspace_tensor=None):
-         """change data shape from (K,N) to (K1, N, K0), K1 = K // K0, the effect is equivalent to:
-         new_tensor = np.reshape(kn_input_tensor, newshape=(K1, K0, N)).swapaxes(1, 2)
-
-         :param kn_input_tensor: input tensor with shape: (K, N)
-         :param workspace_tensor: workspace tensor with shape: (K1, N, K0)
-         tensor will be changed, otherwise the new data will be copied to the workspace tensor,
-         and input tensor will stay unchanged.
-         :return: Tensor with shape: (K1, N, K0)
-         """
-         dtype = kn_input_tensor.dtype
-         k, n = kn_input_tensor.shape
-         K0 = 16
-         K1 = k // K0
-         N = n
-         try:
-             dtype_size = DTYPE_SIZE[dtype]
-         except KeyError:
-             raise ValueError("The argument 'dtype' is not valid.")
-         with self.tik_instance.for_range(0, K1) as index:
-             k1nk0_ub = self.tik_instance.Tensor(dtype, (N, K0), UB, "k1nk0_ub")
-             src_ub = self.tik_instance.Tensor(dtype, (K0, N), UB, "src_ub")
-             burst_len = K0 * N * dtype_size // 32
-             self.cont_data_mv_1_bust(dst=src_ub, src=kn_input_tensor[index * K0 * N],
-                                      burst=burst_len)
-             k1nk0_ub = self.transpose_matrix(src_ub, k1nk0_ub, N)
-             if workspace_tensor is None:
-                 self.cont_data_mv_1_bust(dst=kn_input_tensor[index * K0 * N], src=k1nk0_ub,
-                                          burst=burst_len)
-             else:
-                 self.cont_data_mv_1_bust(dst=workspace_tensor[index * K0 * N], src=k1nk0_ub,
-                                          burst=burst_len)
-         if workspace_tensor is None:
-             return kn_input_tensor.reshape((K1, N, K0))
-
-         return workspace_tensor.reshape((K1, N, K0))
-
-     def N1MN0_TO_MN(self, N1MN0_input):
-         """change data shape from (N1, M, N0) to (M, N), N0=16, N = N1 * K0, the effect is equivalent to:
-         N1MN0_input = np.concatenate(list(map(np.squeeze, np.split(N1MN0_input, N1))), axis=1)
-
-         :param N1MN0_input: input tensor with shape (N, M, N0) in GM or L1.
-         :return:
-         """
-         dtype = N1MN0_input.dtype
-         N1, M, N0 = N1MN0_input.shape
-         try:
-             dtype_size = DTYPE_SIZE[dtype]
-         except KeyError:
-             raise ValueError("The argument 'dtype' is not valid.")
-         with self.tik_instance.new_stmt_scope(disable_sync=False):
-             tmp_ub = self.tik_instance.Tensor(dtype, (M, N1 * N0), name="tmp_ub", scope=UB)
-             # data_move (n1,m,n0) --> (m,n)
-             with self.tik_instance.for_range(0, N1) as i:
-                 self.tik_instance.data_move(
-                     tmp_ub[i * N0:],
-                     N1MN0_input[i * M * N0:],
-                     0,
-                     M,
-                     N0 * dtype_size // 32,
-                     0,
-                     (N1 - 1) * N0 * dtype_size // 32,
-                 )
-             # data_move out
-             self.cont_data_mv_1_bust(dst=N1MN0_input, src=tmp_ub, burst=M * N1 * N0 * dtype_size // 32)
-         return N1MN0_input.reshape((M, N1 * N0))
-
-     def broadcast(self, vec_ub, shape):
-         """ broadcast a vector to a matrix
-         :param vec_ub: a tensor in UB with shape of (M,), and dtype is float16
-         :param shape: the target shape, a tuple with value (M, N), M and N are integer multiples of 16
-         :return: a tensor in UB with shape of (M, N)
-         """
-         M, N = shape
-         dst_ub = self.tik_instance.Tensor(FP16, shape, name="dst_ub", scope=UB)
-
-         with self.tik_instance.new_stmt_scope(disable_sync=False):
-             # (M,) -> (2, M) -> (4, M) -> (8, M) -> (16, M)
-             tmp_ub1 = self.tik_instance.Tensor(FP16, (16, M), name="tmp_ub1", scope=UB)
-             self.tik_instance.data_move(tmp_ub1, vec_ub, 0, 1, M // 16, 0, 0)
-             times = self.tik_instance.Scalar("int32", name="times", init_value=1)
-             with self.tik_instance.for_range(begint=0, endt=16):
-                 with self.tik_instance.if_scope(times <= 8):
-                     offset = times * M
-                     burst = times * M // 16
-                     self.cont_data_mv_1_bust(dst=tmp_ub1[offset], src=tmp_ub1, burst=burst)
-                 with self.tik_instance.else_scope():
-                     self.tik_instance.tik_break()
-                 times.set_as(times * 2)
-
-             # (16, M) -> (M, 16)
-             tmp_ub2 = self.tik_instance.Tensor(FP16, (M, 16), name="tmp_ub2", scope=UB)
-             tmp_ub2_transposed = self.transpose_matrix(tmp_ub1, tmp_ub2, M)
-
-             # (M, 16) -> (M, 32) -> (M, 64) -> ... -> (M, N)
-             self.tik_instance.data_move(dst_ub, tmp_ub2_transposed, 0, M, 1, 0, N // 16 - 1)
-             times.set_as(1)
-             with self.tik_instance.for_range(begint=0, endt=N):
-                 offset = times * 16
-                 with self.tik_instance.if_scope(offset * 2 <= N):
-                     burst = offset // 16
-                     src_stride = N // 16 - burst
-                     dst_stride = N // 16 - burst
-                     self.tik_instance.data_move(dst_ub[offset], dst_ub, 0, M, burst, src_stride,
-                                                 dst_stride)
-                 with self.tik_instance.else_scope():
-                     burst = (N - offset) // 16
-                     src_stride = N // 16 - burst
-                     dst_stride = N // 16 - burst
-                     with self.tik_instance.if_scope(burst > 0):
-                         self.tik_instance.data_move(dst_ub[offset], dst_ub, 0, M, burst, src_stride,
-                                                     dst_stride)
-                     self.tik_instance.tik_break()
-                 times.set_as(times * 2)
-         return dst_ub
-
-     def broadcast_row(self, vec_ub, shape):
-         """broadcast row"""
-         M, N = shape
-         dst_ub = self.tik_instance.Tensor(FP16, shape, name="dst_ub", scope=UB)
-         self.tik_instance.data_move(dst_ub, vec_ub, 0, 1, N // 16, 0, 0)
-         times = self.tik_instance.Scalar("int32", name="times", init_value=1)
-         # (1, N) -> (2, M) -> (4, N) -> ... -> (M, N)
-         with self.tik_instance.for_range(begint=0, endt=M):
-             with self.tik_instance.if_scope(times * 2 <= M):
-                 burst = times * N // 16
-                 offset = times * N
-                 self.tik_instance.data_move(dst_ub[offset], dst_ub, 0, 1, burst, 0, 0)
-             with self.tik_instance.else_scope():
-                 burst = (M - times) * N // 16
-                 offset = times * N
-                 with self.tik_instance.if_scope(burst > 0):
-                     self.tik_instance.data_move(dst_ub[offset], dst_ub, 0, 1, burst, 0, 0)
-                 self.tik_instance.tik_break()
-             times.set_as(times * 2)
-         return dst_ub
-
-     def get_K0(self, dtype=None):
-         """get K0"""
-         if dtype is None:
-             dtype = self.dtype
-         try:
-             dtype_size = DTYPE_SIZE[dtype]
-         except KeyError:
-             raise ValueError("The argument 'dtype' is not valid.")
-         return 32 // dtype_size
-
-     def up_align_to_K0(self, n, dtype=None):
-         """byte alignment by dtype"""
-         if dtype is None:
-             dtype = self.dtype
-         try:
-             dtype_size = DTYPE_SIZE[dtype]
-         except KeyError:
-             raise ValueError("The argument 'dtype' is not valid.")
-         K0 = 32 // dtype_size
-         return (n + K0 - 1) // K0 * K0
-
-     def calc_vec_rec(self, vec_ub, vec_len):
-         """cal the reciprocal of a vector"""
-         dtype = vec_ub.dtype
-         vec_len_aligned = self.up_align_to_K0(vec_len)
-         vec_rec_ub = self.tik_instance.Tensor(dtype, (vec_len_aligned,), scope=UB, name="li_new_rec_ub")
-         try:
-             dtype_size = DTYPE_SIZE[dtype]
-         except KeyError:
-             raise ValueError("The argument 'dtype' is not valid.")
-         mask_len = 256 // dtype_size
-         block_len = 32 // dtype_size
-         work_size = 8 // dtype_size
-
-         with self.tik_instance.new_stmt_scope(disable_sync=False):
-             repeat_times = vec_len // mask_len
-             if repeat_times > 0:
-                 dst_rep_stride = 8
-                 src_rep_stride = 8
-
-                 src_extent_size = (repeat_times - 1) * src_rep_stride * block_len + mask_len
-                 wk_size_unit = ((src_extent_size + block_len - 1) // block_len) * block_len
-                 wk_size = work_size * wk_size_unit
-                 # define work_tensor
-                 work_tensor_ub = self.tik_instance.Tensor(
-                     "float32", (wk_size,), name="work_tensor_ub", scope=UB
-                 )
-                 # if work_tensor is indexed, it has to be written as work_tensor[index:]
-                 self.tik_instance.vec_rec_high_preci(
-                     mask_len,
-                     vec_rec_ub[0:],
-                     vec_ub[0:],
-                     work_tensor_ub[0:],
-                     repeat_times,
-                     dst_rep_stride,
-                     src_rep_stride,
-                 )
-
-             mask_len = vec_len - repeat_times * mask_len
-             if mask_len > 0:
-                 wk_size = work_size * ((mask_len + block_len - 1) // block_len) * block_len
-                 work_tensor_ub2 = self.tik_instance.Tensor(
-                     "float32", (wk_size,), name="work_tensor_ub2", scope=UB
-                 )
-                 self.tik_instance.vec_rec_high_preci(
-                     mask_len,
-                     vec_rec_ub[repeat_times * 128:],
-                     vec_ub[repeat_times * 128:],
-                     work_tensor_ub2[0:],
-                     1,
-                     0,
-                     0,
-                 )
-         return vec_rec_ub
-
-     def row_sum_cube_impl(self, matrix_l1_K1MK0_ed, right_all_one_matrix_l1, rowsum_ub, m, k, precision_type):
-         """compute the matrix row sum with the cube unit: right-multiply by an all-ones matrix of shape (n, 1)
-         :param matrix_l1_K1MK0_ed: input tensor with shape (K1, M, K0)
-         :param right_all_one_matrix_l1: input tensor with shape (K, 16)
-         :param rowsum_ub: output tensor stores the row sum of input tensor
-         :param m: actual tensor height
-         :param k: actual tensor width
-         :return: row sum of the output tensor
-         """
-         K1, M, K0 = matrix_l1_K1MK0_ed.shape
-         # implement rowsum via matmul; the result shape is (m, 16), take the first element of each row
-         with self.tik_instance.new_stmt_scope(disable_sync=False):
-             row_sum_ub_N1MN0 = self.matmul_compute(matrix_l1_K1MK0_ed, right_all_one_matrix_l1, m, k, 16,
-                                                    N1MN0_to_MN=False, precision_type=precision_type)
-             row_sum_ub_MN_ed = row_sum_ub_N1MN0.reshape((M, 16))
-             if precision_type == FP32:
-                 for idx in range(0, m):
-                     cur_row_sum = self.tik_instance.Scalar(FP32, init_value=row_sum_ub_MN_ed[idx, 0])
-                     rowsum_ub[idx].set_as(cur_row_sum)
-             else:
-                 # transpose row_sum_ub_MN_ed first, then take one row, replacing the original row-by-row operation: lij_ub[i].set_as(row_sum_ub_MN_ed[i, 0])
-                 row_sum_ub_trans = self.tik_instance.Tensor(FP16, (16, M), name="row_sum_ub_trans", scope=UB)
-                 row_sum_ub_trans = self.transpose_matrix(row_sum_ub_MN_ed, row_sum_ub_trans, M, True)
-                 self.cont_data_mv_1_bust(dst=rowsum_ub, src=row_sum_ub_trans, burst=M // 16)
-
-         return rowsum_ub
-
-     def matmul_compute(self, A_l1, B_l1, m, k, n, N1MN0_to_MN=True, precision_type=FP16):
-         """calculate matrix multiplication A_l1 * B_l1, and move the result to C_ub,
-         then rearrange C_ub
-         :param A_l1: input tensor in L1 with shape of (K1, M, K0)
-         :param B_l1: input tensor in L1 with shape of (K1, N, K0)
-         :param m: the actual number of rows of A_l1
-         :param k: the actual number of cols of A_l1
-         :param n: the actual number of cols of B_l1
-         :param N1MN0_to_MN: Whether reorder the result tensor.
-         :return: C_ub with tensor with shape of (M, N) if N1MN0_to_MN else (N1, M, N0)
-         """
-         M = self.up_align_to_K0(m)
-         N = self.up_align_to_K0(n)
-         C_ub = self.tik_instance.Tensor(precision_type, (N // 16, M, 16), name="C_ub", scope=UB)
-         try:
-             dtype_size = DTYPE_SIZE[FP32]
-         except KeyError:
-             raise ValueError("The argument 'dtype' is not valid.")
-         with self.tik_instance.new_stmt_scope(disable_sync=False):
-             # matmul
-             C_l0c = self.tik_instance.Tensor(
-                 FP32, (N // 16, M, 16), scope=L0C, name="C_l0c"
-             ) # n1mn0 (n0=16)
-             self.tik_instance.matmul(C_l0c, A_l1, B_l1, m, k, n)
-             # L0C -> UB, fp32 -> fp16 (tensor_mov can do the conversion on the fly)
-             self.tik_instance.tensor_mov(C_ub, C_l0c, "m", 1, M * N * dtype_size // 1024, 0, 0)
-         if N1MN0_to_MN:
-             return self.N1MN0_TO_MN(C_ub)
-         return C_ub
-
-     def move_vector_from_gm_to_ub(self, dst_tensor, src_tensor, gm_offset, vec_len):
-         """load the vector from gm to ub
-         :param dst_tensor:
-         :param src_tensor:
-         :param gm_offset:
-         :return:
-         """
-         try:
-             dtype_size = DTYPE_SIZE[src_tensor.dtype]
-         except KeyError:
-             raise ValueError("The argument 'src_tensor dtype' is not valid.")
-         a_burst_num = 32 // dtype_size
-         full_tik_blk_num, tail_num = divmod(vec_len, a_burst_num)
-         with self.tik_instance.if_scope(full_tik_blk_num > 0):
-             self.cont_data_mv_1_bust(dst=dst_tensor, src=src_tensor[gm_offset],
-                                      burst=full_tik_blk_num)
-         # roll the address back to handle the tail data
-         with self.tik_instance.if_scope(tail_num > 0):
-             offset = vec_len - a_burst_num
-             last_blk_ub = self.tik_instance.Tensor(FP16, (a_burst_num,), name="last_blk_ub", scope=UB)
-             self.cont_data_mv_1_bust(dst=last_blk_ub, src=src_tensor[gm_offset + offset], burst=1)
-             with self.tik_instance.for_range(0, a_burst_num) as idx: # the offset is not 32-byte aligned, so data_move cannot be used
-                 dst_tensor[offset + idx].set_as(last_blk_ub[idx])
-
-     def move_vector_from_ub_to_gm(self, dst_tensor, src_tensor, gm_offset, block_h):
-         """write the vector back to gm
-         :param dst_tensor:
-         :param src_tensor:
-         :param gm_offset:
-         :param block_h:
-         :return:
-         """
-         try:
-             dtype_size = DTYPE_SIZE[src_tensor.dtype]
-         except KeyError:
-             raise ValueError("The argument 'src_tensor dtype' is not valid.")
-         a_burst_num = 32 // dtype_size
-         full_tik_blk_num = block_h // a_burst_num
-         with self.tik_instance.if_scope(full_tik_blk_num > 0):
-             self.cont_data_mv_1_bust(dst=dst_tensor[gm_offset], src=src_tensor,
-                                      burst=full_tik_blk_num)
-         tail_num = block_h % a_burst_num
-         with self.tik_instance.if_scope(tail_num > 0):
-             offset = block_h - a_burst_num
-             tmp_ub = self.tik_instance.Tensor(FP16, (a_burst_num,), name="tmp_ub", scope=UB)
-             with self.tik_instance.for_range(0, a_burst_num) as idx:
-                 tmp_ub[idx].set_as(src_tensor[offset + idx])
-             self.cont_data_mv_1_bust(dst=dst_tensor[gm_offset + offset], src=tmp_ub, burst=1)
-
-     def scale_compute_vector(self, Sij_ub, dim):
-         """scale compute vector"""
-         scale_value = dim ** -0.5
-         scale = self.tik_instance.Scalar(dtype=FP16)
-         scale.set_as(scale_value)
-         self.tik_instance.h_mul(Sij_ub, Sij_ub, scale)
-         return Sij_ub
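
The docstrings above already state the NumPy equivalents of the on-chip layout moves (MK_TO_K1MK0, KN_TO_K1NK0, N1MN0_TO_MN). A standalone sketch that only checks those shape transformations off-device (plain NumPy; the sizes are arbitrary assumptions with K0 = N0 = 16):

    import numpy as np

    M, K, N, K0 = 32, 64, 48, 16
    K1, N1 = K // K0, N // K0

    mk = np.zeros((M, K), dtype=np.float16)
    kn = np.zeros((K, N), dtype=np.float16)

    # MK_TO_K1MK0: (M, K) -> (K1, M, K0)
    k1mk0 = np.stack(np.hsplit(mk, K1), axis=0)
    assert k1mk0.shape == (K1, M, K0)

    # KN_TO_K1NK0: (K, N) -> (K1, N, K0)
    k1nk0 = np.reshape(kn, (K1, K0, N)).swapaxes(1, 2)
    assert k1nk0.shape == (K1, N, K0)

    # N1MN0_TO_MN: (N1, M, N0) -> (M, N1 * N0), the rearrangement applied to matmul results
    n1mn0 = np.zeros((N1, M, K0), dtype=np.float16)
    mn = np.concatenate(list(map(np.squeeze, np.split(n1mn0, N1))), axis=1)
    assert mn.shape == (M, N1 * K0)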

mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
@@ -1,45 +0,0 @@
- # Copyright 2023 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """tiling for sparse """
- from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.strategy import TilingPara
- from mindspore.ops._op_impl._custom_op.flash_attention.tiling_strategy.strategy import TilingStrategy
-
-
- class SparseTiling(TilingStrategy):
-     """A tiling strategy implementation for sparse shape"""
-
-     @classmethod
-     def strategy_name(cls):
-         return "sparse"
-
-     def tiling(self) -> TilingPara:
-         self.Br = min(128, self.Nq)
-         self.Bc = min(128, self.N)
-
-         self.Tr = self.Nq // self.Br
-         self.Tc = self.N // self.Bc
-
-         if self.Nq % self.Br != 0:
-             self.last_Br = self.Nq - self.Tr * self.Br
-             self.Tr += 1
-         else:
-             self.last_Br = self.Br
-         if self.N % self.Bc != 0:
-             self.last_Bc = self.N - self.Tc * self.Bc
-             self.Tc += 1
-         else:
-             self.last_Bc = self.Bc
-
-         return self.gen_tiling_para()
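
The deleted tiling() above is plain ceiling-division blocking with an explicit remainder block. The same arithmetic as a standalone function, for reference (a restatement of the removed logic, independent of the TilingStrategy/TilingPara base classes):

    def sparse_tiling(Nq, N, block=128):
        """Mirror SparseTiling.tiling(): return (Br, Bc, Tr, Tc, last_Br, last_Bc)."""
        Br, Bc = min(block, Nq), min(block, N)
        Tr, Tc = Nq // Br, N // Bc
        # A non-zero remainder becomes one extra, smaller trailing block.
        last_Br = Nq - Tr * Br if Nq % Br else Br
        last_Bc = N - Tc * Bc if N % Bc else Bc
        Tr += 1 if Nq % Br else 0
        Tc += 1 if N % Bc else 0
        return Br, Bc, Tr, Tc, last_Br, last_Bc

    # Example: sparse_tiling(300, 256) == (128, 128, 3, 2, 44, 128)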