mindspore-2.3.0rc1-cp38-none-any.whl → mindspore-2.3.0rc2-cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore has been flagged as possibly problematic.

Files changed (318)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +1 -1
  3. mindspore/_akg/akg/utils/tbe_codegen_utils.py +13 -3
  4. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  5. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  6. mindspore/_checkparam.py +20 -0
  7. mindspore/_extends/parse/parser.py +1 -1
  8. mindspore/_extends/parse/standard_method.py +6 -5
  9. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  10. mindspore/amp.py +5 -5
  11. mindspore/bin/cache_admin +0 -0
  12. mindspore/bin/cache_server +0 -0
  13. mindspore/boost/boost_cell_wrapper.py +1 -1
  14. mindspore/boost/group_loss_scale_manager.py +1 -1
  15. mindspore/common/__init__.py +4 -2
  16. mindspore/common/_register_for_recompute.py +48 -0
  17. mindspore/common/_stub_tensor.py +1 -0
  18. mindspore/common/api.py +56 -4
  19. mindspore/common/dtype.py +5 -3
  20. mindspore/common/dump.py +2 -2
  21. mindspore/common/hook_handle.py +51 -4
  22. mindspore/common/initializer.py +1 -1
  23. mindspore/common/jit_config.py +17 -6
  24. mindspore/common/parameter.py +7 -2
  25. mindspore/common/recompute.py +247 -0
  26. mindspore/common/sparse_tensor.py +2 -2
  27. mindspore/common/symbol.py +1 -1
  28. mindspore/common/tensor.py +74 -36
  29. mindspore/communication/__init__.py +3 -3
  30. mindspore/communication/management.py +30 -30
  31. mindspore/context.py +28 -15
  32. mindspore/dataset/__init__.py +5 -5
  33. mindspore/dataset/audio/__init__.py +2 -2
  34. mindspore/dataset/audio/transforms.py +51 -51
  35. mindspore/dataset/callback/ds_callback.py +2 -2
  36. mindspore/dataset/engine/cache_client.py +1 -1
  37. mindspore/dataset/engine/datasets.py +3 -3
  38. mindspore/dataset/engine/datasets_audio.py +14 -14
  39. mindspore/dataset/engine/datasets_standard_format.py +3 -3
  40. mindspore/dataset/engine/datasets_text.py +38 -38
  41. mindspore/dataset/engine/datasets_user_defined.py +3 -3
  42. mindspore/dataset/engine/datasets_vision.py +68 -68
  43. mindspore/dataset/text/__init__.py +3 -3
  44. mindspore/dataset/text/transforms.py +26 -26
  45. mindspore/dataset/transforms/__init__.py +1 -1
  46. mindspore/dataset/vision/__init__.py +3 -3
  47. mindspore/dataset/vision/transforms.py +92 -92
  48. mindspore/dataset/vision/utils.py +1 -1
  49. mindspore/experimental/optim/adadelta.py +2 -2
  50. mindspore/experimental/optim/adagrad.py +2 -2
  51. mindspore/experimental/optim/adam.py +2 -2
  52. mindspore/experimental/optim/adamax.py +2 -2
  53. mindspore/experimental/optim/adamw.py +2 -2
  54. mindspore/experimental/optim/asgd.py +2 -2
  55. mindspore/experimental/optim/lr_scheduler.py +24 -20
  56. mindspore/experimental/optim/nadam.py +2 -2
  57. mindspore/experimental/optim/optimizer.py +1 -1
  58. mindspore/experimental/optim/radam.py +2 -2
  59. mindspore/experimental/optim/rmsprop.py +2 -2
  60. mindspore/experimental/optim/rprop.py +2 -2
  61. mindspore/experimental/optim/sgd.py +2 -2
  62. mindspore/hal/stream.py +2 -0
  63. mindspore/include/mindapi/base/types.h +5 -0
  64. mindspore/lib/libdnnl.so.2 +0 -0
  65. mindspore/lib/libmindspore.so +0 -0
  66. mindspore/lib/libmindspore_backend.so +0 -0
  67. mindspore/lib/libmindspore_common.so +0 -0
  68. mindspore/lib/libmindspore_core.so +0 -0
  69. mindspore/lib/libmindspore_glog.so.0 +0 -0
  70. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  71. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  72. mindspore/lib/libmindspore_shared_lib.so +0 -0
  73. mindspore/lib/libopencv_core.so.4.5 +0 -0
  74. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  75. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  76. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  77. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +6 -6
  78. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  79. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  80. mindspore/lib/plugin/ascend/liblowlatency_collective.so +0 -0
  81. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  82. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/DeviceBin +0 -0
  83. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/PkgInspect +0 -0
  84. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/op_man +0 -0
  85. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/device/ascend910b/bin/ascend910b.bin +101787 -98559
  86. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_cann_host.so +0 -0
  87. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_host.so +0 -0
  88. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/base/op_register.h +2 -2
  89. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/mix.h +8 -1
  90. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/norm.h +5 -3
  91. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/reduce.h +2 -2
  92. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/backend/backend.h +3 -3
  93. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/backend/rtbackend.h +3 -3
  94. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/base/types.h +0 -1
  95. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/module/module.h +3 -3
  96. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/svector/svector.h +3 -2
  97. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops.so +0 -0
  98. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops_static.a +0 -0
  99. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/add/tiling/add_tiling.h +9 -9
  100. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/apply_rotary_pos_emb_impl.h +2 -6
  101. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb.h +2 -2
  102. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_base.h +460 -0
  103. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_bf16.h +217 -0
  104. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_fp16.h +116 -0
  105. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_tiling.h +16 -24
  106. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_value.h +27 -0
  107. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/asdop/asd_op_impl.h +0 -4
  108. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/FlashAttentionScore_impl.h → flash_attention_score/flash_attention_score_impl.h} +2 -1
  109. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/bs_attention_tiling.h → flash_attention_score/flash_attention_score_tiling.h} +15 -19
  110. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/gelu/tiling/gelu_tiling.h +7 -9
  111. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/lccl/lccl_wrapper.h +58 -0
  112. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul/matmul_impl.h +19 -8
  113. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/pp_matmul_common_tiling.h +18 -8
  114. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/pp_matmul_info.h +7 -4
  115. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/tiling_data.h +44 -6
  116. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/tiling_utils.h +65 -0
  117. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/matmul_stridedslice_fusion_impl.h +10 -6
  118. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/op_param.h +4 -1
  119. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/kernel/paged_attention_mix_hwsync.h +41 -0
  120. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/PagedAttention_impl.h → paged_attention/paged_attention_impl.h} +1 -1
  121. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_tiling.h +63 -0
  122. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/add_param.h +2 -2
  123. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention_param.h → param/attention_param.h} +11 -2
  124. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/matmul_ext_param.h +37 -0
  125. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/sub_param.h +45 -0
  126. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/reshape_and_cache/reshape_and_cache_tiling.h +1 -2
  127. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm.h +23 -0
  128. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_base.h +175 -0
  129. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_normal.h +276 -0
  130. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_split_d.h +280 -0
  131. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/tiling_data.h +35 -0
  132. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/rms_norm_impl.h +45 -0
  133. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/kernel/sub_kernel.h +20 -0
  134. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_impl.h +47 -0
  135. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_tiling.h +25 -0
  136. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/tune_repo/matmul_table.h +323 -23
  137. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/types.h +15 -4
  138. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log_tiling.h +8 -0
  139. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libAdd_impl.so +0 -0
  140. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libSub_impl.so +0 -0
  141. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_layernorm_impl.so +0 -0
  142. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_rms_norm_impl.so +0 -0
  143. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libapply_rotary_pos_emb_impl.so +0 -0
  144. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libcast_impl.so +0 -0
  145. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libgelu_impl.so +0 -0
  146. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_impl.so +0 -0
  147. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_stridedslice_fusion_impl.so +0 -0
  148. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libms_kernels_internal.so +0 -0
  149. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libnot_equal_impl.so +0 -0
  150. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libreshape_and_cache_impl.so +0 -0
  151. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/librms_norm_impl.so +0 -0
  152. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_full_mix.o +0 -0
  153. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_tri_mix.o +0 -0
  154. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_full_mix.o +0 -0
  155. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_tri_mix.o +0 -0
  156. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_full_mix.o +0 -0
  157. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_tri_mix.o +0 -0
  158. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_full_mix.o +0 -0
  159. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_tri_mix.o +0 -0
  160. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bnsd_full_mix.o +0 -0
  161. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bsh_full_mix.o +0 -0
  162. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bnsd_full_mix.o +0 -0
  163. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bsh_full_mix.o +0 -0
  164. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal.h +22 -0
  165. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal_comm.h +70 -0
  166. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal_types.h +103 -0
  167. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lccl.h +47 -0
  168. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lccl_wrapper.h +58 -0
  169. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcoc.h +154 -0
  170. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblcal.so +0 -0
  171. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblccl_wrapper.so +0 -0
  172. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  173. mindspore/log.py +2 -2
  174. mindspore/mint/__init__.py +457 -0
  175. mindspore/mint/nn/__init__.py +430 -0
  176. mindspore/mint/nn/functional.py +424 -0
  177. mindspore/mint/optim/__init__.py +24 -0
  178. mindspore/mint/optim/adamw.py +186 -0
  179. mindspore/multiprocessing/__init__.py +4 -0
  180. mindspore/nn/__init__.py +3 -0
  181. mindspore/nn/cell.py +51 -47
  182. mindspore/nn/extend/__init__.py +29 -0
  183. mindspore/nn/extend/basic.py +140 -0
  184. mindspore/nn/extend/embedding.py +143 -0
  185. mindspore/nn/extend/layer/__init__.py +27 -0
  186. mindspore/nn/extend/layer/normalization.py +107 -0
  187. mindspore/nn/extend/pooling.py +117 -0
  188. mindspore/nn/generator.py +297 -0
  189. mindspore/nn/layer/basic.py +109 -1
  190. mindspore/nn/layer/container.py +2 -2
  191. mindspore/nn/layer/conv.py +6 -6
  192. mindspore/nn/layer/embedding.py +1 -1
  193. mindspore/nn/layer/normalization.py +21 -43
  194. mindspore/nn/layer/padding.py +4 -0
  195. mindspore/nn/optim/ada_grad.py +2 -2
  196. mindspore/nn/optim/adadelta.py +1 -1
  197. mindspore/nn/optim/adafactor.py +1 -1
  198. mindspore/nn/optim/adam.py +7 -7
  199. mindspore/nn/optim/adamax.py +2 -2
  200. mindspore/nn/optim/adasum.py +2 -2
  201. mindspore/nn/optim/asgd.py +2 -2
  202. mindspore/nn/optim/ftrl.py +1 -1
  203. mindspore/nn/optim/lamb.py +3 -3
  204. mindspore/nn/optim/lars.py +1 -1
  205. mindspore/nn/optim/lazyadam.py +2 -2
  206. mindspore/nn/optim/momentum.py +2 -2
  207. mindspore/nn/optim/optimizer.py +2 -2
  208. mindspore/nn/optim/proximal_ada_grad.py +2 -2
  209. mindspore/nn/optim/rmsprop.py +2 -2
  210. mindspore/nn/optim/rprop.py +2 -2
  211. mindspore/nn/optim/sgd.py +2 -2
  212. mindspore/nn/optim/thor.py +2 -2
  213. mindspore/nn/wrap/cell_wrapper.py +9 -9
  214. mindspore/nn/wrap/grad_reducer.py +5 -5
  215. mindspore/ops/_grad_experimental/grad_comm_ops.py +4 -2
  216. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -2
  217. mindspore/ops/_vmap/vmap_math_ops.py +27 -8
  218. mindspore/ops/_vmap/vmap_nn_ops.py +66 -8
  219. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +73 -1
  220. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +12 -3
  221. mindspore/ops/auto_generate/gen_arg_handler.py +24 -0
  222. mindspore/ops/auto_generate/gen_extend_func.py +274 -0
  223. mindspore/ops/auto_generate/gen_ops_def.py +889 -22
  224. mindspore/ops/auto_generate/gen_ops_prim.py +3541 -253
  225. mindspore/ops/auto_generate/pyboost_inner_prim.py +282 -0
  226. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
  227. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +9 -0
  228. mindspore/ops/extend/__init__.py +9 -1
  229. mindspore/ops/extend/array_func.py +134 -27
  230. mindspore/ops/extend/math_func.py +3 -3
  231. mindspore/ops/extend/nn_func.py +363 -2
  232. mindspore/ops/function/__init__.py +19 -2
  233. mindspore/ops/function/array_func.py +463 -439
  234. mindspore/ops/function/clip_func.py +7 -18
  235. mindspore/ops/function/grad/grad_func.py +5 -5
  236. mindspore/ops/function/linalg_func.py +4 -4
  237. mindspore/ops/function/math_func.py +260 -243
  238. mindspore/ops/function/nn_func.py +825 -62
  239. mindspore/ops/function/random_func.py +73 -4
  240. mindspore/ops/function/sparse_unary_func.py +1 -1
  241. mindspore/ops/function/vmap_func.py +1 -1
  242. mindspore/ops/functional.py +2 -2
  243. mindspore/ops/op_info_register.py +1 -31
  244. mindspore/ops/operations/__init__.py +2 -3
  245. mindspore/ops/operations/_grad_ops.py +2 -107
  246. mindspore/ops/operations/_inner_ops.py +5 -5
  247. mindspore/ops/operations/_sequence_ops.py +2 -2
  248. mindspore/ops/operations/array_ops.py +11 -233
  249. mindspore/ops/operations/comm_ops.py +32 -32
  250. mindspore/ops/operations/custom_ops.py +7 -89
  251. mindspore/ops/operations/manually_defined/ops_def.py +329 -4
  252. mindspore/ops/operations/math_ops.py +13 -163
  253. mindspore/ops/operations/nn_ops.py +9 -316
  254. mindspore/ops/operations/random_ops.py +1 -1
  255. mindspore/ops/operations/sparse_ops.py +3 -3
  256. mindspore/ops/primitive.py +2 -2
  257. mindspore/ops_generate/arg_dtype_cast.py +12 -3
  258. mindspore/ops_generate/arg_handler.py +24 -0
  259. mindspore/ops_generate/gen_ops_inner_prim.py +2 -0
  260. mindspore/ops_generate/gen_pyboost_func.py +13 -6
  261. mindspore/ops_generate/pyboost_utils.py +2 -17
  262. mindspore/parallel/__init__.py +3 -2
  263. mindspore/parallel/_auto_parallel_context.py +106 -1
  264. mindspore/parallel/_parallel_serialization.py +34 -2
  265. mindspore/parallel/_utils.py +16 -0
  266. mindspore/parallel/algo_parameter_config.py +4 -4
  267. mindspore/parallel/checkpoint_transform.py +249 -77
  268. mindspore/parallel/cluster/process_entity/_api.py +1 -1
  269. mindspore/parallel/parameter_broadcast.py +1 -1
  270. mindspore/parallel/shard.py +1 -1
  271. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +1 -0
  272. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +17 -5
  273. mindspore/profiler/parser/ascend_msprof_exporter.py +3 -3
  274. mindspore/profiler/parser/ascend_msprof_generator.py +10 -3
  275. mindspore/profiler/parser/ascend_op_generator.py +26 -9
  276. mindspore/profiler/parser/ascend_timeline_generator.py +7 -4
  277. mindspore/profiler/parser/profiler_info.py +11 -1
  278. mindspore/profiler/profiling.py +13 -5
  279. mindspore/rewrite/api/node.py +12 -12
  280. mindspore/rewrite/api/symbol_tree.py +11 -11
  281. mindspore/run_check/_check_version.py +1 -1
  282. mindspore/safeguard/rewrite_obfuscation.py +2 -2
  283. mindspore/train/amp.py +4 -4
  284. mindspore/train/anf_ir_pb2.py +8 -2
  285. mindspore/train/callback/_backup_and_restore.py +2 -2
  286. mindspore/train/callback/_callback.py +4 -4
  287. mindspore/train/callback/_checkpoint.py +2 -2
  288. mindspore/train/callback/_early_stop.py +2 -2
  289. mindspore/train/callback/_landscape.py +4 -4
  290. mindspore/train/callback/_loss_monitor.py +2 -2
  291. mindspore/train/callback/_on_request_exit.py +2 -2
  292. mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
  293. mindspore/train/callback/_summary_collector.py +2 -2
  294. mindspore/train/callback/_time_monitor.py +2 -2
  295. mindspore/train/dataset_helper.py +8 -3
  296. mindspore/train/loss_scale_manager.py +2 -2
  297. mindspore/train/metrics/metric.py +3 -3
  298. mindspore/train/mind_ir_pb2.py +22 -17
  299. mindspore/train/model.py +15 -15
  300. mindspore/train/serialization.py +18 -18
  301. mindspore/train/summary/summary_record.py +7 -7
  302. mindspore/train/train_thor/convert_utils.py +3 -3
  303. mindspore/version.py +1 -1
  304. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/METADATA +1 -1
  305. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/RECORD +309 -262
  306. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/tiling_data.h +0 -59
  307. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_bf16_BNSD_mix.o +0 -0
  308. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_bf16_BSH_mix.o +0 -0
  309. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_fp16_BNSD_mix.o +0 -0
  310. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_fp16_BSH_mix.o +0 -0
  311. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_bf16_BNSD_mix.o +0 -0
  312. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_bf16_BSH_mix.o +0 -0
  313. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_fp16_BNSD_mix.o +0 -0
  314. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_fp16_BSH_mix.o +0 -0
  315. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/bs_attention_mix_hwsync.h → flash_attention_score/kernel/flash_attention_score_mix_hwsync.h} +0 -0
  316. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/WHEEL +0 -0
  317. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/entry_points.txt +0 -0
  318. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/top_level.txt +0 -0
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/tiling_utils.h
@@ -0,0 +1,65 @@
+ /**
+  * Copyright 2024 Huawei Technologies Co., Ltd
+  *
+  * Licensed under the Apache License, Version 2.0 (the "License");
+  * you may not use this file except in compliance with the License.
+  * You may obtain a copy of the License at
+  *
+  * http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+
+ #ifndef MATMUL_TILING_UTILS_H
+ #define MATMUL_TILING_UTILS_H
+
+ #include <stdint.h>
+ #include <sstream>
+ #include <cstdlib>
+ #include <vector>
+
+ namespace mindspore {
+ namespace internal {
+ namespace tiling {
+
+ static std::vector<int> getMatMulTilingFromEnv() {
+   std::vector<int> result;
+   auto env_name = "INTERNAL_MATMUL_TILING";
+   const char* envVarValue = std::getenv(env_name);
+
+   if (envVarValue != nullptr) {
+     std::string envVarString(envVarValue);
+     std::stringstream ss(envVarString);
+     std::string item;
+
+     while (std::getline(ss, item, ',')) {
+       result.push_back(std::stoi(item));
+     }
+   }
+
+   return result;
+ }
+
+
+ static bool getShuffleFlagFromEnv() {
+   auto env_name = "CUSTOM_MATMUL_SHUFFLE";
+   const char* envVarValue = std::getenv(env_name);
+   if (envVarValue != nullptr) {
+     std::string envVarString(envVarValue);
+     if (envVarString != "0" && envVarString != "off") {
+       return true;
+     }
+     return false;
+   }
+   return true;
+ }
+
+
+ } // namespace tiling
+ } // namespace internal
+ } // namespace mindspore
+ #endif // MATMUL_TILING_UTILS_H
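
For reference, a minimal sketch of how these two new hooks behave when exercised from a standalone program; the include path, environment values, and the meaning of the individual tiling integers are illustrative assumptions, not documented API:

```cpp
// Hypothetical driver for the new env hooks (POSIX setenv; Linux wheel target).
#include <cstdio>
#include <cstdlib>
#include "matmul_common/tiling_utils.h"  // assumed to be on the include path

int main() {
  // Comma-separated integers; which tiling fields they map to is not
  // documented in this diff, so the values below are placeholders.
  setenv("INTERNAL_MATMUL_TILING", "128,256,64", 1);
  for (int v : mindspore::internal::tiling::getMatMulTilingFromEnv()) {
    std::printf("%d\n", v);  // prints 128, 256, 64
  }

  // Shuffle defaults to enabled; only the strings "0" or "off" disable it.
  setenv("CUSTOM_MATMUL_SHUFFLE", "off", 1);
  std::printf("shuffle: %s\n",
              mindspore::internal::tiling::getShuffleFlagFromEnv() ? "on" : "off");
  return 0;
}
```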
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/matmul_stridedslice_fusion_impl.h
@@ -24,11 +24,12 @@
  #include "asdops/tensor.h"

  #include "utils.h"
- // #include "pp_matmul_info.h"
  #include "backend_param.h"
- #include "tiling_data.h"
+ #include "matmul_common/pp_matmul_info.h"
+ #include "matmul_common/tiling_utils.h"
+ #include "matmul_common/tiling_data.h"
+ #include "matmul_common/pp_matmul_common_tiling.h"
  #include "param/matmul_qkv_param.h"
- // #include "pp_matmul_common_tiling.h"
  #include "tune_repo/utils.h"

  #include "internal_kernel.h"
@@ -39,6 +40,8 @@
  namespace mindspore {
  namespace internal {

+ using namespace tiling;
+
  class MatMulStridedSliceFusionImpl : public InternelKernelImpl {
  public:
   MatMulStridedSliceFusionImpl(const OpParamPtr &param) : InternelKernelImpl(param){};
@@ -48,7 +51,8 @@ class MatMulStridedSliceFusionImpl : public InternelKernelImpl {
   int Launch() override;
   size_t GetTilingBufSize() override;
   int Tiling(HostRawBuf &tilingBuf) override;
-  int TilingLLMCustom(HostRawBuf &tilingBuf);
+  void TilingBasicFromPp(uint32_t &blockDim, PpTilingData &tilingdata);
+  int TilingLLMCustom(HostRawBuf &tilingBuf, const uint32_t &blockDim, const PpTilingData &tilingdata, bool has_tuned);
   std::vector<uint64_t> GetWorkSpaceSize() override;
   int InferShape(const std::vector<DIMS> &input_shapes, std::vector<DIMS> &output_shapes) override;

@@ -66,8 +70,8 @@ class MatMulStridedSliceFusionImpl : public InternelKernelImpl {

   REPO tuningTable_;
   tiling::MatmulStridedSliceFusionTilingData t_;
-  void GetTunedKey(std::vector<int> &tune_key);
-  void GetTunedValue(const std::vector<int> &tuned_config);
+  std::vector<int> GetTunedKey();
+  void SetTunedValue(const std::vector<int> &tuned_config);
  };

  } // namespace internal
@@ -27,7 +27,6 @@
27
27
  #include "asdops/params/norm.h"
28
28
  #include "asdops/params/softmax.h"
29
29
  #include "asdops/params/split.h"
30
- #include "attention_param.h"
31
30
  #include "asdops/params/expand.h"
32
31
  #include "asdops/params/fill.h"
33
32
  #include "asdops/params/reduce.h"
@@ -99,6 +98,10 @@ struct AddLayerNormParam {
99
98
  };
100
99
 
101
100
  struct ApplyRotaryPosEmbParam {
101
+ // cosFormat=0 shape是[maxSeqLen, headDim], cos/sin不交替
102
+ // cosFormat=1 shape是[maxSeqLen, headDim], cos/sin交替
103
+ // cosFormat=2 shape是[batch*seqLen, headDim], cos/sin不交替
104
+ // cosFormat=3 shape是[batch*seqLen, headDim], cos/sin交替
102
105
  int32_t cosFormat{0};
103
106
  };
104
107
 
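The four cosFormat layouts above reduce to a simple shape rule; a small host-side sketch restating them (the helper name and signature are hypothetical, introduced only for illustration):

```cpp
#include <array>
#include <cstdint>

// Hypothetical helper restating the cosFormat comments: even values store
// cos/sin non-interleaved, odd values interleaved; 0/1 index the table by
// absolute position, 2/3 by the flattened (batch, seqLen) axis.
std::array<int64_t, 2> CosSinTableShape(int32_t cos_format, int64_t max_seq_len,
                                        int64_t batch, int64_t seq_len,
                                        int64_t head_dim) {
  if (cos_format == 0 || cos_format == 1) {
    return {max_seq_len, head_dim};
  }
  return {batch * seq_len, head_dim};
}
```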
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/kernel/paged_attention_mix_hwsync.h
@@ -0,0 +1,41 @@
+ #ifndef BS_FLASHATTENTION_BS__ATTENTION_MIX_HWSYNC_H
+ #define BS_FLASHATTENTION_BS__ATTENTION_MIX_HWSYNC_H
+ constexpr float DROPOUT_PROP = 0.5;
+ constexpr uint32_t LOOP_LEN = 5;
+ constexpr uint32_t UB_HALF_BUF_SIZE = 8 * 2048;
+ constexpr uint32_t BIT_UINT8 = 8;
+ constexpr uint32_t BIT_BLOCK = 256;
+ constexpr uint32_t BLOCK_SIZE = 16;
+ constexpr uint32_t VECTOR_SIZE = 128;
+ constexpr uint32_t VECTOR_SIZE_FP32 = 64;
+ constexpr uint32_t CUBE_MATRIX_SIZE = 256;  // 16 * 16
+ constexpr uint64_t UB_UINT8_BLOCK_SIZE = 16384;  // 64 * 128 * 2B
+ constexpr uint64_t UB_UINT8_LINE_SIZE = 512;  // 64 * 4B; double-sized to guard against overruns
+ constexpr uint64_t UB_FLOAT_LINE_SIZE = 128;  // 64; double-sized to guard against overruns
+ constexpr uint64_t UB_HALF_LINE_SIZE = 256;  // UB_FLOAT_LINE_SIZE * 2
+
+ constexpr uint32_t L0AB_HALF_BUF_SIZE = 16384;  // 128 * 128
+ constexpr uint64_t L1_SIZE = 512 * 1024;  // 512KB
+ constexpr uint64_t L0AB_UINT8_BLOCK_SIZE = 32768;  // 128 * 128 * 2B
+ constexpr uint64_t L1_MAX_SHARE_NUM = (L1_SIZE - 8 * L0AB_UINT8_BLOCK_SIZE) / L0AB_UINT8_BLOCK_SIZE / 2;
+ constexpr uint64_t SUB_SP_SIZE = 2048 * 8;  // 1024*16, 2048*8, 4096*4, 8192*2, 16K*1: five tiling choices
+
+ enum class L1Mode{load,     // load data into the shared region of L1
+                   share,    // consume data already in the shared region
+                   noshare}; // neither load into nor use the shared region
+
+ inline uint64_t ceil(uint64_t y, uint64_t x) {
+   return (y + x - 1) / x;
+ }
+
+ inline uint64_t round(uint64_t y, uint64_t x) {
+   return ceil(y, x) * x;
+ }
+
+ #if BFLOAT16
+ #define CALC_DATA_TYPE bfloat16_t
+ #else
+ #define CALC_DATA_TYPE half
+ #endif
+
+ #endif  // BS_FLASHATTENTION_BS__ATTENTION_MIX_HWSYNC_H
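
The L1 budget arithmetic in this header works out to four shareable blocks; a compile-time check, with the ceil/round helpers re-declared locally so the sketch stands alone:

```cpp
#include <cstdint>

// (512 KB - 8 guard blocks) / 32 KB per block / 2 (double buffering) = 4.
constexpr uint64_t kL1Size = 512 * 1024;
constexpr uint64_t kL0abBlock = 32768;  // 128 * 128 * 2B
static_assert((kL1Size - 8 * kL0abBlock) / kL0abBlock / 2 == 4,
              "L1_MAX_SHARE_NUM evaluates to 4");

// Local copies of the header's ceil/round: round y up to a multiple of x.
constexpr uint64_t Ceil(uint64_t y, uint64_t x) { return (y + x - 1) / x; }
constexpr uint64_t Round(uint64_t y, uint64_t x) { return Ceil(y, x) * x; }
static_assert(Ceil(100, 16) == 7 && Round(100, 16) == 112,
              "100 elements occupy 7 blocks of 16, i.e. 112 padded elements");
```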
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_impl.h
@@ -24,7 +24,7 @@
  #include "asdops/tensor.h"

  #include "internal_kernel.h"
-
+ #include "param/attention_param.h"
  #include "acl_rt.h"

  #include <unordered_map>
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_tiling.h
@@ -0,0 +1,63 @@
+ #ifndef __BS_ATTENTION_TILING_H__
+ #define __BS_ATTENTION_TILING_H__
+
+ #pragma pack (8)
+ typedef struct {
+   uint64_t batch_size;
+   uint64_t num_heads;
+   uint64_t max_seqlen;
+   uint64_t head_dim;
+   uint64_t num_group;
+   uint64_t q_seqlen;
+   uint64_t kv_seqlen;
+   uint64_t table_block_size;
+   uint64_t sync_addr;
+   uint64_t core_num;
+   float tor;
+ } BSAttentionTilingData;
+ #pragma pack()
+
+ #define MAX_CORE_NUM 25
+ #define ATTENTION_DEBUG false  // when enabled, debug data is written into S/P
+ #define ROWMAX true
+ #define OP_NAME PagedAttention
+ #define BUFFER_NUM 4  // inter-core pipeline depth; changing it is not yet supported
+ constexpr uint64_t WORKSPACE_MAX_SEQLEN = 16384;  // max seqlen
+ constexpr uint64_t WORKSPACE_SIZE = 64 * WORKSPACE_MAX_SEQLEN;
+
+ #if BFLOAT16
+ #define TYPE_NAME _bf16
+ #else
+ #define TYPE_NAME _fp16
+ #endif
+
+ #if BSH
+ #define LAYOUT_NAME _BSH
+ #else
+ #define LAYOUT_NAME _BNSD
+ #endif
+
+ #define TRI_NAME _full
+
+ #define CONCAT_(A, B, C, D, E) A##B##C##D##E
+ #define CONCAT(A, B, C, D, E) CONCAT_(A, B, C, D, E)
+ #define FUNC_NAME_AIC CONCAT(OP_NAME, TYPE_NAME, LAYOUT_NAME, TRI_NAME, _mix_aic)
+ #define FUNC_NAME_AIV CONCAT(OP_NAME, TYPE_NAME, LAYOUT_NAME, TRI_NAME, _mix_aiv)
+
+ // ************** mask pattern modes **************//
+ // Mode 1: lower triangle. When LOWER_TRIANGLE is enabled, the lower-triangular pattern is applied directly, with no dependence on a mask input.
+ // #define LOWER_TRIANGLE false
+
+ // Mode 2: block sparse. With LOWER_TRIANGLE off, enabling BLOCK_SPARSE uses pre_token and next_token, with no dependence on a mask input (to be developed).
+ // #define BLOCK_SPARSE false
+
+ // Mode 3: explicit mask. With LOWER_TRIANGLE and BLOCK_SPARSE off, enabling AMASK takes the mask tensor as an input.
+ // #define AMASK true
+
+ // Mode 4: full matrix. If LOWER_TRIANGLE, BLOCK_SPARSE, and AMASK are all off, the attention runs over the full matrix and no elements of S are suppressed.
+ // *******************************************//
+
+ constexpr uint64_t WORKSPACE_MAX_SEQLEN_BLOCK = WORKSPACE_MAX_SEQLEN / 16;
+ constexpr uint64_t BUFFER_SIZE = MAX_CORE_NUM * WORKSPACE_SIZE * sizeof(uint16_t);
+
+ #endif
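
The CONCAT machinery stitches the kernel entry symbol out of the op name, dtype, layout, and mask suffix; a standalone demonstration (macros copied from the header, stringification added here only for printing):

```cpp
#include <cstdio>

#define OP_NAME PagedAttention
#define TYPE_NAME _bf16    // as if BFLOAT16 were defined
#define LAYOUT_NAME _BSH   // as if BSH were defined
#define TRI_NAME _full
#define CONCAT_(A, B, C, D, E) A##B##C##D##E
#define CONCAT(A, B, C, D, E) CONCAT_(A, B, C, D, E)
#define FUNC_NAME_AIC CONCAT(OP_NAME, TYPE_NAME, LAYOUT_NAME, TRI_NAME, _mix_aic)
#define STRINGIFY_(x) #x
#define STRINGIFY(x) STRINGIFY_(x)

int main() {
  // Prints PagedAttention_bf16_BSH_full_mix_aic, matching the shipped
  // paged_attention_bf16_bsh_full_mix.o kernel object up to letter case.
  std::puts(STRINGIFY(FUNC_NAME_AIC));
  return 0;
}
```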
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/add_param.h
@@ -34,11 +34,11 @@ struct AddParam : public OpParam {
   DIMS input1_dims_;
   DIMS input2_dims_;
   bool canSupport() {
-    if (ADD_SUPPORT_DTYPE.find(input1_dtype_) == ADD_SUPPORT_DTYPE.end()) {
+    if (ADD_SUPPORT_DTYPE.find(input1_dtype_) == ADD_SUPPORT_DTYPE.end() || input1_dims_ != input2_dims_) {
      return false;
    }
    if (input1_dims_ == input2_dims_) {
-     return true;
+     return false;
    }
    if (std::abs(int(input1_dims_.size()) - int(input2_dims_.size())) > 1) {
      return false;
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/attention_param.h
@@ -16,12 +16,21 @@
  #ifndef ATTENTION_PARAMS_H
  #define ATTENTION_PARAMS_H

+ #include "types.h"
+ #include "op_param.h"
+
  namespace mindspore {
  namespace internal {
- struct FlashAttentionScoreParam {
+ struct FlashAttentionScoreParam : public OpParam {
+   int head_num = 0;
+   int inner_precise = 0;
+   int pre_tokens = 2147483647;
+   int next_tokens = 0;
+   int sparse_mode = 0;
  };

- struct PagedAttentionParam {
+ struct PagedAttentionParam : public OpParam {
+   int inner_precise = 0;
  };
  } // namespace internal
  } // namespace mindspore
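
For orientation, the new defaults amount to an unbounded backward window (pre_tokens = 2147483647, i.e. INT32_MAX) and no lookahead. A standalone copy of the struct for illustration; the field semantics are assumptions based on the usual pre/next-token window convention, not confirmed by this diff:

```cpp
#include <cassert>
#include <climits>

// Standalone copy of FlashAttentionScoreParam above, for illustration only.
struct FlashAttentionScoreParam {
  int head_num = 0;
  int inner_precise = 0;
  int pre_tokens = 2147483647;  // INT_MAX: assumed to mean "no backward limit"
  int next_tokens = 0;          // assumed to mean "no lookahead", i.e. causal
  int sparse_mode = 0;
};

int main() {
  FlashAttentionScoreParam p;
  assert(p.pre_tokens == INT_MAX);
  p.head_num = 32;  // hypothetical 32-head model
  return 0;
}
```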
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/matmul_ext_param.h
@@ -0,0 +1,37 @@
+ /**
+  * Copyright 2024 Huawei Technologies Co., Ltd
+  *
+  * Licensed under the Apache License, Version 2.0 (the "License");
+  * you may not use this file except in compliance with the License.
+  * You may obtain a copy of the License at
+  *
+  * http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+ #ifndef MATMUL_EXT_PARAMS_H_
+ #define MATMUL_EXT_PARAMS_H_
+
+ #include "types.h"
+ #include "op_param.h"
+
+ namespace mindspore {
+ namespace internal {
+
+ struct MatMulExtParam : public OpParam {
+   int input_dtype = -1;
+   int weight_dtype = -1;
+   int output_dtype = -1;
+   bool with_relu = false;
+   bool with_gelu = false;
+   bool with_bias = false;
+   bool with_bias_fastgelu = false;
+ };
+
+ } // namespace internal
+ } // namespace mindspore
+ #endif
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/sub_param.h
@@ -0,0 +1,45 @@
+ /**
+  * Copyright 2024 Huawei Technologies Co., Ltd
+  *
+  * Licensed under the Apache License, Version 2.0 (the "License");
+  * you may not use this file except in compliance with the License.
+  * You may obtain a copy of the License at
+  *
+  * http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+ #ifndef SUB_PARAMS_H_
+ #define SUB_PARAMS_H_
+
+ #include "types.h"
+ #include "op_param.h"
+ #include <set>
+
+ namespace mindspore {
+ namespace internal {
+ struct SubParam : public OpParam {
+   TensorDType input1_dtype_;
+   TensorDType input2_dtype_;
+   DIMS input1_dims_;
+   DIMS input2_dims_;
+   bool canSupport() {
+     if (input2_dtype_ != AsdOps::TensorDType::TENSOR_DTYPE_INT32) {
+       return false;
+     }
+     if (input2_dims_.size() == 0 || (input2_dims_.size() == 1 && input2_dims_[0] == 1)) {
+       return true;
+     }
+     if (input1_dims_.size() == 0 || (input1_dims_.size() == 1 && input1_dims_[0] == 1)) {
+       return true;
+     }
+     return false;
+   }
+ };
+ } // namespace internal
+ } // namespace mindspore
+ #endif
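
SubParam::canSupport gates this internal Sub kernel to tensor-minus-scalar cases with an int32 second input; here is the predicate restated as a testable free function (Dims and the dtype check are stubbed with plain types; this mirrors the logic, it is not the real API):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Dims = std::vector<int64_t>;

// Mirrors SubParam::canSupport above: int32 second input, and at least one
// operand must be scalar-like (rank 0, or rank 1 with a single element).
bool SubCanSupport(bool input2_is_int32, const Dims &d1, const Dims &d2) {
  if (!input2_is_int32) return false;
  auto is_scalar = [](const Dims &d) {
    return d.empty() || (d.size() == 1 && d[0] == 1);
  };
  return is_scalar(d2) || is_scalar(d1);
}

int main() {
  assert(SubCanSupport(true, {4, 16}, {1}));       // tensor - scalar: supported
  assert(SubCanSupport(true, {}, {4, 16}));        // scalar - tensor: supported
  assert(!SubCanSupport(true, {4, 16}, {4, 16}));  // elementwise: rejected
  assert(!SubCanSupport(false, {4, 16}, {1}));     // non-int32 input2: rejected
  return 0;
}
```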
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/reshape_and_cache/reshape_and_cache_tiling.h
@@ -19,8 +19,7 @@

  struct ReshapeAndCacheTilingData {
   int32_t num_tokens;
-  int32_t num_heads;
-  int32_t head_size;
+  int32_t hidden_size;
  };

  #endif
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm.h
@@ -0,0 +1,23 @@
+ /**
+  * Copyright 2024 Huawei Technologies Co., Ltd
+  *
+  * Licensed under the Apache License, Version 2.0 (the "License");
+  * you may not use this file except in compliance with the License.
+  * You may obtain a copy of the License at
+  *
+  * http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+
+ #ifndef MS_KERNELS_INTERNAL_KERNEL_ASCENDC_RMS_NORM_H_
+ #define MS_KERNELS_INTERNAL_KERNEL_ASCENDC_RMS_NORM_H_
+
+ void rms_norm_do(uint32_t blockDim, void *l2ctrl, void *stream, uint8_t *x, uint8_t *gamma, uint8_t *y, uint8_t *rstd,
+                  uint8_t *workspace, uint8_t *tiling);
+
+ #endif
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_base.h
@@ -0,0 +1,175 @@
+ /**
+  * Copyright 2024 Huawei Technologies Co., Ltd
+  *
+  * Licensed under the Apache License, Version 2.0 (the "License");
+  * you may not use this file except in compliance with the License.
+  * You may obtain a copy of the License at
+  *
+  * http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+
+ /*!
+  * \file rms_norm_base.h
+  * \brief
+  */
+ #ifndef _RMS_NORM_BASE_H_
+ #define _RMS_NORM_BASE_H_
+ #include "kernel_operator.h"
+
+ using namespace AscendC;
+
+ #if __CCE_AICORE__ != 220
+ #define bfloat16_t int16_t
+ #endif
+ constexpr int32_t BUFFER_NUM = 1;  // tensor num for each queue
+ constexpr int32_t NUM_PER_REP_FP32 = 64;  // ONE_REPEAT_BYTE_SIZE / sizeof(float);
+ constexpr int32_t NUM_PER_BLK_FP32 = 8;
+ constexpr float MINUS_HALF = -0.5;
+ constexpr float ZERO = 0;
+ constexpr float ONE = 1;
+
+ template <typename T>
+ __aicore__ inline T CeilDiv(T x, T y) {
+   return y == 0 ? x : (x + y - 1) / y;
+ }
+
+ template <typename Tp, Tp v>
+ struct integral_constant {
+   static constexpr Tp value = v;
+ };
+ using true_type = integral_constant<bool, true>;
+ using false_type = integral_constant<bool, false>;
+ template <typename, typename>
+ struct is_same : public false_type {};
+ template <typename Tp>
+ struct is_same<Tp, Tp> : public true_type {};
+
+ __aicore__ inline void ReduceSumFP32(const LocalTensor<float> &dst_local, const LocalTensor<float> &src_local,
+                                      const LocalTensor<float> &work_local, int32_t count) {
+   // count need smaller than 255 repeat
+   if (g_coreType == AIV) {
+     uint64_t mask = NUM_PER_REP_FP32;
+     int32_t repeatTimes = count / NUM_PER_REP_FP32;
+     int32_t tailCount = count % NUM_PER_REP_FP32;
+     int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32;
+     BinaryRepeatParams repeatParams;
+     repeatParams.src0RepStride = ONE_REPEAT_BYTE_SIZE / ONE_BLK_SIZE;
+     repeatParams.src0BlkStride = 1;
+     repeatParams.src1RepStride = 0;
+     repeatParams.src1BlkStride = 1;
+     repeatParams.dstRepStride = 0;
+     repeatParams.dstBlkStride = 1;
+     Duplicate(work_local, ZERO, NUM_PER_REP_FP32);
+     pipe_barrier(PIPE_V);
+     if (likely(repeatTimes > 0)) {
+       Add(work_local, src_local, work_local, mask, repeatTimes, repeatParams);
+       pipe_barrier(PIPE_V);
+     }
+     if (unlikely(tailCount != 0)) {
+       Add(work_local, src_local[bodyCount], work_local, tailCount, 1, repeatParams);
+       pipe_barrier(PIPE_V);
+     }
+     AscendCUtils::SetMask<float>(NUM_PER_REP_FP32);
+     vcadd((__ubuf__ float *)dst_local.GetPhyAddr(), (__ubuf__ float *)work_local.GetPhyAddr(), 1, 0, 1, 0, false);
+     pipe_barrier(PIPE_V);
+   }
+ }
+
+ __aicore__ inline void ReduceSumCustom(const LocalTensor<float> &dst_local, const LocalTensor<float> &src_local,
+                                        const LocalTensor<float> &work_local, int32_t count) {
+ #if __CCE_AICORE__ == 220
+   ReduceSumFP32(dst_local, src_local, work_local, count);
+ #else
+   ReduceSum(dst_local, src_local, dst_local, count);
+ #endif
+ }
+
+ __aicore__ inline void ReduceSumFP32ToBlock(const LocalTensor<float> &dst_local, const LocalTensor<float> &src_local,
+                                             const LocalTensor<float> &work_local, int32_t count) {
+   // count need smaller than 255 repeat
+   uint64_t mask = NUM_PER_REP_FP32;
+   int32_t repeatTimes = count / NUM_PER_REP_FP32;
+   int32_t tailCount = count % NUM_PER_REP_FP32;
+   int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32;
+   BinaryRepeatParams repeatParams;
+   repeatParams.src0RepStride = ONE_REPEAT_BYTE_SIZE / ONE_BLK_SIZE;
+   repeatParams.src0BlkStride = 1;
+   repeatParams.src1RepStride = 0;
+   repeatParams.src1BlkStride = 1;
+   repeatParams.dstRepStride = 0;
+   repeatParams.dstBlkStride = 1;
+   Duplicate(work_local, ZERO, NUM_PER_REP_FP32);
+   pipe_barrier(PIPE_V);
+   if (likely(repeatTimes > 0)) {
+     Add(work_local, src_local, work_local, mask, repeatTimes, repeatParams);
+     pipe_barrier(PIPE_V);
+   }
+   if (unlikely(tailCount != 0)) {
+     Add(work_local, src_local[bodyCount], work_local, tailCount, 1, repeatParams);
+     pipe_barrier(PIPE_V);
+   }
+   BlockReduceSum(dst_local, work_local, 1, mask, 1, 1, DEFAULT_REPEAT_STRIDE);
+   pipe_barrier(PIPE_V);
+ }
+
+ __aicore__ inline void BlockReduceSumFP32(const LocalTensor<float> &dst_local, const LocalTensor<float> &src_local,
+                                           int32_t count) {
+   // count need multiple of 8
+   int32_t repeatTimes = count / NUM_PER_REP_FP32;
+   int32_t tailCount = count % NUM_PER_REP_FP32;
+   int32_t dstAddr = repeatTimes * 8;
+   int32_t srcAddr = repeatTimes * NUM_PER_REP_FP32;
+   if (likely(repeatTimes > 0)) {
+     BlockReduceSum(dst_local, src_local, repeatTimes, NUM_PER_REP_FP32, 1, 1, DEFAULT_REPEAT_STRIDE);
+     pipe_barrier(PIPE_V);
+   }
+   if (tailCount != 0) {
+     BlockReduceSum(dst_local[dstAddr], src_local[srcAddr], 1, tailCount, 1, 1, DEFAULT_REPEAT_STRIDE);
+     pipe_barrier(PIPE_V);
+   }
+ }
+
+ template <typename T, typename U, typename R>
+ __aicore__ inline void DataCopyCustom(const U &dstTensor, const R &srcTensor, const uint32_t count) {
+ #if __CCE_AICORE__ == 220
+   DataCopyParams copyParams;
+   copyParams.blockLen = count * sizeof(T);
+   copyParams.blockCount = 1;
+   if constexpr (is_same<U, AscendC::LocalTensor<T>>::value) {
+     DataCopyPadParams padParams;
+     DataCopyPad(dstTensor, srcTensor, copyParams, padParams);
+   } else {
+     DataCopyPad(dstTensor, srcTensor, copyParams);
+   }
+ #else
+   // only support count greater than 32byte
+   int32_t numPerBlock = ONE_BLK_SIZE / sizeof(T);
+   if (count % numPerBlock == 0) {
+     DataCopy(dstTensor, srcTensor, count);
+   } else {
+     if constexpr (is_same<U, AscendC::LocalTensor<T>>::value) {
+       int32_t num = AlignUp(count, numPerBlock);
+       DataCopy(dstTensor, srcTensor, num);
+     } else {
+       int32_t num = count / numPerBlock * numPerBlock;
+       DataCopy(dstTensor, srcTensor, num);
+       set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0);
+       wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0);
+       for (int32_t i = 0; i < numPerBlock; i++) {
+         T tensorValue = srcTensor.GetValue(count - numPerBlock + i);
+         srcTensor.SetValue(i, tensorValue);
+       }
+       set_flag(PIPE_S, PIPE_MTE3, EVENT_ID0);
+       wait_flag(PIPE_S, PIPE_MTE3, EVENT_ID0);
+       DataCopy(dstTensor[count - numPerBlock], srcTensor, numPerBlock);
+     }
+   }
+ #endif
+ }
+ #endif  // RMS_NORM_BASE_H_
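
The reductions above all follow the same body/tail decomposition: full 64-lane fp32 repeats plus one masked tail repeat. A plain host-side C++ stand-in of that index arithmetic (not AscendC; the element count is illustrative):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  constexpr int32_t kNumPerRepFp32 = 64;  // matches NUM_PER_REP_FP32
  const int32_t count = 200;              // hypothetical element count
  const int32_t repeatTimes = count / kNumPerRepFp32;      // 3 full repeats
  const int32_t tailCount = count % kNumPerRepFp32;        // 8 tail elements
  const int32_t bodyCount = repeatTimes * kNumPerRepFp32;  // 192 body elements
  // The tail Add starts at src_local[bodyCount] with repeat count 1, exactly
  // as in ReduceSumFP32 and ReduceSumFP32ToBlock above.
  assert(repeatTimes == 3 && tailCount == 8 && bodyCount == 192);
  return 0;
}
```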