mindspore 2.3.0rc1__cp39-none-any.whl → 2.3.0rc2__cp39-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (316)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +1 -1
  3. mindspore/_akg/akg/utils/tbe_codegen_utils.py +13 -3
  4. mindspore/_c_dataengine.cpython-39-aarch64-linux-gnu.so +0 -0
  5. mindspore/_c_expression.cpython-39-aarch64-linux-gnu.so +0 -0
  6. mindspore/_checkparam.py +20 -0
  7. mindspore/_extends/parse/parser.py +1 -1
  8. mindspore/_extends/parse/standard_method.py +6 -5
  9. mindspore/_mindspore_offline_debug.cpython-39-aarch64-linux-gnu.so +0 -0
  10. mindspore/amp.py +5 -5
  11. mindspore/boost/boost_cell_wrapper.py +1 -1
  12. mindspore/boost/group_loss_scale_manager.py +1 -1
  13. mindspore/common/__init__.py +4 -2
  14. mindspore/common/_register_for_recompute.py +48 -0
  15. mindspore/common/_stub_tensor.py +1 -0
  16. mindspore/common/api.py +56 -4
  17. mindspore/common/dtype.py +5 -3
  18. mindspore/common/dump.py +2 -2
  19. mindspore/common/hook_handle.py +51 -4
  20. mindspore/common/initializer.py +1 -1
  21. mindspore/common/jit_config.py +17 -6
  22. mindspore/common/parameter.py +7 -2
  23. mindspore/common/recompute.py +247 -0
  24. mindspore/common/sparse_tensor.py +2 -2
  25. mindspore/common/symbol.py +1 -1
  26. mindspore/common/tensor.py +74 -36
  27. mindspore/communication/__init__.py +3 -3
  28. mindspore/communication/management.py +30 -30
  29. mindspore/context.py +28 -15
  30. mindspore/dataset/__init__.py +5 -5
  31. mindspore/dataset/audio/__init__.py +2 -2
  32. mindspore/dataset/audio/transforms.py +51 -51
  33. mindspore/dataset/callback/ds_callback.py +2 -2
  34. mindspore/dataset/engine/cache_client.py +1 -1
  35. mindspore/dataset/engine/datasets.py +3 -3
  36. mindspore/dataset/engine/datasets_audio.py +14 -14
  37. mindspore/dataset/engine/datasets_standard_format.py +3 -3
  38. mindspore/dataset/engine/datasets_text.py +38 -38
  39. mindspore/dataset/engine/datasets_user_defined.py +3 -3
  40. mindspore/dataset/engine/datasets_vision.py +68 -68
  41. mindspore/dataset/text/__init__.py +3 -3
  42. mindspore/dataset/text/transforms.py +26 -26
  43. mindspore/dataset/transforms/__init__.py +1 -1
  44. mindspore/dataset/vision/__init__.py +3 -3
  45. mindspore/dataset/vision/transforms.py +92 -92
  46. mindspore/dataset/vision/utils.py +1 -1
  47. mindspore/experimental/optim/adadelta.py +2 -2
  48. mindspore/experimental/optim/adagrad.py +2 -2
  49. mindspore/experimental/optim/adam.py +2 -2
  50. mindspore/experimental/optim/adamax.py +2 -2
  51. mindspore/experimental/optim/adamw.py +2 -2
  52. mindspore/experimental/optim/asgd.py +2 -2
  53. mindspore/experimental/optim/lr_scheduler.py +24 -20
  54. mindspore/experimental/optim/nadam.py +2 -2
  55. mindspore/experimental/optim/optimizer.py +1 -1
  56. mindspore/experimental/optim/radam.py +2 -2
  57. mindspore/experimental/optim/rmsprop.py +2 -2
  58. mindspore/experimental/optim/rprop.py +2 -2
  59. mindspore/experimental/optim/sgd.py +2 -2
  60. mindspore/hal/stream.py +2 -0
  61. mindspore/include/mindapi/base/types.h +5 -0
  62. mindspore/lib/libdnnl.so.2 +0 -0
  63. mindspore/lib/libmindspore.so +0 -0
  64. mindspore/lib/libmindspore_backend.so +0 -0
  65. mindspore/lib/libmindspore_common.so +0 -0
  66. mindspore/lib/libmindspore_core.so +0 -0
  67. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  68. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  69. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  70. mindspore/lib/libmindspore_shared_lib.so +0 -0
  71. mindspore/lib/libopencv_core.so.4.5 +0 -0
  72. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  73. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  74. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  75. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +6 -6
  76. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  77. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  78. mindspore/lib/plugin/ascend/liblowlatency_collective.so +0 -0
  79. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  80. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/DeviceBin +0 -0
  81. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/PkgInspect +0 -0
  82. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/op_man +0 -0
  83. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/device/ascend910b/bin/ascend910b.bin +101787 -98559
  84. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_cann_host.so +0 -0
  85. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_host.so +0 -0
  86. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/base/op_register.h +2 -2
  87. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/mix.h +8 -1
  88. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/norm.h +5 -3
  89. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/reduce.h +2 -2
  90. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/backend/backend.h +3 -3
  91. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/backend/rtbackend.h +3 -3
  92. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/base/types.h +0 -1
  93. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/module/module.h +3 -3
  94. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/svector/svector.h +3 -2
  95. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops.so +0 -0
  96. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops_static.a +0 -0
  97. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/add/tiling/add_tiling.h +9 -9
  98. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/apply_rotary_pos_emb_impl.h +2 -6
  99. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb.h +2 -2
  100. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_base.h +460 -0
  101. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_bf16.h +217 -0
  102. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_fp16.h +116 -0
  103. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_tiling.h +16 -24
  104. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_value.h +27 -0
  105. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/asdop/asd_op_impl.h +0 -4
  106. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/FlashAttentionScore_impl.h → flash_attention_score/flash_attention_score_impl.h} +2 -1
  107. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/bs_attention_tiling.h → flash_attention_score/flash_attention_score_tiling.h} +15 -19
  108. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/gelu/tiling/gelu_tiling.h +7 -9
  109. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/lccl/lccl_wrapper.h +58 -0
  110. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul/matmul_impl.h +19 -8
  111. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/pp_matmul_common_tiling.h +18 -8
  112. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/pp_matmul_info.h +7 -4
  113. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/tiling_data.h +44 -6
  114. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/tiling_utils.h +65 -0
  115. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/matmul_stridedslice_fusion_impl.h +10 -6
  116. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/op_param.h +4 -1
  117. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/kernel/paged_attention_mix_hwsync.h +41 -0
  118. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/PagedAttention_impl.h → paged_attention/paged_attention_impl.h} +1 -1
  119. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_tiling.h +63 -0
  120. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/add_param.h +2 -2
  121. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention_param.h → param/attention_param.h} +11 -2
  122. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/matmul_ext_param.h +37 -0
  123. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/sub_param.h +45 -0
  124. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/reshape_and_cache/reshape_and_cache_tiling.h +1 -2
  125. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm.h +23 -0
  126. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_base.h +175 -0
  127. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_normal.h +276 -0
  128. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_split_d.h +280 -0
  129. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/tiling_data.h +35 -0
  130. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/rms_norm_impl.h +45 -0
  131. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/kernel/sub_kernel.h +20 -0
  132. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_impl.h +47 -0
  133. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_tiling.h +25 -0
  134. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/tune_repo/matmul_table.h +323 -23
  135. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/types.h +15 -4
  136. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log_tiling.h +8 -0
  137. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libAdd_impl.so +0 -0
  138. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libSub_impl.so +0 -0
  139. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_layernorm_impl.so +0 -0
  140. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_rms_norm_impl.so +0 -0
  141. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libapply_rotary_pos_emb_impl.so +0 -0
  142. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libcast_impl.so +0 -0
  143. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libgelu_impl.so +0 -0
  144. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_impl.so +0 -0
  145. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_stridedslice_fusion_impl.so +0 -0
  146. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libms_kernels_internal.so +0 -0
  147. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libnot_equal_impl.so +0 -0
  148. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libreshape_and_cache_impl.so +0 -0
  149. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/librms_norm_impl.so +0 -0
  150. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_full_mix.o +0 -0
  151. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_tri_mix.o +0 -0
  152. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_full_mix.o +0 -0
  153. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_tri_mix.o +0 -0
  154. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_full_mix.o +0 -0
  155. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_tri_mix.o +0 -0
  156. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_full_mix.o +0 -0
  157. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_tri_mix.o +0 -0
  158. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bnsd_full_mix.o +0 -0
  159. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bsh_full_mix.o +0 -0
  160. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bnsd_full_mix.o +0 -0
  161. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bsh_full_mix.o +0 -0
  162. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal.h +22 -0
  163. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal_comm.h +70 -0
  164. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal_types.h +103 -0
  165. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lccl.h +47 -0
  166. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lccl_wrapper.h +58 -0
  167. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcoc.h +154 -0
  168. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblcal.so +0 -0
  169. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblccl_wrapper.so +0 -0
  170. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  171. mindspore/log.py +2 -2
  172. mindspore/mint/__init__.py +457 -0
  173. mindspore/mint/nn/__init__.py +430 -0
  174. mindspore/mint/nn/functional.py +424 -0
  175. mindspore/mint/optim/__init__.py +24 -0
  176. mindspore/mint/optim/adamw.py +186 -0
  177. mindspore/multiprocessing/__init__.py +4 -0
  178. mindspore/nn/__init__.py +3 -0
  179. mindspore/nn/cell.py +51 -47
  180. mindspore/nn/extend/__init__.py +29 -0
  181. mindspore/nn/extend/basic.py +140 -0
  182. mindspore/nn/extend/embedding.py +143 -0
  183. mindspore/nn/extend/layer/__init__.py +27 -0
  184. mindspore/nn/extend/layer/normalization.py +107 -0
  185. mindspore/nn/extend/pooling.py +117 -0
  186. mindspore/nn/generator.py +297 -0
  187. mindspore/nn/layer/basic.py +109 -1
  188. mindspore/nn/layer/container.py +2 -2
  189. mindspore/nn/layer/conv.py +6 -6
  190. mindspore/nn/layer/embedding.py +1 -1
  191. mindspore/nn/layer/normalization.py +21 -43
  192. mindspore/nn/layer/padding.py +4 -0
  193. mindspore/nn/optim/ada_grad.py +2 -2
  194. mindspore/nn/optim/adadelta.py +1 -1
  195. mindspore/nn/optim/adafactor.py +1 -1
  196. mindspore/nn/optim/adam.py +7 -7
  197. mindspore/nn/optim/adamax.py +2 -2
  198. mindspore/nn/optim/adasum.py +2 -2
  199. mindspore/nn/optim/asgd.py +2 -2
  200. mindspore/nn/optim/ftrl.py +1 -1
  201. mindspore/nn/optim/lamb.py +3 -3
  202. mindspore/nn/optim/lars.py +1 -1
  203. mindspore/nn/optim/lazyadam.py +2 -2
  204. mindspore/nn/optim/momentum.py +2 -2
  205. mindspore/nn/optim/optimizer.py +2 -2
  206. mindspore/nn/optim/proximal_ada_grad.py +2 -2
  207. mindspore/nn/optim/rmsprop.py +2 -2
  208. mindspore/nn/optim/rprop.py +2 -2
  209. mindspore/nn/optim/sgd.py +2 -2
  210. mindspore/nn/optim/thor.py +2 -2
  211. mindspore/nn/wrap/cell_wrapper.py +9 -9
  212. mindspore/nn/wrap/grad_reducer.py +5 -5
  213. mindspore/ops/_grad_experimental/grad_comm_ops.py +4 -2
  214. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -2
  215. mindspore/ops/_vmap/vmap_math_ops.py +27 -8
  216. mindspore/ops/_vmap/vmap_nn_ops.py +66 -8
  217. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +73 -1
  218. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +12 -3
  219. mindspore/ops/auto_generate/gen_arg_handler.py +24 -0
  220. mindspore/ops/auto_generate/gen_extend_func.py +274 -0
  221. mindspore/ops/auto_generate/gen_ops_def.py +889 -22
  222. mindspore/ops/auto_generate/gen_ops_prim.py +3541 -253
  223. mindspore/ops/auto_generate/pyboost_inner_prim.py +282 -0
  224. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
  225. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +9 -0
  226. mindspore/ops/extend/__init__.py +9 -1
  227. mindspore/ops/extend/array_func.py +134 -27
  228. mindspore/ops/extend/math_func.py +3 -3
  229. mindspore/ops/extend/nn_func.py +363 -2
  230. mindspore/ops/function/__init__.py +19 -2
  231. mindspore/ops/function/array_func.py +463 -439
  232. mindspore/ops/function/clip_func.py +7 -18
  233. mindspore/ops/function/grad/grad_func.py +5 -5
  234. mindspore/ops/function/linalg_func.py +4 -4
  235. mindspore/ops/function/math_func.py +260 -243
  236. mindspore/ops/function/nn_func.py +825 -62
  237. mindspore/ops/function/random_func.py +73 -4
  238. mindspore/ops/function/sparse_unary_func.py +1 -1
  239. mindspore/ops/function/vmap_func.py +1 -1
  240. mindspore/ops/functional.py +2 -2
  241. mindspore/ops/op_info_register.py +1 -31
  242. mindspore/ops/operations/__init__.py +2 -3
  243. mindspore/ops/operations/_grad_ops.py +2 -107
  244. mindspore/ops/operations/_inner_ops.py +5 -5
  245. mindspore/ops/operations/_sequence_ops.py +2 -2
  246. mindspore/ops/operations/array_ops.py +11 -233
  247. mindspore/ops/operations/comm_ops.py +32 -32
  248. mindspore/ops/operations/custom_ops.py +7 -89
  249. mindspore/ops/operations/manually_defined/ops_def.py +329 -4
  250. mindspore/ops/operations/math_ops.py +13 -163
  251. mindspore/ops/operations/nn_ops.py +9 -316
  252. mindspore/ops/operations/random_ops.py +1 -1
  253. mindspore/ops/operations/sparse_ops.py +3 -3
  254. mindspore/ops/primitive.py +2 -2
  255. mindspore/ops_generate/arg_dtype_cast.py +12 -3
  256. mindspore/ops_generate/arg_handler.py +24 -0
  257. mindspore/ops_generate/gen_ops_inner_prim.py +2 -0
  258. mindspore/ops_generate/gen_pyboost_func.py +13 -6
  259. mindspore/ops_generate/pyboost_utils.py +2 -17
  260. mindspore/parallel/__init__.py +3 -2
  261. mindspore/parallel/_auto_parallel_context.py +106 -1
  262. mindspore/parallel/_parallel_serialization.py +34 -2
  263. mindspore/parallel/_utils.py +16 -0
  264. mindspore/parallel/algo_parameter_config.py +4 -4
  265. mindspore/parallel/checkpoint_transform.py +249 -77
  266. mindspore/parallel/cluster/process_entity/_api.py +1 -1
  267. mindspore/parallel/parameter_broadcast.py +1 -1
  268. mindspore/parallel/shard.py +1 -1
  269. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +1 -0
  270. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +17 -5
  271. mindspore/profiler/parser/ascend_msprof_exporter.py +3 -3
  272. mindspore/profiler/parser/ascend_msprof_generator.py +10 -3
  273. mindspore/profiler/parser/ascend_op_generator.py +26 -9
  274. mindspore/profiler/parser/ascend_timeline_generator.py +7 -4
  275. mindspore/profiler/parser/profiler_info.py +11 -1
  276. mindspore/profiler/profiling.py +13 -5
  277. mindspore/rewrite/api/node.py +12 -12
  278. mindspore/rewrite/api/symbol_tree.py +11 -11
  279. mindspore/run_check/_check_version.py +1 -1
  280. mindspore/safeguard/rewrite_obfuscation.py +2 -2
  281. mindspore/train/amp.py +4 -4
  282. mindspore/train/anf_ir_pb2.py +8 -2
  283. mindspore/train/callback/_backup_and_restore.py +2 -2
  284. mindspore/train/callback/_callback.py +4 -4
  285. mindspore/train/callback/_checkpoint.py +2 -2
  286. mindspore/train/callback/_early_stop.py +2 -2
  287. mindspore/train/callback/_landscape.py +4 -4
  288. mindspore/train/callback/_loss_monitor.py +2 -2
  289. mindspore/train/callback/_on_request_exit.py +2 -2
  290. mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
  291. mindspore/train/callback/_summary_collector.py +2 -2
  292. mindspore/train/callback/_time_monitor.py +2 -2
  293. mindspore/train/dataset_helper.py +8 -3
  294. mindspore/train/loss_scale_manager.py +2 -2
  295. mindspore/train/metrics/metric.py +3 -3
  296. mindspore/train/mind_ir_pb2.py +22 -17
  297. mindspore/train/model.py +15 -15
  298. mindspore/train/serialization.py +18 -18
  299. mindspore/train/summary/summary_record.py +7 -7
  300. mindspore/train/train_thor/convert_utils.py +3 -3
  301. mindspore/version.py +1 -1
  302. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/METADATA +1 -1
  303. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/RECORD +307 -260
  304. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/tiling_data.h +0 -59
  305. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_bf16_BNSD_mix.o +0 -0
  306. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_bf16_BSH_mix.o +0 -0
  307. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_fp16_BNSD_mix.o +0 -0
  308. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_fp16_BSH_mix.o +0 -0
  309. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_bf16_BNSD_mix.o +0 -0
  310. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_bf16_BSH_mix.o +0 -0
  311. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_fp16_BNSD_mix.o +0 -0
  312. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_fp16_BSH_mix.o +0 -0
  313. /mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/bs_attention_mix_hwsync.h → flash_attention_score/kernel/flash_attention_score_mix_hwsync.h} +0 -0
  314. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/WHEEL +0 -0
  315. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/entry_points.txt +0 -0
  316. {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/top_level.txt +0 -0
@@ -26,9 +26,8 @@ from mindspore.ops.auto_generate import gen_arg_handler as handler
  from mindspore.common import Tensor, CSRTensor, COOTensor
  from mindspore.common._stub_tensor import _convert_stub
  from mindspore._c_expression import typing
- from mindspore._c_expression import pyboost_cast
  from mindspore._c_expression import Tensor as Tensor_
- from mindspore._c_expression import pyboost_tile
+ from mindspore._c_expression import pyboost_cast, pyboost_tile, pyboost_zeros, pyboost_ones
  from mindspore.common import dtype as mstype
  from mindspore.common._utils import is_shape_unknown
  from mindspore import _checkparam as validator
@@ -36,6 +35,10 @@ from mindspore.ops.operations.manually_defined._inner import ScalarCast
  from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum
  from mindspore.common.initializer import Zero
  from mindspore.common.parameter import Parameter
+ from mindspore.ops.auto_generate.gen_ops_prim import FlashAttentionScore
+
+
+ dtype_to_type_id = DtypeToEnum()


  dtype_to_type_id = DtypeToEnum()
@@ -1249,9 +1252,57 @@ def _infer_value_for_Reduce(input_x, axis, keep_dims, prim_name):
      return value


- def infer_value_for_Cast(x, dst_type_enum):
+ def _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, prim_name):
+     """Infer value for Common ReduceExtand op."""
+     value = None
+     if input_x is not None:
+         prim_map = {
+             'MeanExt': np.mean,
+             'SumExt': np.sum,
+             'ProdExt': np.prod,
+         }
+         np_reduce_extand_func = prim_map.get(prim_name, None)
+
+         if np_reduce_extand_func is not None:
+             value = input_x.asnumpy()
+             if isinstance(axis, int):
+                 pass
+             elif axis:
+                 axis = tuple(set(axis))
+             else:
+                 axis = tuple(range(len(value.shape)))
+             if dtype is not None:
+                 np_dtype = mstype.dtype_to_nptype(typing.type_id_to_type(dtype))
+                 value = np_reduce_extand_func(value, axis, dtype=np_dtype, keepdims=keep_dims)
+             else:
+                 value = np_reduce_extand_func(value, axis, keepdims=keep_dims)
+
+             value = np.array(value)
+             value = Tensor(value)
+     return value
+
+
+ def _infer_value_for_max_min(input_x, prim_name):
+     """Infer value for Max/Min op."""
+     value = None
+     if input_x is not None:
+         prim_map = {
+             'Max': np.max,
+             'Min': np.min,
+         }
+         np_reduce_func = prim_map.get(prim_name, None)
+
+         if np_reduce_func is not None:
+             value = input_x.asnumpy()
+             value = np_reduce_func(value, None, keepdims=False)
+             value = np.array(value)
+             value = Tensor(value)
+     return value
+
+
+ def infer_value_for_Cast(x, dst_type_enum=None):
      """Infer value for Cast op."""
-     if x is None:
+     if x is None or dst_type_enum is None:
          return None
      dst_type = typing.type_id_to_type(dst_type_enum)
      src_type = mstype.get_py_obj_dtype(x)
@@ -1280,11 +1331,21 @@ def infer_value_for_ReduceMax(input_x, axis, keep_dims):
      return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceMax')


+ def infer_value_for_Max(input_x):
+     """Infer value for Max op."""
+     return _infer_value_for_max_min(input_x, 'Max')
+
+
  def infer_value_for_ReduceMin(input_x, axis, keep_dims):
      """Infer value for ReduceMin op."""
      return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceMin')


+ def infer_value_for_Min(input_x):
+     """Infer value for Max op."""
+     return _infer_value_for_max_min(input_x, 'Min')
+
+
  def infer_value_for_ReduceProd(input_x, axis, keep_dims):
      """Infer value for ReduceProd op."""
      return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceProd')
@@ -1305,6 +1366,21 @@ def infer_value_for_ReduceAny(input_x, axis, keep_dims):
      return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceAny')


+ def infer_value_for_MeanExt(input_x, axis, keep_dims, dtype):
+     """Infer value for MeanExt op."""
+     return _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, 'MeanExt')
+
+
+ def infer_value_for_SumExt(input_x, axis, keep_dims, dtype):
+     """Infer value for SumExt op."""
+     return _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, 'SumExt')
+
+
+ def infer_value_for_ProdExt(input_x, axis, keep_dims, dtype):
+     """Infer value for ProdExt op."""
+     return _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, 'ProdExt')
+
+
  def infer_value_for_Diag(input_x):
      """Infer value for Diag op."""
      if input_x is None:
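Taken together, the hunks above extend compile-time constant folding to the new MeanExt/SumExt/ProdExt and Max/Min operators. The following is a minimal pure-NumPy sketch of the reduction mapping those helpers rely on; it is illustrative only (the function and variable names are not from the package), and it omits the Tensor wrapping and the dtype-enum conversion via typing.type_id_to_type that the real code performs.

    import numpy as np

    # Illustrative mapping mirroring _infer_value_for_ReduceExtand: prim name -> NumPy reducer.
    _REDUCERS = {'MeanExt': np.mean, 'SumExt': np.sum, 'ProdExt': np.prod}

    def fold_reduce_ext(values, axis, keep_dims, prim_name):
        """Constant-fold a ReduceExt-style op when its input is a known constant array."""
        reducer = _REDUCERS.get(prim_name)
        if reducer is None or values is None:
            return None
        if axis is None or axis == ():
            # Empty axis means reduce over every dimension, as in the helper above.
            axis = tuple(range(values.ndim))
        elif not isinstance(axis, int):
            axis = tuple(set(axis))
        return np.array(reducer(values, axis, keepdims=keep_dims))

    print(fold_reduce_ext(np.arange(6.0).reshape(2, 3), axis=1, keep_dims=False, prim_name='SumExt'))
    # -> [ 3. 12.]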
@@ -1389,3 +1465,252 @@ def infer_value_for_Reshape(x, shape):
      else:
          out = Tensor(x.asnumpy().reshape(shape))
      return out
+
+
+ class Ones(Primitive):
+     r"""
+     Creates a tensor filled with value ones.
+
+     Refer to :func:`mindspore.ops.ones` for more details.
+
+     .. warning::
+         For argument `size`, Tensor type input will be deprecated in the future version.
+
+     Inputs:
+         - **shape** (Union[tuple[int], List[int], int, Tensor]) - The specified shape of output tensor.
+         - **type** (:class:`mindspore.dtype`) - The specified type of output tensor.
+
+     Outputs:
+         Tensor, whose dtype and size are defined by input.
+
+     Raises:
+         TypeError: If `shape` is neither an int nor an tuple/list/Tensor of int.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU`` ``CPU``
+
+     Examples:
+         >>> import mindspore
+         >>> from mindspore import ops
+         >>> ones = ops.Ones()
+         >>> output = ones((2, 2), mindspore.float32)
+         >>> print(output)
+         [[1. 1.]
+          [1. 1.]]
+         >>> output = ones((3, 3), mindspore.float32)
+         >>> print(output)
+         [[1. 1. 1.]
+          [1. 1. 1.]
+          [1. 1. 1.]]
+     """
+
+     __mindspore_signature__ = (
+         sig.make_sig('size'),
+         sig.make_sig('type', default=None),
+     )
+
+     @prim_arg_register
+     def __init__(self):
+         pass
+
+     def __call__(self, size, type=None):
+         return _convert_stub(pyboost_ones(self, [size, type if type is None \
+             else handler.dtype_to_type_id('Ones', 'type', type)]))
+
+
+ class Zeros(Primitive):
+     r"""
+     Zeros will be deprecated in the future. Please use class `mindspore.ops.zeros` instead.
+
+     Creates a tensor filled with value zeros.
+
+     Creates a tensor with shape described by the first argument and
+     fills it with value zeros in type of the second argument.
+
+     .. warning::
+         For argument `size`, Tensor type input will be deprecated in the future version.
+
+     Inputs:
+         - **shape** (tuple[int], List[int], int, Tensor) - The specified shape of output tensor.
+         - **type** (mindspore.dtype) - The specified type of output tensor.
+
+     Outputs:
+         Tensor, whose dtype and size are defined by input.
+
+     Raises:
+         TypeError: If `shape` is neither an int nor an tuple/list/Tensor of int.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU`` ``CPU``
+
+     Examples:
+         >>> import mindspore
+         >>> from mindspore import ops
+         >>> zeros = ops.Zeros()
+         >>> output = zeros((2, 2), mindspore.float32)
+         >>> print(output)
+         [[0. 0.]
+          [0. 0.]]
+
+     """
+
+     __mindspore_signature__ = (
+         sig.make_sig('size'),
+         sig.make_sig('type', default=None),
+     )
+
+     @prim_arg_register
+     def __init__(self):
+         pass
+
+     def __call__(self, size, type=None):
+         return _convert_stub(pyboost_zeros(self, [size, type if type is None else \
+             handler.dtype_to_type_id('Zeros', 'type', type)]))
+
+
+ def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mask=None, padding_mask=None,
+                           attn_mask=None, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, keep_prob=1.0,
+                           scalar_value=1.0, pre_tokens=2147483647, next_tokens=2147483647, inner_precise=0,
+                           input_layout='BSH', sparse_mode=0):
+     r"""
+     The interface is not open to the public, just for internal use,
+
+     .. math::
+         \begin{array}{ll} \\
+             y = Dropout(Softmax(Mask(scale_value \mul (real_shift + query * key), attn_mask), -1), keep_prob) \\
+             \mul value \\
+         \end{array}
+
+     B -- Batch size. Value range 1 to 2k.
+     S1 -- Sequence length of query. Value range 1 to 512k.
+     S2 -- Sequence length of key and value. Value range 1 to 512k.
+     N1 -- Num heads of query. Value range 1 to 256.
+     N2 -- Num heads of key and value, and N2 must be a factor of N1.
+     D -- Head size. The value ranges is a multiple of 16, with the max value of 512.
+     H1 -- Hidden size of query, which equals to N1 * D.
+     H2 -- Hidden size of key and value, which equals to N2 * D.
+
+     .. warning::
+         This is an experimental API that is subject to change or deletion. Only support on Atlas training series.
+
+     Args:
+         query (Tensor[float16, bfloat16]): The query tensor. Input tensor of shape :math:`(B, S1, H1)`,
+             `(B, N1, S1, D)`, `(S1, B, H1)`, `(B, S1, N1, D)` or `(T1, N1, D)`.
+         key (Tensor[float16, bfloat16]): The key tensor. Input tensor of shape :math:`(B, S2, H2)`,
+             `(B, N2, S2, D)`, `(S2, B, H2)`, `(B, S2, N2, D)` or `(T2, N2, D)`.
+         value (Tensor[float16, bfloat16]): The value tensor. Input tensor of shape :math:`(B, S2, H2)`,
+             `(B, N2, S2, D)`, `(S2, B, H2)`, `(B, S2, N2, D)` or `(T2, N2, D)`. The key and value have the same shape.
+         head_num (int): The head num of query, equal to N1.
+         real_shift (Union[Tensor[float16, bfloat16], None]): Also known as pse. The position embedding code. If S
+             is greater than 1024 and the mask of the lower triangle is used, enter only the inverse 1024 lines of
+             the lower triangle for memory optimization. Input tensor of shape :math:`(B, N1, S1, S2)`,
+             `(1, N1, S1, S2)`, `(B, N1, 1024, S2)`, `(1, N1, 1024, S2)`.
+
+             - ALiBi scenario: real_shift must meet the ALiBi rule, and sparse_mode is 2 or 3 for the lower triangle.
+               In this scenario, real_shift is `(B, N1, 1024, S2)`, `(1, N1, 1024, S2)`.
+             - Non-ALiBi scenario: real_shift is `(B, N1, S1, S2)`, `(1, N1, S1, S2)`.
+
+             The shape of `real_shift` should be `(B, N1, 1024, S2)` and `(1, N1, 1024, S2)` when input_layout is
+             `TND`.
+         drop_mask (Union[Tensor[uint8], None]): The dropout mask tensor. Input tensor of shape :math:
+             `(B, N1, S1, S2 // 8) or None`. S2 is a multiple of 8 when not None.
+         padding_mask (None): Reserved parameter. Not implemented yet.
+         attn_mask (Union[Tensor[uint8], Tensor[bool], None]): The attention mask tensor. For each element, 0
+             indicates retention and 1 indicates discard. Input tensor of shape :math:`(B, N1, S1, S2)`,
+             `(B, 1, S1, S2)`, `(S1, S2)` or `(2048, 2048)`. In compression scenario, sparse_mode is 2, 3, or 4,
+             attn_mask must be `(2048, 2048)`. When sparse_mode is 5, attn_mask must be `(B, N1, S1, S2)`,
+             `(B, 1, S1, S2)`. When sparse_mode is 0 and 1, attn_mask should be `(B, N1, S1, S2)`, `(B, 1, S1, S2)`,
+             `(S1, S2)`.
+         prefix (Union[List[int64], Tuple[int64] None]): N value of each Batch in the prefix sparse calculation
+             scenario. Input tensor of shape :math:`(B,)`. B max value 32. Not none only when sparse_mode is 5.
+             If S1 > S2, N ranges from 0 to S2. If S1 <= S2, N ranges from S2 - S1 to S2.
+         actual_seq_qlen (Union[List[int64], Tuple[int64], None]): Size of query corresponding to each batch, array
+             with increasing values and the last value equal to T1.
+         actual_seq_kvlen (Union[List[int64], Tuple[int64], None]): Size of key and value corresponding to each batch,
+             array with increasing values and the last value equal to T2.
+         keep_prob (float): The keep probability of dropout. Value range is (0.0, 1.0]. Default: 1.0. when keep_prob
+             is 1.0, drop_mask should be none.
+         scale_value (float): The scale factor of score. Generally, the value is 1.0 / (D ** 0.5). Default: 1.0.
+         pre_tokens (int): Parameter for sparse computation, represents how many tokens are counted forward.
+             When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647.
+         next_tokens (int): Parameter for sparse computation, represents how many tokens are counted backward.
+             When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647.
+             The value of pre_tokens corresponds to S1, and the value of next_tokens corresponds to S2. They define the
+             valid area on the attn_mask matrix. It must ensure that the band is not empty.
+             The following values are not allowed:
+
+             - pre_tokens < 0 and next_tokens < 0.
+             - (pre_tokens < 0 and next_tokens >= 0) and (next_tokens < abs(pre_tokens) or abs(pre_tokens) >= S2).
+             - (pre_tokens >= 0 and next_tokens < 0) and (abs(next_tokens) > pre_tokens or abs(next_tokens) >= S1).
+
+         inner_precise (int): The parameter is reserved and not implemented yet. Default: 0.
+         input_layout (str): Specifies the layout of input `query`, key and value. The value can be "BSH", "BNSD",
+             "SBH", "BSND" or "TND". "TND" is an experimental format. Default: "BSH".
+             When input_layout is "TND", the following restrictions must be met.
+             There are two lists that represent the length of the input sequence: list_seq_q and list_seq_k. Each
+             value in the list indicates the length of the sequence in the batch. For example, list_seq_q = [4, 2, 6],
+             list_seq_k = [10, 3, 9]. The element of list indicate S. T1 is sum(list_seq_q) = 12, T2 is
+             sum(list_seq_k) = 22.
+             max_seqlen_q = max(list_seq_q), max_seqlen_k = max(list_seq_k).
+             qk_pointer = sum(list_seq_q * list_seq_k), which is the sum of the element multiplication.
+
+             - The lengths of two lists are the same, and size of list is batch. batch is less than or equal to 1024.
+             - When input_layout is "TND", actual_seq_qlen and actual_seq_kvlen must be not none.
+               Otherwise, they are none.
+             - The actual_seq_qlen and actual_seq_kvlen are the cumulative sum of sequence of key/value, so they must
+               be non-decreasing.
+             - If real_shift is not none, list_seq_q and list_seq_k must be same. The maximum value of list_seq_q and
+               list_seq_k is greater than 1024. Real_shift should be `(B, N1, 1024, S2)` and `(1, N1, 1024, S2)`, and
+               S2 is equal to max_seqlen_k.
+             - Attn mask must be a lower trianglar matrix, so sparse_mode should be 2 or 3. The shape of attn_mask
+               should be `(2048, 2048)`.
+             - The shape of drop_mask is (qk_pointer * N1 // 8,).
+             - Prefix is none.
+             - Next_tokens is 0, and pre_tokens is not less than max_seqlen_q.
+             - When sparse_mode is 3, S1 of each batch should be less than or equal to S2.
+             - 0 should not exist in list_seq_k.
+
+         sparse_mode (int): Indicates sparse mode. Default 0.
+
+             - 0: Indicates the defaultMask mode. If attn_mask is not passed, the mask operation is not performed,
+               and preTokens and nextTokens(internally assigned as INT_MAX) are ignored. If passed in, the full
+               attn_mask matrix (S1 * S2) needs to be passed in, indicating that the part between preTokens and
+               nextTokens needs to be calculated.
+             - 1: Represents allMask, that is, passing in the complete attn_mask matrix.
+             - 2: Representing the leftUpCausal mode corresponds to the lower triangle scenario divided by the left
+               vertex, and the optimized attn_mask matrix (2048*2048) is required.
+             - 3: Representing the rightDownCausal model corresponds to the lower triangle scene divided by the lower
+               right vertex, and the optimized attn_mask matrix (2048*2048) is required.
+             - 4: Represents the band scenario, that is, the part between counting preTokens and nextTokens, and the
+               optimized attn_mask matrix (2048*2048) is required.
+             - 5: Represents the prefix scenario, that is, on the basis of rightDownCasual, a matrix with length S1 and
+               width N is added to the left side. The value of N is obtained by the new input prefix, and the N value
+               of each Batch axis is different, not implemented yet.
+             - 6: Represents the global scenario, not implemented yet.
+             - 7: Represents the dilated scenario, not implemented yet.
+             - 8: Represents the block_local scenario, not implemented yet.
+
+     Returns:
+         attention_out (Tensor[float16, bfloat16]), The output of attention, its shape, and data type are the same
+         as the query.
+
+     Supported Platforms:
+         ``Ascend``
+
+     Examples:
+         >>> import mindspore
+         >>> import mindspore.common.dtype as mstype
+         >>> import numpy as np
+         >>> from mindspore import ops, Tensor
+         >>> query = Tensor(np.ones([2, 4, 64]), dtype=mstype.float16)
+         >>> key = Tensor(np.ones([2, 4, 64]), dtype=mstype.float16)
+         >>> value = Tensor(np.ones([2, 4, 64]), dtype=mstype.float16)
+         >>> head_num = 4
+         >>> output = ops.flash_attention_score(query, key, value, head_num)
+         >>> print(output.shape)
+         (2, 4, 64)
+     """
+     rank_op = _get_cache_prim(FlashAttentionScore)(head_num, keep_prob, scalar_value, pre_tokens, next_tokens,
+                                                    inner_precise, input_layout, sparse_mode)
+     return rank_op(query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen,
+                    actual_seq_kvlen)[3]
@@ -30,7 +30,7 @@ from mindspore.ops._utils import get_broadcast_shape
  from mindspore.ops.primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
  from mindspore._c_expression import Tensor as Tensor_
  from ..auto_generate import (Add, Addcdiv, Addcmul, ReduceMean, ReduceSum, ReduceAll, ReduceAny,
-                              ReduceMax, ReduceMin, ReduceProd, Betainc, Neg,
+                              ReduceMax, ReduceMin, ReduceProd, Betainc, Neg, MatMul, BatchMatMul,
                               Mul, Square, Rsqrt, Sqrt, Reciprocal, Pow, Exp,
                               Logit, ReduceStd, Expm1, Log, Log1p, Erf, Erfc,
                               Minimum, RealDiv, FloorDiv, Floor, FloorMod, Ceil,
@@ -39,8 +39,8 @@ from ..auto_generate import (Add, Addcdiv, Addcmul, ReduceMean, ReduceSum, Reduc
                               LogicalXor, Cos, ACos, Sin, Asin, Abs, Round, Atan, Atanh, Atan2,
                               LinSpace, MatrixDeterminant, LogMatrixDeterminant, Erfinv, Conj,
                               Real, Complex, Angle, MatrixExp, CholeskyInverse, Trace, Cholesky,
-                              FFTWithSize, NextAfter, NanToNum, Eig, Qr, Roll, Maximum, Div, CumProd,
-                              CumSum, Less, LessEqual, AssignAdd)
+                              FFTWithSize, NextAfter, NanToNum, Eig, Qr, Roll, Maximum, Div, DivMod, CumProd,
+                              CumSum, Less, LessEqual, AssignAdd, IsFinite, TanhGrad)


  def _infer_shape_reduce(x, axis, keep_dims, prim_name):
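The MatMul, BatchMatMul and IsFinite names added to this import are the primitives whose local definitions are removed from math_ops.py in the hunks below, so they are re-exported rather than dropped; user-facing calls should keep working unchanged. A minimal check, adapted from the removed MatMul docstring example (requires a MindSpore 2.3 install):

    import numpy as np
    import mindspore
    from mindspore import Tensor, ops

    # MatMul is now provided via mindspore.ops.auto_generate and re-exported through mindspore.ops,
    # so the call pattern from the removed docstring is unchanged.
    a = Tensor(np.ones((1, 3)), mindspore.float32)
    b = Tensor(np.ones((3, 4)), mindspore.float32)
    matmul = ops.MatMul()
    print(matmul(a, b))   # expected: [[3. 3. 3. 3.]]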
@@ -716,130 +716,6 @@ class LpNorm(Primitive):
          self.init_prim_io_names(inputs=['input'], outputs=['output'])


- class MatMul(Primitive):
-     r"""
-     Multiplies matrix `a` and matrix `b`.
-
-     .. math::
-
-         (Output)_{i j}=\sum_{k=1}^{p} a_{i k} b_{k j}=a_{i 1} b_{1 j}+a_{i 2} b_{2 j}+\cdots+a_{i p} b_{p j}, p\in N
-
-     where the :math:`i,j` indicates the output of the i-th row and j-th column element.
-
-     Note:
-         - If :math:`N * M` cannot be divided by 16, the performance will be poor in ascend environment.
-         - The dtype of inputs must be same.
-         - On Ascend, float64 doesn't be supported.
-
-     Args:
-         transpose_a (bool): If ``True`` , `a` is transposed before multiplication. Default: ``False`` .
-         transpose_b (bool): If ``True`` , `b` is transposed before multiplication. Default: ``False`` .
-
-     Inputs:
-         - **a** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`. If
-           `transpose_a` is ``True`` , its shape must be :math:`(C, N)` after transpose.
-         - **b** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(C, M)`. If
-           `transpose_b` is ``True`` , its shape must be :math:`(M, C)` after transpose.
-
-     Outputs:
-         Tensor, the shape of the output tensor is :math:`(N, M)`.
-
-     Raises:
-         TypeError: If `transpose_a` or `transpose_b` is not a bool.
-         TypeError: If the dtype of `a` and the dtype of `b` are not the same.
-         ValueError: If the column of matrix dimensions of `a` is not equal to
-             the row of matrix dimensions of `b`.
-         ValueError: If length of shape of `a` or `b` is not equal to 2.
-
-     Supported Platforms:
-         ``Ascend`` ``GPU`` ``CPU``
-
-     Examples:
-         >>> import mindspore
-         >>> import numpy as np
-         >>> from mindspore import Tensor, ops
-         >>> a = Tensor(np.ones(shape=[1, 3]), mindspore.float32)
-         >>> b = Tensor(np.ones(shape=[3, 4]), mindspore.float32)
-         >>> matmul = ops.MatMul()
-         >>> output = matmul(a, b)
-         >>> print(output)
-         [[3. 3. 3. 3.]]
-     """
-
-     @prim_attr_register
-     def __init__(self, transpose_a=False, transpose_b=False):
-         """Initialize MatMul."""
-         self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
-         cls_name = self.name
-         validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
-         validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
-         self.add_prim_attr('transpose_x1', self.transpose_a)
-         self.add_prim_attr('transpose_x2', self.transpose_b)
-
-
- class BatchMatMul(Primitive):
-     r"""
-     Computes matrix multiplication between two tensors by batch.
-
-     .. math::
-
-         \text{output}[..., :, :] = \text{matrix}(x[..., :, :]) * \text{matrix}(y[..., :, :])
-
-     The rank of both two input tensors must be same and not less than `2`.
-
-     Args:
-         transpose_a (bool): If ``True`` , the last two dimensions of `x` is transposed before multiplication.
-             Default: ``False`` .
-         transpose_b (bool): If ``True`` , the last two dimensions of `y` is transposed before multiplication.
-             Default: ``False`` .
-
-     Inputs:
-         - **x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(*B, N, C)`,
-           where :math:`*B` represents the batch size which can be multidimensional, :math:`N` and :math:`C` are the
-           size of the last two dimensions. If `transpose_a` is ``True`` , its shape must be :math:`(*B, C, N)`.
-         - **y** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(*B, C, M)`. If
-           `transpose_b` is ``True`` , its shape must be :math:`(*B, M, C)`.
-
-     Outputs:
-         Tensor, the shape of the output tensor is :math:`(*B, N, M)`.
-
-     Raises:
-         TypeError: If `transpose_a` or `transpose_b` is not a bool.
-         ValueError: If length of shape of `x` is not equal to length of shape of `y` or
-             length of shape of inputs is less than 2.
-
-     Supported Platforms:
-         ``Ascend`` ``GPU`` ``CPU``
-
-     Examples:
-         >>> import mindspore
-         >>> import numpy as np
-         >>> from mindspore import Tensor, ops
-         >>> x = Tensor(np.ones(shape=[2, 4, 1, 3]), mindspore.float32)
-         >>> y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
-         >>> batmatmul = ops.BatchMatMul()
-         >>> output = batmatmul(x, y)
-         >>> print(output.shape)
-         (2, 4, 1, 4)
-         >>> x = Tensor(np.ones(shape=[2, 4, 3, 1]), mindspore.float32)
-         >>> y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
-         >>> batmatmul = ops.BatchMatMul(transpose_a=True)
-         >>> output = batmatmul(x, y)
-         >>> print(output.shape)
-         (2, 4, 1, 4)
-     """
-
-     @prim_attr_register
-     def __init__(self, transpose_a=False, transpose_b=False):
-         """Initialize BatchMatMul."""
-         self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
-         cls_name = self.name
-         validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
-         validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
-         self.add_prim_attr('adj_x1', self.transpose_a)
-         self.add_prim_attr('adj_x2', self.transpose_b)
-
-
-
  class AddN(Primitive):
      """
      Computes addition of all input tensors element-wise.
@@ -1165,8 +1041,8 @@ class Sub(_MathBinaryOp):
      Inputs:
          - **x** (Union[Tensor, number.Number, bool]) - The first input is a number.Number or
            a bool or a tensor whose data type is
-           `number <https://www.mindspore.cn/docs/en/r2.3.q1/api_python/mindspore.html#mindspore.dtype>`_ or
-           `bool_ <https://www.mindspore.cn/docs/en/r2.3.q1/api_python/mindspore.html#mindspore.dtype>`_.
+           `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_ or
+           `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_.
          - **y** (Union[Tensor, number.Number, bool]) - The second input, when the first input is a Tensor,
            the second input should be a number.Number or bool value, or a Tensor whose data type is number or bool.

@@ -1507,7 +1383,7 @@ class Heaviside(Primitive):
                 0, & \text { if x }<0 \\
                 \text { values, } & \text { if x }==0 \\
                 1, & \text { if x }>0
-             \end{array}\right.
+             \end{array}\right

      .. warning::
          This is an experimental API that is subject to change or deletion.
@@ -1564,8 +1440,8 @@ class DivNoNan(Primitive):
      Inputs:
          - **x1** (Union[Tensor, number.Number, bool]) - The first input is a number.Number or
            a bool or a tensor whose data type is
-           `number <https://www.mindspore.cn/docs/en/r2.3.q1/api_python/mindspore.html#mindspore.dtype>`_ or
-           `bool_ <https://www.mindspore.cn/docs/en/r2.3.q1/api_python/mindspore.html#mindspore.dtype>`_.
+           `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_ or
+           `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_.
          - **x2** (Union[Tensor, number.Number, bool]) - The second input is a number.Number or
            a bool when the first input is a bool or a tensor whose data type is number or bool\_.
            When the first input is Scalar, the second input must be a Tensor whose data type is number or bool\_.
@@ -1937,8 +1813,8 @@ class Xlogy(Primitive):
      Inputs:
          - **x** (Union[Tensor, number.Number, bool]) - The first input is a number.Number or
            a bool or a tensor whose data type is
-           `number <https://www.mindspore.cn/docs/en/r2.3.q1/api_python/mindspore.html#mindspore.dtype>`_ or
-           `bool_ <https://www.mindspore.cn/docs/en/r2.3.q1/api_python/mindspore.html#mindspore.dtype>`_.
+           `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_ or
+           `bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_.
          - **y** (Union[Tensor, number.Number, bool]) - The second input is a number.Number or
            a bool when the first input is a tensor or a tensor whose data type is number or bool\_.
            When the first input is Scalar, the second input must be a Tensor whose data type is number or bool\_.
@@ -2238,32 +2114,6 @@ class IsInf(Primitive):
          self.init_prim_io_names(inputs=['x'], outputs=['output'])


- class IsFinite(Primitive):
-     r"""
-     Determines which elements are finite for each position.
-
-     Refer to :func:`mindspore.ops.isfinite` for more details.
-
-     Supported Platforms:
-         ``Ascend`` ``GPU`` ``CPU``
-
-     Examples:
-         >>> import mindspore
-         >>> import numpy as np
-         >>> from mindspore import Tensor, ops
-         >>> is_finite = ops.IsFinite()
-         >>> x = Tensor(np.array([np.log(-1), 1, np.log(0)]), mindspore.float32)
-         >>> output = is_finite(x)
-         >>> print(output)
-         [False True False]
-     """
-
-     @prim_attr_register
-     def __init__(self):
-         """Initialize IsFinite"""
-         self.init_prim_io_names(inputs=['x'], outputs=['output'])
-
-
  class FloatStatus(Primitive):
      """
      Determines if the elements contain Not a Number(NaN), infinite or negative infinite. 0 for normal, 1 for overflow.
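As with MatMul and BatchMatMul, the IsFinite primitive removed here is re-exported from ..auto_generate via the import hunk earlier in this file, so the removed docstring example should still behave the same. A minimal check adapted from that example (requires a MindSpore 2.3 install; np.log(-1) and np.log(0) intentionally produce nan and -inf with runtime warnings):

    import numpy as np
    import mindspore
    from mindspore import Tensor, ops

    # IsFinite is now imported from ..auto_generate; the usage pattern is unchanged.
    is_finite = ops.IsFinite()
    x = Tensor(np.array([np.log(-1), 1, np.log(0)]), mindspore.float32)
    print(is_finite(x))   # expected: [False  True False]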
@@ -2362,7 +2212,7 @@ class NPUGetFloatStatus(Primitive):
          >>> import mindspore.nn as nn
          >>> from mindspore import ops
          >>> from mindspore import dtype as mstype
-         >>> from mindspore.common.tensor import Tensor
+         >>> from mindspore import Tensor
          >>> class Net(nn.Cell):
          ...     def __init__(self):
          ...         super().__init__()
@@ -2427,7 +2277,7 @@ class NPUClearFloatStatus(Primitive):
          >>> import mindspore.nn as nn
          >>> from mindspore import ops
          >>> from mindspore import dtype as mstype
-         >>> from mindspore.common.tensor import Tensor
+         >>> from mindspore import Tensor
          >>> class Net(nn.Cell):
          ...     def __init__(self):
          ...         super().__init__()
@@ -2790,7 +2640,7 @@ class SquareSumAll(Primitive):
      .. math::
          \left\{\begin{matrix}out_{x} = {\textstyle \sum_{0}^{N}} (x_{i})^2
          \\out_{y} = {\textstyle \sum_{0}^{N}} (y_{i})^2
-         \end{matrix}\right.
+         \end{matrix}\right

      Note:
          SquareSumAll only supports float16 and float32 data type.