mindspore 2.3.0rc1__cp39-none-any.whl → 2.3.0rc2__cp39-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (316) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +1 -1
  3. mindspore/_akg/akg/utils/tbe_codegen_utils.py +13 -3
  4. mindspore/_c_dataengine.cpython-39-aarch64-linux-gnu.so +0 -0
  5. mindspore/_c_expression.cpython-39-aarch64-linux-gnu.so +0 -0
  6. mindspore/_checkparam.py +20 -0
  7. mindspore/_extends/parse/parser.py +1 -1
  8. mindspore/_extends/parse/standard_method.py +6 -5
  9. mindspore/_mindspore_offline_debug.cpython-39-aarch64-linux-gnu.so +0 -0
  10. mindspore/amp.py +5 -5
  11. mindspore/boost/boost_cell_wrapper.py +1 -1
  12. mindspore/boost/group_loss_scale_manager.py +1 -1
  13. mindspore/common/__init__.py +4 -2
  14. mindspore/common/_register_for_recompute.py +48 -0
  15. mindspore/common/_stub_tensor.py +1 -0
  16. mindspore/common/api.py +56 -4
  17. mindspore/common/dtype.py +5 -3
  18. mindspore/common/dump.py +2 -2
  19. mindspore/common/hook_handle.py +51 -4
  20. mindspore/common/initializer.py +1 -1
  21. mindspore/common/jit_config.py +17 -6
  22. mindspore/common/parameter.py +7 -2
  23. mindspore/common/recompute.py +247 -0
  24. mindspore/common/sparse_tensor.py +2 -2
  25. mindspore/common/symbol.py +1 -1
  26. mindspore/common/tensor.py +74 -36
  27. mindspore/communication/__init__.py +3 -3
  28. mindspore/communication/management.py +30 -30
  29. mindspore/context.py +28 -15
  30. mindspore/dataset/__init__.py +5 -5
  31. mindspore/dataset/audio/__init__.py +2 -2
  32. mindspore/dataset/audio/transforms.py +51 -51
  33. mindspore/dataset/callback/ds_callback.py +2 -2
  34. mindspore/dataset/engine/cache_client.py +1 -1
  35. mindspore/dataset/engine/datasets.py +3 -3
  36. mindspore/dataset/engine/datasets_audio.py +14 -14
  37. mindspore/dataset/engine/datasets_standard_format.py +3 -3
  38. mindspore/dataset/engine/datasets_text.py +38 -38
  39. mindspore/dataset/engine/datasets_user_defined.py +3 -3
  40. mindspore/dataset/engine/datasets_vision.py +68 -68
  41. mindspore/dataset/text/__init__.py +3 -3
  42. mindspore/dataset/text/transforms.py +26 -26
  43. mindspore/dataset/transforms/__init__.py +1 -1
  44. mindspore/dataset/vision/__init__.py +3 -3
  45. mindspore/dataset/vision/transforms.py +92 -92
  46. mindspore/dataset/vision/utils.py +1 -1
  47. mindspore/experimental/optim/adadelta.py +2 -2
  48. mindspore/experimental/optim/adagrad.py +2 -2
  49. mindspore/experimental/optim/adam.py +2 -2
  50. mindspore/experimental/optim/adamax.py +2 -2
  51. mindspore/experimental/optim/adamw.py +2 -2
  52. mindspore/experimental/optim/asgd.py +2 -2
  53. mindspore/experimental/optim/lr_scheduler.py +24 -20
  54. mindspore/experimental/optim/nadam.py +2 -2
  55. mindspore/experimental/optim/optimizer.py +1 -1
  56. mindspore/experimental/optim/radam.py +2 -2
  57. mindspore/experimental/optim/rmsprop.py +2 -2
  58. mindspore/experimental/optim/rprop.py +2 -2
  59. mindspore/experimental/optim/sgd.py +2 -2
  60. mindspore/hal/stream.py +2 -0
  61. mindspore/include/mindapi/base/types.h +5 -0
  62. mindspore/lib/libdnnl.so.2 +0 -0
  63. mindspore/lib/libmindspore.so +0 -0
  64. mindspore/lib/libmindspore_backend.so +0 -0
  65. mindspore/lib/libmindspore_common.so +0 -0
  66. mindspore/lib/libmindspore_core.so +0 -0
  67. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  68. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  69. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  70. mindspore/lib/libmindspore_shared_lib.so +0 -0
  71. mindspore/lib/libopencv_core.so.4.5 +0 -0
  72. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  73. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  74. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  75. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +6 -6
  76. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  77. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  78. mindspore/lib/plugin/ascend/liblowlatency_collective.so +0 -0
  79. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  80. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/DeviceBin +0 -0
  81. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/PkgInspect +0 -0
  82. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/op_man +0 -0
  83. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/device/ascend910b/bin/ascend910b.bin +101787 -98559
  84. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_cann_host.so +0 -0
  85. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_host.so +0 -0
  86. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/base/op_register.h +2 -2
  87. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/mix.h +8 -1
  88. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/norm.h +5 -3
  89. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/reduce.h +2 -2
  90. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/backend/backend.h +3 -3
  91. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/backend/rtbackend.h +3 -3
  92. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/base/types.h +0 -1
  93. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/module/module.h +3 -3
  94. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/svector/svector.h +3 -2
  95. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops.so +0 -0
  96. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops_static.a +0 -0
  97. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/add/tiling/add_tiling.h +9 -9
  98. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/apply_rotary_pos_emb_impl.h +2 -6
  99. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb.h +2 -2
  100. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_base.h +460 -0
  101. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_bf16.h +217 -0
  102. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_fp16.h +116 -0
  103. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_tiling.h +16 -24
  104. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_value.h +27 -0
  105. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/asdop/asd_op_impl.h +0 -4
  106. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/FlashAttentionScore_impl.h → flash_attention_score/flash_attention_score_impl.h} +2 -1
  107. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/bs_attention_tiling.h → flash_attention_score/flash_attention_score_tiling.h} +15 -19
  108. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/gelu/tiling/gelu_tiling.h +7 -9
  109. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/lccl/lccl_wrapper.h +58 -0
  110. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul/matmul_impl.h +19 -8
  111. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/pp_matmul_common_tiling.h +18 -8
  112. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/pp_matmul_info.h +7 -4
  113. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/tiling_data.h +44 -6
  114. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/tiling_utils.h +65 -0
  115. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/matmul_stridedslice_fusion_impl.h +10 -6
  116. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/op_param.h +4 -1
  117. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/kernel/paged_attention_mix_hwsync.h +41 -0
  118. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/PagedAttention_impl.h → paged_attention/paged_attention_impl.h} +1 -1
  119. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_tiling.h +63 -0
  120. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/add_param.h +2 -2
  121. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention_param.h → param/attention_param.h} +11 -2
  122. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/matmul_ext_param.h +37 -0
  123. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/sub_param.h +45 -0
  124. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/reshape_and_cache/reshape_and_cache_tiling.h +1 -2
  125. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm.h +23 -0
  126. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_base.h +175 -0
  127. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_normal.h +276 -0
  128. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_split_d.h +280 -0
  129. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/tiling_data.h +35 -0
  130. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/rms_norm_impl.h +45 -0
  131. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/kernel/sub_kernel.h +20 -0
  132. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_impl.h +47 -0
  133. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_tiling.h +25 -0
  134. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/tune_repo/matmul_table.h +323 -23
  135. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/types.h +15 -4
  136. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log_tiling.h +8 -0
  137. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libAdd_impl.so +0 -0
  138. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libSub_impl.so +0 -0
  139. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_layernorm_impl.so +0 -0
  140. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_rms_norm_impl.so +0 -0
  141. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libapply_rotary_pos_emb_impl.so +0 -0
  142. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libcast_impl.so +0 -0
  143. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libgelu_impl.so +0 -0
  144. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_impl.so +0 -0
  145. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_stridedslice_fusion_impl.so +0 -0
  146. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libms_kernels_internal.so +0 -0
  147. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libnot_equal_impl.so +0 -0
  148. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libreshape_and_cache_impl.so +0 -0
  149. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/librms_norm_impl.so +0 -0
  150. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_full_mix.o +0 -0
  151. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_tri_mix.o +0 -0
  152. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_full_mix.o +0 -0
  153. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_tri_mix.o +0 -0
  154. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_full_mix.o +0 -0
  155. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_tri_mix.o +0 -0
  156. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_full_mix.o +0 -0
  157. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_tri_mix.o +0 -0
  158. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bnsd_full_mix.o +0 -0
  159. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bsh_full_mix.o +0 -0
  160. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bnsd_full_mix.o +0 -0
  161. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bsh_full_mix.o +0 -0
  162. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal.h +22 -0
  163. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal_comm.h +70 -0
  164. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal_types.h +103 -0
  165. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lccl.h +47 -0
  166. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lccl_wrapper.h +58 -0
  167. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcoc.h +154 -0
  168. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblcal.so +0 -0
  169. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblccl_wrapper.so +0 -0
  170. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  171. mindspore/log.py +2 -2
  172. mindspore/mint/__init__.py +457 -0
  173. mindspore/mint/nn/__init__.py +430 -0
  174. mindspore/mint/nn/functional.py +424 -0
  175. mindspore/mint/optim/__init__.py +24 -0
  176. mindspore/mint/optim/adamw.py +186 -0
  177. mindspore/multiprocessing/__init__.py +4 -0
  178. mindspore/nn/__init__.py +3 -0
  179. mindspore/nn/cell.py +51 -47
  180. mindspore/nn/extend/__init__.py +29 -0
  181. mindspore/nn/extend/basic.py +140 -0
  182. mindspore/nn/extend/embedding.py +143 -0
  183. mindspore/nn/extend/layer/__init__.py +27 -0
  184. mindspore/nn/extend/layer/normalization.py +107 -0
  185. mindspore/nn/extend/pooling.py +117 -0
  186. mindspore/nn/generator.py +297 -0
  187. mindspore/nn/layer/basic.py +109 -1
  188. mindspore/nn/layer/container.py +2 -2
  189. mindspore/nn/layer/conv.py +6 -6
  190. mindspore/nn/layer/embedding.py +1 -1
  191. mindspore/nn/layer/normalization.py +21 -43
  192. mindspore/nn/layer/padding.py +4 -0
  193. mindspore/nn/optim/ada_grad.py +2 -2
  194. mindspore/nn/optim/adadelta.py +1 -1
  195. mindspore/nn/optim/adafactor.py +1 -1
  196. mindspore/nn/optim/adam.py +7 -7
  197. mindspore/nn/optim/adamax.py +2 -2
  198. mindspore/nn/optim/adasum.py +2 -2
  199. mindspore/nn/optim/asgd.py +2 -2
  200. mindspore/nn/optim/ftrl.py +1 -1
  201. mindspore/nn/optim/lamb.py +3 -3
  202. mindspore/nn/optim/lars.py +1 -1
  203. mindspore/nn/optim/lazyadam.py +2 -2
  204. mindspore/nn/optim/momentum.py +2 -2
  205. mindspore/nn/optim/optimizer.py +2 -2
  206. mindspore/nn/optim/proximal_ada_grad.py +2 -2
  207. mindspore/nn/optim/rmsprop.py +2 -2
  208. mindspore/nn/optim/rprop.py +2 -2
  209. mindspore/nn/optim/sgd.py +2 -2
  210. mindspore/nn/optim/thor.py +2 -2
  211. mindspore/nn/wrap/cell_wrapper.py +9 -9
  212. mindspore/nn/wrap/grad_reducer.py +5 -5
  213. mindspore/ops/_grad_experimental/grad_comm_ops.py +4 -2
  214. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -2
  215. mindspore/ops/_vmap/vmap_math_ops.py +27 -8
  216. mindspore/ops/_vmap/vmap_nn_ops.py +66 -8
  217. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +73 -1
  218. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +12 -3
  219. mindspore/ops/auto_generate/gen_arg_handler.py +24 -0
  220. mindspore/ops/auto_generate/gen_extend_func.py +274 -0
  221. mindspore/ops/auto_generate/gen_ops_def.py +889 -22
  222. mindspore/ops/auto_generate/gen_ops_prim.py +3541 -253
  223. mindspore/ops/auto_generate/pyboost_inner_prim.py +282 -0
  224. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
  225. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +9 -0
  226. mindspore/ops/extend/__init__.py +9 -1
  227. mindspore/ops/extend/array_func.py +134 -27
  228. mindspore/ops/extend/math_func.py +3 -3
  229. mindspore/ops/extend/nn_func.py +363 -2
  230. mindspore/ops/function/__init__.py +19 -2
  231. mindspore/ops/function/array_func.py +463 -439
  232. mindspore/ops/function/clip_func.py +7 -18
  233. mindspore/ops/function/grad/grad_func.py +5 -5
  234. mindspore/ops/function/linalg_func.py +4 -4
  235. mindspore/ops/function/math_func.py +260 -243
  236. mindspore/ops/function/nn_func.py +825 -62
  237. mindspore/ops/function/random_func.py +73 -4
  238. mindspore/ops/function/sparse_unary_func.py +1 -1
  239. mindspore/ops/function/vmap_func.py +1 -1
  240. mindspore/ops/functional.py +2 -2
  241. mindspore/ops/op_info_register.py +1 -31
  242. mindspore/ops/operations/__init__.py +2 -3
  243. mindspore/ops/operations/_grad_ops.py +2 -107
  244. mindspore/ops/operations/_inner_ops.py +5 -5
  245. mindspore/ops/operations/_sequence_ops.py +2 -2
  246. mindspore/ops/operations/array_ops.py +11 -233
  247. mindspore/ops/operations/comm_ops.py +32 -32
  248. mindspore/ops/operations/custom_ops.py +7 -89
  249. mindspore/ops/operations/manually_defined/ops_def.py +329 -4
  250. mindspore/ops/operations/math_ops.py +13 -163
  251. mindspore/ops/operations/nn_ops.py +9 -316
  252. mindspore/ops/operations/random_ops.py +1 -1
  253. mindspore/ops/operations/sparse_ops.py +3 -3
  254. mindspore/ops/primitive.py +2 -2
  255. mindspore/ops_generate/arg_dtype_cast.py +12 -3
  256. mindspore/ops_generate/arg_handler.py +24 -0
  257. mindspore/ops_generate/gen_ops_inner_prim.py +2 -0
  258. mindspore/ops_generate/gen_pyboost_func.py +13 -6
  259. mindspore/ops_generate/pyboost_utils.py +2 -17
  260. mindspore/parallel/__init__.py +3 -2
  261. mindspore/parallel/_auto_parallel_context.py +106 -1
  262. mindspore/parallel/_parallel_serialization.py +34 -2
  263. mindspore/parallel/_utils.py +16 -0
  264. mindspore/parallel/algo_parameter_config.py +4 -4
  265. mindspore/parallel/checkpoint_transform.py +249 -77
  266. mindspore/parallel/cluster/process_entity/_api.py +1 -1
  267. mindspore/parallel/parameter_broadcast.py +1 -1
  268. mindspore/parallel/shard.py +1 -1
  269. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +1 -0
  270. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +17 -5
  271. mindspore/profiler/parser/ascend_msprof_exporter.py +3 -3
  272. mindspore/profiler/parser/ascend_msprof_generator.py +10 -3
  273. mindspore/profiler/parser/ascend_op_generator.py +26 -9
  274. mindspore/profiler/parser/ascend_timeline_generator.py +7 -4
  275. mindspore/profiler/parser/profiler_info.py +11 -1
  276. mindspore/profiler/profiling.py +13 -5
  277. mindspore/rewrite/api/node.py +12 -12
  278. mindspore/rewrite/api/symbol_tree.py +11 -11
  279. mindspore/run_check/_check_version.py +1 -1
  280. mindspore/safeguard/rewrite_obfuscation.py +2 -2
  281. mindspore/train/amp.py +4 -4
  282. mindspore/train/anf_ir_pb2.py +8 -2
  283. mindspore/train/callback/_backup_and_restore.py +2 -2
  284. mindspore/train/callback/_callback.py +4 -4
  285. mindspore/train/callback/_checkpoint.py +2 -2
  286. mindspore/train/callback/_early_stop.py +2 -2
  287. mindspore/train/callback/_landscape.py +4 -4
  288. mindspore/train/callback/_loss_monitor.py +2 -2
  289. mindspore/train/callback/_on_request_exit.py +2 -2
  290. mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
  291. mindspore/train/callback/_summary_collector.py +2 -2
  292. mindspore/train/callback/_time_monitor.py +2 -2
  293. mindspore/train/dataset_helper.py +8 -3
  294. mindspore/train/loss_scale_manager.py +2 -2
  295. mindspore/train/metrics/metric.py +3 -3
  296. mindspore/train/mind_ir_pb2.py +22 -17
  297. mindspore/train/model.py +15 -15
  298. mindspore/train/serialization.py +18 -18
  299. mindspore/train/summary/summary_record.py +7 -7
  300. mindspore/train/train_thor/convert_utils.py +3 -3
  301. mindspore/version.py +1 -1
  302. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/METADATA +1 -1
  303. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/RECORD +307 -260
  304. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/tiling_data.h +0 -59
  305. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_bf16_BNSD_mix.o +0 -0
  306. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_bf16_BSH_mix.o +0 -0
  307. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_fp16_BNSD_mix.o +0 -0
  308. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_fp16_BSH_mix.o +0 -0
  309. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_bf16_BNSD_mix.o +0 -0
  310. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_bf16_BSH_mix.o +0 -0
  311. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_fp16_BNSD_mix.o +0 -0
  312. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_fp16_BSH_mix.o +0 -0
  313. /mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/bs_attention_mix_hwsync.h → flash_attention_score/kernel/flash_attention_score_mix_hwsync.h} +0 -0
  314. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/WHEEL +0 -0
  315. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/entry_points.txt +0 -0
  316. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,47 @@
