mindspore 2.3.0rc1__cp38-none-any.whl → 2.3.0rc2__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (318)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +1 -1
  3. mindspore/_akg/akg/utils/tbe_codegen_utils.py +13 -3
  4. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  5. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  6. mindspore/_checkparam.py +20 -0
  7. mindspore/_extends/parse/parser.py +1 -1
  8. mindspore/_extends/parse/standard_method.py +6 -5
  9. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  10. mindspore/amp.py +5 -5
  11. mindspore/bin/cache_admin +0 -0
  12. mindspore/bin/cache_server +0 -0
  13. mindspore/boost/boost_cell_wrapper.py +1 -1
  14. mindspore/boost/group_loss_scale_manager.py +1 -1
  15. mindspore/common/__init__.py +4 -2
  16. mindspore/common/_register_for_recompute.py +48 -0
  17. mindspore/common/_stub_tensor.py +1 -0
  18. mindspore/common/api.py +56 -4
  19. mindspore/common/dtype.py +5 -3
  20. mindspore/common/dump.py +2 -2
  21. mindspore/common/hook_handle.py +51 -4
  22. mindspore/common/initializer.py +1 -1
  23. mindspore/common/jit_config.py +17 -6
  24. mindspore/common/parameter.py +7 -2
  25. mindspore/common/recompute.py +247 -0
  26. mindspore/common/sparse_tensor.py +2 -2
  27. mindspore/common/symbol.py +1 -1
  28. mindspore/common/tensor.py +74 -36
  29. mindspore/communication/__init__.py +3 -3
  30. mindspore/communication/management.py +30 -30
  31. mindspore/context.py +28 -15
  32. mindspore/dataset/__init__.py +5 -5
  33. mindspore/dataset/audio/__init__.py +2 -2
  34. mindspore/dataset/audio/transforms.py +51 -51
  35. mindspore/dataset/callback/ds_callback.py +2 -2
  36. mindspore/dataset/engine/cache_client.py +1 -1
  37. mindspore/dataset/engine/datasets.py +3 -3
  38. mindspore/dataset/engine/datasets_audio.py +14 -14
  39. mindspore/dataset/engine/datasets_standard_format.py +3 -3
  40. mindspore/dataset/engine/datasets_text.py +38 -38
  41. mindspore/dataset/engine/datasets_user_defined.py +3 -3
  42. mindspore/dataset/engine/datasets_vision.py +68 -68
  43. mindspore/dataset/text/__init__.py +3 -3
  44. mindspore/dataset/text/transforms.py +26 -26
  45. mindspore/dataset/transforms/__init__.py +1 -1
  46. mindspore/dataset/vision/__init__.py +3 -3
  47. mindspore/dataset/vision/transforms.py +92 -92
  48. mindspore/dataset/vision/utils.py +1 -1
  49. mindspore/experimental/optim/adadelta.py +2 -2
  50. mindspore/experimental/optim/adagrad.py +2 -2
  51. mindspore/experimental/optim/adam.py +2 -2
  52. mindspore/experimental/optim/adamax.py +2 -2
  53. mindspore/experimental/optim/adamw.py +2 -2
  54. mindspore/experimental/optim/asgd.py +2 -2
  55. mindspore/experimental/optim/lr_scheduler.py +24 -20
  56. mindspore/experimental/optim/nadam.py +2 -2
  57. mindspore/experimental/optim/optimizer.py +1 -1
  58. mindspore/experimental/optim/radam.py +2 -2
  59. mindspore/experimental/optim/rmsprop.py +2 -2
  60. mindspore/experimental/optim/rprop.py +2 -2
  61. mindspore/experimental/optim/sgd.py +2 -2
  62. mindspore/hal/stream.py +2 -0
  63. mindspore/include/mindapi/base/types.h +5 -0
  64. mindspore/lib/libdnnl.so.2 +0 -0
  65. mindspore/lib/libmindspore.so +0 -0
  66. mindspore/lib/libmindspore_backend.so +0 -0
  67. mindspore/lib/libmindspore_common.so +0 -0
  68. mindspore/lib/libmindspore_core.so +0 -0
  69. mindspore/lib/libmindspore_glog.so.0 +0 -0
  70. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  71. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  72. mindspore/lib/libmindspore_shared_lib.so +0 -0
  73. mindspore/lib/libopencv_core.so.4.5 +0 -0
  74. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  75. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  76. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  77. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +6 -6
  78. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  79. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  80. mindspore/lib/plugin/ascend/liblowlatency_collective.so +0 -0
  81. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  82. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/DeviceBin +0 -0
  83. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/PkgInspect +0 -0
  84. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/op_man +0 -0
  85. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/device/ascend910b/bin/ascend910b.bin +101787 -98559
  86. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_cann_host.so +0 -0
  87. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_host.so +0 -0
  88. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/base/op_register.h +2 -2
  89. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/mix.h +8 -1
  90. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/norm.h +5 -3
  91. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/reduce.h +2 -2
  92. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/backend/backend.h +3 -3
  93. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/backend/rtbackend.h +3 -3
  94. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/base/types.h +0 -1
  95. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/module/module.h +3 -3
  96. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/svector/svector.h +3 -2
  97. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops.so +0 -0
  98. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops_static.a +0 -0
  99. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/add/tiling/add_tiling.h +9 -9
  100. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/apply_rotary_pos_emb_impl.h +2 -6
  101. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb.h +2 -2
  102. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_base.h +460 -0
  103. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_bf16.h +217 -0
  104. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_fp16.h +116 -0
  105. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_tiling.h +16 -24
  106. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_value.h +27 -0
  107. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/asdop/asd_op_impl.h +0 -4
  108. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/FlashAttentionScore_impl.h → flash_attention_score/flash_attention_score_impl.h} +2 -1
  109. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/bs_attention_tiling.h → flash_attention_score/flash_attention_score_tiling.h} +15 -19
  110. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/gelu/tiling/gelu_tiling.h +7 -9
  111. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/lccl/lccl_wrapper.h +58 -0
  112. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul/matmul_impl.h +19 -8
  113. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/pp_matmul_common_tiling.h +18 -8
  114. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/pp_matmul_info.h +7 -4
  115. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/tiling_data.h +44 -6
  116. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/tiling_utils.h +65 -0
  117. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/matmul_stridedslice_fusion_impl.h +10 -6
  118. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/op_param.h +4 -1
  119. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/kernel/paged_attention_mix_hwsync.h +41 -0
  120. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/PagedAttention_impl.h → paged_attention/paged_attention_impl.h} +1 -1
  121. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_tiling.h +63 -0
  122. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/add_param.h +2 -2
  123. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention_param.h → param/attention_param.h} +11 -2
  124. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/matmul_ext_param.h +37 -0
  125. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/sub_param.h +45 -0
  126. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/reshape_and_cache/reshape_and_cache_tiling.h +1 -2
  127. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm.h +23 -0
  128. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_base.h +175 -0
  129. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_normal.h +276 -0
  130. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_split_d.h +280 -0
  131. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/tiling_data.h +35 -0
  132. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/rms_norm_impl.h +45 -0
  133. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/kernel/sub_kernel.h +20 -0
  134. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_impl.h +47 -0
  135. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_tiling.h +25 -0
  136. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/tune_repo/matmul_table.h +323 -23
  137. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/types.h +15 -4
  138. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log_tiling.h +8 -0
  139. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libAdd_impl.so +0 -0
  140. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libSub_impl.so +0 -0
  141. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_layernorm_impl.so +0 -0
  142. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_rms_norm_impl.so +0 -0
  143. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libapply_rotary_pos_emb_impl.so +0 -0
  144. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libcast_impl.so +0 -0
  145. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libgelu_impl.so +0 -0
  146. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_impl.so +0 -0
  147. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_stridedslice_fusion_impl.so +0 -0
  148. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libms_kernels_internal.so +0 -0
  149. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libnot_equal_impl.so +0 -0
  150. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libreshape_and_cache_impl.so +0 -0
  151. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/librms_norm_impl.so +0 -0
  152. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_full_mix.o +0 -0
  153. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_tri_mix.o +0 -0
  154. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_full_mix.o +0 -0
  155. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_tri_mix.o +0 -0
  156. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_full_mix.o +0 -0
  157. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_tri_mix.o +0 -0
  158. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_full_mix.o +0 -0
  159. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_tri_mix.o +0 -0
  160. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bnsd_full_mix.o +0 -0
  161. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bsh_full_mix.o +0 -0
  162. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bnsd_full_mix.o +0 -0
  163. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bsh_full_mix.o +0 -0
  164. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal.h +22 -0
  165. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal_comm.h +70 -0
  166. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal_types.h +103 -0
  167. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lccl.h +47 -0
  168. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lccl_wrapper.h +58 -0
  169. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcoc.h +154 -0
  170. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblcal.so +0 -0
  171. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblccl_wrapper.so +0 -0
  172. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  173. mindspore/log.py +2 -2
  174. mindspore/mint/__init__.py +457 -0
  175. mindspore/mint/nn/__init__.py +430 -0
  176. mindspore/mint/nn/functional.py +424 -0
  177. mindspore/mint/optim/__init__.py +24 -0
  178. mindspore/mint/optim/adamw.py +186 -0
  179. mindspore/multiprocessing/__init__.py +4 -0
  180. mindspore/nn/__init__.py +3 -0
  181. mindspore/nn/cell.py +51 -47
  182. mindspore/nn/extend/__init__.py +29 -0
  183. mindspore/nn/extend/basic.py +140 -0
  184. mindspore/nn/extend/embedding.py +143 -0
  185. mindspore/nn/extend/layer/__init__.py +27 -0
  186. mindspore/nn/extend/layer/normalization.py +107 -0
  187. mindspore/nn/extend/pooling.py +117 -0
  188. mindspore/nn/generator.py +297 -0
  189. mindspore/nn/layer/basic.py +109 -1
  190. mindspore/nn/layer/container.py +2 -2
  191. mindspore/nn/layer/conv.py +6 -6
  192. mindspore/nn/layer/embedding.py +1 -1
  193. mindspore/nn/layer/normalization.py +21 -43
  194. mindspore/nn/layer/padding.py +4 -0
  195. mindspore/nn/optim/ada_grad.py +2 -2
  196. mindspore/nn/optim/adadelta.py +1 -1
  197. mindspore/nn/optim/adafactor.py +1 -1
  198. mindspore/nn/optim/adam.py +7 -7
  199. mindspore/nn/optim/adamax.py +2 -2
  200. mindspore/nn/optim/adasum.py +2 -2
  201. mindspore/nn/optim/asgd.py +2 -2
  202. mindspore/nn/optim/ftrl.py +1 -1
  203. mindspore/nn/optim/lamb.py +3 -3
  204. mindspore/nn/optim/lars.py +1 -1
  205. mindspore/nn/optim/lazyadam.py +2 -2
  206. mindspore/nn/optim/momentum.py +2 -2
  207. mindspore/nn/optim/optimizer.py +2 -2
  208. mindspore/nn/optim/proximal_ada_grad.py +2 -2
  209. mindspore/nn/optim/rmsprop.py +2 -2
  210. mindspore/nn/optim/rprop.py +2 -2
  211. mindspore/nn/optim/sgd.py +2 -2
  212. mindspore/nn/optim/thor.py +2 -2
  213. mindspore/nn/wrap/cell_wrapper.py +9 -9
  214. mindspore/nn/wrap/grad_reducer.py +5 -5
  215. mindspore/ops/_grad_experimental/grad_comm_ops.py +4 -2
  216. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -2
  217. mindspore/ops/_vmap/vmap_math_ops.py +27 -8
  218. mindspore/ops/_vmap/vmap_nn_ops.py +66 -8
  219. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +73 -1
  220. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +12 -3
  221. mindspore/ops/auto_generate/gen_arg_handler.py +24 -0
  222. mindspore/ops/auto_generate/gen_extend_func.py +274 -0
  223. mindspore/ops/auto_generate/gen_ops_def.py +889 -22
  224. mindspore/ops/auto_generate/gen_ops_prim.py +3541 -253
  225. mindspore/ops/auto_generate/pyboost_inner_prim.py +282 -0
  226. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
  227. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +9 -0
  228. mindspore/ops/extend/__init__.py +9 -1
  229. mindspore/ops/extend/array_func.py +134 -27
  230. mindspore/ops/extend/math_func.py +3 -3
  231. mindspore/ops/extend/nn_func.py +363 -2
  232. mindspore/ops/function/__init__.py +19 -2
  233. mindspore/ops/function/array_func.py +463 -439
  234. mindspore/ops/function/clip_func.py +7 -18
  235. mindspore/ops/function/grad/grad_func.py +5 -5
  236. mindspore/ops/function/linalg_func.py +4 -4
  237. mindspore/ops/function/math_func.py +260 -243
  238. mindspore/ops/function/nn_func.py +825 -62
  239. mindspore/ops/function/random_func.py +73 -4
  240. mindspore/ops/function/sparse_unary_func.py +1 -1
  241. mindspore/ops/function/vmap_func.py +1 -1
  242. mindspore/ops/functional.py +2 -2
  243. mindspore/ops/op_info_register.py +1 -31
  244. mindspore/ops/operations/__init__.py +2 -3
  245. mindspore/ops/operations/_grad_ops.py +2 -107
  246. mindspore/ops/operations/_inner_ops.py +5 -5
  247. mindspore/ops/operations/_sequence_ops.py +2 -2
  248. mindspore/ops/operations/array_ops.py +11 -233
  249. mindspore/ops/operations/comm_ops.py +32 -32
  250. mindspore/ops/operations/custom_ops.py +7 -89
  251. mindspore/ops/operations/manually_defined/ops_def.py +329 -4
  252. mindspore/ops/operations/math_ops.py +13 -163
  253. mindspore/ops/operations/nn_ops.py +9 -316
  254. mindspore/ops/operations/random_ops.py +1 -1
  255. mindspore/ops/operations/sparse_ops.py +3 -3
  256. mindspore/ops/primitive.py +2 -2
  257. mindspore/ops_generate/arg_dtype_cast.py +12 -3
  258. mindspore/ops_generate/arg_handler.py +24 -0
  259. mindspore/ops_generate/gen_ops_inner_prim.py +2 -0
  260. mindspore/ops_generate/gen_pyboost_func.py +13 -6
  261. mindspore/ops_generate/pyboost_utils.py +2 -17
  262. mindspore/parallel/__init__.py +3 -2
  263. mindspore/parallel/_auto_parallel_context.py +106 -1
  264. mindspore/parallel/_parallel_serialization.py +34 -2
  265. mindspore/parallel/_utils.py +16 -0
  266. mindspore/parallel/algo_parameter_config.py +4 -4
  267. mindspore/parallel/checkpoint_transform.py +249 -77
  268. mindspore/parallel/cluster/process_entity/_api.py +1 -1
  269. mindspore/parallel/parameter_broadcast.py +1 -1
  270. mindspore/parallel/shard.py +1 -1
  271. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +1 -0
  272. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +17 -5
  273. mindspore/profiler/parser/ascend_msprof_exporter.py +3 -3
  274. mindspore/profiler/parser/ascend_msprof_generator.py +10 -3
  275. mindspore/profiler/parser/ascend_op_generator.py +26 -9
  276. mindspore/profiler/parser/ascend_timeline_generator.py +7 -4
  277. mindspore/profiler/parser/profiler_info.py +11 -1
  278. mindspore/profiler/profiling.py +13 -5
  279. mindspore/rewrite/api/node.py +12 -12
  280. mindspore/rewrite/api/symbol_tree.py +11 -11
  281. mindspore/run_check/_check_version.py +1 -1
  282. mindspore/safeguard/rewrite_obfuscation.py +2 -2
  283. mindspore/train/amp.py +4 -4
  284. mindspore/train/anf_ir_pb2.py +8 -2
  285. mindspore/train/callback/_backup_and_restore.py +2 -2
  286. mindspore/train/callback/_callback.py +4 -4
  287. mindspore/train/callback/_checkpoint.py +2 -2
  288. mindspore/train/callback/_early_stop.py +2 -2
  289. mindspore/train/callback/_landscape.py +4 -4
  290. mindspore/train/callback/_loss_monitor.py +2 -2
  291. mindspore/train/callback/_on_request_exit.py +2 -2
  292. mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
  293. mindspore/train/callback/_summary_collector.py +2 -2
  294. mindspore/train/callback/_time_monitor.py +2 -2
  295. mindspore/train/dataset_helper.py +8 -3
  296. mindspore/train/loss_scale_manager.py +2 -2
  297. mindspore/train/metrics/metric.py +3 -3
  298. mindspore/train/mind_ir_pb2.py +22 -17
  299. mindspore/train/model.py +15 -15
  300. mindspore/train/serialization.py +18 -18
  301. mindspore/train/summary/summary_record.py +7 -7
  302. mindspore/train/train_thor/convert_utils.py +3 -3
  303. mindspore/version.py +1 -1
  304. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/METADATA +1 -1
  305. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/RECORD +309 -262
  306. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/tiling_data.h +0 -59
  307. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_bf16_BNSD_mix.o +0 -0
  308. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_bf16_BSH_mix.o +0 -0
  309. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_fp16_BNSD_mix.o +0 -0
  310. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_fp16_BSH_mix.o +0 -0
  311. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_bf16_BNSD_mix.o +0 -0
  312. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_bf16_BSH_mix.o +0 -0
  313. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_fp16_BNSD_mix.o +0 -0
  314. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_fp16_BSH_mix.o +0 -0
  315. /mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/bs_attention_mix_hwsync.h → flash_attention_score/kernel/flash_attention_score_mix_hwsync.h} +0 -0
  316. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/WHEEL +0 -0
  317. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/entry_points.txt +0 -0
  318. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,460 @@
