mindspore 2.3.0rc1-cp39-none-any.whl → 2.3.0rc2-cp39-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (316)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +1 -1
  3. mindspore/_akg/akg/utils/tbe_codegen_utils.py +13 -3
  4. mindspore/_c_dataengine.cpython-39-aarch64-linux-gnu.so +0 -0
  5. mindspore/_c_expression.cpython-39-aarch64-linux-gnu.so +0 -0
  6. mindspore/_checkparam.py +20 -0
  7. mindspore/_extends/parse/parser.py +1 -1
  8. mindspore/_extends/parse/standard_method.py +6 -5
  9. mindspore/_mindspore_offline_debug.cpython-39-aarch64-linux-gnu.so +0 -0
  10. mindspore/amp.py +5 -5
  11. mindspore/boost/boost_cell_wrapper.py +1 -1
  12. mindspore/boost/group_loss_scale_manager.py +1 -1
  13. mindspore/common/__init__.py +4 -2
  14. mindspore/common/_register_for_recompute.py +48 -0
  15. mindspore/common/_stub_tensor.py +1 -0
  16. mindspore/common/api.py +56 -4
  17. mindspore/common/dtype.py +5 -3
  18. mindspore/common/dump.py +2 -2
  19. mindspore/common/hook_handle.py +51 -4
  20. mindspore/common/initializer.py +1 -1
  21. mindspore/common/jit_config.py +17 -6
  22. mindspore/common/parameter.py +7 -2
  23. mindspore/common/recompute.py +247 -0
  24. mindspore/common/sparse_tensor.py +2 -2
  25. mindspore/common/symbol.py +1 -1
  26. mindspore/common/tensor.py +74 -36
  27. mindspore/communication/__init__.py +3 -3
  28. mindspore/communication/management.py +30 -30
  29. mindspore/context.py +28 -15
  30. mindspore/dataset/__init__.py +5 -5
  31. mindspore/dataset/audio/__init__.py +2 -2
  32. mindspore/dataset/audio/transforms.py +51 -51
  33. mindspore/dataset/callback/ds_callback.py +2 -2
  34. mindspore/dataset/engine/cache_client.py +1 -1
  35. mindspore/dataset/engine/datasets.py +3 -3
  36. mindspore/dataset/engine/datasets_audio.py +14 -14
  37. mindspore/dataset/engine/datasets_standard_format.py +3 -3
  38. mindspore/dataset/engine/datasets_text.py +38 -38
  39. mindspore/dataset/engine/datasets_user_defined.py +3 -3
  40. mindspore/dataset/engine/datasets_vision.py +68 -68
  41. mindspore/dataset/text/__init__.py +3 -3
  42. mindspore/dataset/text/transforms.py +26 -26
  43. mindspore/dataset/transforms/__init__.py +1 -1
  44. mindspore/dataset/vision/__init__.py +3 -3
  45. mindspore/dataset/vision/transforms.py +92 -92
  46. mindspore/dataset/vision/utils.py +1 -1
  47. mindspore/experimental/optim/adadelta.py +2 -2
  48. mindspore/experimental/optim/adagrad.py +2 -2
  49. mindspore/experimental/optim/adam.py +2 -2
  50. mindspore/experimental/optim/adamax.py +2 -2
  51. mindspore/experimental/optim/adamw.py +2 -2
  52. mindspore/experimental/optim/asgd.py +2 -2
  53. mindspore/experimental/optim/lr_scheduler.py +24 -20
  54. mindspore/experimental/optim/nadam.py +2 -2
  55. mindspore/experimental/optim/optimizer.py +1 -1
  56. mindspore/experimental/optim/radam.py +2 -2
  57. mindspore/experimental/optim/rmsprop.py +2 -2
  58. mindspore/experimental/optim/rprop.py +2 -2
  59. mindspore/experimental/optim/sgd.py +2 -2
  60. mindspore/hal/stream.py +2 -0
  61. mindspore/include/mindapi/base/types.h +5 -0
  62. mindspore/lib/libdnnl.so.2 +0 -0
  63. mindspore/lib/libmindspore.so +0 -0
  64. mindspore/lib/libmindspore_backend.so +0 -0
  65. mindspore/lib/libmindspore_common.so +0 -0
  66. mindspore/lib/libmindspore_core.so +0 -0
  67. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  68. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  69. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  70. mindspore/lib/libmindspore_shared_lib.so +0 -0
  71. mindspore/lib/libopencv_core.so.4.5 +0 -0
  72. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  73. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  74. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  75. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +6 -6
  76. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  77. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  78. mindspore/lib/plugin/ascend/liblowlatency_collective.so +0 -0
  79. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  80. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/DeviceBin +0 -0
  81. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/PkgInspect +0 -0
  82. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/op_man +0 -0
  83. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/device/ascend910b/bin/ascend910b.bin +101787 -98559
  84. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_cann_host.so +0 -0
  85. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_host.so +0 -0
  86. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/base/op_register.h +2 -2
  87. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/mix.h +8 -1
  88. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/norm.h +5 -3
  89. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/reduce.h +2 -2
  90. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/backend/backend.h +3 -3
  91. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/backend/rtbackend.h +3 -3
  92. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/base/types.h +0 -1
  93. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/module/module.h +3 -3
  94. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/svector/svector.h +3 -2
  95. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops.so +0 -0
  96. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops_static.a +0 -0
  97. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/add/tiling/add_tiling.h +9 -9
  98. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/apply_rotary_pos_emb_impl.h +2 -6
  99. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb.h +2 -2
  100. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_base.h +460 -0
  101. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_bf16.h +217 -0
  102. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_fp16.h +116 -0
  103. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_tiling.h +16 -24
  104. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_value.h +27 -0
  105. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/asdop/asd_op_impl.h +0 -4
  106. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/FlashAttentionScore_impl.h → flash_attention_score/flash_attention_score_impl.h} +2 -1
  107. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/bs_attention_tiling.h → flash_attention_score/flash_attention_score_tiling.h} +15 -19
  108. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/gelu/tiling/gelu_tiling.h +7 -9
  109. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/lccl/lccl_wrapper.h +58 -0
  110. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul/matmul_impl.h +19 -8
  111. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/pp_matmul_common_tiling.h +18 -8
  112. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/pp_matmul_info.h +7 -4
  113. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/tiling_data.h +44 -6
  114. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/tiling_utils.h +65 -0
  115. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/matmul_stridedslice_fusion_impl.h +10 -6
  116. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/op_param.h +4 -1
  117. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/kernel/paged_attention_mix_hwsync.h +41 -0
  118. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/PagedAttention_impl.h → paged_attention/paged_attention_impl.h} +1 -1
  119. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_tiling.h +63 -0
  120. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/add_param.h +2 -2
  121. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention_param.h → param/attention_param.h} +11 -2
  122. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/matmul_ext_param.h +37 -0
  123. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/sub_param.h +45 -0
  124. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/reshape_and_cache/reshape_and_cache_tiling.h +1 -2
  125. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm.h +23 -0
  126. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_base.h +175 -0
  127. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_normal.h +276 -0
  128. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_split_d.h +280 -0
  129. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/tiling_data.h +35 -0
  130. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/rms_norm_impl.h +45 -0
  131. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/kernel/sub_kernel.h +20 -0
  132. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_impl.h +47 -0
  133. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_tiling.h +25 -0
  134. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/tune_repo/matmul_table.h +323 -23
  135. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/types.h +15 -4
  136. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log_tiling.h +8 -0
  137. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libAdd_impl.so +0 -0
  138. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libSub_impl.so +0 -0
  139. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_layernorm_impl.so +0 -0
  140. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_rms_norm_impl.so +0 -0
  141. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libapply_rotary_pos_emb_impl.so +0 -0
  142. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libcast_impl.so +0 -0
  143. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libgelu_impl.so +0 -0
  144. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_impl.so +0 -0
  145. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_stridedslice_fusion_impl.so +0 -0
  146. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libms_kernels_internal.so +0 -0
  147. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libnot_equal_impl.so +0 -0
  148. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libreshape_and_cache_impl.so +0 -0
  149. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/librms_norm_impl.so +0 -0
  150. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_full_mix.o +0 -0
  151. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_tri_mix.o +0 -0
  152. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_full_mix.o +0 -0
  153. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_tri_mix.o +0 -0
  154. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_full_mix.o +0 -0
  155. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_tri_mix.o +0 -0
  156. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_full_mix.o +0 -0
  157. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_tri_mix.o +0 -0
  158. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bnsd_full_mix.o +0 -0
  159. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bsh_full_mix.o +0 -0
  160. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bnsd_full_mix.o +0 -0
  161. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bsh_full_mix.o +0 -0
  162. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal.h +22 -0
  163. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal_comm.h +70 -0
  164. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal_types.h +103 -0
  165. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lccl.h +47 -0
  166. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lccl_wrapper.h +58 -0
  167. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcoc.h +154 -0
  168. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblcal.so +0 -0
  169. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblccl_wrapper.so +0 -0
  170. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  171. mindspore/log.py +2 -2
  172. mindspore/mint/__init__.py +457 -0
  173. mindspore/mint/nn/__init__.py +430 -0
  174. mindspore/mint/nn/functional.py +424 -0
  175. mindspore/mint/optim/__init__.py +24 -0
  176. mindspore/mint/optim/adamw.py +186 -0
  177. mindspore/multiprocessing/__init__.py +4 -0
  178. mindspore/nn/__init__.py +3 -0
  179. mindspore/nn/cell.py +51 -47
  180. mindspore/nn/extend/__init__.py +29 -0
  181. mindspore/nn/extend/basic.py +140 -0
  182. mindspore/nn/extend/embedding.py +143 -0
  183. mindspore/nn/extend/layer/__init__.py +27 -0
  184. mindspore/nn/extend/layer/normalization.py +107 -0
  185. mindspore/nn/extend/pooling.py +117 -0
  186. mindspore/nn/generator.py +297 -0
  187. mindspore/nn/layer/basic.py +109 -1
  188. mindspore/nn/layer/container.py +2 -2
  189. mindspore/nn/layer/conv.py +6 -6
  190. mindspore/nn/layer/embedding.py +1 -1
  191. mindspore/nn/layer/normalization.py +21 -43
  192. mindspore/nn/layer/padding.py +4 -0
  193. mindspore/nn/optim/ada_grad.py +2 -2
  194. mindspore/nn/optim/adadelta.py +1 -1
  195. mindspore/nn/optim/adafactor.py +1 -1
  196. mindspore/nn/optim/adam.py +7 -7
  197. mindspore/nn/optim/adamax.py +2 -2
  198. mindspore/nn/optim/adasum.py +2 -2
  199. mindspore/nn/optim/asgd.py +2 -2
  200. mindspore/nn/optim/ftrl.py +1 -1
  201. mindspore/nn/optim/lamb.py +3 -3
  202. mindspore/nn/optim/lars.py +1 -1
  203. mindspore/nn/optim/lazyadam.py +2 -2
  204. mindspore/nn/optim/momentum.py +2 -2
  205. mindspore/nn/optim/optimizer.py +2 -2
  206. mindspore/nn/optim/proximal_ada_grad.py +2 -2
  207. mindspore/nn/optim/rmsprop.py +2 -2
  208. mindspore/nn/optim/rprop.py +2 -2
  209. mindspore/nn/optim/sgd.py +2 -2
  210. mindspore/nn/optim/thor.py +2 -2
  211. mindspore/nn/wrap/cell_wrapper.py +9 -9
  212. mindspore/nn/wrap/grad_reducer.py +5 -5
  213. mindspore/ops/_grad_experimental/grad_comm_ops.py +4 -2
  214. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -2
  215. mindspore/ops/_vmap/vmap_math_ops.py +27 -8
  216. mindspore/ops/_vmap/vmap_nn_ops.py +66 -8
  217. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +73 -1
  218. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +12 -3
  219. mindspore/ops/auto_generate/gen_arg_handler.py +24 -0
  220. mindspore/ops/auto_generate/gen_extend_func.py +274 -0
  221. mindspore/ops/auto_generate/gen_ops_def.py +889 -22
  222. mindspore/ops/auto_generate/gen_ops_prim.py +3541 -253
  223. mindspore/ops/auto_generate/pyboost_inner_prim.py +282 -0
  224. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
  225. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +9 -0
  226. mindspore/ops/extend/__init__.py +9 -1
  227. mindspore/ops/extend/array_func.py +134 -27
  228. mindspore/ops/extend/math_func.py +3 -3
  229. mindspore/ops/extend/nn_func.py +363 -2
  230. mindspore/ops/function/__init__.py +19 -2
  231. mindspore/ops/function/array_func.py +463 -439
  232. mindspore/ops/function/clip_func.py +7 -18
  233. mindspore/ops/function/grad/grad_func.py +5 -5
  234. mindspore/ops/function/linalg_func.py +4 -4
  235. mindspore/ops/function/math_func.py +260 -243
  236. mindspore/ops/function/nn_func.py +825 -62
  237. mindspore/ops/function/random_func.py +73 -4
  238. mindspore/ops/function/sparse_unary_func.py +1 -1
  239. mindspore/ops/function/vmap_func.py +1 -1
  240. mindspore/ops/functional.py +2 -2
  241. mindspore/ops/op_info_register.py +1 -31
  242. mindspore/ops/operations/__init__.py +2 -3
  243. mindspore/ops/operations/_grad_ops.py +2 -107
  244. mindspore/ops/operations/_inner_ops.py +5 -5
  245. mindspore/ops/operations/_sequence_ops.py +2 -2
  246. mindspore/ops/operations/array_ops.py +11 -233
  247. mindspore/ops/operations/comm_ops.py +32 -32
  248. mindspore/ops/operations/custom_ops.py +7 -89
  249. mindspore/ops/operations/manually_defined/ops_def.py +329 -4
  250. mindspore/ops/operations/math_ops.py +13 -163
  251. mindspore/ops/operations/nn_ops.py +9 -316
  252. mindspore/ops/operations/random_ops.py +1 -1
  253. mindspore/ops/operations/sparse_ops.py +3 -3
  254. mindspore/ops/primitive.py +2 -2
  255. mindspore/ops_generate/arg_dtype_cast.py +12 -3
  256. mindspore/ops_generate/arg_handler.py +24 -0
  257. mindspore/ops_generate/gen_ops_inner_prim.py +2 -0
  258. mindspore/ops_generate/gen_pyboost_func.py +13 -6
  259. mindspore/ops_generate/pyboost_utils.py +2 -17
  260. mindspore/parallel/__init__.py +3 -2
  261. mindspore/parallel/_auto_parallel_context.py +106 -1
  262. mindspore/parallel/_parallel_serialization.py +34 -2
  263. mindspore/parallel/_utils.py +16 -0
  264. mindspore/parallel/algo_parameter_config.py +4 -4
  265. mindspore/parallel/checkpoint_transform.py +249 -77
  266. mindspore/parallel/cluster/process_entity/_api.py +1 -1
  267. mindspore/parallel/parameter_broadcast.py +1 -1
  268. mindspore/parallel/shard.py +1 -1
  269. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +1 -0
  270. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +17 -5
  271. mindspore/profiler/parser/ascend_msprof_exporter.py +3 -3
  272. mindspore/profiler/parser/ascend_msprof_generator.py +10 -3
  273. mindspore/profiler/parser/ascend_op_generator.py +26 -9
  274. mindspore/profiler/parser/ascend_timeline_generator.py +7 -4
  275. mindspore/profiler/parser/profiler_info.py +11 -1
  276. mindspore/profiler/profiling.py +13 -5
  277. mindspore/rewrite/api/node.py +12 -12
  278. mindspore/rewrite/api/symbol_tree.py +11 -11
  279. mindspore/run_check/_check_version.py +1 -1
  280. mindspore/safeguard/rewrite_obfuscation.py +2 -2
  281. mindspore/train/amp.py +4 -4
  282. mindspore/train/anf_ir_pb2.py +8 -2
  283. mindspore/train/callback/_backup_and_restore.py +2 -2
  284. mindspore/train/callback/_callback.py +4 -4
  285. mindspore/train/callback/_checkpoint.py +2 -2
  286. mindspore/train/callback/_early_stop.py +2 -2
  287. mindspore/train/callback/_landscape.py +4 -4
  288. mindspore/train/callback/_loss_monitor.py +2 -2
  289. mindspore/train/callback/_on_request_exit.py +2 -2
  290. mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
  291. mindspore/train/callback/_summary_collector.py +2 -2
  292. mindspore/train/callback/_time_monitor.py +2 -2
  293. mindspore/train/dataset_helper.py +8 -3
  294. mindspore/train/loss_scale_manager.py +2 -2
  295. mindspore/train/metrics/metric.py +3 -3
  296. mindspore/train/mind_ir_pb2.py +22 -17
  297. mindspore/train/model.py +15 -15
  298. mindspore/train/serialization.py +18 -18
  299. mindspore/train/summary/summary_record.py +7 -7
  300. mindspore/train/train_thor/convert_utils.py +3 -3
  301. mindspore/version.py +1 -1
  302. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/METADATA +1 -1
  303. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/RECORD +307 -260
  304. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/tiling_data.h +0 -59
  305. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_bf16_BNSD_mix.o +0 -0
  306. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_bf16_BSH_mix.o +0 -0
  307. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_fp16_BNSD_mix.o +0 -0
  308. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_fp16_BSH_mix.o +0 -0
  309. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_bf16_BNSD_mix.o +0 -0
  310. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_bf16_BSH_mix.o +0 -0
  311. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_fp16_BNSD_mix.o +0 -0
  312. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_fp16_BSH_mix.o +0 -0
  313. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/bs_attention_mix_hwsync.h → flash_attention_score/kernel/flash_attention_score_mix_hwsync.h} +0 -0
  314. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/WHEEL +0 -0
  315. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/entry_points.txt +0 -0
  316. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/top_level.txt +0 -0
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_fp16.h
@@ -0,0 +1,116 @@
+ /**
+  * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+  *
+  * Licensed under the Apache License, Version 2.0 (the "License");
+  * you may not use this file except in compliance with the License.
+  * You may obtain a copy of the License at
+  *
+  * http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+ #ifndef ROTARY_POS_EMB_FP16
+ #define ROTARY_POS_EMB_FP16
+ #include "apply_rotary_pos_emb_base.h"
+ template <typename QK_DTYPE, typename COS_DTYPE, bool IF_COS_BROADCAST>
+ class RopeFp16 : public RopeBase<QK_DTYPE, COS_DTYPE, IF_COS_BROADCAST> {
+  public:
+   __aicore__ inline RopeFp16(RopeTilingData *tilingData) : RopeBase<QK_DTYPE, COS_DTYPE, IF_COS_BROADCAST>(tilingData) {
+     this->repeatSize_ = 128;  // 128 = 256B / sizeof(half)
+     this->maxProcessNum_ = this->tilingData_->maxUbSize / sizeof(uint16_t);
+     this->repeatTimesQ_ = (this->tilingData_->hiddenSizeQ + this->repeatSize_ - 1) / this->repeatSize_;
+     this->repeatTimesK_ = (this->tilingData_->hiddenSizeK + this->repeatSize_ - 1) / this->repeatSize_;
+     headDimAlign_ = ((this->tilingData_->headDim + ELE_NUM_FP16 - 1) / ELE_NUM_FP16) * ELE_NUM_FP16;
+     this->alignHalfHeadDim_ = (this->rotateStride_ * NUM_TWO) % ELE_NUM_FP16;
+     this->hiddenSizeAlign_ = ((this->hiddenSize_ + this->repeatSize_ - 1) / this->repeatSize_) * this->repeatSize_;
+
+     this->cosPad_ = 0;
+     this->sinPad_ = this->cosPad_ + this->hiddenSizeAlign_;
+     this->negOne_ = this->sinPad_ + this->hiddenSizeAlign_;
+     this->oriPos_ = this->negOne_ + this->hiddenSizeAlign_;
+     this->padBefore_ = this->oriPos_ + this->hiddenSizeAlign_;
+     this->removeBefore_ = this->padBefore_ + this->hiddenSizeAlign_;
+     sinResPos_ = this->removeBefore_ + this->hiddenSizeAlign_;
+     this->repeatTimes_ = this->hiddenSizeAlign_ / this->repeatSize_;
+
+     this->syncOffset_ =
+       (this->tilingData_->headDim % ELE_NUM_FP16 == 0) ? this->hiddenSizeAlign_ : this->headNum_ * headDimAlign_;
+     this->offsetExtraGm_ = NUM_TWO * block_idx * this->syncOffset_;
+     this->pipe_.InitBuffer(outQueueCO2_, 1, ((this->maxProcessNum_ - this->batchSize_ * NUM_TWO) * sizeof(QK_DTYPE)));
+     AscendC::LocalTensor<QK_DTYPE> cache_perloop_ub_ = outQueueCO2_.AllocTensor<QK_DTYPE>();
+     commonUbuf_ = (__ubuf__ QK_DTYPE *)cache_perloop_ub_.GetPhyAddr();
+   }
+
+   __aicore__ inline void Process(__gm__ uint8_t *extraGm) {
+     if (this->tilingData_->cosFormat == 1) {
+       pipe_barrier((PIPE_ALL));
+       this->ExpandCosSin(commonUbuf_, this->cosGm_, (__gm__ COS_DTYPE *)extraGm);
+       this->cosGm_ = (__gm__ COS_DTYPE *)extraGm;
+       pipe_barrier((PIPE_ALL));
+       this->ExpandCosSin(commonUbuf_, this->sinGm_,
+                          (__gm__ COS_DTYPE *)extraGm + this->tilingData_->ntokens * this->tilingData_->headDim);
+       this->sinGm_ = (__gm__ COS_DTYPE *)extraGm + this->tilingData_->ntokens * this->tilingData_->headDim;
+       extraGm =
+         extraGm + this->tilingData_->ntokens * this->tilingData_->headDim * 4;  // sizeof(uint8_t) * 2 = sizeof(half)
+       pipe_barrier((PIPE_ALL));
+     }
+
+     this->ExpandNeg(commonUbuf_, sinResPos_, this->headNum_, this->repeatTimes_);  // choose 1 -1 or -1 0 depending on alignment
+     for (uint32_t zz = 0; zz < this->dynamicRound_; ++zz) {
+       this->CosSinBroadcast(extraGm, zz, commonUbuf_, this->tilingData_->hiddenSizeQ);  // cos/sin do not depend on Q or K
+
+       this->QkComm(this->qGm_ + block_idx * this->nlCoreRun_ * this->tilingData_->hiddenSizeQ +
+                      zz * this->tilingData_->hiddenSizeQ,
+                    extraGm, this->tilingData_->hiddenSizeQ, commonUbuf_, this->tilingData_->headNumQ);
+
+       if (this->alignRotary_ == 0) {
+         pipe_barrier((PIPE_V));
+         this->CalcRopeAlign(commonUbuf_, this->repeatTimesQ_, this->oriPos_, this->removeBefore_, this->padBefore_);
+       } else {
+         set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
+         wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
+         this->CalcRope(commonUbuf_, this->repeatTimesQ_, this->oriPos_, this->removeBefore_, this->padBefore_,
+                        sinResPos_, this->padBefore_);
+       }
+       pipe_barrier((PIPE_ALL));  // required
+       copy_ubuf_to_gm(this->outQGm_ + block_idx * this->nlCoreRun_ * this->tilingData_->hiddenSizeQ +
+                         zz * this->tilingData_->hiddenSizeQ,
+                       commonUbuf_ + this->padBefore_, 0, 1, this->tilingData_->hiddenSizeQ / ELE_NUM_FP16, 0, 0);
+
+       set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID1);
+       wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID1);
+
+       this->QkComm(this->kGm_ + block_idx * this->nlCoreRun_ * this->tilingData_->hiddenSizeK +
+                      zz * this->tilingData_->hiddenSizeK,
+                    extraGm, this->tilingData_->hiddenSizeK, commonUbuf_, this->tilingData_->headNumK);
+
+       if (this->alignRotary_ == 0) {
+         pipe_barrier((PIPE_V));
+         this->CalcRopeAlign(commonUbuf_, this->repeatTimesK_, this->oriPos_, this->removeBefore_, this->padBefore_);
+       } else {
+         set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
+         wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
+         this->CalcRope(commonUbuf_, this->repeatTimesK_, this->oriPos_, this->removeBefore_, this->padBefore_,
+                        sinResPos_, this->padBefore_);
+       }
+       pipe_barrier((PIPE_ALL));  // required
+       copy_ubuf_to_gm(this->outKGm_ + block_idx * this->nlCoreRun_ * this->tilingData_->hiddenSizeK +
+                         zz * this->tilingData_->hiddenSizeK,
+                       commonUbuf_ + this->padBefore_, 0, 1, this->tilingData_->hiddenSizeK / ELE_NUM_FP16, 0, 0);
+       set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1);
+       wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1);
+     }
+   }
+
+  private:
+   AscendC::TQue<AscendC::QuePosition::VECIN, 1> outQueueCO2_;
+   __ubuf__ QK_DTYPE *commonUbuf_{nullptr};
+   uint32_t headDimAlign_;   // headDim rounded up to a whole block
+   uint32_t sinResPos_{0};   // position of the 0 0 0 1 1 1 pattern in the fp32 buffer
+ };
+
+ #endif
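The constructor above carves the unified buffer into seven back-to-back regions, each hiddenSizeAlign_ elements wide. A minimal standalone sketch of that offset arithmetic, with an assumed hiddenSizeAlign value purely for illustration:

#include <cstdint>
#include <cstdio>

int main() {
  // Sketch of the RopeFp16 UB layout: seven regions of hiddenSizeAlign
  // elements each, laid out back to back. The value below is illustrative.
  const uint32_t hiddenSizeAlign = 1024;
  const uint32_t cosPad = 0;
  const uint32_t sinPad = cosPad + hiddenSizeAlign;           // 1024
  const uint32_t negOne = sinPad + hiddenSizeAlign;           // 2048
  const uint32_t oriPos = negOne + hiddenSizeAlign;           // 3072
  const uint32_t padBefore = oriPos + hiddenSizeAlign;        // 4096
  const uint32_t removeBefore = padBefore + hiddenSizeAlign;  // 5120
  const uint32_t sinResPos = removeBefore + hiddenSizeAlign;  // 6144
  printf("last region (sinResPos) starts at element %u\n", sinResPos);
  return 0;
}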
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_tiling.h
@@ -14,38 +14,30 @@
   * limitations under the License.
   */
 
- #ifndef MS_KERNELS_INTERNAL_KERNEL_ASCENDC_APPLY_ROTARY_POS_EMB_TILING_DATA_H_
- #define MS_KERNELS_INTERNAL_KERNEL_ASCENDC_APPLY_ROTARY_POS_EMB_TILING_DATA_H_
+ #ifndef MS_KERNELS_INTERNAL_KERNEL_ASCENDC_ROPE_TILING_DATA_H_
+ #define MS_KERNELS_INTERNAL_KERNEL_ASCENDC_ROPE_TILING_DATA_H_
 
  #include <stdint.h>
 
- struct ApplyRotaryPosEmbTilingData {
+ struct RopeTilingData {
   uint32_t hiddenSizeQ{16};
   uint32_t hiddenSizeK{16};
-  uint32_t headDim{1};
+  uint32_t headDim{1};      // max head dim across Q and K
   uint32_t headNumQ{1};
   uint32_t headNumK{1};
-  uint32_t rotaryCoeff{4};
-  uint32_t ntokens{1};
-  uint32_t ropeFormat{0};
-  uint32_t realCore{0};
-  uint32_t cosFormat{0};
-  uint32_t batch{32};
-  uint32_t highPrecision{0};
-  uint32_t maxUbSize{0};
+  uint32_t rotaryCoeff{4};  // rotary coefficient
+  uint32_t ntokens{1};      // total number of tokens
+  uint32_t realCore{0};     // number of cores actually used
+  uint32_t cosFormat{0};    // whether cos/sin are reused
+  uint32_t batch{32};       // number of batches
+  uint32_t maxUbSize{0};    // maximum UB size
+  uint32_t tilingId{0};
 
-  int32_t ndim;
-  int32_t qkDtype;   // 0=fp16 1=bf16 2=fp32
-  int32_t posDtype;  // 0=i64 1=u64 2=i32 3=u32
-
-  // int32_t batch;
-  int32_t numHeadQ;
-  int32_t numHeadK;
-  // int32_t hiddenDim;
-  int32_t seqLen;
-  int32_t maxSeqLen;
-
-  int32_t posSize;  // seqLen==1 ? batch : seqLen
+  uint32_t seqLen;
+  uint32_t broadCastCos{0};
+  uint32_t posDtype;
+  uint32_t posSize;
+  uint32_t maxSeqLen;
  };
 
  #endif
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_value.h
@@ -0,0 +1,27 @@
+ /**
+  * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+  *
+  * Licensed under the Apache License, Version 2.0 (the "License");
+  * you may not use this file except in compliance with the License.
+  * You may obtain a copy of the License at
+  *
+  * http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+ #ifndef COMMON_VAL_H
+ #define COMMON_VAL_H
+ const constexpr uint32_t NUM_TWO = 2;
+ const constexpr uint32_t BLK_SIZE = 32;             // bytes per block
+ const constexpr uint32_t ELE_NUM_FP16 = 16;         // fp16 elements per block
+ const constexpr uint32_t ELE_NUM_FP32 = 8;          // fp32 elements per block
+ const constexpr uint32_t MAX_LEN_FP16 = 8192;       // max length (hiddenSize) in the non-fp16 case
+ const constexpr uint8_t DEFAULT_REPEAT_STRIDE = 8;  // default stride, 8 * 32 = 256
+ const constexpr int64_t REG_910B = 48;              // saturation-mode register position
+ const constexpr int64_t REG_310P = 53;              // saturation-mode register position
+ const constexpr int64_t SLICE_SIZE = 4096;          // slice size
+ #endif
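These constants encode the 32-byte block geometry that the kernels' round-up arithmetic is built on (for example headDimAlign_ in RopeFp16 above). A small self-contained sketch of that alignment pattern, assuming nothing beyond the constant defined here:

#include <cassert>
#include <cstdint>

constexpr uint32_t ELE_NUM_FP16 = 16;  // fp16 elements per 32-byte block

// Round x up to a whole number of blocks, the same expression used for
// headDimAlign_ in the fp16 rope kernel.
constexpr uint32_t AlignToBlock(uint32_t x) {
  return ((x + ELE_NUM_FP16 - 1) / ELE_NUM_FP16) * ELE_NUM_FP16;
}

int main() {
  static_assert(AlignToBlock(16) == 16, "already aligned");
  static_assert(AlignToBlock(17) == 32, "rounds up");
  assert(AlignToBlock(100) == 112);  // headDim 100 -> 112 elements
  return 0;
}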
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/asdop/asd_op_impl.h
@@ -41,8 +41,6 @@ class AsdOpsImpl : public InternelKernelImpl {
   std::vector<uint64_t> GetWorkSpaceSize() override;
   int InferShape(const std::vector<DIMS> &input_shapes, std::vector<DIMS> &output_shapes) override;
 
-  void SetCacheInfo(const CacheInfo &cache_info);
-
  private:
   AsdOps::Tactic *InitAndGetTactic();
 
@@ -52,8 +50,6 @@ class AsdOpsImpl : public InternelKernelImpl {
   AsdOps::LaunchParam launch_param_;
   AsdOps::OpDesc op_desc_;
   bool validated_ = false;
-
-  RunInfo run_info_;
  };
 
  }  // namespace internal
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/flash_attention_score/flash_attention_score_impl.h
@@ -24,7 +24,7 @@
  #include "asdops/tensor.h"
 
  #include "internal_kernel.h"
-
+ #include "param/attention_param.h"
  #include "acl_rt.h"
 
  #include <unordered_map>
@@ -49,6 +49,7 @@ class FlashAttentionScoreImpl : public InternelKernelImpl {
 
  private:
   uint64_t B, N, Q_S, KV_S, D, G, CORE_NUM;
+  int inner_precise, pre_tokens, next_tokens, sparse_mode;
   bool BFLOAT16, BSH, ALIBI, AMASK;
   const std::vector<Tensor *> *inputs_;
   const std::vector<Tensor *> *outputs_;
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/flash_attention_score/flash_attention_score_tiling.h
@@ -18,16 +18,12 @@ typedef struct {
  #pragma pack()
 
  #define MAX_CORE_NUM 25
- #define ATTENTION_DEBUG false  // when enabled, debug data is written into S/P
- #define BUFFER_NUM 2           // inter-core pipeline depth; changing it is not yet supported
-
- #if PA
- #define INC true
- #define OP_NAME PagedAttention
- #else
- #define INC false
+ #define ATTENTION_DEBUG false  // when enabled, debug data is written into S/P
+ #define ROWMAX true
  #define OP_NAME FlashAttentionScore
- #endif
+ #define BUFFER_NUM 2  // inter-core pipeline depth; changing it is not yet supported
+ constexpr uint64_t WORKSPACE_MAX_SEQLEN = 16384;  // max seqlen
+ constexpr uint64_t WORKSPACE_SIZE = 128 * WORKSPACE_MAX_SEQLEN;
 
  #if BFLOAT16
  #define TYPE_NAME _bf16
@@ -41,14 +37,16 @@ typedef struct {
  #define LAYOUT_NAME _BNSD
  #endif
 
- #define CONCAT_(A, B, C, D) A##B##C##D
- #define CONCAT(A, B, C, D) CONCAT_(A, B, C, D)
- #define FUNC_NAME_AIC CONCAT(OP_NAME, TYPE_NAME, LAYOUT_NAME, _mix_aic)
- #define FUNC_NAME_AIV CONCAT(OP_NAME, TYPE_NAME, LAYOUT_NAME, _mix_aiv)
-
- #if INC
- #define CORE_PER_KV_HEAD 4  // enabled for incremental inference: how many tasks each kv_head is split into
+ #if LOWER_TRIANGLE
+ #define TRI_NAME _tri
+ #else
+ #define TRI_NAME _full
  #endif
+
+ #define CONCAT_(A, B, C, D, E) A##B##C##D##E
+ #define CONCAT(A, B, C, D, E) CONCAT_(A, B, C, D, E)
+ #define FUNC_NAME_AIC CONCAT(OP_NAME, TYPE_NAME, LAYOUT_NAME, TRI_NAME, _mix_aic)
+ #define FUNC_NAME_AIV CONCAT(OP_NAME, TYPE_NAME, LAYOUT_NAME, TRI_NAME, _mix_aiv)
 
  // ************** mask pattern modes **************//
  // Mode 1: lower triangle; with LOWER_TRIANGLE enabled the lower triangle is applied directly, no mask needed
@@ -63,9 +61,7 @@ typedef struct {
  // Mode 4: full matrix; if LOWER_TRIANGLE, BLOCK_SPARSE and AMASK are all disabled, attention runs over the full matrix without suppressing elements of S
  // *******************************************//
 
- constexpr uint64_t WORKSPACE_MAX_SEQLEN = 10240;                            // max seqlen: 10240
- constexpr uint64_t WORKSPACE_MAX_SEQLEN_BLOCK = WORKSPACE_MAX_SEQLEN / 16;  // max seqlen: 10240
- constexpr uint64_t WORKSPACE_SIZE = 128 * WORKSPACE_MAX_SEQLEN;
+ constexpr uint64_t WORKSPACE_MAX_SEQLEN_BLOCK = WORKSPACE_MAX_SEQLEN / 16;
  constexpr uint64_t BUFFER_SIZE = MAX_CORE_NUM * WORKSPACE_SIZE * sizeof(uint16_t);
 
  #endif
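With the new five-argument CONCAT, the triangle variant is baked into the kernel entry names: for a build with BFLOAT16, BSH and LOWER_TRIANGLE set, FUNC_NAME_AIC expands to FlashAttentionScore_bf16_BSH_tri_mix_aic, which echoes the renamed kernel objects in the file list (e.g. flash_attention_score_bf16_bsh_tri_mix.o). A standalone sketch of the token pasting; the variant macros are normally supplied by the build, so the fixed definitions here are for illustration only:

#include <cstdio>

#define OP_NAME FlashAttentionScore
#define TYPE_NAME _bf16
#define LAYOUT_NAME _BSH
#define TRI_NAME _tri

#define CONCAT_(A, B, C, D, E) A##B##C##D##E
#define CONCAT(A, B, C, D, E) CONCAT_(A, B, C, D, E)
#define FUNC_NAME_AIC CONCAT(OP_NAME, TYPE_NAME, LAYOUT_NAME, TRI_NAME, _mix_aic)

// Two-step stringification so FUNC_NAME_AIC is fully expanded first.
#define STR_(x) #x
#define STR(x) STR_(x)

int main() {
  printf("%s\n", STR(FUNC_NAME_AIC));  // FlashAttentionScore_bf16_BSH_tri_mix_aic
  return 0;
}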
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/gelu/tiling/gelu_tiling.h
@@ -26,15 +26,13 @@ struct GeLUTilingData {
   uint32_t tailBlockTileNum{0};
  };
  static std::ostream &operator<<(std::ostream &os, const GeLUTilingData &dt) {
-  os << "blockDims:" << dt.blockDims << std::endl;
-  os << "totalLength:" << dt.totalLength << std::endl;
-  os << "blockLength:" << dt.blockLength << std::endl;
-  os << "tileLength:" << dt.tileLength << std::endl;
-  os << "tileNum:" << dt.tileNum << std::endl;
-  os << "tailBlockTileNum:" << dt.tailBlockTileNum << std::endl;
-  os << "tilingKey:" << dt.tilingKey << std::endl;
-  // os << "axisDim:" << dt.axisDim << std::endl;
-  // os << "splitNum:" << dt.splitNum << std::endl;
+  os << "blockDims:" << dt.blockDims;
+  os << ", totalLength:" << dt.totalLength;
+  os << ", blockLength:" << dt.blockLength;
+  os << ", tileLength:" << dt.tileLength;
+  os << ", tileNum:" << dt.tileNum;
+  os << ", tailBlockTileNum:" << dt.tailBlockTileNum;
+  os << ", tilingKey:" << dt.tilingKey;
   return os;
  }
  #endif  // MS_KERNELS_INTERNAL_ASCENDC_GELU_TILING_H
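The change collapses the tiling dump from seven std::endl-terminated lines into one comma-separated line. A self-contained sketch with illustrative field values (the struct here only mirrors the fields the operator prints):

#include <cstdint>
#include <iostream>

// Mirror of the printed GeLUTilingData fields; values are illustrative.
struct GeLUTilingData {
  uint32_t blockDims{8}, totalLength{4096}, blockLength{512};
  uint32_t tileLength{128}, tileNum{4}, tailBlockTileNum{2}, tilingKey{0};
};

static std::ostream &operator<<(std::ostream &os, const GeLUTilingData &dt) {
  os << "blockDims:" << dt.blockDims << ", totalLength:" << dt.totalLength
     << ", blockLength:" << dt.blockLength << ", tileLength:" << dt.tileLength
     << ", tileNum:" << dt.tileNum << ", tailBlockTileNum:" << dt.tailBlockTileNum
     << ", tilingKey:" << dt.tilingKey;
  return os;
}

int main() {
  // Prints one line: blockDims:8, totalLength:4096, blockLength:512, ...
  std::cout << GeLUTilingData{} << std::endl;
  return 0;
}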
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/lccl/lccl_wrapper.h
@@ -0,0 +1,58 @@
+ /**
+  * Copyright 2024 Huawei Technologies Co., Ltd
+  *
+  * Licensed under the Apache License, Version 2.0 (the "License");
+  * you may not use this file except in compliance with the License.
+  * You may obtain a copy of the License at
+  *
+  * http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+ #ifndef LCCL_WRAPPER_H_
+ #define LCCL_WRAPPER_H_
+
+ #include <memory>
+ #include "lccl.h"
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ using namespace Lcal;
+ using LcclComm = std::shared_ptr<Lccl>;
+ enum class LcclResult {
+   LCAL_SUCCESS = 0,
+   LCAL_ERROR_NOT_INITIALIZED = -1,
+   LCAL_ERROR_ASDRT = -2,
+   LCAL_ERROR_PARA_CHECK_FAIL = -3,
+   LCAL_ERROR_INTERNAL = -4,
+   LCAL_ERROR_TIMEOUT = -5,
+   LCCL_ERROR_INIT_HCCL_FAILED = -6
+ };
+
+ extern LcclResult LcclCommInitRank(uint32_t nRanks, uint32_t rank, LcclComm *comm);
+
+ extern LcclResult LcclAllReduce(void *sendBuff, void *recvBuff, int64_t count, HcclDataType dataType,
+                                 HcclReduceOp op, aclrtStream stream);
+
+ extern LcclResult LcclReduceScatter(void *sendBuff, void *recvBuff, int64_t count, HcclDataType dataType,
+                                     HcclReduceOp op, aclrtStream stream);
+
+ extern LcclResult LcclAllGather(void *sendBuff, void *recvBuff, int64_t count, HcclDataType dataType, aclrtStream stream);
+
+ extern LcclResult LcclAll2All(void *sendBuff, void *recvBuff, int64_t count, HcclDataType dataType, aclrtStream stream);
+
+ extern LcclResult LcclBroadcast(void *buff, int64_t count, HcclDataType dataType, int32_t root, aclrtStream stream);
+
+ extern LcclResult LcclCommDestroy(LcclComm comm);
+
+ #ifdef __cplusplus
+ }
+ #endif
+
+ #endif  // LCCL_WRAPPER_H_
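A minimal sketch of driving this C API end to end, using only the functions declared above. The buffers, count, and stream are assumed to come from the caller, and the HCCL enum constants (HCCL_DATA_TYPE_FP16, HCCL_REDUCE_SUM) are the standard ones from hccl_types.h:

#include "lccl_wrapper.h"  // the header introduced in this hunk

// Hypothetical all-reduce over device buffers; error handling is minimal.
int RunAllReduce(uint32_t nRanks, uint32_t rank, void *devIn, void *devOut,
                 int64_t count, aclrtStream stream) {
  LcclComm comm;
  if (LcclCommInitRank(nRanks, rank, &comm) != LcclResult::LCAL_SUCCESS) {
    return -1;  // communicator creation failed
  }
  LcclResult ret = LcclAllReduce(devIn, devOut, count, HCCL_DATA_TYPE_FP16,
                                 HCCL_REDUCE_SUM, stream);
  LcclCommDestroy(comm);
  return ret == LcclResult::LCAL_SUCCESS ? 0 : -1;
}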
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul/matmul_impl.h
@@ -24,11 +24,13 @@
  #include "asdops/tensor.h"
 
  #include "utils.h"
- #include "pp_matmul_info.h"
- #include "tiling_data.h"
- #include "pp_matmul_common_tiling.h"
+ #include "backend_param.h"
+ #include "param/matmul_ext_param.h"
+ #include "matmul_common/pp_matmul_info.h"
+ #include "matmul_common/tiling_utils.h"
+ #include "matmul_common/tiling_data.h"
+ #include "matmul_common/pp_matmul_common_tiling.h"
  #include "tune_repo/utils.h"
-
  #include "internal_kernel.h"
 
  #include "acl_rt.h"
@@ -37,6 +39,8 @@
  namespace mindspore {
  namespace internal {
 
+ using namespace tiling;
+
  enum class MatMulAlgo { PP = 0, LLM_CUSTOM = 1 };
 
  class MatMulImpl : public InternelKernelImpl {
@@ -48,12 +52,15 @@ class MatMulImpl : public InternelKernelImpl {
   int Launch() override;
   size_t GetTilingBufSize() override;
   int Tiling(HostRawBuf &tilingBuf) override;
-  int TilingPp(HostRawBuf &tilingBuf, uint64_t tilingId);
-  int TilingLLMCustom(HostRawBuf &tilingBuf, uint64_t tilingId);
+  void TilingBasicFromPp(uint32_t &blockDim, PpTilingData &tilingdata);
+  int TilingPp(HostRawBuf &tilingBuf, uint64_t tilingId, const uint32_t &blockDim, const PpTilingData &tilingdata);
+  int TilingLLMCustom(HostRawBuf &tilingBuf, uint64_t tilingId, const uint32_t &blockDim,
+                      const PpTilingData &tilingdata);
   std::vector<uint64_t> GetWorkSpaceSize() override;
   int InferShape(const std::vector<DIMS> &input_shapes, std::vector<DIMS> &output_shapes) override;
-  bool CanUseLLMCustom();
-  int GetCoresFromValidLLMCustomTile(const std::vector<int> &cand);
+  bool UseCustomMatMul();
+  void GetTunedKey();
+  void SetTunedValueCustom(const std::vector<int> &tuned_config);
 
  private:
   uint32_t m_, k_, n_;
@@ -61,7 +68,11 @@ class MatMulImpl : public InternelKernelImpl {
   MatMulAlgo algo_ = MatMulAlgo::PP;
   DeviceRawBuf tiling_addr_;
   std::string soc_{"Ascend910B2"};
+  HardwareInfo hwInfo_;
+  CustomMatmulTilingData t_;
+  std::vector<int> tune_key_;
   REPO tuningTable_;
+  REPO tuningTableCustom_;
   TensorDType input_dtype_;
   TensorDType output_dtype_;
   int block_dim_ = 0;
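The signature changes read as a refactor where the Pp baseline tiling is computed once and shared: TilingBasicFromPp fills blockDim and a PpTilingData, both TilingPp and TilingLLMCustom consume them, and UseCustomMatMul() picks the path. The dispatch below is an inference from the declarations, not the library's actual body; all types are stubs:

#include <cstdint>

struct HostRawBuf {};    // stub for the real buffer type
struct PpTilingData {};  // stub for the real tiling data

struct MatMulImplSketch {
  bool UseCustomMatMul() { return false; }  // stub decision
  void TilingBasicFromPp(uint32_t &blockDim, PpTilingData &) { blockDim = 20; }
  int TilingPp(HostRawBuf &, uint64_t, const uint32_t &, const PpTilingData &) { return 0; }
  int TilingLLMCustom(HostRawBuf &, uint64_t, const uint32_t &, const PpTilingData &) { return 0; }

  int Tiling(HostRawBuf &tilingBuf) {
    uint32_t blockDim = 0;
    PpTilingData tilingdata;
    TilingBasicFromPp(blockDim, tilingdata);  // shared Pp baseline, computed once
    const uint64_t tilingId = 0;              // illustrative
    return UseCustomMatMul()
               ? TilingLLMCustom(tilingBuf, tilingId, blockDim, tilingdata)
               : TilingPp(tilingBuf, tilingId, blockDim, tilingdata);
  }
};

int main() {
  HostRawBuf buf;
  return MatMulImplSketch{}.Tiling(buf);
}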
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/pp_matmul_common_tiling.h
@@ -17,6 +17,8 @@
  #ifndef MATMUL_COMMMON_TILING_H
  #define MATMUL_COMMMON_TILING_H
 
+ #include <cmath>
+ #include <iostream>
  #include "pp_matmul_info.h"
 
  namespace mindspore {
@@ -100,9 +102,12 @@ inline __attribute__((always_inline)) float CostFunc(const HardwareType &hwInfo,
  template <bool PRI_FLAG, typename OpShareType, typename TilingType, typename HardwareType, typename MatMulInfoType>
  void TilingFunc(OpShareType &opShape, TilingType &tilingParam, const HardwareType &hwInfo, const MatMulInfoType &mmInfo,
                  bool compressFlag = false, const uint32_t tilingN = 1) {
+  using namespace std;
   float costMin = 1;
-  uint32_t priAxes = Round<16>(PRI_FLAG ? opShape.m : opShape.n);
-  uint32_t axes = Round<16>(PRI_FLAG ? opShape.n : opShape.m);
+  const uint32_t CONST_16 = 16;
+  uint32_t roundBase = pow(2, ceil(log(CeilDiv(PRI_FLAG ? opShape.n : opShape.m, CONST_16)))) * CONST_16;
+  uint32_t priAxes = RoundUp(PRI_FLAG ? opShape.m : opShape.n, CONST_16);
+  uint32_t axes = RoundUp(PRI_FLAG ? opShape.n : opShape.m, roundBase);
   float axes0Max = static_cast<float>(AXES_ALIGN_SIZE) / mmInfo.inDtype;
 
   uint32_t n0TilingInit =
@@ -129,10 +134,15 @@
   }
   opShape.m0 = PRI_FLAG ? priAxes0 : axes0;
   opShape.n0 = PRI_FLAG ? axes0 : priAxes0;
+  if ((mmInfo.qkv_n0 + mmInfo.qkv_n1 + mmInfo.qkv_n2 != 0) &&
+      (mmInfo.qkv_n0 < opShape.n0 || mmInfo.qkv_n1 < opShape.n0 ||
+       (mmInfo.qkv_n2 < opShape.n0 && mmInfo.qkv_n2 > 1))) {
+    continue;
+  }
   float cost = CostFunc<HardwareType, OpShareType>(hwInfo, opShape);
   if (cost < costMin) {
     costMin = cost;
-    tilingParam.SetBaseOp(hwInfo.coreNum, opShape.m0, opShape.n0);
+    tilingParam.SetBaseOp(hwInfo.coreNum, opShape.m0, opShape.n0, mmInfo.qkv_n0, mmInfo.qkv_n1, mmInfo.qkv_n2);
   }
  }
 }
@@ -140,7 +150,7 @@
 
  template <typename PpTilingDataType>
  uint32_t Swizzl(PpTilingDataType &tilingData) {
-  uint32_t swizzleDirect = 0;
+  uint32_t swizzlDirect = 0;
   uint32_t swizzlCount = 1;
   float m0 = tilingData.opShape.m0;
   float n0 = tilingData.opShape.n0;
@@ -154,14 +164,14 @@ uint32_t Swizzl(PpTilingDataType &tilingData) {
   float cost;
   // B0 + A < A0 + B
   if (i * n0 + m < m0 * c + n) {
-    swizzleDirect = 1;  // Nz
+    swizzlDirect = 1;  // Nz
     cost = n0 * i + m0 * c;
     if (cost <= mincost) {
      mincost = cost;
      swizzlCount = i;
     }
   } else {
-    swizzleDirect = 0;  // Zn
+    swizzlDirect = 0;  // Zn
     cost = m0 * i + n0 * c;
     if (cost < mincost) {
      mincost = cost;
@@ -169,9 +179,9 @@
     }
   }
  }
- tilingData.swizzleDirect = swizzleDirect;
+ tilingData.swizzlDirect = swizzlDirect;
  tilingData.swizzlCount = swizzlCount;
- return swizzleDirect;
+ return swizzlDirect;
 }
 
 }  // namespace tiling
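The new roundBase makes the secondary axis round up to a power-of-two multiple of 16 instead of a flat 16; note that log here resolves to std::log, the natural logarithm. A worked standalone sketch of the arithmetic, assuming conventional CeilDiv/RoundUp helpers matching the names used above:

#include <cmath>
#include <cstdint>
#include <cstdio>

static uint32_t CeilDiv(uint32_t a, uint32_t b) { return (a + b - 1) / b; }
static uint32_t RoundUp(uint32_t x, uint32_t base) { return CeilDiv(x, base) * base; }

int main() {
  const uint32_t CONST_16 = 16;
  const uint32_t axis = 1000;  // illustrative secondary-axis extent
  // CeilDiv(1000, 16) = 63; ln(63) ~ 4.14; ceil -> 5; 2^5 = 32; 32 * 16 = 512.
  uint32_t roundBase = static_cast<uint32_t>(
      std::pow(2, std::ceil(std::log(CeilDiv(axis, CONST_16))))) * CONST_16;
  printf("roundBase=%u axes=%u\n", roundBase, RoundUp(axis, roundBase));  // 512 1024
  return 0;
}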
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/pp_matmul_info.h
@@ -26,7 +26,10 @@ namespace internal {
  namespace tiling {
  struct MatMulInfo {
   uint32_t batchSize{0};
-  uint32_t m{0};  // actual input m
+  uint32_t m{0};  // actual input m
+  uint32_t qkv_n0{0};
+  uint32_t qkv_n1{0};
+  uint32_t qkv_n2{0};
   uint32_t n{0};  // actual input n
   uint32_t k{0};  // actual input k
   bool transA{0};  // false: 0, true: 1
@@ -57,12 +60,12 @@ struct PpTilingData {
   uint32_t swizzlCount{1};
   uint32_t tilingKey{0};
   uint32_t blockDim{1};
-  uint32_t swizzleDirect{0};
+  uint32_t swizzlDirect{0};
   uint32_t splitk{0};
 
   void SetBaseShape(uint32_t batchSize, uint32_t m, uint32_t k, uint32_t n);
-  void SetBaseOp(uint32_t coreNum, uint32_t mBase, uint32_t nBase);
-  void SetTilingKey(const MatMulInfo &mmInfo, uint32_t swizzleDirect, uint32_t enSplitK);
+  void SetBaseOp(uint32_t coreNum, uint32_t mBase, uint32_t nBase, uint32_t qkv_n0, uint32_t qkv_n1, uint32_t qkv_n2);
+  void SetTilingKey(const MatMulInfo &mmInfo, uint32_t swizzlDirect, uint32_t enSplitK);
   uint32_t End(const MatMulInfo &mmInfo);
  };
  }  // namespace tiling
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/tiling_data.h
@@ -18,7 +18,6 @@
  #define MATMUL_TILING_DATA_H
 
  #include <stdint.h>
- #include <algorithm>
 
  namespace mindspore {
  namespace internal {
@@ -38,8 +37,16 @@ struct PpMatmulTilingData {
   uint32_t swizzlCount{0};
   uint32_t tilingKey{0};
   uint32_t blockDim{1};
-  uint32_t swizzleDirect{0};
+  uint32_t swizzlDirect{0};
   uint32_t splitk{0};
+  uint32_t enShuffleK{0};
+  uint32_t unused0{0};
+  uint32_t unused1{0};
+  uint32_t unused2{0};
+  uint32_t unused3{0};
+  uint32_t unused4{0};
+  uint32_t unused5{0};
+  uint32_t unused6{0};
   uint32_t tilingId{0};
  };
 
@@ -60,18 +67,49 @@ struct CustomMatmulTilingData {
   uint32_t BaseKNum{0};
   uint32_t BaseNNum{0};
   uint32_t MmadM{0};
+  uint32_t MmadK{0};
+  uint32_t MmadN{0};
+  uint32_t fractal_k_num{0};
+  uint32_t FractalKInBlockNum{0};
+  uint32_t PartKInMmad{0};
+  uint32_t TransA{0};
+  uint32_t TransB{0};
+  uint32_t shuffleFlag{0};
   uint32_t tilingId{0};
+ };
+
+ constexpr size_t maxTilingBufSize = sizeof(CustomMatmulTilingData);
+
+ struct MatmulStridedSliceFusionTilingData {
+  uint32_t tilingId{0};
+  uint32_t BlockDimM{0};
+  uint32_t BlockDimN{0};
+  uint32_t BlockTotal{0};
+  uint32_t M{0};
+  uint32_t K{0};
+  uint32_t N{0};
+  uint32_t N0{0};
+  uint32_t N1{0};
+  uint32_t N2{0};
+  uint32_t BaseM{0};
+  uint32_t BaseK{0};
+  uint32_t BaseN{0};
+  uint32_t BlockLenM{0};
+  uint32_t BlockLenK{0};
+  uint32_t BlockLenN{0};
+  uint32_t BaseMNum{0};
+  uint32_t BaseKNum{0};
+  uint32_t BaseNNum{0};
+  uint32_t MmadM{0};
   uint32_t MmadK{0};
   uint32_t MmadN{0};
-  uint32_t FractalKNum{0};
   uint32_t FractalKInBlockNum{0};
-  uint32_t PartKInL0A{2};
+  uint32_t PartKInMmad{2};
   uint32_t TransA{0};
   uint32_t TransB{1};
+  uint32_t shuffleFlag{0};
  };
 
- constexpr size_t maxTilingBufSize = std::max(sizeof(PpMatmulTilingData), sizeof(CustomMatmulTilingData));
-
  }  // namespace tiling
  }  // namespace internal
  }  // namespace mindspore
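maxTilingBufSize is now sizeof(CustomMatmulTilingData) alone rather than a std::max over both structs; the new unused0..unused6 padding in PpMatmulTilingData plausibly keeps the two layouts the same size, though the diff does not state that. If a caller depends on that invariant, a compile-time guard next to the structs is cheap:

// Hedged guard (not in the diff): check PpMatmulTilingData still fits in a
// buffer sized by sizeof(CustomMatmulTilingData).
static_assert(sizeof(mindspore::internal::tiling::PpMatmulTilingData) <=
                  mindspore::internal::tiling::maxTilingBufSize,
              "PpMatmulTilingData exceeds maxTilingBufSize");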