1
+ /**
2
+ * Copyright 2024 Huawei Technologies Co., Ltd
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+ #ifndef MS_KERNELS_INTERNAL_KERNEL_SUB_IMPL_H_
17
+ #define MS_KERNELS_INTERNAL_KERNEL_SUB_IMPL_H_
18
+
19
+ #include <vector>
20
+ #include "include/internal_kernel.h"
21
+ #include "asdops/types.h"
22
+
23
+ namespace mindspore {
24
+ namespace internal {
25
+ class SubImpl : public InternelKernelImpl {
26
+ public:
27
+ SubImpl(const OpParamPtr &param) : InternelKernelImpl(param) {}
28
+ virtual ~SubImpl() {}
29
+ bool Init(const ValidateInfo &info) override;
30
+ void SetStream(const void *stream_ptr) override;
31
+ void SetDeviceTilingBuf(const DeviceRawBuf &tilingBuf) override;
32
+ int Launch() override;
33
+ uint64_t GetTilingBufSize() override;
34
+ int Tiling(HostRawBuf &tilingBuf) override;
35
+ std::vector<uint64_t> GetWorkSpaceSize() override;
36
+ int InferShape(const std::vector<DIMS> &input_shapes, std::vector<DIMS> &output_shapes) override;
37
+
38
+ private:
39
+ int32_t GetMaxUbCount(uint32_t in_dtype);
40
+
41
+ private:
42
+ void *stream_ptr_ = nullptr;
43
+ uint8_t *device_tiling_ = nullptr;
44
+ };
45
+ } // namespace internal
46
+ } // namespace mindspore
47
+ #endif // MS_KERNELS_INTERNAL_KERNEL_SUB_IMPL_H_
@@ -0,0 +1,25 @@
1
+ /**
2
+ * Copyright 2024 Huawei Technologies Co., Ltd
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+ #ifndef MS_KERNELS_INTERNAL_ASCENDC_SUB_TILING_H
17
+ #define MS_KERNELS_INTERNAL_ASCENDC_SUB_TILING_H
18
+ #include "utils/elewise_tiling.h"
19
+ namespace mindspore::internal {
20
+ struct SubTilingData : public ElewiseTailTilingData {
21
+ uint32_t broadcast_mode_{0};
22
+ uint32_t input_dtype_{0};
23
+ };
24
+ } // namespace mindspore::internal
25
+ #endif // MS_KERNELS_INTERNAL_ASCENDC_SUB_TILING_H
@@ -18,8 +18,6 @@
18
18
 
19
19
  #include "tune_repo/utils.h"
20
20
 
21
- static const int SMALL_M_MARK = -1;
22
-
23
21
  /// The key and value of this map are arranged in the following format.
24
22
  /// - key: {b_shape, m_shape, k_shape, n_shape, trans_a, trans_b, inDtype, outDtype}
25
23
  /// - value : {m0, n0, k0, mLoop, nLoop, kLoop, coreLoop, blockDim, swizzlCount, swizzlDirect}
@@ -55,32 +53,334 @@ static REPO tlmTuneConfig910B2{
55
53
  {{1, 4, 6912, 11264, 0, 1, 2, 2}, {16, 96, 1024, 1, 118, 7, 118, 24, 169, 1}},
56
54
  };
57
55
 
58
- static REPO MatMulQkvTuneConfig910B2{
59
- // prefill key: seqlen, k, n0, n1, n2, ta, tb, in_type, out_type
60
- {{1024, 11264, 1408, 128, 128, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 256, 64, 128}},
61
- {{1024, 11264, 1408, 128, 128, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 256, 64, 128}},
62
- {{2048, 11264, 1408, 128, 128, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 256, 64, 128}},
63
- {{2048, 11264, 1408, 128, 128, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 256, 64, 128}},
64
- {{4096, 11264, 1408, 128, 128, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 256, 64, 128}},
65
- {{4096, 11264, 1408, 128, 128, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 256, 64, 128}},
56
+ static REPO MatMulTuneConfig910B2{
57
+ {{4096, 4096, 11008, 0, 0, 0, 1, 1, 1}, {8, 3, 256, 256, 128, 64}},
58
+ {{4096, 4096, 11008, 0, 0, 0, 1, 27, 27}, {8, 3, 256, 256, 128, 64}},
59
+ {{2048, 4096, 11008, 0, 0, 0, 1, 1, 1}, {8, 3, 256, 256, 128, 64}},
60
+ {{2048, 4096, 11008, 0, 0, 0, 1, 27, 27}, {8, 3, 256, 256, 128, 64}},
61
+ {{4096, 11008, 4096, 0, 0, 0, 1, 1, 1}, {3, 8, 128, 256, 256, 64}},
62
+ {{4096, 11008, 4096, 0, 0, 0, 1, 27, 27}, {3, 8, 128, 256, 256, 64}},
63
+ {{1024, 11008, 4096, 0, 0, 0, 1, 1, 1}, {2, 12, 176, 256, 176, 64}},
64
+ {{1024, 11008, 4096, 0, 0, 0, 1, 27, 27}, {2, 12, 176, 256, 176, 64}},
65
+ {{4096, 7168, 8192, 0, 0, 0, 1, 1, 1}, {3, 8, 128, 256, 256, 64}},
66
+ {{4096, 7168, 8192, 0, 0, 0, 1, 27, 27}, {3, 8, 128, 256, 256, 64}},
67
+ {{2048, 7168, 8192, 0, 0, 0, 1, 1, 1}, {4, 6, 256, 256, 128, 64}},
68
+ {{2048, 7168, 8192, 0, 0, 0, 1, 27, 27}, {4, 6, 256, 256, 128, 64}},
69
+ {{1024, 7168, 8192, 0, 0, 0, 1, 1, 1}, {4, 6, 256, 256, 128, 64}},
70
+ {{1024, 7168, 8192, 0, 0, 0, 1, 27, 27}, {4, 6, 256, 256, 128, 64}},
71
+ {{16, 7168, 8192, 0, 0, 0, 1, 1, 1}, {1, 24, 16, 256, 176, 64}},
72
+ {{16, 7168, 8192, 0, 0, 0, 1, 27, 27}, {1, 24, 16, 256, 176, 64}},
73
+ {{4096, 8192, 7168, 0, 0, 0, 1, 1, 1}, {8, 3, 256, 256, 128, 64}},
74
+ {{4096, 8192, 7168, 0, 0, 0, 1, 27, 27}, {8, 3, 256, 256, 128, 64}},
75
+ {{2048, 8192, 7168, 0, 0, 0, 1, 1, 1}, {3, 8, 240, 256, 128, 64}},
76
+ {{2048, 8192, 7168, 0, 0, 0, 1, 27, 27}, {3, 8, 240, 256, 128, 64}},
77
+ {{1024, 8192, 7168, 0, 0, 0, 1, 1, 1}, {2, 12, 256, 256, 128, 64}},
78
+ {{1024, 8192, 7168, 0, 0, 0, 1, 27, 27}, {2, 12, 256, 256, 128, 64}},
79
+ {{4096, 2048, 8192, 0, 0, 0, 1, 1, 1}, {3, 8, 128, 256, 256, 64}},
80
+ {{4096, 2048, 8192, 0, 0, 0, 1, 27, 27}, {3, 8, 128, 256, 256, 64}},
81
+ {{2048, 2048, 8192, 0, 0, 0, 1, 1, 1}, {8, 3, 128, 256, 256, 64}},
82
+ {{2048, 2048, 8192, 0, 0, 0, 1, 27, 27}, {8, 3, 128, 256, 256, 64}},
83
+ {{1024, 2048, 8192, 0, 0, 0, 1, 1, 1}, {8, 3, 128, 256, 256, 64}},
84
+ {{1024, 2048, 8192, 0, 0, 0, 1, 27, 27}, {8, 3, 128, 256, 256, 64}},
85
+ {{4096, 8192, 3584, 0, 0, 0, 1, 1, 1}, {6, 4, 240, 256, 128, 64}},
86
+ {{4096, 8192, 3584, 0, 0, 0, 1, 27, 27}, {6, 4, 240, 256, 128, 64}},
87
+ {{2048, 8192, 3584, 0, 0, 0, 1, 1, 1}, {4, 6, 256, 256, 128, 64}},
88
+ {{2048, 8192, 3584, 0, 0, 0, 1, 27, 27}, {4, 6, 256, 256, 128, 64}},
89
+ {{512, 8192, 3584, 0, 0, 0, 1, 1, 1}, {1, 23, 176, 256, 160, 64}},
90
+ {{512, 8192, 3584, 0, 0, 0, 1, 27, 27}, {1, 23, 176, 256, 160, 64}},
91
+ {{4096, 3584, 8192, 0, 0, 0, 1, 1, 1}, {4, 6, 256, 256, 128, 64}},
92
+ {{4096, 3584, 8192, 0, 0, 0, 1, 27, 27}, {4, 6, 256, 256, 128, 64}},
93
+ {{2048, 3584, 8192, 0, 0, 0, 1, 1, 1}, {4, 6, 256, 256, 128, 64}},
94
+ {{2048, 3584, 8192, 0, 0, 0, 1, 27, 27}, {4, 6, 256, 256, 128, 64}},
95
+ {{512, 3584, 8192, 0, 0, 0, 1, 1, 1}, {1, 22, 256, 256, 128, 64}},
96
+ {{512, 3584, 8192, 0, 0, 0, 1, 27, 27}, {1, 22, 256, 256, 128, 64}},
97
+ {{4096, 1024, 8192, 0, 0, 0, 1, 1, 1}, {3, 8, 128, 256, 256, 64}},
98
+ {{4096, 1024, 8192, 0, 0, 0, 1, 27, 27}, {3, 8, 128, 256, 256, 64}},
99
+ {{2048, 1024, 8192, 0, 0, 0, 1, 1, 1}, {8, 3, 128, 256, 256, 64}},
100
+ {{2048, 1024, 8192, 0, 0, 0, 1, 27, 27}, {8, 3, 128, 256, 256, 64}},
101
+ {{1024, 1024, 8192, 0, 0, 0, 1, 1, 1}, {8, 3, 128, 256, 256, 64}},
102
+ {{1024, 1024, 8192, 0, 0, 0, 1, 27, 27}, {8, 3, 128, 256, 256, 64}},
103
+ {{128, 1024, 8192, 0, 0, 0, 1, 1, 1}, {1, 22, 128, 256, 192, 64}},
104
+ {{128, 1024, 8192, 0, 0, 0, 1, 27, 27}, {1, 22, 128, 256, 192, 64}},
105
+ {{64, 1024, 8192, 0, 0, 0, 1, 1, 1}, {1, 22, 64, 256, 192, 64}},
106
+ {{64, 1024, 8192, 0, 0, 0, 1, 27, 27}, {1, 22, 64, 256, 192, 64}},
107
+ };
108
+
109
+ static REPO MatMulTuneConfig910B4{
110
+ {{512, 3584, 8192, 0, 0, 0, 1, 1, 1}, {2, 10, 256, 256, 128, 64}},
111
+ {{512, 3584, 8192, 0, 0, 0, 1, 27, 27}, {2, 10, 256, 256, 128, 64}},
112
+ {{64, 1024, 8192, 0, 0, 0, 1, 1, 1}, {1, 20, 64, 256, 208, 64}},
113
+ {{64, 1024, 8192, 0, 0, 0, 1, 27, 27}, {1, 20, 64, 256, 208, 64}},
114
+ {{16, 2048, 8192, 0, 0, 0, 1, 1, 1}, {1, 20, 16, 256, 208, 64}},
115
+ {{16, 2048, 8192, 0, 0, 0, 1, 27, 27}, {1, 20, 16, 256, 208, 64}},
116
+ {{4096, 1536, 12288, 0, 0, 0, 1, 1, 1}, {4, 5, 128, 256, 256, 64}},
117
+ {{4096, 1536, 12288, 0, 0, 0, 1, 27, 27}, {4, 5, 128, 256, 256, 64}},
118
+ {{1, 1536, 12288, 0, 0, 0, 1, 1, 1}, {1, 18, 1, 512, 240, 64}},
119
+ {{1, 1536, 12288, 0, 0, 0, 1, 27, 27}, {1, 18, 1, 512, 240, 64}},
120
+ {{2048, 5376, 12288, 0, 0, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
121
+ {{2048, 5376, 12288, 0, 0, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
122
+ {{512, 1536, 12288, 0, 0, 0, 1, 1, 1}, {2, 10, 256, 256, 128, 64}},
123
+ {{512, 1536, 12288, 0, 0, 0, 1, 27, 27}, {2, 10, 256, 256, 128, 64}},
124
+ {{256, 7168, 8192, 0, 0, 0, 1, 1, 1}, {2, 10, 128, 512, 128, 128}},
125
+ {{256, 7168, 8192, 0, 0, 0, 1, 27, 27}, {2, 10, 128, 512, 128, 128}},
126
+ {{512, 11008, 4096, 0, 0, 0, 1, 1, 1}, {2, 10, 256, 256, 112, 64}},
127
+ {{512, 11008, 4096, 0, 0, 0, 1, 27, 27}, {2, 10, 256, 256, 112, 64}},
128
+ {{128, 1024, 8192, 0, 0, 0, 1, 1, 1}, {1, 20, 128, 256, 208, 64}},
129
+ {{128, 1024, 8192, 0, 0, 0, 1, 27, 27}, {1, 20, 128, 256, 208, 64}},
130
+ {{1024, 2048, 8192, 0, 0, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
131
+ {{1024, 2048, 8192, 0, 0, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
132
+ {{256, 4096, 11008, 0, 0, 0, 1, 1, 1}, {1, 18, 256, 256, 128, 64}},
133
+ {{256, 4096, 11008, 0, 0, 0, 1, 27, 27}, {1, 18, 256, 256, 128, 64}},
134
+ {{64, 2048, 8192, 0, 0, 0, 1, 1, 1}, {1, 19, 64, 256, 224, 64}},
135
+ {{64, 2048, 8192, 0, 0, 0, 1, 27, 27}, {1, 19, 64, 256, 224, 64}},
136
+ {{512, 8192, 3584, 0, 0, 0, 1, 1, 1}, {2, 10, 256, 256, 128, 64}},
137
+ {{512, 8192, 3584, 0, 0, 0, 1, 27, 27}, {2, 10, 256, 256, 128, 64}},
138
+ {{2048, 1024, 8192, 0, 0, 0, 1, 1, 1}, {4, 5, 128, 256, 256, 64}},
139
+ {{2048, 1024, 8192, 0, 0, 0, 1, 27, 27}, {4, 5, 128, 256, 256, 64}},
140
+ {{4096, 7168, 8192, 0, 0, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
141
+ {{4096, 7168, 8192, 0, 0, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
142
+ {{4096, 1024, 8192, 0, 0, 0, 1, 1, 1}, {5, 4, 128, 256, 256, 64}},
143
+ {{4096, 1024, 8192, 0, 0, 0, 1, 27, 27}, {5, 4, 128, 256, 256, 64}},
144
+ {{2048, 8192, 7168, 0, 0, 0, 1, 1, 1}, {2, 10, 128, 256, 256, 64}},
145
+ {{2048, 8192, 7168, 0, 0, 0, 1, 27, 27}, {2, 10, 128, 256, 256, 64}},
146
+ {{1, 3584, 8192, 0, 0, 0, 1, 1, 1}, {1, 20, 1, 256, 208, 64}},
147
+ {{1, 3584, 8192, 0, 0, 0, 1, 27, 27}, {1, 20, 1, 256, 208, 64}},
148
+ {{1, 2752, 8192, 0, 0, 0, 1, 1, 1}, {1, 20, 1, 512, 208, 64}},
149
+ {{1, 2752, 8192, 0, 0, 0, 1, 27, 27}, {1, 20, 1, 512, 208, 64}},
150
+ {{4096, 5376, 12288, 0, 0, 0, 1, 1, 1}, {4, 5, 128, 256, 256, 64}},
151
+ {{4096, 5376, 12288, 0, 0, 0, 1, 27, 27}, {4, 5, 128, 256, 256, 64}},
152
+ {{2048, 2048, 8192, 0, 0, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
153
+ {{2048, 2048, 8192, 0, 0, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
154
+ {{32, 1536, 12288, 0, 0, 0, 1, 1, 1}, {1, 20, 32, 256, 208, 64}},
155
+ {{32, 1536, 12288, 0, 0, 0, 1, 27, 27}, {1, 20, 32, 256, 208, 64}},
156
+ {{2048, 7168, 8192, 0, 0, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
157
+ {{2048, 7168, 8192, 0, 0, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
158
+ {{2048, 3584, 8192, 0, 0, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
159
+ {{2048, 3584, 8192, 0, 0, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
160
+ {{1024, 3584, 8192, 0, 0, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
161
+ {{1024, 3584, 8192, 0, 0, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
162
+ {{4096, 4096, 11008, 0, 0, 0, 1, 1, 1}, {4, 5, 128, 256, 256, 64}},
163
+ {{4096, 4096, 11008, 0, 0, 0, 1, 27, 27}, {4, 5, 128, 256, 256, 64}},
164
+ {{256, 2752, 8192, 0, 0, 0, 1, 1, 1}, {2, 10, 128, 512, 128, 128}},
165
+ {{256, 2752, 8192, 0, 0, 0, 1, 27, 27}, {2, 10, 128, 512, 128, 128}},
166
+ {{2048, 4096, 11008, 0, 0, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
167
+ {{2048, 4096, 11008, 0, 0, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
168
+ {{32, 2048, 8192, 0, 0, 0, 1, 1, 1}, {1, 20, 32, 256, 208, 64}},
169
+ {{32, 2048, 8192, 0, 0, 0, 1, 27, 27}, {1, 20, 32, 256, 208, 64}},
170
+ {{16, 3584, 8192, 0, 0, 0, 1, 1, 1}, {1, 20, 16, 256, 208, 64}},
171
+ {{16, 3584, 8192, 0, 0, 0, 1, 27, 27}, {1, 20, 16, 256, 208, 64}},
172
+ {{4096, 8192, 7168, 0, 0, 0, 1, 1, 1}, {4, 5, 128, 256, 256, 64}},
173
+ {{4096, 8192, 7168, 0, 0, 0, 1, 27, 27}, {4, 5, 128, 256, 256, 64}},
174
+ {{4096, 2048, 8192, 0, 0, 0, 1, 1, 1}, {5, 4, 128, 256, 256, 64}},
175
+ {{4096, 2048, 8192, 0, 0, 0, 1, 27, 27}, {5, 4, 128, 256, 256, 64}},
176
+ {{4096, 3584, 8192, 0, 0, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
177
+ {{4096, 3584, 8192, 0, 0, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
178
+ {{1, 2048, 8192, 0, 0, 0, 1, 1, 1}, {1, 18, 1, 512, 240, 64}},
179
+ {{1, 2048, 8192, 0, 0, 0, 1, 27, 27}, {1, 18, 1, 512, 240, 64}},
180
+ {{256, 3584, 8192, 0, 0, 0, 1, 1, 1}, {1, 19, 256, 256, 112, 64}},
181
+ {{256, 3584, 8192, 0, 0, 0, 1, 27, 27}, {1, 19, 256, 256, 112, 64}},
182
+ {{1024, 5376, 12288, 0, 0, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
183
+ {{1024, 5376, 12288, 0, 0, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
184
+ {{256, 8192, 7168, 0, 0, 0, 1, 1, 1}, {2, 10, 128, 512, 128, 128}},
185
+ {{256, 8192, 7168, 0, 0, 0, 1, 27, 27}, {2, 10, 128, 512, 128, 128}},
186
+ {{32, 1024, 8192, 0, 0, 0, 1, 1, 1}, {1, 20, 32, 256, 208, 64}},
187
+ {{32, 1024, 8192, 0, 0, 0, 1, 27, 27}, {1, 20, 32, 256, 208, 64}},
188
+ {{512, 1024, 8192, 0, 0, 0, 1, 1, 1}, {2, 10, 256, 256, 128, 64}},
189
+ {{512, 1024, 8192, 0, 0, 0, 1, 27, 27}, {2, 10, 256, 256, 128, 64}},
190
+ {{4096, 8192, 3584, 0, 0, 0, 1, 1, 1}, {4, 5, 128, 256, 256, 64}},
191
+ {{4096, 8192, 3584, 0, 0, 0, 1, 27, 27}, {4, 5, 128, 256, 256, 64}},
192
+ {{512, 2752, 8192, 0, 0, 0, 1, 1, 1}, {4, 5, 128, 512, 128, 128}},
193
+ {{512, 2752, 8192, 0, 0, 0, 1, 27, 27}, {4, 5, 128, 512, 128, 128}},
194
+ {{1024, 1024, 8192, 0, 0, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
195
+ {{1024, 1024, 8192, 0, 0, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
196
+ {{1, 1024, 8192, 0, 0, 0, 1, 1, 1}, {1, 20, 1, 256, 208, 64}},
197
+ {{1, 1024, 8192, 0, 0, 0, 1, 27, 27}, {1, 20, 1, 256, 208, 64}},
198
+ {{16, 1024, 8192, 0, 0, 0, 1, 1, 1}, {1, 20, 16, 256, 208, 64}},
199
+ {{16, 1024, 8192, 0, 0, 0, 1, 27, 27}, {1, 20, 16, 256, 208, 64}},
200
+ {{512, 5376, 12288, 0, 0, 0, 1, 1, 1}, {2, 10, 256, 256, 128, 64}},
201
+ {{512, 5376, 12288, 0, 0, 0, 1, 27, 27}, {2, 10, 256, 256, 128, 64}},
66
202
 
67
- // infer key: batch, k, n0, n1, n2, ta, tb, in_type, out_type
68
- {{SMALL_M_MARK, 11264, 1408, 128, 128, 0, 1, 1, 1}, {1, 13, SMALL_M_MARK, 256, 128, SMALL_M_MARK, 64, 128}},
69
- {{SMALL_M_MARK, 11264, 1408, 128, 128, 0, 1, 27, 27}, {1, 13, SMALL_M_MARK, 256, 128, SMALL_M_MARK, 64, 128}},
70
203
  };
71
204
 
72
- static REPO MatMulQkvTuneConfig910B4{
205
+ static REPO MultiMatMulsTuneConfig910B2{
73
206
  // prefill key: seqlen, k, n0, n1, n2, ta, tb, in_type, out_type
74
- {{1024, 11264, 1408, 128, 128, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 256, 64, 128}},
75
- {{1024, 11264, 1408, 128, 128, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 256, 64, 128}},
76
- {{2048, 11264, 1408, 128, 128, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 256, 64, 128}},
77
- {{2048, 11264, 1408, 128, 128, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 256, 64, 128}},
78
- {{4096, 11264, 1408, 128, 128, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 256, 64, 128}},
79
- {{4096, 11264, 1408, 128, 128, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 256, 64, 128}},
207
+ {{1024, 11264, 1408, 128, 128, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
208
+ {{1024, 11264, 1408, 128, 128, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
209
+ {{2048, 11264, 1408, 128, 128, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
210
+ {{2048, 11264, 1408, 128, 128, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
211
+ {{4096, 11264, 1408, 128, 128, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
212
+ {{4096, 11264, 1408, 128, 128, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
80
213
 
81
- // infer key: batch, k, n0, n1, n2, ta, tb, in_type, out_type
82
- {{SMALL_M_MARK, 11264, 1408, 128, 128, 0, 1, 1, 1}, {1, 13, SMALL_M_MARK, 256, 128, SMALL_M_MARK, 64, 128}},
83
- {{SMALL_M_MARK, 11264, 1408, 128, 128, 0, 1, 27, 27}, {1, 13, SMALL_M_MARK, 256, 128, SMALL_M_MARK, 64, 128}},
214
+ {{256, 8192, 1024, 128, 128, 0, 1, 1, 1}, {2, 10, 128, 512, 128, 128}},
215
+ {{256, 8192, 1024, 128, 128, 0, 1, 27, 27}, {2, 10, 128, 512, 128, 128}},
216
+
217
+ // // infer key: batch, k, n0, n1, n2, ta, tb, in_type, out_type
218
+ // {{SMALL_M_MARK, 11264, 1408, 128, 128, 0, 1, 1, 1}, {1, 13, SMALL_M_MARK, 256, 128, 64}},
219
+ // {{SMALL_M_MARK, 11264, 1408, 128, 128, 0, 1, 27, 27}, {1, 13, SMALL_M_MARK, 256, 128, 64}},
220
+ };
221
+
222
+ static REPO MultiMatMulsTuneConfig910B4{
223
+ // prefill key: seqlen, k, n0, n1, n2, ta, tb, in_type, out_type
224
+ {{2048, 4096, 11008, 11008, 0, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
225
+ {{2048, 4096, 11008, 11008, 0, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
226
+ {{1024, 4096, 11008, 11008, 0, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
227
+ {{1024, 4096, 11008, 11008, 0, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
228
+ {{512, 8192, 3584, 3584, 0, 0, 1, 1, 1}, {2, 10, 256, 256, 128, 64}},
229
+ {{512, 8192, 3584, 3584, 0, 0, 1, 27, 27}, {2, 10, 256, 256, 128, 64}},
230
+ {{32, 4096, 11008, 11008, 0, 0, 1, 1, 1}, {1, 18, 32, 256, 256, 64}},
231
+ {{32, 4096, 11008, 11008, 0, 0, 1, 27, 27}, {1, 18, 32, 256, 256, 64}},
232
+ {{32, 8192, 3584, 3584, 0, 0, 1, 1, 1}, {1, 19, 32, 256, 192, 64}},
233
+ {{32, 8192, 3584, 3584, 0, 0, 1, 27, 27}, {1, 19, 32, 256, 192, 64}},
234
+ {{2048, 4096, 4096, 4096, 4096, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
235
+ {{2048, 4096, 4096, 4096, 4096, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
236
+ {{512, 4096, 4096, 4096, 4096, 0, 1, 1, 1}, {2, 10, 256, 256, 128, 64}},
237
+ {{512, 4096, 4096, 4096, 4096, 0, 1, 27, 27}, {2, 10, 256, 256, 128, 64}},
238
+ {{4096, 12288, 5376, 5376, 0, 0, 1, 1, 1}, {4, 5, 128, 256, 256, 64}},
239
+ {{4096, 12288, 5376, 5376, 0, 0, 1, 27, 27}, {4, 5, 128, 256, 256, 64}},
240
+ {{2048, 12288, 1536, 1536, 1536, 0, 1, 1, 1}, {4, 5, 128, 256, 256, 64}},
241
+ {{2048, 12288, 1536, 1536, 1536, 0, 1, 27, 27}, {4, 5, 128, 256, 256, 64}},
242
+ {{1024, 8192, 2048, 256, 256, 0, 1, 1, 1}, {2, 10, 256, 256, 128, 64}},
243
+ {{1024, 8192, 2048, 256, 256, 0, 1, 27, 27}, {2, 10, 256, 256, 128, 64}},
244
+ {{1024, 8192, 3584, 3584, 0, 0, 1, 1, 1}, {2, 10, 256, 256, 128, 64}},
245
+ {{1024, 8192, 3584, 3584, 0, 0, 1, 27, 27}, {2, 10, 256, 256, 128, 64}},
246
+ {{1024, 8192, 1024, 128, 128, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
247
+ {{1024, 8192, 1024, 128, 128, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
248
+ {{128, 4096, 4096, 4096, 4096, 0, 1, 1, 1}, {1, 20, 128, 256, 208, 64}},
249
+ {{128, 4096, 4096, 4096, 4096, 0, 1, 27, 27}, {1, 20, 128, 256, 208, 64}},
250
+ {{64, 8192, 1024, 128, 128, 0, 1, 1, 1}, {1, 20, 64, 512, 64, 256}},
251
+ {{64, 8192, 1024, 128, 128, 0, 1, 27, 27}, {1, 20, 64, 512, 64, 256}},
252
+ {{2048, 8192, 2752, 2752, 0, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
253
+ {{2048, 8192, 2752, 2752, 0, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
254
+ {{1, 4096, 4096, 4096, 4096, 0, 1, 1, 1}, {1, 19, 1, 256, 224, 64}},
255
+ {{1, 4096, 4096, 4096, 4096, 0, 1, 27, 27}, {1, 19, 1, 256, 224, 64}},
256
+ {{16, 8192, 3584, 3584, 0, 0, 1, 1, 1}, {1, 19, 16, 512, 192, 64}},
257
+ {{16, 8192, 3584, 3584, 0, 0, 1, 27, 27}, {1, 19, 16, 512, 192, 64}},
258
+ {{16, 12288, 5376, 5376, 0, 0, 1, 1, 1}, {1, 20, 16, 256, 112, 128}},
259
+ {{16, 12288, 5376, 5376, 0, 0, 1, 27, 27}, {1, 20, 16, 256, 112, 128}},
260
+ {{16, 4096, 11008, 11008, 0, 0, 1, 1, 1}, {1, 18, 16, 256, 256, 64}},
261
+ {{16, 4096, 11008, 11008, 0, 0, 1, 27, 27}, {1, 18, 16, 256, 256, 64}},
262
+ {{512, 8192, 2048, 256, 256, 0, 1, 1, 1}, {2, 10, 256, 256, 128, 64}},
263
+ {{512, 8192, 2048, 256, 256, 0, 1, 27, 27}, {2, 10, 256, 256, 128, 64}},
264
+ {{256, 12288, 1536, 1536, 1536, 0, 1, 1, 1}, {2, 9, 128, 256, 256, 64}},
265
+ {{256, 12288, 1536, 1536, 1536, 0, 1, 27, 27}, {2, 9, 128, 256, 256, 64}},
266
+ {{2048, 8192, 3584, 3584, 0, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
267
+ {{2048, 8192, 3584, 3584, 0, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
268
+ {{512, 12288, 1536, 1536, 1536, 0, 1, 1, 1}, {2, 9, 256, 256, 128, 64}},
269
+ {{512, 12288, 1536, 1536, 1536, 0, 1, 27, 27}, {2, 9, 256, 256, 128, 64}},
270
+ {{1, 8192, 2752, 2752, 0, 0, 1, 1, 1}, {1, 20, 1, 512, 144, 64}},
271
+ {{1, 8192, 2752, 2752, 0, 0, 1, 27, 27}, {1, 20, 1, 512, 144, 64}},
272
+ {{256, 4096, 4096, 4096, 4096, 0, 1, 1, 1}, {1, 20, 256, 256, 128, 64}},
273
+ {{256, 4096, 4096, 4096, 4096, 0, 1, 27, 27}, {1, 20, 256, 256, 128, 64}},
274
+ {{1024, 8192, 2752, 2752, 0, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
275
+ {{1024, 8192, 2752, 2752, 0, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
276
+ {{1024, 4096, 4096, 4096, 4096, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
277
+ {{1024, 4096, 4096, 4096, 4096, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
278
+ {{64, 4096, 4096, 4096, 4096, 0, 1, 1, 1}, {1, 20, 64, 512, 160, 64}},
279
+ {{64, 4096, 4096, 4096, 4096, 0, 1, 27, 27}, {1, 20, 64, 512, 160, 64}},
280
+ {{4096, 8192, 2752, 2752, 0, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
281
+ {{4096, 8192, 2752, 2752, 0, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
282
+ {{4096, 4096, 4096, 4096, 4096, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
283
+ {{4096, 4096, 4096, 4096, 4096, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
284
+ {{64, 8192, 3584, 3584, 0, 0, 1, 1, 1}, {1, 19, 64, 256, 192, 64}},
285
+ {{64, 8192, 3584, 3584, 0, 0, 1, 27, 27}, {1, 19, 64, 256, 192, 64}},
286
+ {{256, 8192, 2048, 256, 256, 0, 1, 1, 1}, {2, 10, 128, 512, 128, 128}},
287
+ {{256, 8192, 2048, 256, 256, 0, 1, 27, 27}, {2, 10, 128, 512, 128, 128}},
288
+ {{2048, 8192, 1024, 128, 128, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
289
+ {{2048, 8192, 1024, 128, 128, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
290
+ {{32, 12288, 1536, 1536, 1536, 0, 1, 1, 1}, {1, 18, 32, 512, 128, 128}},
291
+ {{32, 12288, 1536, 1536, 1536, 0, 1, 27, 27}, {1, 18, 32, 512, 128, 128}},
292
+ {{128, 4096, 11008, 11008, 0, 0, 1, 1, 1}, {1, 18, 128, 256, 256, 64}},
293
+ {{128, 4096, 11008, 11008, 0, 0, 1, 27, 27}, {1, 18, 128, 256, 256, 64}},
294
+ {{256, 8192, 3584, 3584, 0, 0, 1, 1, 1}, {1, 19, 256, 256, 128, 64}},
295
+ {{256, 8192, 3584, 3584, 0, 0, 1, 27, 27}, {1, 19, 256, 256, 128, 64}},
296
+ {{128, 8192, 3584, 3584, 0, 0, 1, 1, 1}, {1, 19, 128, 256, 192, 64}},
297
+ {{128, 8192, 3584, 3584, 0, 0, 1, 27, 27}, {1, 19, 128, 256, 192, 64}},
298
+ {{16, 12288, 1536, 1536, 1536, 0, 1, 1, 1}, {1, 18, 16, 512, 128, 128}},
299
+ {{16, 12288, 1536, 1536, 1536, 0, 1, 27, 27}, {1, 18, 16, 512, 128, 128}},
300
+ {{32, 8192, 2048, 256, 256, 0, 1, 1, 1}, {1, 20, 32, 256, 128, 128}},
301
+ {{32, 8192, 2048, 256, 256, 0, 1, 27, 27}, {1, 20, 32, 256, 128, 128}},
302
+ {{32, 12288, 5376, 5376, 0, 0, 1, 1, 1}, {1, 19, 32, 256, 192, 64}},
303
+ {{32, 12288, 5376, 5376, 0, 0, 1, 27, 27}, {1, 19, 32, 256, 192, 64}},
304
+ {{1, 4096, 11008, 11008, 0, 0, 1, 1, 1}, {1, 18, 1, 256, 256, 64}},
305
+ {{1, 4096, 11008, 11008, 0, 0, 1, 27, 27}, {1, 18, 1, 256, 256, 64}},
306
+ {{1, 8192, 2048, 256, 256, 0, 1, 1, 1}, {1, 20, 1, 512, 128, 128}},
307
+ {{1, 8192, 2048, 256, 256, 0, 1, 27, 27}, {1, 20, 1, 512, 128, 128}},
308
+ {{2048, 12288, 5376, 5376, 0, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
309
+ {{2048, 12288, 5376, 5376, 0, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
310
+ {{128, 8192, 2048, 256, 256, 0, 1, 1, 1}, {1, 20, 128, 512, 128, 128}},
311
+ {{128, 8192, 2048, 256, 256, 0, 1, 27, 27}, {1, 20, 128, 512, 128, 128}},
312
+ {{512, 12288, 5376, 5376, 0, 0, 1, 1, 1}, {2, 10, 128, 256, 224, 64}},
313
+ {{512, 12288, 5376, 5376, 0, 0, 1, 27, 27}, {2, 10, 128, 256, 224, 64}},
314
+ {{16, 8192, 1024, 128, 128, 0, 1, 1, 1}, {1, 20, 16, 512, 64, 256}},
315
+ {{16, 8192, 1024, 128, 128, 0, 1, 27, 27}, {1, 20, 16, 512, 64, 256}},
316
+ {{512, 4096, 11008, 11008, 0, 0, 1, 1, 1}, {2, 10, 256, 256, 128, 64}},
317
+ {{512, 4096, 11008, 11008, 0, 0, 1, 27, 27}, {2, 10, 256, 256, 128, 64}},
318
+ {{1, 8192, 3584, 3584, 0, 0, 1, 1, 1}, {1, 19, 1, 512, 192, 64}},
319
+ {{1, 8192, 3584, 3584, 0, 0, 1, 27, 27}, {1, 19, 1, 512, 192, 64}},
320
+ {{128, 8192, 2752, 2752, 0, 0, 1, 1, 1}, {1, 20, 128, 256, 144, 64}},
321
+ {{128, 8192, 2752, 2752, 0, 0, 1, 27, 27}, {1, 20, 128, 256, 144, 64}},
322
+ {{128, 12288, 1536, 1536, 1536, 0, 1, 1, 1}, {1, 18, 128, 256, 256, 64}},
323
+ {{128, 12288, 1536, 1536, 1536, 0, 1, 27, 27}, {1, 18, 128, 256, 256, 64}},
324
+ {{256, 8192, 2752, 2752, 0, 0, 1, 1, 1}, {1, 20, 256, 256, 96, 64}},
325
+ {{256, 8192, 2752, 2752, 0, 0, 1, 27, 27}, {1, 20, 256, 256, 96, 64}},
326
+ {{128, 12288, 5376, 5376, 0, 0, 1, 1, 1}, {1, 19, 128, 512, 96, 128}},
327
+ {{128, 12288, 5376, 5376, 0, 0, 1, 27, 27}, {1, 19, 128, 512, 96, 128}},
328
+ {{256, 12288, 5376, 5376, 0, 0, 1, 1, 1}, {2, 10, 128, 256, 224, 64}},
329
+ {{256, 12288, 5376, 5376, 0, 0, 1, 27, 27}, {2, 10, 128, 256, 224, 64}},
330
+ {{64, 4096, 11008, 11008, 0, 0, 1, 1, 1}, {1, 18, 64, 256, 256, 64}},
331
+ {{64, 4096, 11008, 11008, 0, 0, 1, 27, 27}, {1, 18, 64, 256, 256, 64}},
332
+ {{256, 4096, 11008, 11008, 0, 0, 1, 1, 1}, {2, 10, 128, 256, 256, 64}},
333
+ {{256, 4096, 11008, 11008, 0, 0, 1, 27, 27}, {2, 10, 128, 256, 256, 64}},
334
+ {{2048, 8192, 2048, 256, 256, 0, 1, 1, 1}, {2, 10, 128, 256, 256, 64}},
335
+ {{2048, 8192, 2048, 256, 256, 0, 1, 27, 27}, {2, 10, 128, 256, 256, 64}},
336
+ {{1, 12288, 1536, 1536, 1536, 0, 1, 1, 1}, {1, 18, 1, 512, 128, 128}},
337
+ {{1, 12288, 1536, 1536, 1536, 0, 1, 27, 27}, {1, 18, 1, 512, 128, 128}},
338
+ {{4096, 4096, 11008, 11008, 0, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
339
+ {{4096, 4096, 11008, 11008, 0, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
340
+ {{32, 4096, 4096, 4096, 4096, 0, 1, 1, 1}, {1, 20, 32, 256, 208, 64}},
341
+ {{32, 4096, 4096, 4096, 4096, 0, 1, 27, 27}, {1, 20, 32, 256, 208, 64}},
342
+ {{16, 4096, 4096, 4096, 4096, 0, 1, 1, 1}, {1, 18, 16, 256, 240, 64}},
343
+ {{16, 4096, 4096, 4096, 4096, 0, 1, 27, 27}, {1, 18, 16, 256, 240, 64}},
344
+ {{32, 8192, 2752, 2752, 0, 0, 1, 1, 1}, {1, 20, 32, 512, 144, 64}},
345
+ {{32, 8192, 2752, 2752, 0, 0, 1, 27, 27}, {1, 20, 32, 512, 144, 64}},
346
+ {{64, 8192, 2048, 256, 256, 0, 1, 1, 1}, {1, 20, 64, 512, 128, 128}},
347
+ {{64, 8192, 2048, 256, 256, 0, 1, 27, 27}, {1, 20, 64, 512, 128, 128}},
348
+ {{1, 8192, 1024, 128, 128, 0, 1, 1, 1}, {1, 20, 1, 512, 64, 256}},
349
+ {{1, 8192, 1024, 128, 128, 0, 1, 27, 27}, {1, 20, 1, 512, 64, 256}},
350
+ {{16, 8192, 2048, 256, 256, 0, 1, 1, 1}, {1, 20, 16, 512, 128, 128}},
351
+ {{16, 8192, 2048, 256, 256, 0, 1, 27, 27}, {1, 20, 16, 512, 128, 128}},
352
+ {{128, 8192, 1024, 128, 128, 0, 1, 1, 1}, {1, 20, 128, 512, 64, 128}},
353
+ {{128, 8192, 1024, 128, 128, 0, 1, 27, 27}, {1, 20, 128, 512, 64, 128}},
354
+ {{64, 12288, 1536, 1536, 1536, 0, 1, 1, 1}, {1, 18, 64, 256, 256, 64}},
355
+ {{64, 12288, 1536, 1536, 1536, 0, 1, 27, 27}, {1, 18, 64, 256, 256, 64}},
356
+ {{4096, 8192, 3584, 3584, 0, 0, 1, 1, 1}, {4, 5, 128, 256, 256, 64}},
357
+ {{4096, 8192, 3584, 3584, 0, 0, 1, 27, 27}, {4, 5, 128, 256, 256, 64}},
358
+ {{32, 8192, 1024, 128, 128, 0, 1, 1, 1}, {1, 20, 32, 512, 64, 256}},
359
+ {{32, 8192, 1024, 128, 128, 0, 1, 27, 27}, {1, 20, 32, 512, 64, 256}},
360
+ {{1, 12288, 5376, 5376, 0, 0, 1, 1, 1}, {1, 20, 1, 256, 112, 128}},
361
+ {{1, 12288, 5376, 5376, 0, 0, 1, 27, 27}, {1, 20, 1, 256, 112, 128}},
362
+ {{512, 8192, 1024, 128, 128, 0, 1, 1, 1}, {2, 10, 256, 256, 128, 64}},
363
+ {{512, 8192, 1024, 128, 128, 0, 1, 27, 27}, {2, 10, 256, 256, 128, 64}},
364
+ {{1024, 12288, 5376, 5376, 0, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
365
+ {{1024, 12288, 5376, 5376, 0, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
366
+ {{4096, 12288, 1536, 1536, 1536, 0, 1, 1, 1}, {3, 6, 128, 256, 256, 64}},
367
+ {{4096, 12288, 1536, 1536, 1536, 0, 1, 27, 27}, {3, 6, 128, 256, 256, 64}},
368
+ {{64, 12288, 5376, 5376, 0, 0, 1, 1, 1}, {1, 19, 64, 256, 192, 64}},
369
+ {{64, 12288, 5376, 5376, 0, 0, 1, 27, 27}, {1, 19, 64, 256, 192, 64}},
370
+ {{64, 8192, 2752, 2752, 0, 0, 1, 1, 1}, {1, 20, 64, 256, 144, 64}},
371
+ {{64, 8192, 2752, 2752, 0, 0, 1, 27, 27}, {1, 20, 64, 256, 144, 64}},
372
+ {{256, 8192, 1024, 128, 128, 0, 1, 1, 1}, {2, 10, 128, 256, 128, 128}},
373
+ {{256, 8192, 1024, 128, 128, 0, 1, 27, 27}, {2, 10, 128, 256, 128, 128}},
374
+ {{4096, 8192, 1024, 128, 128, 0, 1, 1, 1}, {4, 5, 256, 256, 128, 64}},
375
+ {{4096, 8192, 1024, 128, 128, 0, 1, 27, 27}, {4, 5, 256, 256, 128, 64}},
376
+ {{16, 8192, 2752, 2752, 0, 0, 1, 1, 1}, {1, 20, 16, 512, 144, 64}},
377
+ {{16, 8192, 2752, 2752, 0, 0, 1, 27, 27}, {1, 20, 16, 512, 144, 64}},
378
+ {{4096, 8192, 2048, 256, 256, 0, 1, 1, 1}, {4, 5, 128, 256, 256, 64}},
379
+ {{4096, 8192, 2048, 256, 256, 0, 1, 27, 27}, {4, 5, 128, 256, 256, 64}},
380
+ {{1024, 12288, 1536, 1536, 1536, 0, 1, 1, 1}, {2, 9, 128, 256, 256, 64}},
381
+ {{1024, 12288, 1536, 1536, 1536, 0, 1, 27, 27}, {2, 9, 128, 256, 256, 64}},
382
+ {{512, 8192, 2752, 2752, 0, 0, 1, 1, 1}, {2, 10, 256, 256, 112, 64}},
383
+ {{512, 8192, 2752, 2752, 0, 0, 1, 27, 27}, {2, 10, 256, 256, 112, 64}},
84
384
  };
85
385
 
86
386
  #endif // TUNE_REPO_MATMUL_TABLE_H
@@ -31,26 +31,37 @@ using DIMS = AsdOps::SVector<int64_t>;
31
31
  using RunInfo = AsdOps::RunInfo;
32
32
 
33
33
  struct TilingCacheInfo {
34
- AsdOps::KernelInfo kernel_info_;
34
+ RunInfo run_info_;
35
35
  uint64_t workspace_size_{0};
36
36
  uint32_t core_num_{1};
37
+ uint32_t cache_id_{1};
38
+ bool use_asd_tiling_{false};
37
39
 
38
40
  TilingCacheInfo() {}
39
41
 
40
42
  TilingCacheInfo(const TilingCacheInfo &other) {
41
- other.kernel_info_.CopyTo(this->kernel_info_);
43
+ if (other.use_asd_tiling_) {
44
+ other.run_info_.CopyTo(this->run_info_);
45
+ }
42
46
  this->workspace_size_ = other.workspace_size_;
43
47
  this->core_num_ = other.core_num_;
48
+ this->cache_id_ = other.cache_id_;
49
+ this->use_asd_tiling_ = other.use_asd_tiling_;
44
50
  }
45
51
 
46
52
  const TilingCacheInfo &operator=(const TilingCacheInfo &other) {
47
- other.kernel_info_.CopyTo(this->kernel_info_);
53
+ if (other.use_asd_tiling_) {
54
+ other.run_info_.CopyTo(this->run_info_);
55
+ }
48
56
  this->workspace_size_ = other.workspace_size_;
49
57
  this->core_num_ = other.core_num_;
58
+ this->cache_id_ = other.cache_id_;
59
+ this->use_asd_tiling_ = other.use_asd_tiling_;
50
60
  return *this;
51
61
  }
52
62
 
53
- AsdOps::KernelInfo &GetKernelInfo() { return kernel_info_; }
63
+ AsdOps::KernelInfo &GetKernelInfo() { return run_info_.GetKernelInfo(); }
64
+ void SetAsdTiling() { use_asd_tiling_ = true; }
54
65
  };
55
66
 
56
67
  using CacheInfo = TilingCacheInfo;
@@ -5,6 +5,7 @@
5
5
  #include "src/utils/elewise_tiling.h"
6
6
  #include "src/compare/compare_tiling.h"
7
7
  #include "src/cast/cast_tiling.h"
8
+ #include "src/sub/sub_tiling.h"
8
9
 
9
10
  namespace mindspore::internal {
10
11
  static std::ostream &operator<<(std::ostream &os, const CastTilingData &dt) {
@@ -42,5 +43,12 @@ static std::ostream &operator<<(std::ostream &os, const CompareTilingData &dt) {
42
43
  os << *ele_tiling;
43
44
  return os;
44
45
  }
46
+ static std::ostream &operator<<(std::ostream &os, const SubTilingData &dt) {
47
+ os << ", broadcast_mode_:" << dt.broadcast_mode_;
48
+ os << ", input_dtype_:" << dt.input_dtype_;
49
+ ElewiseTailTilingData *ele_tiling = (ElewiseTailTilingData *)&dt;
50
+ os << *ele_tiling;
51
+ return os;
52
+ }
45
53
  } // namespace mindspore::internal
46
54
  #endif // MS_KERNELS_INTERNAL_KERNEL_UTILS_LOG_LOG_TILING_H_
@@ -0,0 +1,22 @@
1
+ /*
2
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+ #ifndef LCAL_H
17
+ #define LCAL_H
18
+ #include "lcal_types.h"
19
+ #include "lcal_comm.h"
20
+ #include "lccl.h"
21
+ #include "lcoc.h"
22
+ #endif // LCAL_H