1
+
2
+ /**
3
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+ #ifndef ROTARY_POS_EMB_BASE
18
+ #define ROTARY_POS_EMB_BASE
19
+
20
+ #include "apply_rotary_pos_emb_tiling.h"
21
+ #include "apply_rotary_pos_emb_value.h"
22
+ #include "kernel_operator.h"
23
+
24
+ template <typename QkDtype, typename CosDtype, bool IF_COS_BROADCAST>
25
+ class RopeBase {
26
+ public:
27
+ // QkDtype :输入qk和输出qk的数据类型
28
+ // CosDtype :输入cos/sin的数据类型
29
+ // IF_COS_BROADCAST :cos sin是否已扩展
30
+ // 构造函数
31
+ __aicore__ inline RopeBase(RopeTilingData *tilingData) {
32
+ setCtrl_ = get_ctrl();
33
+ #if __CCE_AICORE__ == 220
34
+ set_ctrl(sbitset0(get_ctrl(), REG_910B));
35
+ #elif __CCE_AICORE__ == 200
36
+ set_ctrl(sbitset1(get_ctrl(), REG_310P));
37
+ #endif
38
+ this->tilingData_ = tilingData;
39
+ batchSize_ = (tilingData_->cosFormat == 0)
40
+ ? 0
41
+ : ((tilingData_->batch + DEFAULT_REPEAT_STRIDE - 1) / DEFAULT_REPEAT_STRIDE) * DEFAULT_REPEAT_STRIDE;
42
+ hiddenSize_ =
43
+ tilingData_->hiddenSizeK > tilingData_->hiddenSizeQ ? tilingData_->hiddenSizeK : tilingData_->hiddenSizeQ;
44
+ nlCoreRun_ = (tilingData_->ntokens + tilingData_->realCore - 1) / tilingData_->realCore;
45
+ lCoreRun_ = tilingData_->ntokens - (tilingData_->realCore - 1) * nlCoreRun_;
46
+ headNum_ = tilingData_->headNumK > tilingData_->headNumQ ? tilingData_->headNumK : tilingData_->headNumQ;
47
+ rotateStride_ = tilingData_->headDim / tilingData_->rotaryCoeff;
48
+ dynamicRound_ = (block_idx == tilingData_->realCore - 1) ? lCoreRun_ : nlCoreRun_;
49
+ rotaryStrideOffset = (tilingData_->headDim == tilingData_->rotaryCoeff) ? 1 : rotateStride_;
50
+ alignRotary_ = rotateStride_ % ELE_NUM_FP16;
51
+ pipe_.InitBuffer(seqLenQueue_, 1, (batchSize_ * sizeof(int32_t)));
52
+ }
53
+
54
+ // 初始化Gm
55
+ __aicore__ inline void RopeInitGm(__gm__ uint8_t *q, __gm__ uint8_t *k, __gm__ uint8_t *cos, __gm__ uint8_t *sin,
56
+ __gm__ uint8_t *seqLen, __gm__ uint8_t *outQ, __gm__ uint8_t *outK) {
57
+ qGm_ = (__gm__ QkDtype *)q;
58
+ kGm_ = (__gm__ QkDtype *)k;
59
+ cosGm_ = (__gm__ CosDtype *)cos;
60
+ sinGm_ = (__gm__ CosDtype *)sin;
61
+ outQGm_ = (__gm__ QkDtype *)outQ;
62
+ outKGm_ = (__gm__ QkDtype *)outK;
63
+ seqLenGm_ = (__gm__ uint32_t *)seqLen;
64
+ }
65
+
66
+ template <typename T>
67
+ __aicore__ inline void Copy2Ub(__gm__ T *src, __ubuf__ T *dst, uint32_t copyLen) {
68
+ #if defined(__CCE_KT_TEST__) || (__CCE_AICORE__ == 220)
69
+ if (g_coreType == AscendC::AIC) return;
70
+ #endif
71
+ uint32_t blkSizeReal = BLK_SIZE / sizeof(T);
72
+ if (copyLen % blkSizeReal != 0) {
73
+ copy_gm_to_ubuf(dst, src, 0, 1, (copyLen + blkSizeReal - 1) / blkSizeReal, 0, 0);
74
+ pipe_barrier((PIPE_ALL));
75
+ } else {
76
+ copy_gm_to_ubuf(dst, src, 0, 1, copyLen / blkSizeReal, 0, 0);
77
+ pipe_barrier((PIPE_ALL));
78
+ }
79
+ }
80
+
81
+ template <typename T>
82
+ __aicore__ inline void Copy2Gm(__ubuf__ T *src, __gm__ T *dst, uint32_t hiddenSizeLen) {
83
+ #if defined(__CCE_KT_TEST__) || (__CCE_AICORE__ == 220)
84
+ if (g_coreType == AscendC::AIC) return;
85
+ #endif
86
+ uint32_t blkSizeReal = BLK_SIZE / sizeof(T);
87
+ if (hiddenSizeLen % blkSizeReal != 0) {
88
+ copy_ubuf_to_gm(dst, src, 0, 1, (hiddenSizeLen + blkSizeReal - 1) / blkSizeReal, 0, 0);
89
+ } else {
90
+ copy_ubuf_to_gm(dst, src, 0, 1, hiddenSizeLen / blkSizeReal, 0, 0);
91
+ }
92
+ }
93
+
94
+ // 此函数用来复用unpad情況下的cos和sin
95
+ // 例:cos[0~7] cos[0~3]用于第一个batch, cos[0~4]用于第二个batch
96
+ __aicore__ inline void ExpandCosSin(__ubuf__ CosDtype *tempBuf, __gm__ CosDtype *src, __gm__ CosDtype *extraGm) {
97
+ #if defined(__CCE_KT_TEST__) || (__CCE_AICORE__ == 220)
98
+ if (g_coreType == AscendC::AIC) return;
99
+ #endif
100
+ // cos or sin,[maxseqlen,headsize]-->[sumseqlen,hiddensize]
101
+ AscendC::LocalTensor<int32_t> seqLenLocal = seqLenQueue_.AllocTensor<int32_t>();
102
+ // copy_gm_to_ubuf((__ubuf__ int32_t *)seqLenLocal.GetPhyAddr(), seqLenGm_, 0, 1,
103
+ // batchSize_ * sizeof(int32_t) / 32, 0, 0);
104
+ int32_t seqLenTmp = this->tilingData_->ntokens / this->tilingData_->batch;
105
+ for (uint32_t i = 0; i < this->tilingData_->batch; i++) {
106
+ seqLenLocal.SetValue(i, seqLenTmp);
107
+ }
108
+ pipe_barrier((PIPE_ALL));
109
+ int32_t rowsPerLoop = (maxProcessNum_ - batchSize_ * NUM_TWO) / tilingData_->headDim;
110
+ int32_t cosoffset = 0;
111
+ for (uint32_t perBatch = 0; perBatch < tilingData_->batch; perBatch++) {
112
+ int32_t rowsRepeat = seqLenLocal.GetValue(perBatch) / rowsPerLoop;
113
+ int32_t rowsRemain = seqLenLocal.GetValue(perBatch) % rowsPerLoop;
114
+ for (int32_t j = 0; j < rowsRepeat; j++) {
115
+ Copy2Ub(src + (j * rowsPerLoop) * tilingData_->headDim, tempBuf, rowsPerLoop * tilingData_->headDim);
116
+ Copy2Gm(tempBuf, (extraGm + (cosoffset + j * rowsPerLoop) * tilingData_->headDim),
117
+ rowsPerLoop * tilingData_->headDim);
118
+ pipe_barrier((PIPE_ALL));
119
+ }
120
+ if (rowsRemain > 0) {
121
+ Copy2Ub(src + (rowsRepeat * rowsPerLoop) * tilingData_->headDim, tempBuf, rowsRemain * tilingData_->headDim);
122
+ Copy2Gm(tempBuf, (extraGm + (cosoffset + rowsRepeat * rowsPerLoop) * tilingData_->headDim),
123
+ rowsRemain * tilingData_->headDim);
124
+ pipe_barrier((PIPE_ALL));
125
+ }
126
+ cosoffset += seqLenLocal.GetValue(perBatch);
127
+ }
128
+ seqLenQueue_.FreeTensor(seqLenLocal);
129
+ pipe_barrier((PIPE_ALL));
130
+ }
131
+
132
+ // 构建tensor -1 -1 -1 0 0 0
133
+ // 构建tensor 0 0 0 1 1 1
134
+ template <typename BUF_TYPE>
135
+ __aicore__ inline void ExpandNeg(__ubuf__ BUF_TYPE *tempBuf, uint32_t bufPos, uint32_t headNumTemp,
136
+ uint32_t repeatTimeTemp) {
137
+ if (tilingData_->headDim != tilingData_->rotaryCoeff) {
138
+ if (alignRotary_ == 0) { // 对齐直接 -1 1
139
+ for (uint32_t i = 0; i < rotateStride_; ++i) {
140
+ *(tempBuf + negOne_ + i) = (BUF_TYPE)-1;
141
+ *(tempBuf + negOne_ + i + rotateStride_) = (BUF_TYPE)1;
142
+ }
143
+ set_flag(PIPE_S, PIPE_V, EVENT_ID1);
144
+ wait_flag(PIPE_S, PIPE_V, EVENT_ID1);
145
+ for (uint32_t i = 1; i < headNumTemp * tilingData_->rotaryCoeff / NUM_TWO; ++i) {
146
+ // halfHeadDim = rotateStride_ * 2
147
+ copy_ubuf_to_ubuf(tempBuf + negOne_ + rotateStride_ * NUM_TWO * i, tempBuf + negOne_, 0, 1,
148
+ rotateStride_ * sizeof(BUF_TYPE) / ELE_NUM_FP16, 0, 0);
149
+ }
150
+ } else {
151
+ for (uint32_t i = 0; i < rotateStride_; ++i) { // 非对齐 -1 0
152
+ *(tempBuf + negOne_ + i) = (BUF_TYPE)-1;
153
+ *(tempBuf + negOne_ + i + rotateStride_) = (BUF_TYPE)0;
154
+ }
155
+ set_flag(PIPE_S, PIPE_V, EVENT_ID1);
156
+ wait_flag(PIPE_S, PIPE_V, EVENT_ID1);
157
+ for (uint32_t i = 0; i < headNumTemp * tilingData_->rotaryCoeff / NUM_TWO; ++i) {
158
+ if ((rotateStride_ * NUM_TWO) * sizeof(BUF_TYPE) % BLK_SIZE == 0) {
159
+ copy_ubuf_to_ubuf(tempBuf + negOne_ + rotateStride_ * NUM_TWO * i, tempBuf + negOne_, 0, 1,
160
+ rotateStride_ * NUM_TWO * sizeof(BUF_TYPE) / ELE_NUM_FP16, 0, 0);
161
+ } else {
162
+ for (uint32_t j = 0; j < rotateStride_ * NUM_TWO; j++) {
163
+ *(tempBuf + negOne_ + rotateStride_ * NUM_TWO * i + j) = *(tempBuf + negOne_ + j);
164
+ }
165
+ }
166
+ }
167
+ set_flag(PIPE_S, PIPE_V, EVENT_ID1);
168
+ wait_flag(PIPE_S, PIPE_V, EVENT_ID1);
169
+ pipe_barrier(PIPE_V);
170
+ vadds(tempBuf + bufPos, tempBuf + negOne_, (BUF_TYPE)1, repeatTimeTemp, 1, 1, DEFAULT_REPEAT_STRIDE,
171
+ DEFAULT_REPEAT_STRIDE);
172
+ }
173
+ } else {
174
+ set_vector_mask((uint64_t)-1, (uint64_t)-1);
175
+ vector_dup(tempBuf + negOne_, (BUF_TYPE)-1.0, repeatTimeTemp, 1, 1, (uint16_t)DEFAULT_REPEAT_STRIDE,
176
+ (uint16_t)DEFAULT_REPEAT_STRIDE);
177
+ set_vector_mask(0xaaaaaaaaaaaaaaaa, 0xaaaaaaaaaaaaaaaa);
178
+ vector_dup(tempBuf + negOne_, (BUF_TYPE)0.0, repeatTimeTemp, 1, 1, (uint16_t)DEFAULT_REPEAT_STRIDE,
179
+ (uint16_t)DEFAULT_REPEAT_STRIDE);
180
+ set_vector_mask((uint64_t)-1, (uint64_t)-1);
181
+ pipe_barrier((PIPE_V));
182
+ vadds(tempBuf + bufPos, tempBuf + negOne_, (BUF_TYPE)1, repeatTimeTemp, 1, 1, DEFAULT_REPEAT_STRIDE,
183
+ DEFAULT_REPEAT_STRIDE);
184
+ }
185
+ }
186
+
187
// Broadcast cos/sin from one (tilingData_->headDim) slice to (heads * tilingData_->headDim)
// elements in UB, so each head sees the same cos/sin row.
// Aligned headDim: replicate directly UB -> UB. Non-aligned headDim: stage the replication
// through workspace GM (extraGm), since the DMA block length must be 32B-granular.
__aicore__ inline void CosSinCommonBroardcast(__gm__ uint8_t *extraGm, uint32_t z, __ubuf__ CosDtype *tempBuf,
                                              uint32_t calcLen) {
  // Always load one headDim-sized row of cos and sin from GM first.
  uint32_t cosOffset = block_idx * nlCoreRun_ * tilingData_->headDim + z * tilingData_->headDim;
  uint32_t sinOffset = block_idx * nlCoreRun_ * tilingData_->headDim + z * tilingData_->headDim;
  set_flag(PIPE_S, PIPE_MTE2, EVENT_ID1);
  wait_flag(PIPE_S, PIPE_MTE2, EVENT_ID1);
  copy_gm_to_ubuf(tempBuf + cosPad_, cosGm_ + cosOffset, 0, 1,
                  (tilingData_->headDim * sizeof(CosDtype) + BLK_SIZE - 1) / BLK_SIZE, 0, 0);
  copy_gm_to_ubuf(tempBuf + sinPad_, sinGm_ + sinOffset, 0, 1,
                  (tilingData_->headDim * sizeof(CosDtype) + BLK_SIZE - 1) / BLK_SIZE, 0, 0);
  if (tilingData_->cosFormat == 1) {
    // NOTE(review): full barrier only for cosFormat==1 — presumably cos/sin were just
    // rewritten into workspace by ExpandCosSin; confirm against the caller.
    pipe_barrier(PIPE_ALL);
  }
  set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID3);
  set_flag(PIPE_MTE2, PIPE_V, EVENT_ID3);
  if ((tilingData_->headDim * sizeof(CosDtype)) % BLK_SIZE != 0) {
    wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID3);
    // Non-32B-aligned headDim: expand cos (headDim) -> (heads * headDim) via workspace.
    // One UB->GM copy per head, then read the packed result back into UB.
    for (uint32_t i = 0; i < calcLen / tilingData_->headDim; ++i) {
      copy_ubuf_to_gm((__gm__ CosDtype *)extraGm + offsetExtraGm_ + tilingData_->headDim * i, tempBuf + cosPad_, 0, 1,
                      (tilingData_->headDim * sizeof(CosDtype) + BLK_SIZE - 1) / BLK_SIZE, 0, 0);
      pipe_barrier((PIPE_ALL));
    }
    Copy2Ub<CosDtype>((__gm__ CosDtype *)extraGm + offsetExtraGm_, tempBuf + cosPad_, calcLen);
    // Expand sin the same way, reusing the same workspace region.
    for (uint32_t i = 0; i < calcLen / tilingData_->headDim; ++i) {
      copy_ubuf_to_gm((__gm__ CosDtype *)extraGm + offsetExtraGm_ + tilingData_->headDim * i, tempBuf + sinPad_, 0, 1,
                      (tilingData_->headDim * sizeof(CosDtype) + BLK_SIZE - 1) / BLK_SIZE, 0, 0);
      pipe_barrier((PIPE_ALL));
    }
    Copy2Ub<CosDtype>((__gm__ CosDtype *)extraGm + offsetExtraGm_, tempBuf + sinPad_, calcLen);
    wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID3);
  } else {
    wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID3);
    // Aligned headDim: replicate row 0 to rows 1..heads-1 entirely inside UB.
    for (uint32_t i = 1; i < calcLen / tilingData_->headDim; ++i) {
      copy_ubuf_to_ubuf(tempBuf + cosPad_ + tilingData_->headDim * i, tempBuf + cosPad_, 0, 1,
                        tilingData_->headDim * sizeof(CosDtype) / BLK_SIZE, 0, 0);
      copy_ubuf_to_ubuf(tempBuf + sinPad_ + tilingData_->headDim * i, tempBuf + sinPad_, 0, 1,
                        tilingData_->headDim * sizeof(CosDtype) / BLK_SIZE, 0, 0);
    }
    wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID3);
  }
}
233
+
234
// Load cos/sin for iteration z into UB. When the input already provides per-head
// cos/sin (IF_COS_BROADCAST), a single strided GM->UB copy suffices; otherwise the
// single-head row is broadcast across heads via CosSinCommonBroardcast.
template <typename BUF_TYPE>
__aicore__ inline void CosSinBroadcast(__gm__ uint8_t *extraGm, uint32_t z, __ubuf__ BUF_TYPE *tempBuf,
                                       uint32_t Calclen) {
  if constexpr (IF_COS_BROADCAST) {
    // cos/sin are already hiddenSizeQ-wide per token: direct copies.
    copy_gm_to_ubuf(tempBuf + cosPad_,
                    cosGm_ + block_idx * nlCoreRun_ * tilingData_->hiddenSizeQ + z * tilingData_->hiddenSizeQ, 0, 1,
                    Calclen * sizeof(BUF_TYPE) / BLK_SIZE, 0, 0);
    copy_gm_to_ubuf(tempBuf + sinPad_,
                    sinGm_ + block_idx * nlCoreRun_ * tilingData_->hiddenSizeQ + z * tilingData_->hiddenSizeQ, 0, 1,
                    Calclen * sizeof(BUF_TYPE) / BLK_SIZE, 0, 0);
  } else {
    CosSinCommonBroardcast(extraGm, z, tempBuf, Calclen);
  }
}
249
+
250
// Shared q/k preparation: load one hiddenSize slice from GM into UB at oriPos_ and
// build the "rotated" companion views needed by the rope math:
//   removeBefore_ <- slice shifted left by the rotate stride
//   padBefore_    <- slice shifted right by the rotate stride (workspace path only)
// Aligned rotation distance swaps halves with strided UB->UB copies; a non-aligned
// distance is staged through workspace GM (extraGm1) to realize the byte shift.
template <typename BUF_TYPE>
__aicore__ inline void QkComm(__gm__ BUF_TYPE *src, __gm__ uint8_t *extraGm1, uint32_t hiddenSizeTmp,
                              __ubuf__ BUF_TYPE *tempBuf, uint32_t headNumTemp) {
  uint32_t hiddenSizeBlk = hiddenSizeTmp / ELE_NUM_FP16;
  set_flag(PIPE_S, PIPE_MTE2, EVENT_ID1);
  wait_flag(PIPE_S, PIPE_MTE2, EVENT_ID1);
  copy_gm_to_ubuf(tempBuf + oriPos_,  // gm -> ub
                  src, 0, 1, hiddenSizeBlk, 0, 0);
  set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
  set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID2);
  if (alignRotary_ == 0) {
    wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
    wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID2);
    // Aligned path: swap the two rotation halves per head with strided UB->UB copies.
    copy_ubuf_to_ubuf(tempBuf + removeBefore_ + rotaryStrideOffset, tempBuf + oriPos_, 0,
                      headNumTemp * tilingData_->rotaryCoeff / 2, rotaryStrideOffset / ELE_NUM_FP16,
                      rotaryStrideOffset / ELE_NUM_FP16, rotaryStrideOffset / ELE_NUM_FP16);

    copy_ubuf_to_ubuf(tempBuf + removeBefore_, tempBuf + oriPos_ + rotaryStrideOffset, 0,
                      headNumTemp * tilingData_->rotaryCoeff / 2, rotaryStrideOffset / ELE_NUM_FP16,
                      rotaryStrideOffset / ELE_NUM_FP16, rotaryStrideOffset / ELE_NUM_FP16);
  } else {
    wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
    wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID2);
    // ub -> workspace[0 ~ hiddensize]
    copy_ubuf_to_gm((__gm__ BUF_TYPE *)extraGm1 + offsetExtraGm_, tempBuf + oriPos_, 0, 1, hiddenSizeBlk, 0, 0);
    // ub -> workspace[hiddensize ~ 2 * hiddensize]
    copy_ubuf_to_gm((__gm__ BUF_TYPE *)extraGm1 + offsetExtraGm_ + hiddenSizeTmp, tempBuf + oriPos_, 0, 1,
                    hiddenSizeBlk, 0, 0);
    // workspace[rotary ~ hiddensize + rotary] -> ub[hiddensize ~ 2 * hiddensize]
    pipe_barrier((PIPE_ALL));
    copy_gm_to_ubuf(tempBuf + removeBefore_, (__gm__ BUF_TYPE *)extraGm1 + offsetExtraGm_ + rotateStride_, 0, 1,
                    hiddenSizeBlk, 0, 0);
    // gm[hiddensize - rotary ~ 2 * hiddensize - rotary] -> ub[2 * hiddensize ~ 3 * hiddensize]
    copy_gm_to_ubuf(tempBuf + padBefore_,
                    (__gm__ BUF_TYPE *)extraGm1 + offsetExtraGm_ + hiddenSizeTmp - rotateStride_, 0, 1, hiddenSizeBlk,
                    0, 0);
  }
}
289
+
290
// Core rope computation (non-aligned rotation variant). Combines the original slice,
// its two shifted views (removeTemp / padTemp), the cos/sin tables and the +/-1 sign
// masks (negOne_ / posTemp) into the rotary-embedded result written at tempBuf + res.
// PIPE_V barriers separate each dependent group of vector ops.
template <typename BUF_TYPE>
__aicore__ inline void CalcRope(__ubuf__ BUF_TYPE *tempBuf, uint32_t repeatTimes1, uint32_t oriPosTemp,
                                uint32_t removeTemp, uint32_t padTemp, uint32_t posTemp, uint32_t res) {
  set_vector_mask((uint64_t)-1, (uint64_t)-1);
#if defined(__CCE_KT_TEST__) || (__CCE_AICORE__ == 220)
  // Vector math only runs on the AIV core; cube cores bail out.
  if (g_coreType == AscendC::AIC) return;
#endif

  // ori *= cos ; pad *= posMask  (all vector args: blockStride 1, repeatStride default)
  vmul(tempBuf + oriPosTemp, tempBuf + cosPad_, tempBuf + oriPosTemp, repeatTimes1, 1, 1, 1, DEFAULT_REPEAT_STRIDE,
       DEFAULT_REPEAT_STRIDE, DEFAULT_REPEAT_STRIDE);
  vmul(tempBuf + padTemp, tempBuf + posTemp, tempBuf + padTemp, repeatTimes1, 1, 1, 1, DEFAULT_REPEAT_STRIDE,
       DEFAULT_REPEAT_STRIDE, DEFAULT_REPEAT_STRIDE);
  pipe_barrier((PIPE_V));
  // remove *= sin ; pad *= sin
  vmul(tempBuf + removeTemp, tempBuf + sinPad_, tempBuf + removeTemp, repeatTimes1, 1, 1, 1, DEFAULT_REPEAT_STRIDE,
       DEFAULT_REPEAT_STRIDE, DEFAULT_REPEAT_STRIDE);
  vmul(tempBuf + padTemp, tempBuf + sinPad_, tempBuf + padTemp, repeatTimes1, 1, 1, 1, DEFAULT_REPEAT_STRIDE,
       DEFAULT_REPEAT_STRIDE, DEFAULT_REPEAT_STRIDE);
  pipe_barrier((PIPE_V));

  // remove *= negMask ; pad += ori
  vmul(tempBuf + removeTemp, tempBuf + negOne_, tempBuf + removeTemp, repeatTimes1, 1, 1, 1, DEFAULT_REPEAT_STRIDE,
       DEFAULT_REPEAT_STRIDE, DEFAULT_REPEAT_STRIDE);
  vadd(tempBuf + padTemp, tempBuf + oriPosTemp, tempBuf + padTemp, repeatTimes1, 1, 1, 1, DEFAULT_REPEAT_STRIDE,
       DEFAULT_REPEAT_STRIDE, DEFAULT_REPEAT_STRIDE);
  pipe_barrier((PIPE_V));

  // res = remove + pad
  vadd(tempBuf + res, tempBuf + removeTemp, tempBuf + padTemp, repeatTimes1, 1, 1, 1, DEFAULT_REPEAT_STRIDE,
       DEFAULT_REPEAT_STRIDE, DEFAULT_REPEAT_STRIDE);
  pipe_barrier((PIPE_V));
  // Restore the control register saved by the caller/ctor.
  set_ctrl(setCtrl_);
}
370
+
371
// Core rope computation (aligned rotation variant). The halves were already swapped
// in UB by QkComm, so only the +/-1 mask (negOne_), cos and sin are needed; the result
// is written at tempBuf + padTemp.
template <typename BUF_TYPE>
__aicore__ inline void CalcRopeAlign(__ubuf__ BUF_TYPE *tempBuf, uint32_t repeatTimes1, uint32_t oriPosTemp,
                                     uint32_t removeTemp, uint32_t padTemp) {
  set_vector_mask((uint64_t)-1, (uint64_t)-1);
#if defined(__CCE_KT_TEST__) || (__CCE_AICORE__ == 220)
  // Vector math only runs on the AIV core; cube cores bail out.
  if (g_coreType == AscendC::AIC) return;
#endif
  // ori *= cos ; remove *= negMask
  vmul(tempBuf + oriPosTemp, tempBuf + cosPad_, tempBuf + oriPosTemp, repeatTimes1, 1, 1, 1, DEFAULT_REPEAT_STRIDE,
       DEFAULT_REPEAT_STRIDE, DEFAULT_REPEAT_STRIDE);
  vmul(tempBuf + removeTemp, tempBuf + negOne_, tempBuf + removeTemp, repeatTimes1, 1, 1, 1, DEFAULT_REPEAT_STRIDE,
       DEFAULT_REPEAT_STRIDE, DEFAULT_REPEAT_STRIDE);
  pipe_barrier((PIPE_V));
  // remove *= sin
  vmul(tempBuf + removeTemp, tempBuf + sinPad_, tempBuf + removeTemp, repeatTimes1, 1, 1, 1, DEFAULT_REPEAT_STRIDE,
       DEFAULT_REPEAT_STRIDE, DEFAULT_REPEAT_STRIDE);
  pipe_barrier((PIPE_V));
  // pad = remove + ori
  vadd(tempBuf + padTemp, tempBuf + removeTemp, tempBuf + oriPosTemp, repeatTimes1, 1, 1, 1, DEFAULT_REPEAT_STRIDE,
       DEFAULT_REPEAT_STRIDE, DEFAULT_REPEAT_STRIDE);
  pipe_barrier((PIPE_V));
  // Restore the control register saved by the caller/ctor.
  set_ctrl(setCtrl_);
}
420
+
421
 public:
  RopeTilingData *tilingData_ = nullptr;
  __gm__ QkDtype *qGm_{nullptr};
  __gm__ QkDtype *kGm_{nullptr};
  __gm__ CosDtype *cosGm_{nullptr};
  __gm__ CosDtype *sinGm_{nullptr};
  __gm__ uint32_t *seqLenGm_{nullptr};
  __gm__ QkDtype *outQGm_{nullptr};
  __gm__ QkDtype *outKGm_{nullptr};
  AscendC::TPipe pipe_;
  AscendC::TQue<AscendC::QuePosition::VECIN, 1> seqLenQueue_;

  uint32_t cosPad_{0};           // UB offset of the broadcast cos table
  uint32_t sinPad_{0};           // UB offset of the broadcast sin table
  uint32_t negOne_{0};           // UB offset of the -1 -1 -1 0 0 0 sign mask
  uint32_t oriPos_{0};           // UB offset of the q/k slice
  uint32_t padBefore_{0};        // holds qk[-x : hiddensize - x]
  uint32_t removeBefore_{0};     // holds qk[x : hiddensize + x]
  uint32_t repeatSize_{0};       // elements processed per vector repeat
  uint32_t maxProcessNum_{0};    // maximum number of elements processed at once
  uint32_t repeatTimesQ_{0};     // vector repeat count for q
  uint32_t repeatTimesK_{0};     // vector repeat count for k
  uint32_t hiddenSizeAlign_{0};  // hiddensize rounded up to repeatSize_
  uint32_t repeatTimes_{0};      // repeat count after alignment
  uint32_t headNum_{0};          // number of heads
  uint32_t hiddenSize_{0};       // max(hiddenSizeQ, hiddenSizeK)
  uint32_t nlCoreRun_{0};        // rounds run by every core except the last
  uint32_t lCoreRun_{0};         // rounds run by the last core
  uint32_t batchSize_{0};        // batch rounded up
  uint32_t rotateStride_{0};     // headDim / rotary coefficient
  uint32_t offsetExtraGm_{0};    // per-core offset into the workspace
  uint32_t dynamicRound_{0};     // rounds per core
  uint32_t setCtrl_;             // saved control register for restore
  uint32_t alignHalfHeadDim_{0};  // whether headDim / rotaryCoeff * 2 is aligned
  uint32_t rotaryStrideOffset{0};  // length of each rotation
  uint32_t alignRotary_;         // whether the rotation distance is aligned
  uint32_t syncOffset_;          // workspace offset used by each core
458
+ };
459
+
460
+ #endif
@@ -0,0 +1,217 @@
1
+ /**
2
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+ #ifndef ROTARY_POS_EMB_BF16
17
+ #define ROTARY_POS_EMB_BF16
18
+ #include "apply_rotary_pos_emb_base.h"
19
template <typename QK_DTYPE, typename COS_DTYPE, bool IF_COS_BROADCAST>
// Rotary-position-embedding kernel for bf16 q/k. Data is loaded as bf16, converted to
// fp32 for the vector math (to preserve precision), then rounded back to bf16 on store.
// Large hidden sizes are processed in UB-sized slices.
class RopeBf16 : public RopeBase<QK_DTYPE, COS_DTYPE, IF_COS_BROADCAST> {
 public:
  // Computes the UB memory layout (bf16 region: oriPos/removeBefore/padBefore;
  // fp32 region: cosPad/sinPad/negOne/oriPosF32/PadBeforeF32/removeBeforeF32/posOneF32,
  // each sliceSizeTmp_ elements apart) and the per-core slicing plan for q and k.
  __aicore__ inline RopeBf16(RopeTilingData *tilingData) : RopeBase<QK_DTYPE, COS_DTYPE, IF_COS_BROADCAST>(tilingData) {
    this->repeatSize_ = 64;                   // 64 = 256B / sizeof(float)
    this->maxProcessNum_ = 3 * MAX_LEN_FP16;  // 3 fp16 buffers are needed
    this->repeatTimesQ_ = (this->tilingData_->hiddenSizeQ + this->repeatSize_ - 1) / this->repeatSize_;
    this->repeatTimesK_ = (this->tilingData_->hiddenSizeK + this->repeatSize_ - 1) / this->repeatSize_;
    headDimAlign_ = ((this->tilingData_->headDim + ELE_NUM_FP16 - 1) / ELE_NUM_FP16) * ELE_NUM_FP16;
    this->alignHalfHeadDim_ = (this->rotateStride_ * NUM_TWO) % ELE_NUM_FP32;
    this->hiddenSizeAlign_ = ((this->hiddenSize_ + this->repeatSize_ - 1) / this->repeatSize_) * this->repeatSize_;
    this->syncOffset_ =
        (this->tilingData_->headDim % ELE_NUM_FP16 == 0) ? this->hiddenSizeAlign_ : this->headNum_ * headDimAlign_;
    this->offsetExtraGm_ = NUM_TWO * block_idx * this->syncOffset_;

    // Largest multiple of headDim that fits in one slice (round down).
    sliceSizeTmp_ = (SLICE_SIZE / this->tilingData_->headDim) * this->tilingData_->headDim;

    // bf16/fp16 UB region offsets
    this->oriPos_ = 0;
    this->removeBefore_ = this->oriPos_ + sliceSizeTmp_;
    this->padBefore_ = this->removeBefore_ + sliceSizeTmp_;

    // fp32 UB region offsets
    this->cosPad_ = 0;
    this->sinPad_ = this->cosPad_ + sliceSizeTmp_;
    this->negOne_ = this->sinPad_ + sliceSizeTmp_;
    oriPosF32_ = this->negOne_ + sliceSizeTmp_;
    PadBeforeF32_ = oriPosF32_ + sliceSizeTmp_;
    removeBeforeF32_ = PadBeforeF32_ + sliceSizeTmp_;
    posOneF32_ = removeBeforeF32_ + sliceSizeTmp_;

    this->pipe_.InitBuffer(
        qkfp32QueueCO2_, 1,
        (this->tilingData_->maxUbSize - (this->batchSize_ + this->maxProcessNum_) * sizeof(half)));  // rest for fp32
    AscendC::LocalTensor<float> qkfp32_perloop_ub = qkfp32QueueCO2_.AllocTensor<float>();
    qkfp32Ubuf_ = (__ubuf__ float *)qkfp32_perloop_ub.GetPhyAddr();
    this->pipe_.InitBuffer(outQueueCO2_, 1, ((this->maxProcessNum_) * sizeof(half)));
    AscendC::LocalTensor<QK_DTYPE> cache_perloop_ub2 = outQueueCO2_.AllocTensor<QK_DTYPE>();
    commonUbuf_ = (__ubuf__ QK_DTYPE *)cache_perloop_ub2.GetPhyAddr();

    // Decide whether q/k must be processed in multiple slices.
    if (this->tilingData_->hiddenSizeQ > sliceSizeTmp_) {
      sliceTimeQ_ = (this->tilingData_->hiddenSizeQ + sliceSizeTmp_ - 1) / sliceSizeTmp_;  // round up
      lastSliceSizeQ_ = this->tilingData_->hiddenSizeQ - (sliceTimeQ_ - 1) * sliceSizeTmp_;
    } else {
      sliceTimeQ_ = 1;
      lastSliceSizeQ_ = this->tilingData_->hiddenSizeQ;
    }

    if (this->tilingData_->hiddenSizeK > sliceSizeTmp_) {
      sliceTimeK_ = (this->tilingData_->hiddenSizeK + sliceSizeTmp_ - 1) / sliceSizeTmp_;  // round up
      lastSliceSizeK_ = this->tilingData_->hiddenSizeK - (sliceTimeK_ - 1) * sliceSizeTmp_;
    } else {
      sliceTimeK_ = 1;
      lastSliceSizeK_ = this->tilingData_->hiddenSizeK;
    }
  }

  // Convert the bf16 cos/sin tables in commonUbuf_ to fp32 in qkfp32Ubuf_ (same offsets).
  __aicore__ inline void ConvertCos(uint32_t repeatTimes) {
    vconv_bf162f32(qkfp32Ubuf_ + this->cosPad_, commonUbuf_ + this->cosPad_, repeatTimes, 1, 1, DEFAULT_REPEAT_STRIDE,
                   DEFAULT_REPEAT_STRIDE / NUM_TWO);
    vconv_bf162f32(qkfp32Ubuf_ + this->sinPad_, commonUbuf_ + this->sinPad_, repeatTimes, 1, 1, DEFAULT_REPEAT_STRIDE,
                   DEFAULT_REPEAT_STRIDE / NUM_TWO);
  }

  // Convert the bf16 q/k slice and its two shifted views to fp32 working buffers.
  __aicore__ inline void CastB162F32(uint32_t repeatTimes1) {
    vconv_bf162f32(qkfp32Ubuf_ + oriPosF32_, commonUbuf_ + this->oriPos_, repeatTimes1, 1, 1, DEFAULT_REPEAT_STRIDE,
                   DEFAULT_REPEAT_STRIDE / NUM_TWO);
    vconv_bf162f32(qkfp32Ubuf_ + removeBeforeF32_, commonUbuf_ + this->removeBefore_, repeatTimes1, 1, 1,
                   DEFAULT_REPEAT_STRIDE, DEFAULT_REPEAT_STRIDE / NUM_TWO);
    vconv_bf162f32(qkfp32Ubuf_ + PadBeforeF32_, commonUbuf_ + this->padBefore_, repeatTimes1, 1, 1,
                   DEFAULT_REPEAT_STRIDE, DEFAULT_REPEAT_STRIDE / NUM_TWO);
  }

  // Round the fp32 result back to bf16 (round mode "r") via src1, then DMA it to GM.
  __aicore__ inline void CastF322B16(__gm__ QK_DTYPE *dst, __ubuf__ QK_DTYPE *src1, __ubuf__ float *src,
                                     uint32_t repeatTimes1, uint32_t hiddenSize1) {
    vconv_f322bf16r(src1, src, repeatTimes1, 1, 1, DEFAULT_REPEAT_STRIDE / NUM_TWO, DEFAULT_REPEAT_STRIDE);
    set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
    wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
    copy_ubuf_to_gm(dst, src1, 0, 1, hiddenSize1 / ELE_NUM_FP16, 0, 0);
  }

  // Kernel entry: optionally expand cos/sin into workspace, build the sign masks once,
  // then for every round apply rope to each q slice and each k slice.
  __aicore__ inline void Process(__gm__ uint8_t *extraGm) {
    if (this->tilingData_->cosFormat == 1) {
      // cosFormat==1: expand cos/sin into workspace and point cosGm_/sinGm_ at it.
      pipe_barrier((PIPE_ALL));
      this->ExpandCosSin(commonUbuf_, this->cosGm_, (__gm__ COS_DTYPE *)extraGm);
      this->cosGm_ = (__gm__ COS_DTYPE *)extraGm;
      pipe_barrier((PIPE_ALL));
      this->ExpandCosSin(commonUbuf_, this->sinGm_,
                         (__gm__ COS_DTYPE *)extraGm + this->tilingData_->ntokens * this->tilingData_->headDim);
      this->sinGm_ = (__gm__ COS_DTYPE *)extraGm + this->tilingData_->ntokens * this->tilingData_->headDim;
      // Skip past the two expanded tables: 2 tables * 2 bytes/element = 4 bytes per element.
      extraGm =
          extraGm + this->tilingData_->ntokens * this->tilingData_->headDim * 4;
      pipe_barrier((PIPE_ALL));
    }

    uint32_t dynamicSliceQ =
        this->tilingData_->hiddenSizeQ > sliceSizeTmp_ ? sliceSizeTmp_ : this->tilingData_->hiddenSizeQ;
    uint32_t headNumTempQ = dynamicSliceQ / this->tilingData_->headDim;

    uint32_t dynamicSliceK =
        this->tilingData_->hiddenSizeK > sliceSizeTmp_ ? sliceSizeTmp_ : this->tilingData_->hiddenSizeK;
    uint32_t headNumTempK = dynamicSliceK / this->tilingData_->headDim;
    uint32_t repeatTemp = (dynamicSliceQ + this->repeatSize_ - 1) / this->repeatSize_;
    // Build the -1/+1 sign masks once; they are reused by every slice.
    this->ExpandNeg(qkfp32Ubuf_, posOneF32_, headNumTempQ, repeatTemp);
    for (uint32_t zz = 0; zz < this->dynamicRound_; ++zz) {
      this->CosSinBroadcast(extraGm, zz, commonUbuf_, dynamicSliceQ);
      if (this->tilingData_->headDim % ELE_NUM_FP16 == 0) {
        pipe_barrier(PIPE_V);
        ConvertCos(repeatTemp);
      } else {
        set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
        wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
        ConvertCos(repeatTemp);
      }
      for (uint32_t perSlice = 0; perSlice < sliceTimeQ_; ++perSlice) {  // each q slice on this core
        uint32_t dynamicSliceQTemp = (perSlice == sliceTimeQ_ - 1) ? lastSliceSizeQ_ : sliceSizeTmp_;
        headNumTempQ = dynamicSliceQTemp / this->tilingData_->headDim;
        uint32_t repeatTimeOnce = (dynamicSliceQTemp + this->repeatSize_ - 1) / this->repeatSize_;
        pipe_barrier(PIPE_MTE2);
        set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1);
        wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1);
        this->QkComm(this->qGm_ + block_idx * this->nlCoreRun_ * this->tilingData_->hiddenSizeQ +
                         zz * this->tilingData_->hiddenSizeQ + perSlice * sliceSizeTmp_,
                     extraGm, dynamicSliceQTemp, commonUbuf_, headNumTempQ);

        if (this->alignRotary_ == 0) {
          pipe_barrier((PIPE_V));
          CastB162F32(repeatTimeOnce);

          pipe_barrier((PIPE_V));
          this->CalcRopeAlign(qkfp32Ubuf_, repeatTimeOnce, oriPosF32_, removeBeforeF32_, PadBeforeF32_);
        } else {
          set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
          wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);

          CastB162F32(repeatTimeOnce);
          pipe_barrier((PIPE_V));
          this->CalcRope(qkfp32Ubuf_, repeatTimeOnce, oriPosF32_, removeBeforeF32_, PadBeforeF32_, posOneF32_,
                         PadBeforeF32_);
        }

        CastF322B16(this->outQGm_ + block_idx * this->nlCoreRun_ * this->tilingData_->hiddenSizeQ +
                        zz * this->tilingData_->hiddenSizeQ + perSlice * sliceSizeTmp_,
                    commonUbuf_ + this->padBefore_, qkfp32Ubuf_ + PadBeforeF32_, repeatTimeOnce, dynamicSliceQTemp);
        pipe_barrier((PIPE_ALL));
      }

      for (uint32_t perSlice = 0; perSlice < sliceTimeK_; ++perSlice) {  // each k slice on this core
        uint32_t dynamicSliceKTemp = (perSlice == sliceTimeK_ - 1) ? lastSliceSizeK_ : sliceSizeTmp_;
        headNumTempK = dynamicSliceKTemp / this->tilingData_->headDim;
        uint32_t repeatTimeOnce = (dynamicSliceKTemp + this->repeatSize_ - 1) / this->repeatSize_;
        pipe_barrier(PIPE_MTE2);
        this->QkComm(this->kGm_ + block_idx * this->nlCoreRun_ * this->tilingData_->hiddenSizeK +
                         zz * this->tilingData_->hiddenSizeK + perSlice * sliceSizeTmp_,
                     extraGm, dynamicSliceKTemp, commonUbuf_, headNumTempK);

        if (this->alignRotary_ == 0) {
          pipe_barrier((PIPE_V));
          CastB162F32(repeatTimeOnce);

          pipe_barrier((PIPE_V));
          this->CalcRopeAlign(qkfp32Ubuf_, repeatTimeOnce, oriPosF32_, removeBeforeF32_, PadBeforeF32_);
        } else {
          set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
          wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
          CastB162F32(repeatTimeOnce);

          pipe_barrier((PIPE_V));
          this->CalcRope(qkfp32Ubuf_, repeatTimeOnce, oriPosF32_, removeBeforeF32_, PadBeforeF32_, posOneF32_,
                         PadBeforeF32_);
        }

        CastF322B16(this->outKGm_ + block_idx * this->nlCoreRun_ * this->tilingData_->hiddenSizeK +
                        zz * this->tilingData_->hiddenSizeK + perSlice * sliceSizeTmp_,
                    commonUbuf_ + this->padBefore_, qkfp32Ubuf_ + PadBeforeF32_, repeatTimeOnce, dynamicSliceKTemp);
        pipe_barrier((PIPE_ALL));
      }
    }
  }

 private:
  AscendC::TQue<AscendC::QuePosition::VECIN, 1> outQueueCO2_;
  AscendC::TQue<AscendC::QuePosition::VECIN, 1> qkfp32QueueCO2_;
  __ubuf__ QK_DTYPE *commonUbuf_{nullptr};
  __ubuf__ float *qkfp32Ubuf_{nullptr};
  uint32_t oriPosF32_{0};        // offset of q/k in the fp32 buffer
  uint32_t PadBeforeF32_{0};     // fp32 buffer: holds qk[-x : hiddensize - x]
  uint32_t removeBeforeF32_{0};  // fp32 buffer: holds qk[x : hiddensize + x]
  uint32_t posOneF32_{0};        // fp32 buffer: offset of the 0 0 0 1 1 1 mask
  uint32_t headDimAlign_;        // headDim rounded up to the fp16 block size
  uint32_t sliceTimeQ_;          // number of slices for q
  uint32_t lastSliceSizeQ_;      // size of the last q slice
  uint32_t sliceTimeK_;          // number of slices for k
  uint32_t lastSliceSizeK_;      // size of the last k slice
  uint32_t sliceSizeTmp_;        // slice size (multiple of headDim)
};
+ };
216
+
217
+ #endif