mindspore 2.2.14__cp38-cp38-manylinux1_x86_64.whl → 2.3.0rc2__cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. See the original registry diff page for more details.

Files changed (1172)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -4
  3. mindspore/_akg/akg/composite/build_module.py +155 -11
  4. mindspore/_akg/akg/config/repository.json +38 -0
  5. mindspore/_akg/akg/ms/info_version_adapt.py +29 -0
  6. mindspore/_akg/akg/tvm/contrib/nvcc.py +4 -1
  7. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +2 -1
  8. mindspore/_akg/akg/utils/composite_op_helper.py +4 -2
  9. mindspore/_akg/akg/utils/dump_ascend_meta.py +2 -2
  10. mindspore/_akg/akg/utils/gen_random.py +14 -8
  11. mindspore/_akg/akg/utils/op_dsl.py +11 -0
  12. mindspore/_akg/akg/utils/tbe_codegen_utils.py +18 -8
  13. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  14. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  15. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  16. mindspore/_checkparam.py +78 -0
  17. mindspore/_extends/builtin_operations.py +2 -1
  18. mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
  19. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
  20. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
  21. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
  22. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  23. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
  24. mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
  25. mindspore/_extends/parse/__init__.py +18 -14
  26. mindspore/_extends/parse/compile_config.py +229 -0
  27. mindspore/_extends/parse/parser.py +155 -59
  28. mindspore/_extends/parse/resources.py +40 -7
  29. mindspore/_extends/parse/standard_method.py +127 -206
  30. mindspore/_extends/remote/kernel_build_server.py +2 -0
  31. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  32. mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _profiler.py} +13 -16
  33. mindspore/amp.py +24 -18
  34. mindspore/bin/cache_admin +0 -0
  35. mindspore/bin/cache_server +0 -0
  36. mindspore/boost/boost_cell_wrapper.py +1 -1
  37. mindspore/boost/group_loss_scale_manager.py +1 -1
  38. mindspore/common/__init__.py +7 -3
  39. mindspore/common/_jit_fallback_utils.py +2 -3
  40. mindspore/common/_register_for_adapter.py +7 -0
  41. mindspore/common/_register_for_recompute.py +48 -0
  42. mindspore/common/_stub_tensor.py +7 -1
  43. mindspore/common/_utils.py +5 -17
  44. mindspore/common/api.py +145 -50
  45. mindspore/common/auto_dynamic_shape.py +27 -14
  46. mindspore/common/dtype.py +9 -6
  47. mindspore/common/dump.py +5 -4
  48. mindspore/common/hook_handle.py +51 -4
  49. mindspore/common/initializer.py +1 -1
  50. mindspore/common/jit_config.py +33 -13
  51. mindspore/common/lazy_inline.py +58 -17
  52. mindspore/common/mindir_util.py +12 -2
  53. mindspore/common/mutable.py +79 -14
  54. mindspore/common/parameter.py +24 -4
  55. mindspore/common/recompute.py +247 -0
  56. mindspore/common/seed.py +9 -9
  57. mindspore/common/sparse_tensor.py +251 -18
  58. mindspore/common/symbol.py +122 -0
  59. mindspore/common/tensor.py +391 -465
  60. mindspore/communication/__init__.py +3 -3
  61. mindspore/communication/_comm_helper.py +5 -0
  62. mindspore/communication/management.py +53 -38
  63. mindspore/config/op_info.config +22 -54
  64. mindspore/context.py +176 -55
  65. mindspore/dataset/__init__.py +5 -5
  66. mindspore/dataset/audio/__init__.py +6 -6
  67. mindspore/dataset/audio/transforms.py +711 -158
  68. mindspore/dataset/callback/ds_callback.py +2 -2
  69. mindspore/dataset/engine/cache_client.py +2 -2
  70. mindspore/dataset/engine/datasets.py +72 -38
  71. mindspore/dataset/engine/datasets_audio.py +14 -14
  72. mindspore/dataset/engine/datasets_standard_format.py +33 -3
  73. mindspore/dataset/engine/datasets_text.py +38 -38
  74. mindspore/dataset/engine/datasets_user_defined.py +7 -7
  75. mindspore/dataset/engine/datasets_vision.py +75 -71
  76. mindspore/dataset/engine/offload.py +5 -7
  77. mindspore/dataset/text/__init__.py +3 -3
  78. mindspore/dataset/text/transforms.py +408 -121
  79. mindspore/dataset/text/utils.py +9 -9
  80. mindspore/dataset/transforms/__init__.py +1 -1
  81. mindspore/dataset/transforms/transforms.py +261 -76
  82. mindspore/dataset/utils/browse_dataset.py +9 -9
  83. mindspore/dataset/vision/__init__.py +3 -3
  84. mindspore/dataset/vision/c_transforms.py +5 -5
  85. mindspore/dataset/vision/transforms.py +2264 -514
  86. mindspore/dataset/vision/utils.py +40 -9
  87. mindspore/dataset/vision/validators.py +7 -1
  88. mindspore/experimental/optim/__init__.py +12 -2
  89. mindspore/experimental/optim/adadelta.py +161 -0
  90. mindspore/experimental/optim/adagrad.py +168 -0
  91. mindspore/experimental/optim/adam.py +35 -34
  92. mindspore/experimental/optim/adamax.py +170 -0
  93. mindspore/experimental/optim/adamw.py +40 -16
  94. mindspore/experimental/optim/asgd.py +153 -0
  95. mindspore/experimental/optim/lr_scheduler.py +66 -121
  96. mindspore/experimental/optim/nadam.py +157 -0
  97. mindspore/experimental/optim/optimizer.py +15 -8
  98. mindspore/experimental/optim/radam.py +194 -0
  99. mindspore/experimental/optim/rmsprop.py +154 -0
  100. mindspore/experimental/optim/rprop.py +164 -0
  101. mindspore/experimental/optim/sgd.py +28 -19
  102. mindspore/hal/__init__.py +34 -0
  103. mindspore/hal/_ascend.py +57 -0
  104. mindspore/hal/_base.py +57 -0
  105. mindspore/hal/_cpu.py +56 -0
  106. mindspore/hal/_gpu.py +57 -0
  107. mindspore/hal/device.py +356 -0
  108. mindspore/hal/event.py +179 -0
  109. mindspore/hal/stream.py +339 -0
  110. mindspore/include/api/data_type.h +2 -2
  111. mindspore/include/api/dual_abi_helper.h +16 -3
  112. mindspore/include/api/model.h +1 -3
  113. mindspore/include/api/status.h +14 -0
  114. mindspore/include/c_api/model_c.h +173 -0
  115. mindspore/include/c_api/ms/base/types.h +1 -0
  116. mindspore/include/c_api/types_c.h +19 -0
  117. mindspore/include/dataset/execute.h +1 -3
  118. mindspore/include/mindapi/base/format.h +125 -23
  119. mindspore/include/mindapi/base/types.h +12 -0
  120. mindspore/lib/libdnnl.so.2 +0 -0
  121. mindspore/lib/libmindspore.so +0 -0
  122. mindspore/lib/libmindspore_backend.so +0 -0
  123. mindspore/lib/libmindspore_common.so +0 -0
  124. mindspore/lib/libmindspore_core.so +0 -0
  125. mindspore/lib/libmindspore_glog.so.0 +0 -0
  126. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  127. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  128. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  129. mindspore/lib/libmindspore_shared_lib.so +0 -0
  130. mindspore/lib/libmpi_adapter.so +0 -0
  131. mindspore/lib/libmpi_collective.so +0 -0
  132. mindspore/lib/libnnacl.so +0 -0
  133. mindspore/lib/libopencv_core.so.4.5 +0 -0
  134. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  135. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  136. mindspore/lib/libps_cache.so +0 -0
  137. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +2044 -154
  138. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +2044 -33
  139. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/build_tbe_kernel.py +529 -0
  140. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/compiler.py +56 -0
  141. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/custom.py +1109 -0
  142. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/get_file_path.py +36 -0
  143. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
  144. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/tbe_topi.py +556 -0
  145. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
  146. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  147. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +6318 -1760
  148. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  149. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_add_custom.h +49 -0
  150. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_decoder_kv_cache.h +59 -0
  151. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_prompt_kv_cache.h +59 -0
  152. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/lib/libcust_opapi.so +0 -0
  153. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +52 -0
  154. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +232 -0
  155. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +232 -0
  156. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.cpp +81 -0
  157. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.py +134 -0
  158. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/decoder_kv_cache.cpp +192 -0
  159. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/decoder_kv_cache.py +134 -0
  160. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/prompt_kv_cache.cpp +274 -0
  161. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/prompt_kv_cache.py +134 -0
  162. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/lib/linux/x86_64/libcust_opmaster_rt2.0.so +0 -0
  163. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/liboptiling.so +0 -0
  164. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/inc/op_proto.h +39 -0
  165. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/lib/linux/x86_64/libcust_opsproto_rt2.0.so +0 -0
  166. mindspore/lib/plugin/ascend/libakg.so +0 -0
  167. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  168. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  169. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  170. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  171. mindspore/lib/plugin/cpu/libakg.so +0 -0
  172. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  173. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  174. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  175. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  176. mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
  177. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  178. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  179. mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
  180. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  181. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  182. mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
  183. mindspore/lib/plugin/{libmindspore_ascend.so.1 → libmindspore_ascend.so.2} +0 -0
  184. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  185. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  186. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  187. mindspore/log.py +2 -2
  188. mindspore/mindrecord/__init__.py +5 -1
  189. mindspore/mindrecord/config.py +809 -0
  190. mindspore/mindrecord/filereader.py +25 -0
  191. mindspore/mindrecord/filewriter.py +74 -56
  192. mindspore/mindrecord/mindpage.py +40 -6
  193. mindspore/mindrecord/shardutils.py +3 -2
  194. mindspore/mindrecord/shardwriter.py +7 -0
  195. mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
  196. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
  197. mindspore/mindrecord/tools/csv_to_mr.py +4 -9
  198. mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
  199. mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
  200. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
  201. mindspore/mint/__init__.py +457 -0
  202. mindspore/mint/nn/__init__.py +430 -0
  203. mindspore/mint/nn/functional.py +424 -0
  204. mindspore/mint/optim/__init__.py +24 -0
  205. mindspore/mint/optim/adamw.py +186 -0
  206. mindspore/multiprocessing/__init__.py +72 -0
  207. mindspore/nn/__init__.py +3 -0
  208. mindspore/nn/cell.py +131 -174
  209. mindspore/nn/dynamic_lr.py +2 -2
  210. mindspore/nn/extend/__init__.py +29 -0
  211. mindspore/nn/extend/basic.py +140 -0
  212. mindspore/nn/extend/embedding.py +143 -0
  213. mindspore/{rewrite/ast_creator_register.py → nn/extend/layer/__init__.py} +9 -19
  214. mindspore/nn/extend/layer/normalization.py +107 -0
  215. mindspore/nn/extend/pooling.py +117 -0
  216. mindspore/nn/generator.py +297 -0
  217. mindspore/nn/layer/activation.py +79 -90
  218. mindspore/nn/layer/basic.py +113 -81
  219. mindspore/nn/layer/channel_shuffle.py +3 -16
  220. mindspore/nn/layer/container.py +3 -3
  221. mindspore/nn/layer/conv.py +71 -71
  222. mindspore/nn/layer/embedding.py +105 -44
  223. mindspore/nn/layer/image.py +4 -7
  224. mindspore/nn/layer/normalization.py +52 -66
  225. mindspore/nn/layer/padding.py +30 -39
  226. mindspore/nn/layer/pooling.py +13 -9
  227. mindspore/nn/layer/rnn_cells.py +5 -15
  228. mindspore/nn/layer/rnns.py +6 -5
  229. mindspore/nn/layer/thor_layer.py +1 -2
  230. mindspore/nn/layer/timedistributed.py +1 -1
  231. mindspore/nn/layer/transformer.py +52 -50
  232. mindspore/nn/learning_rate_schedule.py +6 -5
  233. mindspore/nn/loss/loss.py +43 -64
  234. mindspore/nn/optim/ada_grad.py +4 -2
  235. mindspore/nn/optim/adadelta.py +3 -1
  236. mindspore/nn/optim/adafactor.py +1 -1
  237. mindspore/nn/optim/adam.py +102 -181
  238. mindspore/nn/optim/adamax.py +4 -2
  239. mindspore/nn/optim/adasum.py +2 -2
  240. mindspore/nn/optim/asgd.py +4 -2
  241. mindspore/nn/optim/ftrl.py +31 -61
  242. mindspore/nn/optim/lamb.py +5 -3
  243. mindspore/nn/optim/lars.py +2 -2
  244. mindspore/nn/optim/lazyadam.py +6 -4
  245. mindspore/nn/optim/momentum.py +13 -25
  246. mindspore/nn/optim/optimizer.py +6 -3
  247. mindspore/nn/optim/proximal_ada_grad.py +4 -2
  248. mindspore/nn/optim/rmsprop.py +9 -3
  249. mindspore/nn/optim/rprop.py +4 -2
  250. mindspore/nn/optim/sgd.py +6 -5
  251. mindspore/nn/optim/thor.py +2 -2
  252. mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
  253. mindspore/nn/probability/distribution/beta.py +2 -2
  254. mindspore/nn/probability/distribution/categorical.py +4 -6
  255. mindspore/nn/probability/distribution/cauchy.py +2 -2
  256. mindspore/nn/probability/distribution/exponential.py +1 -1
  257. mindspore/nn/probability/distribution/gumbel.py +2 -2
  258. mindspore/nn/probability/distribution/poisson.py +2 -2
  259. mindspore/nn/probability/distribution/uniform.py +2 -2
  260. mindspore/nn/reinforcement/_tensors_queue.py +13 -1
  261. mindspore/nn/wrap/__init__.py +2 -1
  262. mindspore/nn/wrap/cell_wrapper.py +33 -12
  263. mindspore/nn/wrap/grad_reducer.py +148 -8
  264. mindspore/nn/wrap/loss_scale.py +7 -7
  265. mindspore/numpy/__init__.py +2 -0
  266. mindspore/numpy/array_creations.py +2 -0
  267. mindspore/numpy/array_ops.py +1 -5
  268. mindspore/numpy/fft.py +431 -0
  269. mindspore/numpy/math_ops.py +54 -60
  270. mindspore/numpy/utils.py +3 -0
  271. mindspore/ops/__init__.py +5 -4
  272. mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
  273. mindspore/ops/_grad_experimental/grad_comm_ops.py +14 -18
  274. mindspore/ops/_grad_experimental/grad_math_ops.py +68 -283
  275. mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
  276. mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
  277. mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
  278. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  279. mindspore/ops/_op_impl/__init__.py +0 -1
  280. mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
  281. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
  282. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
  283. mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
  284. mindspore/ops/_op_impl/cpu/__init__.py +1 -3
  285. mindspore/ops/_op_impl/cpu/adam.py +2 -2
  286. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
  287. mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
  288. mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
  289. mindspore/ops/_vmap/vmap_array_ops.py +137 -101
  290. mindspore/ops/_vmap/vmap_base.py +8 -1
  291. mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
  292. mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
  293. mindspore/ops/_vmap/vmap_image_ops.py +70 -13
  294. mindspore/ops/_vmap/vmap_math_ops.py +101 -57
  295. mindspore/ops/_vmap/vmap_nn_ops.py +230 -97
  296. mindspore/ops/_vmap/vmap_other_ops.py +1 -1
  297. mindspore/ops/auto_generate/__init__.py +31 -0
  298. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +205 -0
  299. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +257 -0
  300. mindspore/ops/auto_generate/gen_arg_handler.py +171 -0
  301. mindspore/ops/auto_generate/gen_extend_func.py +404 -0
  302. mindspore/ops/auto_generate/gen_ops_def.py +5653 -0
  303. mindspore/ops/auto_generate/gen_ops_prim.py +11623 -0
  304. mindspore/ops/auto_generate/pyboost_inner_prim.py +359 -0
  305. mindspore/ops/composite/__init__.py +5 -2
  306. mindspore/ops/composite/base.py +118 -17
  307. mindspore/ops/composite/math_ops.py +9 -48
  308. mindspore/ops/composite/multitype_ops/_compile_utils.py +168 -602
  309. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +24 -133
  310. mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
  311. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
  312. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
  313. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
  314. mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
  315. mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
  316. mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
  317. mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
  318. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
  319. mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
  320. mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
  321. mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
  322. mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
  323. mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
  324. mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
  325. mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
  326. mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
  327. mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
  328. mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
  329. mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
  330. mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
  331. mindspore/ops/composite/multitype_ops/not_in_impl.py +6 -1
  332. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
  333. mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
  334. mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
  335. mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
  336. mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
  337. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
  338. mindspore/ops/deprecated.py +14 -3
  339. mindspore/ops/extend/__init__.py +54 -0
  340. mindspore/ops/extend/array_func.py +259 -0
  341. mindspore/ops/extend/math_func.py +76 -0
  342. mindspore/ops/extend/nn_func.py +384 -0
  343. mindspore/ops/function/__init__.py +37 -12
  344. mindspore/ops/function/array_func.py +702 -1867
  345. mindspore/ops/function/clip_func.py +19 -31
  346. mindspore/ops/function/debug_func.py +1 -4
  347. mindspore/ops/function/fft_func.py +31 -0
  348. mindspore/ops/function/grad/grad_func.py +24 -17
  349. mindspore/ops/function/image_func.py +27 -21
  350. mindspore/ops/function/linalg_func.py +35 -68
  351. mindspore/ops/function/math_func.py +639 -2531
  352. mindspore/ops/function/nn_func.py +1274 -832
  353. mindspore/ops/function/other_func.py +4 -5
  354. mindspore/ops/function/parameter_func.py +5 -93
  355. mindspore/ops/function/random_func.py +84 -71
  356. mindspore/ops/function/sparse_unary_func.py +9 -16
  357. mindspore/ops/function/spectral_func.py +1 -1
  358. mindspore/ops/function/vmap_func.py +14 -14
  359. mindspore/ops/functional.py +57 -63
  360. mindspore/ops/op_info_register.py +16 -43
  361. mindspore/ops/operations/__init__.py +19 -20
  362. mindspore/ops/operations/_grad_ops.py +20 -828
  363. mindspore/ops/operations/_inner_ops.py +180 -288
  364. mindspore/ops/operations/_scalar_ops.py +5 -480
  365. mindspore/ops/operations/_sequence_ops.py +6 -36
  366. mindspore/ops/operations/array_ops.py +83 -2697
  367. mindspore/ops/operations/comm_ops.py +38 -46
  368. mindspore/ops/operations/custom_ops.py +14 -96
  369. mindspore/ops/operations/debug_ops.py +100 -31
  370. mindspore/ops/operations/image_ops.py +1 -217
  371. mindspore/ops/operations/inner_ops.py +3 -38
  372. mindspore/ops/operations/linalg_ops.py +1 -49
  373. mindspore/{rewrite/ast_transformers → ops/operations/manually_defined}/__init__.py +11 -4
  374. mindspore/ops/operations/manually_defined/_inner.py +61 -0
  375. mindspore/ops/operations/manually_defined/ops_def.py +1716 -0
  376. mindspore/ops/operations/math_ops.py +581 -4629
  377. mindspore/ops/operations/nn_ops.py +260 -1941
  378. mindspore/ops/operations/other_ops.py +50 -42
  379. mindspore/ops/operations/random_ops.py +3 -52
  380. mindspore/ops/operations/sparse_ops.py +3 -3
  381. mindspore/ops/primitive.py +196 -96
  382. mindspore/ops_generate/__init__.py +27 -0
  383. mindspore/ops_generate/arg_dtype_cast.py +257 -0
  384. mindspore/ops_generate/arg_handler.py +171 -0
  385. mindspore/ops_generate/gen_aclnn_implement.py +266 -0
  386. mindspore/ops_generate/gen_ops.py +1062 -0
  387. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  388. mindspore/ops_generate/gen_pyboost_func.py +939 -0
  389. mindspore/ops_generate/gen_utils.py +188 -0
  390. mindspore/ops_generate/op_proto.py +138 -0
  391. mindspore/ops_generate/pyboost_utils.py +349 -0
  392. mindspore/ops_generate/template.py +238 -0
  393. mindspore/parallel/__init__.py +6 -4
  394. mindspore/parallel/_auto_parallel_context.py +52 -2
  395. mindspore/parallel/_cell_wrapper.py +16 -9
  396. mindspore/parallel/_cost_model_context.py +1 -1
  397. mindspore/parallel/_dp_allreduce_fusion.py +159 -159
  398. mindspore/parallel/_parallel_serialization.py +29 -13
  399. mindspore/parallel/_ps_context.py +1 -1
  400. mindspore/parallel/_recovery_context.py +1 -1
  401. mindspore/parallel/_tensor.py +19 -7
  402. mindspore/parallel/_transformer/__init__.py +1 -1
  403. mindspore/parallel/_transformer/layers.py +1 -1
  404. mindspore/parallel/_transformer/loss.py +1 -1
  405. mindspore/parallel/_transformer/moe.py +1 -1
  406. mindspore/parallel/_transformer/op_parallel_config.py +1 -1
  407. mindspore/parallel/_transformer/transformer.py +1 -1
  408. mindspore/parallel/_utils.py +147 -6
  409. mindspore/parallel/algo_parameter_config.py +6 -6
  410. mindspore/parallel/checkpoint_transform.py +180 -24
  411. mindspore/parallel/cluster/__init__.py +15 -0
  412. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  413. mindspore/parallel/cluster/process_entity/_api.py +345 -0
  414. mindspore/parallel/cluster/process_entity/_utils.py +116 -0
  415. mindspore/parallel/cluster/run.py +139 -0
  416. mindspore/parallel/mpi/__init__.py +1 -1
  417. mindspore/parallel/mpi/_mpi_config.py +1 -1
  418. mindspore/parallel/parameter_broadcast.py +152 -0
  419. mindspore/parallel/shard.py +99 -2
  420. mindspore/profiler/common/util.py +20 -0
  421. mindspore/profiler/envprofiling.py +1 -1
  422. mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
  423. mindspore/profiler/parser/ascend_analysis/constant.py +66 -0
  424. mindspore/profiler/parser/ascend_analysis/file_manager.py +77 -0
  425. mindspore/profiler/parser/ascend_analysis/function_event.py +146 -0
  426. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +109 -0
  427. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +80 -0
  428. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +52 -0
  429. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +116 -0
  430. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  431. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +59 -0
  432. mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
  433. mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
  434. mindspore/profiler/parser/ascend_flops_generator.py +20 -4
  435. mindspore/profiler/parser/ascend_hccl_generator.py +25 -277
  436. mindspore/profiler/parser/ascend_msprof_exporter.py +112 -132
  437. mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
  438. mindspore/profiler/parser/ascend_op_generator.py +92 -42
  439. mindspore/profiler/parser/ascend_timeline_generator.py +294 -133
  440. mindspore/profiler/parser/base_timeline_generator.py +6 -0
  441. mindspore/profiler/parser/framework_parser.py +3 -2
  442. mindspore/profiler/parser/integrator.py +3 -1
  443. mindspore/profiler/parser/msadvisor_analyzer.py +1 -1
  444. mindspore/profiler/parser/msadvisor_parser.py +1 -1
  445. mindspore/profiler/parser/profiler_info.py +16 -1
  446. mindspore/profiler/profiling.py +305 -167
  447. mindspore/rewrite/__init__.py +2 -13
  448. mindspore/rewrite/api/node.py +121 -35
  449. mindspore/rewrite/api/pattern_engine.py +2 -3
  450. mindspore/rewrite/api/scoped_value.py +16 -15
  451. mindspore/rewrite/api/symbol_tree.py +45 -29
  452. mindspore/rewrite/ast_helpers/__init__.py +3 -6
  453. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  454. mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
  455. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  456. mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
  457. mindspore/rewrite/common/__init__.py +1 -2
  458. mindspore/rewrite/common/config.py +24 -0
  459. mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
  460. mindspore/rewrite/{namer.py → common/namer.py} +63 -18
  461. mindspore/rewrite/common/namespace.py +118 -0
  462. mindspore/rewrite/node/__init__.py +5 -5
  463. mindspore/rewrite/node/call_function.py +23 -7
  464. mindspore/rewrite/node/cell_container.py +7 -3
  465. mindspore/rewrite/node/control_flow.py +53 -28
  466. mindspore/rewrite/node/node.py +212 -196
  467. mindspore/rewrite/node/node_manager.py +51 -22
  468. mindspore/rewrite/node/node_topological_manager.py +3 -23
  469. mindspore/rewrite/parsers/__init__.py +12 -0
  470. mindspore/rewrite/parsers/arguments_parser.py +8 -9
  471. mindspore/rewrite/parsers/assign_parser.py +635 -413
  472. mindspore/rewrite/parsers/attribute_parser.py +3 -4
  473. mindspore/rewrite/parsers/class_def_parser.py +107 -144
  474. mindspore/rewrite/parsers/constant_parser.py +5 -5
  475. mindspore/rewrite/parsers/container_parser.py +4 -6
  476. mindspore/rewrite/parsers/expr_parser.py +55 -0
  477. mindspore/rewrite/parsers/for_parser.py +31 -98
  478. mindspore/rewrite/parsers/function_def_parser.py +13 -5
  479. mindspore/rewrite/parsers/if_parser.py +28 -10
  480. mindspore/rewrite/parsers/module_parser.py +8 -182
  481. mindspore/rewrite/parsers/parser.py +1 -5
  482. mindspore/rewrite/parsers/parser_register.py +1 -1
  483. mindspore/rewrite/parsers/return_parser.py +5 -10
  484. mindspore/rewrite/parsers/while_parser.py +59 -0
  485. mindspore/rewrite/sparsify/utils.py +1 -1
  486. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  487. mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +704 -185
  488. mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
  489. mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
  490. mindspore/run_check/_check_version.py +6 -14
  491. mindspore/run_check/run_check.py +1 -1
  492. mindspore/safeguard/rewrite_obfuscation.py +9 -19
  493. mindspore/scipy/__init__.py +2 -1
  494. mindspore/scipy/fft.py +133 -0
  495. mindspore/scipy/linalg.py +140 -55
  496. mindspore/scipy/ops.py +15 -71
  497. mindspore/scipy/ops_grad.py +5 -34
  498. mindspore/scipy/optimize/line_search.py +2 -2
  499. mindspore/scipy/optimize/minimize.py +1 -1
  500. mindspore/train/__init__.py +3 -2
  501. mindspore/train/_utils.py +178 -4
  502. mindspore/train/amp.py +167 -245
  503. mindspore/train/anf_ir_pb2.py +8 -2
  504. mindspore/train/callback/_backup_and_restore.py +4 -4
  505. mindspore/train/callback/_callback.py +4 -4
  506. mindspore/train/callback/_checkpoint.py +39 -13
  507. mindspore/train/callback/_early_stop.py +2 -2
  508. mindspore/train/callback/_landscape.py +14 -8
  509. mindspore/train/callback/_loss_monitor.py +2 -2
  510. mindspore/train/callback/_on_request_exit.py +2 -2
  511. mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
  512. mindspore/train/callback/_summary_collector.py +7 -7
  513. mindspore/train/callback/_time_monitor.py +2 -2
  514. mindspore/train/data_sink.py +1 -1
  515. mindspore/train/dataset_helper.py +18 -4
  516. mindspore/train/loss_scale_manager.py +2 -2
  517. mindspore/train/metrics/accuracy.py +7 -7
  518. mindspore/train/metrics/confusion_matrix.py +8 -6
  519. mindspore/train/metrics/cosine_similarity.py +6 -4
  520. mindspore/train/metrics/error.py +2 -2
  521. mindspore/train/metrics/metric.py +3 -3
  522. mindspore/train/metrics/perplexity.py +2 -1
  523. mindspore/train/metrics/topk.py +2 -2
  524. mindspore/train/mind_ir_pb2.py +89 -15
  525. mindspore/train/model.py +24 -22
  526. mindspore/train/serialization.py +257 -133
  527. mindspore/train/summary/summary_record.py +51 -28
  528. mindspore/train/train_thor/convert_utils.py +3 -3
  529. mindspore/version.py +1 -1
  530. {mindspore-2.2.14.dist-info → mindspore-2.3.0rc2.dist-info}/METADATA +2 -2
  531. {mindspore-2.2.14.dist-info → mindspore-2.3.0rc2.dist-info}/RECORD +534 -1066
  532. {mindspore-2.2.14.dist-info → mindspore-2.3.0rc2.dist-info}/entry_points.txt +1 -0
  533. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
  534. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
  535. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
  536. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
  537. mindspore/config/super_bar_config.json +0 -544
  538. mindspore/gen_ops.py +0 -273
  539. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  540. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  541. mindspore/nn/layer/flash_attention.py +0 -189
  542. mindspore/ops/_op_impl/cpu/concat.py +0 -39
  543. mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
  544. mindspore/ops/_op_impl/tbe/__init__.py +0 -47
  545. mindspore/ops/_op_impl/tbe/abs.py +0 -38
  546. mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
  547. mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
  548. mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
  549. mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
  550. mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
  551. mindspore/ops/_op_impl/tbe/acos.py +0 -37
  552. mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
  553. mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
  554. mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
  555. mindspore/ops/_op_impl/tbe/acosh.py +0 -37
  556. mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
  557. mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
  558. mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
  559. mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
  560. mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
  561. mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
  562. mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
  563. mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
  564. mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
  565. mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
  566. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
  567. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
  568. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
  569. mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
  570. mindspore/ops/_op_impl/tbe/add.py +0 -42
  571. mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
  572. mindspore/ops/_op_impl/tbe/add_n.py +0 -39
  573. mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
  574. mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
  575. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
  576. mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
  577. mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
  578. mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
  579. mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
  580. mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
  581. mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
  582. mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
  583. mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
  584. mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
  585. mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
  586. mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
  587. mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
  588. mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
  589. mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
  590. mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
  591. mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
  592. mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
  593. mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
  594. mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
  595. mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
  596. mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
  597. mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
  598. mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
  599. mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
  600. mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
  601. mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
  602. mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
  603. mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
  604. mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
  605. mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
  606. mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
  607. mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
  608. mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
  609. mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
  610. mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
  611. mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
  612. mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
  613. mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
  614. mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
  615. mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
  616. mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
  617. mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
  618. mindspore/ops/_op_impl/tbe/asin.py +0 -37
  619. mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
  620. mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
  621. mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
  622. mindspore/ops/_op_impl/tbe/asinh.py +0 -37
  623. mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
  624. mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
  625. mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
  626. mindspore/ops/_op_impl/tbe/assign.py +0 -79
  627. mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
  628. mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
  629. mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
  630. mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
  631. mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
  632. mindspore/ops/_op_impl/tbe/atan.py +0 -37
  633. mindspore/ops/_op_impl/tbe/atan2.py +0 -38
  634. mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
  635. mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
  636. mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
  637. mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
  638. mindspore/ops/_op_impl/tbe/atanh.py +0 -37
  639. mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
  640. mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
  641. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
  642. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
  643. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
  644. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
  645. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
  646. mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
  647. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
  648. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
  649. mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
  650. mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
  651. mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
  652. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
  653. mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
  654. mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
  655. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
  656. mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
  657. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
  658. mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
  659. mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
  660. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
  661. mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
  662. mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
  663. mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
  664. mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
  665. mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
  666. mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
  667. mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
  668. mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
  669. mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
  670. mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
  671. mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
  672. mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
  673. mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
  674. mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
  675. mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
  676. mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
  677. mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
  678. mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
  679. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
  680. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
  681. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
  682. mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
  683. mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
  684. mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
  685. mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
  686. mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
  687. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
  688. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
  689. mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
  690. mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
  691. mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
  692. mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
  693. mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
  694. mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
  695. mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
  696. mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
  697. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
  698. mindspore/ops/_op_impl/tbe/cast.py +0 -55
  699. mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
  700. mindspore/ops/_op_impl/tbe/cdist.py +0 -38
  701. mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
  702. mindspore/ops/_op_impl/tbe/ceil.py +0 -37
  703. mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
  704. mindspore/ops/_op_impl/tbe/celu.py +0 -39
  705. mindspore/ops/_op_impl/tbe/centralization.py +0 -39
  706. mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
  707. mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
  708. mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
  709. mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
  710. mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
  711. mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
  712. mindspore/ops/_op_impl/tbe/concat.py +0 -40
  713. mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
  714. mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
  715. mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
  716. mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
  717. mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
  718. mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
  719. mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
  720. mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
  721. mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
  722. mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
  723. mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
  724. mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
  725. mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
  726. mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
  727. mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
  728. mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
  729. mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
  730. mindspore/ops/_op_impl/tbe/cos.py +0 -37
  731. mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
  732. mindspore/ops/_op_impl/tbe/cosh.py +0 -37
  733. mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
  734. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
  735. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
  736. mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
  737. mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
  738. mindspore/ops/_op_impl/tbe/cummin.py +0 -41
  739. mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
  740. mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
  741. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
  742. mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
  743. mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
  744. mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
  745. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
  746. mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
  747. mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
  748. mindspore/ops/_op_impl/tbe/diag.py +0 -38
  749. mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
  750. mindspore/ops/_op_impl/tbe/dilation.py +0 -40
  751. mindspore/ops/_op_impl/tbe/div.py +0 -41
  752. mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
  753. mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
  754. mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
  755. mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
  756. mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
  757. mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
  758. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
  759. mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
  760. mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
  761. mindspore/ops/_op_impl/tbe/elu.py +0 -38
  762. mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
  763. mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
  764. mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
  765. mindspore/ops/_op_impl/tbe/equal.py +0 -42
  766. mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
  767. mindspore/ops/_op_impl/tbe/erf.py +0 -37
  768. mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
  769. mindspore/ops/_op_impl/tbe/erfc.py +0 -37
  770. mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
  771. mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
  772. mindspore/ops/_op_impl/tbe/exp.py +0 -40
  773. mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
  774. mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
  775. mindspore/ops/_op_impl/tbe/expm1.py +0 -37
  776. mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
  777. mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
  778. mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
  779. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
  780. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
  781. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
  782. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
  783. mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
  784. mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
  785. mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
  786. mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
  787. mindspore/ops/_op_impl/tbe/fill.py +0 -56
  788. mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
  789. mindspore/ops/_op_impl/tbe/flatten.py +0 -48
  790. mindspore/ops/_op_impl/tbe/floor.py +0 -37
  791. mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
  792. mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
  793. mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
  794. mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
  795. mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
  796. mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
  797. mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
  798. mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
  799. mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
  800. mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
  801. mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
  802. mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
  803. mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
  804. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  805. mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
  806. mindspore/ops/_op_impl/tbe/gelu.py +0 -37
  807. mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
  808. mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
  809. mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
  810. mindspore/ops/_op_impl/tbe/ger.py +0 -43
  811. mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
  812. mindspore/ops/_op_impl/tbe/greater.py +0 -43
  813. mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
  814. mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
  815. mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
  816. mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
  817. mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
  818. mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
  819. mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
  820. mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
  821. mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
  822. mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
  823. mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
  824. mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
  825. mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
  826. mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
  827. mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
  828. mindspore/ops/_op_impl/tbe/im2col.py +0 -42
  829. mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
  830. mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
  831. mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
  832. mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
  833. mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
  834. mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
  835. mindspore/ops/_op_impl/tbe/inv.py +0 -38
  836. mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
  837. mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
  838. mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
  839. mindspore/ops/_op_impl/tbe/invert.py +0 -37
  840. mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
  841. mindspore/ops/_op_impl/tbe/iou.py +0 -38
  842. mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
  843. mindspore/ops/_op_impl/tbe/is_close.py +0 -40
  844. mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
  845. mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
  846. mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
  847. mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
  848. mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
  849. mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
  850. mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
  851. mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
  852. mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
  853. mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
  854. mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
  855. mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
  856. mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
  857. mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
  858. mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
  859. mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
  860. mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
  861. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
  862. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
  863. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
  864. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
  865. mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
  866. mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
  867. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
  868. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
  869. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
  870. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
  871. mindspore/ops/_op_impl/tbe/lerp.py +0 -38
  872. mindspore/ops/_op_impl/tbe/less.py +0 -41
  873. mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
  874. mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
  875. mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
  876. mindspore/ops/_op_impl/tbe/log.py +0 -40
  877. mindspore/ops/_op_impl/tbe/log1p.py +0 -37
  878. mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
  879. mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
  880. mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
  881. mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
  882. mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
  883. mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
  884. mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
  885. mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
  886. mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
  887. mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
  888. mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
  889. mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
  890. mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
  891. mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
  892. mindspore/ops/_op_impl/tbe/lrn.py +0 -41
  893. mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
  894. mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
  895. mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
  896. mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
  897. mindspore/ops/_op_impl/tbe/matmul.py +0 -53
  898. mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
  899. mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
  900. mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
  901. mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
  902. mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
  903. mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
  904. mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
  905. mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
  906. mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
  907. mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
  908. mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
  909. mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
  910. mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
  911. mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
  912. mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
  913. mindspore/ops/_op_impl/tbe/maximum.py +0 -39
  914. mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
  915. mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
  916. mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
  917. mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
  918. mindspore/ops/_op_impl/tbe/minimum.py +0 -40
  919. mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
  920. mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
  921. mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
  922. mindspore/ops/_op_impl/tbe/mish.py +0 -37
  923. mindspore/ops/_op_impl/tbe/mod.py +0 -41
  924. mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
  925. mindspore/ops/_op_impl/tbe/mul.py +0 -37
  926. mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
  927. mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
  928. mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
  929. mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
  930. mindspore/ops/_op_impl/tbe/neg.py +0 -39
  931. mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
  932. mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
  933. mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
  934. mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
  935. mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
  936. mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
  937. mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
  938. mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
  939. mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
  940. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
  941. mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
  942. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
  943. mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
  944. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
  945. mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
  946. mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
  947. mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
  948. mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
  949. mindspore/ops/_op_impl/tbe/pack.py +0 -58
  950. mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
  951. mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
  952. mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
  953. mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
  954. mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
  955. mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
  956. mindspore/ops/_op_impl/tbe/pdist.py +0 -36
  957. mindspore/ops/_op_impl/tbe/pooling.py +0 -46
  958. mindspore/ops/_op_impl/tbe/population_count.py +0 -38
  959. mindspore/ops/_op_impl/tbe/pow.py +0 -41
  960. mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
  961. mindspore/ops/_op_impl/tbe/prelu.py +0 -37
  962. mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
  963. mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
  964. mindspore/ops/_op_impl/tbe/range.py +0 -39
  965. mindspore/ops/_op_impl/tbe/real_div.py +0 -38
  966. mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
  967. mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
  968. mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
  969. mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
  970. mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
  971. mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
  972. mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
  973. mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
  974. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
  975. mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
  976. mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
  977. mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
  978. mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
  979. mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
  980. mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
  981. mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
  982. mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
  983. mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
  984. mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
  985. mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
  986. mindspore/ops/_op_impl/tbe/relu.py +0 -39
  987. mindspore/ops/_op_impl/tbe/relu6.py +0 -38
  988. mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
  989. mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
  990. mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
  991. mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
  992. mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
  993. mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
  994. mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
  995. mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
  996. mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
  997. mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
  998. mindspore/ops/_op_impl/tbe/renorm.py +0 -39
  999. mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
  1000. mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
  1001. mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
  1002. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
  1003. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
  1004. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
  1005. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
  1006. mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
  1007. mindspore/ops/_op_impl/tbe/rint.py +0 -37
  1008. mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
  1009. mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
  1010. mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
  1011. mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
  1012. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
  1013. mindspore/ops/_op_impl/tbe/roll.py +0 -42
  1014. mindspore/ops/_op_impl/tbe/round.py +0 -38
  1015. mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
  1016. mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
  1017. mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
  1018. mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
  1019. mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
  1020. mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
  1021. mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
  1022. mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
  1023. mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
  1024. mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
  1025. mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
  1026. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
  1027. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
  1028. mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
  1029. mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
  1030. mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
  1031. mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
  1032. mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
  1033. mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
  1034. mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
  1035. mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
  1036. mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
  1037. mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
  1038. mindspore/ops/_op_impl/tbe/select.py +0 -38
  1039. mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
  1040. mindspore/ops/_op_impl/tbe/selu.py +0 -39
  1041. mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
  1042. mindspore/ops/_op_impl/tbe/sgd.py +0 -62
  1043. mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
  1044. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
  1045. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
  1046. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
  1047. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
  1048. mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
  1049. mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
  1050. mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
  1051. mindspore/ops/_op_impl/tbe/sign.py +0 -38
  1052. mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
  1053. mindspore/ops/_op_impl/tbe/sin.py +0 -37
  1054. mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
  1055. mindspore/ops/_op_impl/tbe/sinh.py +0 -37
  1056. mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
  1057. mindspore/ops/_op_impl/tbe/slice.py +0 -58
  1058. mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
  1059. mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
  1060. mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
  1061. mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
  1062. mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
  1063. mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
  1064. mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
  1065. mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
  1066. mindspore/ops/_op_impl/tbe/softmax.py +0 -37
  1067. mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
  1068. mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
  1069. mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
  1070. mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
  1071. mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
  1072. mindspore/ops/_op_impl/tbe/softplus.py +0 -37
  1073. mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
  1074. mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
  1075. mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
  1076. mindspore/ops/_op_impl/tbe/softsign.py +0 -37
  1077. mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
  1078. mindspore/ops/_op_impl/tbe/sort.py +0 -38
  1079. mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
  1080. mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
  1081. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
  1082. mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
  1083. mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
  1084. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
  1085. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
  1086. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
  1087. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
  1088. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
  1089. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
  1090. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
  1091. mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
  1092. mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
  1093. mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
  1094. mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
  1095. mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
  1096. mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
  1097. mindspore/ops/_op_impl/tbe/split_d.py +0 -38
  1098. mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
  1099. mindspore/ops/_op_impl/tbe/split_v.py +0 -39
  1100. mindspore/ops/_op_impl/tbe/splitv.py +0 -39
  1101. mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
  1102. mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
  1103. mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
  1104. mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
  1105. mindspore/ops/_op_impl/tbe/square.py +0 -38
  1106. mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
  1107. mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
  1108. mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
  1109. mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
  1110. mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
  1111. mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
  1112. mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
  1113. mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
  1114. mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
  1115. mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
  1116. mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
  1117. mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
  1118. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
  1119. mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
  1120. mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
  1121. mindspore/ops/_op_impl/tbe/sub.py +0 -39
  1122. mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
  1123. mindspore/ops/_op_impl/tbe/tan.py +0 -38
  1124. mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
  1125. mindspore/ops/_op_impl/tbe/tanh.py +0 -37
  1126. mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
  1127. mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
  1128. mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
  1129. mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
  1130. mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
  1131. mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
  1132. mindspore/ops/_op_impl/tbe/tile.py +0 -37
  1133. mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
  1134. mindspore/ops/_op_impl/tbe/top_k.py +0 -42
  1135. mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
  1136. mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
  1137. mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
  1138. mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
  1139. mindspore/ops/_op_impl/tbe/transpose.py +0 -60
  1140. mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
  1141. mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
  1142. mindspore/ops/_op_impl/tbe/trunc.py +0 -39
  1143. mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
  1144. mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
  1145. mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
  1146. mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
  1147. mindspore/ops/_op_impl/tbe/unpack.py +0 -38
  1148. mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
  1149. mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
  1150. mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
  1151. mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
  1152. mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
  1153. mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
  1154. mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
  1155. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
  1156. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
  1157. mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
  1158. mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
  1159. mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
  1160. mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
  1161. mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
  1162. mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
  1163. mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
  1164. mindspore/ops/_tracefunc.py +0 -241
  1165. mindspore/ops/arg_dtype_cast.py +0 -54
  1166. mindspore/rewrite/api/tree_node_helper.py +0 -60
  1167. mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
  1168. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
  1169. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
  1170. mindspore/rewrite/namespace.py +0 -53
  1171. {mindspore-2.2.14.dist-info → mindspore-2.3.0rc2.dist-info}/WHEEL +0 -0
  1172. {mindspore-2.2.14.dist-info → mindspore-2.3.0rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1716 @@
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """Operators for nn."""
16
+ from __future__ import absolute_import
17
+ from __future__ import division
18
+
19
+ import numbers
20
+ import math
21
+ import numpy as np
22
+ from mindspore.ops import signature as sig
23
+ from mindspore.ops.primitive import Primitive, prim_attr_register, prim_arg_register, PrimitiveWithInfer
24
+ from mindspore.ops._primitive_cache import _get_cache_prim
25
+ from mindspore.ops.auto_generate import gen_arg_handler as handler
26
+ from mindspore.common import Tensor, CSRTensor, COOTensor
27
+ from mindspore.common._stub_tensor import _convert_stub
28
+ from mindspore._c_expression import typing
29
+ from mindspore._c_expression import Tensor as Tensor_
30
+ from mindspore._c_expression import pyboost_cast, pyboost_tile, pyboost_zeros, pyboost_ones
31
+ from mindspore.common import dtype as mstype
32
+ from mindspore.common._utils import is_shape_unknown
33
+ from mindspore import _checkparam as validator
34
+ from mindspore.ops.operations.manually_defined._inner import ScalarCast
35
+ from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum
36
+ from mindspore.common.initializer import Zero
37
+ from mindspore.common.parameter import Parameter
38
+ from mindspore.ops.auto_generate.gen_ops_prim import FlashAttentionScore
39
+
40
+
41
# Module-level DtypeToEnum primitive instance. Judging by its name it converts a
# dtype into its enum/type-id form (NOTE(review): confirm in gen_ops_inner_prim).
# A single assignment is sufficient; re-creating it immediately was redundant.
dtype_to_type_id = DtypeToEnum()
45
+
46
+
47
class ScalarDiv(Primitive):
    r"""
    Divides the first input scalar by the second input scalar (true division).

    .. math::

        out = \frac{x}{y}

    .. note::
        The inputs may be constant or variable values; usage matches Python's '/' operator.
        This primitive only has a 'CPU' implementation; on other platforms it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar, the dividend.
        - **y** (Scalar) - A constant or variable scalar, the divisor.

    Outputs:
        Scalar, the type of scalar is float.

    Raises:
        TypeError: If `x` and `y` are not scalar.
        ValueError: If `y` is 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Set up the ScalarDiv primitive (it carries no attributes)."""

    def __call__(self, x, y):
        # Guard against division by zero before delegating to Python's '/'.
        if y != 0:
            return x / y
        raise ValueError('The divisor could not be zero. But the divisor is zero now.')
81
+
82
+
83
class ScalarFloorDiv(Primitive):
    r"""
    Computes the floored quotient of dividing the first input scalar by the second input scalar.

    .. math::

        out = \left\lfloor \frac{x}{y} \right\rfloor

    .. note::
        The inputs can be constant/variable value. Usage is the same as '//' in Python.
        This primitive only has a 'CPU' implementation; on other platforms it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar, the dividend.
        - **y** (Scalar) - A constant or variable scalar, the divisor.

    Outputs:
        Scalar, the floored quotient. The Python-side call uses '//', so two ints give an int
        and any float input gives a float.

    Raises:
        TypeError: If `x` and `y` are not scalar.
        ValueError: If `y` is 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Initialize ScalarFloorDiv"""
        self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])

    def __call__(self, x, y):
        # Reject a zero divisor explicitly so the error message matches ScalarDiv.
        if y == 0:
            raise ValueError('The divisor could not be zero. But the divisor is zero now.')
        return x // y
118
+
119
+
120
class ScalarAdd(Primitive):
    r"""
    Returns the sum of two input scalars.

    .. note::
        The inputs may be constant or variable values; usage matches Python's '+' operator.
        This primitive only has a 'CPU' implementation; on other platforms it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.
        - **y** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, whose data type is the one with higher precision or more digits of the two inputs.

    Raises:
        TypeError: If `x` and `y` are not scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Set up the ScalarAdd primitive (it carries no attributes)."""

    def __call__(self, x, y):
        total = x + y
        return total
147
+
148
+
149
class ScalarPow(Primitive):
    r"""
    Raises the first input scalar to the power of the second input scalar.

    .. math::

        out = x^{y}

    .. note::
        The inputs can be constant/variable value. Usage is the same as '**' (or the built-in
        ``pow``) in Python.
        This primitive only has a 'CPU' implementation; on other platforms it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar, the base.
        - **y** (Scalar) - A constant or variable scalar, the exponent.

    Outputs:
        Scalar, and the data type is the one with higher precision or higher digits among the two inputs.

    Raises:
        TypeError: If `x` and `y` are not scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Initialize ScalarPow"""

    def __call__(self, x, y):
        return pow(x, y)
176
+
177
+
178
class ScalarLog(Primitive):
    r"""
    Computes the natural logarithm of the input scalar.

    .. math::

        out = \ln(x)

    .. note::
        The input can be a constant/variable value. Usage is the same as ``math.log`` in Python.
        This primitive only has a 'CPU' implementation; on other platforms it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, the type of scalar is float.

    Raises:
        TypeError: If `x` is not scalar.
        ValueError: If `x` is not positive (raised by ``math.log`` in the Python-side call).

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Initialize ScalarLog"""

    def __call__(self, x):
        return math.log(x)
204
+
205
+
206
class ScalarUadd(Primitive):
    r"""
    Applies unary plus to the input scalar, i.e. returns it unchanged.

    .. note::
        The input can be a constant/variable value. Usage is the same as unary '+' in Python.
        This primitive only has a 'CPU' implementation; on other platforms it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, the input `x` itself.

    Raises:
        TypeError: If `x` is not scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Initialize ScalarUadd"""

    def __call__(self, x):
        return x
232
+
233
+
234
class ScalarUsub(Primitive):
    r"""
    Applies unary minus to the input scalar, i.e. returns its negation.

    .. note::
        The input can be a constant/variable value. Usage is the same as unary '-' in Python.
        This primitive only has a 'CPU' implementation; on other platforms it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, the negation of `x`.

    Raises:
        TypeError: If `x` is not scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Initialize ScalarUsub"""

    def __call__(self, x):
        return -x
261
+
262
+
263
class ScalarSub(Primitive):
    r"""
    Returns the first input scalar minus the second input scalar.

    .. note::
        The inputs may be constant or variable values; usage matches Python's '-' operator.
        This primitive only has a 'CPU' implementation; on other platforms it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar, the minuend.
        - **y** (Scalar) - A constant or variable scalar, the subtrahend.

    Outputs:
        Scalar, whose data type is the one with higher precision or more digits of the two inputs.

    Raises:
        TypeError: If `x` and `y` are not scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Set up the ScalarSub primitive (it carries no attributes)."""

    def __call__(self, x, y):
        difference = x - y
        return difference
290
+
291
+
292
class ScalarMul(Primitive):
    r"""
    Multiplies two input scalars.

    .. note::
        The inputs can be constant/variable value. Usage is the same as '*' in Python.
        This primitive only has a 'CPU' implementation; on other platforms it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.
        - **y** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, and the data type is the one with higher precision or higher digits among the two inputs.

    Raises:
        TypeError: If `x` and `y` are not scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Initialize ScalarMul"""

    def __call__(self, x, y):
        return x * y
319
+
320
+
321
class ScalarEq(Primitive):
    r"""
    Tests whether two input scalars are equal.

    .. note::
        The inputs may be constant or variable values; usage matches Python's '==' operator.
        This primitive only has a 'CPU' implementation; on other platforms it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.
        - **y** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, the type of scalar is bool.

    Raises:
        TypeError: If `x` and `y` are not scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Set up the ScalarEq primitive (it carries no attributes)."""

    def __call__(self, x, y):
        same = x == y
        return same
348
+
349
+
350
class ScalarGt(Primitive):
    r"""
    Tests whether the first input scalar is strictly greater than the second one.

    .. note::
        The inputs may be constant or variable values; usage matches Python's '>' operator.
        This primitive only has a 'CPU' implementation; on other platforms it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.
        - **y** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, the type of scalar is bool.

    Raises:
        TypeError: If `x` and `y` are not scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Set up the scalar_gt primitive (it carries no attributes)."""

    def __call__(self, x, y):
        greater = x > y
        return greater
377
+
378
+
379
class ScalarLt(Primitive):
    r"""
    Tests whether the first input scalar is strictly less than the second one, i.e. :math:`x < y`.

    .. note::
        The inputs may be constant or variable values; usage matches Python's '<' operator.
        This primitive only has a 'CPU' implementation; on other platforms it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.
        - **y** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, the type of scalar is bool.

    Raises:
        TypeError: If `x` and `y` are not scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Set up the scalar_lt primitive (it carries no attributes)."""

    def __call__(self, x, y):
        smaller = x < y
        return smaller
406
+
407
+
408
class ScalarGe(Primitive):
    r"""
    Tests whether the first input scalar is greater than or equal to the second one.

    .. note::
        The inputs may be constant or variable values; usage matches Python's '>=' operator.
        This primitive only has a 'CPU' implementation; on other platforms it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.
        - **y** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, the type of scalar is bool.

    Raises:
        TypeError: If `x` and `y` are not scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Set up the scalar_ge primitive (it carries no attributes)."""

    def __call__(self, x, y):
        at_least = x >= y
        return at_least
435
+
436
+
437
class ScalarLe(Primitive):
    r"""
    Tests whether the first input scalar is less than or equal to the second one.

    .. note::
        The inputs may be constant or variable values; usage matches Python's '<=' operator.
        This primitive only has a 'CPU' implementation; on other platforms it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.
        - **y** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, the type of scalar is bool.

    Raises:
        TypeError: If `x` and `y` are not scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Set up the scalar_le primitive (it carries no attributes)."""

    def __call__(self, x, y):
        at_most = x <= y
        return at_most
464
+
465
+
466
class ScalarMod(Primitive):
    r"""
    Computes the remainder of dividing the first input scalar by the second input scalar.

    .. math::

        out = x \text{ % } y

    .. note::
        The inputs can be constant/variable value. Usage is the same as '%' in Python.
        This primitive only has a 'CPU' implementation; on other platforms it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar, the dividend.
        - **y** (Scalar) - A constant or variable scalar, the divisor.

    Outputs:
        Scalar, the type is the one with higher precision or higher digits among the two inputs.

    Raises:
        TypeError: If `x` and `y` are not scalar.
        ValueError: If `y` is 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Initialize ScalarMod"""

    def __call__(self, x, y):
        # A zero divisor is rejected explicitly (documented above under Raises).
        if y == 0:
            raise ValueError('Cannot perform modulo operation on zero.')
        return x % y
499
+
500
+
501
class ScalarBool(Primitive):
    r"""
    Converts the input scalar to its truth value.

    .. note::
        The input may be a constant or variable value.
        This primitive only has a 'CPU' implementation; on other platforms it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, the type is bool.

    Raises:
        TypeError: If `x` is not a scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Set up the ScalarBool primitive (it carries no attributes)."""

    def __call__(self, x):
        truth = bool(x)
        return truth
527
+
528
+
529
# Module-level singleton instances of the scalar primitives.
# The construction order is kept unchanged.

# Binary arithmetic.
scalar_div = ScalarDiv()
scalar_mod = ScalarMod()
scalar_add = ScalarAdd()
scalar_mul = ScalarMul()
scalar_sub = ScalarSub()

# Comparisons (each returns a bool).
scalar_gt = ScalarGt()
scalar_ge = ScalarGe()
scalar_le = ScalarLe()
scalar_lt = ScalarLt()
scalar_eq = ScalarEq()

# Truth value, floor division, logarithm, power and unary operators.
scalar_bool = ScalarBool()
scalar_floordiv = ScalarFloorDiv()
scalar_log = ScalarLog()
scalar_pow = ScalarPow()
scalar_uadd = ScalarUadd()
scalar_usub = ScalarUsub()
545
+
546
+
547
class BatchNorm(Primitive):
    r"""
    Batch Normalization for input data and updated parameters.

    Batch Normalization is widely used in convolutional neural networks. This operation
    applies Batch Normalization over inputs to avoid internal covariate shift as described
    in the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal
    Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
    features using a mini-batch of data and the learned parameters, as described by
    the following formula:

    .. math::

        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta

    where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon,
    :math:`mean` is the mean of :math:`x`,
    :math:`variance` is the variance of :math:`x`.

    .. warning::
        - If the operation is used for inference, and outputs "reserve_space_1" and "reserve_space_2" are available,
          then "reserve_space_1" has the same value as "mean" and "reserve_space_2" has the same value as "variance".
        - For Ascend 310, the result accuracy fails to reach 1‰ due to the square root instruction.

    Args:
        is_training (bool): If `is_training` is ``True`` , `mean` and `variance` are computed during training.
            If `is_training` is ``False`` , they're loaded from checkpoint during inference. Default: ``False`` .
        epsilon (float): A small value added for numerical stability. Default: ``1e-5``, value must be (0, 1] .
        momentum (float): The hyper parameter to compute moving average for running_mean and running_var
            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
            Momentum value must be [0, 1]. Default: ``0.1`` .
        data_format (str): The optional value for data format, is ``'NHWC'`` or ``'NCHW'``, and the ``'NHWC'`` format
            is only supported in GPU target. Default: ``"NCHW"`` .

    Inputs:
        If `is_training` is ``False`` , inputs are Tensors.

        - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
        - **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
        - **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
        - **mean** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
        - **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.

        If `is_training` is ``True`` , `scale`, `bias`, `mean` and `variance` are Parameters.

        - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
        - **scale** (Parameter) - Parameter of shape :math:`(C,)`, with float16 or float32 data type.
        - **bias** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `scale`.
        - **mean** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `scale`.
        - **variance** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `scale`.

    Outputs:
        Tuple of 5 Tensors, the normalized inputs and the updated parameters.

        - **output_x** (Tensor) - The same type and shape as the input_x. The shape is :math:`(N, C)`.
        - **batch_mean** (Tensor) - The mean calculated per-dimension over the mini-batches,
          shape is :math:`(C,)`.
        - **batch_variance** (Tensor) - The variance calculated per-dimension over the mini-batches,
          shape is :math:`(C,)`.
        - **reserve_space_1** (Tensor) - The mean that needs to be reused when calculating gradients,
          one-dimensional Tensor. The shape is :math:`(C,)`.
        - **reserve_space_2** (Tensor) - The variance that needs to be reused when calculating gradients,
          one-dimensional Tensor. The shape is :math:`(C,)`.

    Raises:
        TypeError: If `is_training` is not a bool.
        TypeError: If dtype of `epsilon` or `momentum` is not float.
        TypeError: If `data_format` is not a str.
        TypeError: If `input_x`, `scale`, `bias`, `mean` or `variance` is not a Tensor.
        TypeError: If dtype of `input_x`, `scale` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> input_x = Tensor(np.ones([2, 2]), mindspore.float32)
        >>> scale = Tensor(np.ones([2]), mindspore.float32)
        >>> bias = Tensor(np.ones([2]), mindspore.float32)
        >>> mean = Tensor(np.ones([2]), mindspore.float32)
        >>> variance = Tensor(np.ones([2]), mindspore.float32)
        >>> batch_norm = ops.BatchNorm()
        >>> output = batch_norm(input_x, scale, bias, mean, variance)
        >>> print(output[0])
        [[1. 1.]
         [1. 1.]]
    """
    # scale/bias share dtype group T2 and mean/variance share T3; all four are marked
    # RW_WRITE because training mode writes the updated values back.
    __mindspore_signature__ = (
        sig.make_sig('input_x', dtype=sig.sig_dtype.T1),
        sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T2),
        sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T2),
        sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T3),
        sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T3),
    )

    @prim_arg_register
    def __init__(self,
                 is_training=False,
                 epsilon=1e-5,
                 momentum=0.1,
                 data_format="NCHW"):
        """Configure BatchNorm for training (in-place parameter updates) or inference."""
        if is_training is not False:
            # Training: the RW_WRITE signature stays, and the op is flagged as mutating memory.
            self.add_prim_attr('side_effect_mem', True)
        else:
            # Inference: replace the RW_WRITE signature with an empty one.
            self.set_signatures(tuple())
        self.is_training = is_training
        self.epsilon = epsilon
        self.momentum = momentum
        # data_format is stored as an enum value rather than the raw string.
        self.data_format = handler.str_to_enum("BatchNorm", "data_format", data_format)

    def __call__(self, *args):
        # The configuration values are appended as trailing positional inputs.
        trailing = (self.is_training, self.epsilon, self.momentum, self.data_format)
        return super().__call__(*args, *trailing)
669
+
670
+
671
def batch_norm_(input_x,
                scale,
                bias,
                mean,
                variance,
                is_training=False,
                epsilon=1e-5,
                momentum=0.1,
                data_format="NCHW"):
    r"""
    Batch Normalization for input data and updated parameters.

    Batch Normalization is widely used in convolutional neural networks. This operation
    applies Batch Normalization over inputs to avoid internal covariate shift as described
    in the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal
    Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
    features using a mini-batch of data and the learned parameters, as described by
    the following formula:

    .. math::

        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta

    where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon,
    :math:`mean` is the mean of :math:`x`,
    :math:`variance` is the variance of :math:`x`.

    .. warning::
        - If the operation is used for inference, and outputs "reserve_space_1" and "reserve_space_2" are available,
          then "reserve_space_1" has the same value as "mean" and "reserve_space_2" has the same value as "variance".
        - For Atlas 200/300/500 inference product,
          the result accuracy fails to reach 1‰ due to the square root instruction.

    Note:
        - If `is_training` is ``False``, `scale`, `bias`, `mean` and `variance` are Tensors.
        - If `is_training` is ``True``, `scale`, `bias`, `mean` and `variance` are Parameters.

    Args:
        input_x (Tensor): Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
        scale (Union[Tensor, Parameter]): The shape :math:`(C,)`, with float16 or float32 data type.
        bias (Union[Tensor, Parameter]): The shape :math:`(C,)`, has the same data type with `scale`.
        mean (Union[Tensor, Parameter]): The shape :math:`(C,)`, with float16 or float32 data type.
        variance (Union[Tensor, Parameter]): The shape :math:`(C,)`, has the same data type with `mean`.
        is_training (bool, optional): If ``True``, `mean` and `variance` are computed during training.
            If ``False``, they're loaded from checkpoint during inference. Default: ``False`` .
        epsilon (float, optional): A small value added for numerical stability.
            Default: ``1e-5``, value must be (0, 1] .
        momentum (float, optional): The hyper parameter to compute moving average for running_mean and running_var
            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
            Momentum value must be [0, 1].
            Default: ``0.1`` .
        data_format (str, optional): The optional value for data format, is ``'NHWC'`` or ``'NCHW'``,
            and the ``'NHWC'`` format is only supported in GPU target.
            Default: ``"NCHW"`` .

    Returns:
        Tuple of 5 Tensors:

        - output_x (Tensor): The same type and shape as the input_x. The shape is :math:`(N, C)`.
        - batch_mean (Tensor): Tensor of shape :math:`(C,)`.
        - batch_variance (Tensor): Tensor of shape :math:`(C,)`.
        - reserve_space_1 (Tensor): Tensor of shape :math:`(C,)`.
        - reserve_space_2 (Tensor): Tensor of shape :math:`(C,)`.

    Raises:
        TypeError: If `is_training` is not a bool.
        TypeError: If dtype of `epsilon` or `momentum` is not float.
        TypeError: If `data_format` is not a str.
        TypeError: If `input_x`, `scale`, `bias`, `mean` or `variance` is not a Tensor.
        TypeError: If dtype of `input_x`, `scale` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> input_x = Tensor(np.ones([2, 2]), mindspore.float32)
        >>> scale = Tensor(np.ones([2]), mindspore.float32)
        >>> bias = Tensor(np.ones([2]), mindspore.float32)
        >>> mean = Tensor(np.ones([2]), mindspore.float32)
        >>> variance = Tensor(np.ones([2]), mindspore.float32)
        >>> output = ops.batch_norm_(input_x, scale, bias, mean, variance, is_training=False,
        ...                          epsilon=1e-5, momentum=0.1, data_format="NCHW")
        >>> print(output[0])
        [[1. 1.]
         [1. 1.]]
    """
    # Reuse a cached BatchNorm primitive keyed by its configuration.
    batch_norm_op = _get_cache_prim(BatchNorm)(is_training, epsilon, momentum,
                                               data_format)
    return batch_norm_op(input_x, scale, bias, mean, variance)
760
+
761
+
762
+ class Rank(Primitive):
763
+ """
764
+ Returns the rank of a tensor.
765
+
766
+ Refer to :func:`mindspore.ops.rank` for more details.
767
+
768
+ Supported Platforms:
769
+ ``Ascend`` ``GPU`` ``CPU``
770
+
771
+ Examples:
772
+ >>> import mindspore
773
+ >>> import numpy as np
774
+ >>> from mindspore import Tensor, ops
775
+ >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
776
+ >>> rank = ops.Rank()
777
+ >>> output = rank(input_tensor)
778
+ >>> print(output)
779
+ 2
780
+ >>> print(type(output))
781
+ <class 'int'>
782
+ """
783
+
784
+ @prim_attr_register
785
+ def __init__(self):
786
+ """Initialize Rank"""
787
+
788
+ def __call__(self, x):
789
+ if not isinstance(x, (Tensor, Tensor_)):
790
+ raise TypeError("the input x must be Tensor!")
791
+ return len(x.shape)
792
+
793
+
794
+ def rank(input_x):
795
+ """
796
+ Returns the rank of a tensor.
797
+
798
+ Returns a 0-D int32 Tensor representing the rank of input; the rank of a tensor
799
+ is the number of indices required to uniquely select each element of the tensor.
800
+
801
+ Args:
802
+ input_x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The data type is Number.
803
+
804
+ Returns:
805
+ Tensor. 0-D int32 Tensor representing the rank of input, i.e., :math:`R`. The data type is an int.
806
+
807
+ Raises:
808
+ TypeError: If `input_x` is not a Tensor.
809
+
810
+ Supported Platforms:
811
+ ``Ascend`` ``GPU`` ``CPU``
812
+
813
+ Examples:
814
+ >>> import mindspore
815
+ >>> import numpy as np
816
+ >>> from mindspore import Tensor, ops
817
+ >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
818
+ >>> output = ops.rank(input_tensor)
819
+ >>> print(output)
820
+ 2
821
+ >>> print(type(output))
822
+ <class 'int'>
823
+
824
+ """
825
+ rank_op = _get_cache_prim(Rank)()
826
+ return rank_op(input_x)
827
+
828
+
829
+ class Shape(Primitive):
830
+ """
831
+ Returns the shape of the input tensor.
832
+
833
+ Refer to :func:`mindspore.ops.shape` for more details.
834
+
835
+ Inputs:
836
+ - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
837
+
838
+ Outputs:
839
+ tuple[int], the output tuple is constructed by multiple integers,
840
+ :math:`(x_1, x_2, ..., x_R)`.
841
+
842
+ Supported Platforms:
843
+ ``Ascend`` ``GPU`` ``CPU``
844
+
845
+ Examples:
846
+ >>> import mindspore
847
+ >>> import numpy as np
848
+ >>> from mindspore import Tensor, ops
849
+ >>> input_x = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
850
+ >>> shape = ops.Shape()
851
+ >>> output = shape(input_x)
852
+ >>> print(output)
853
+ (3, 2, 1)
854
+ """
855
+
856
+ @prim_attr_register
857
+ def __init__(self):
858
+ """Initialize Shape"""
859
+
860
+ def __call__(self, x):
861
+ if isinstance(x, (Tensor, COOTensor, CSRTensor, Tensor_)):
862
+ return x.shape
863
+ raise TypeError(f"For primitive[{self.name}], the input argument must be Tensor, but got {type(x)}.")
864
+
865
+
866
+ def shape_(input_x):
867
+ """
868
+ Returns the shape of the input tensor.
869
+
870
+ Args:
871
+ input_x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
872
+
873
+ Returns:
874
+ tuple[int], the output tuple is constructed by multiple integers,
875
+ :math:`(x_1, x_2, ..., x_R)`.
876
+
877
+ Raises:
878
+ TypeError: If `input_x` is not a Tensor.
879
+
880
+ Supported Platforms:
881
+ ``Ascend`` ``GPU`` ``CPU``
882
+
883
+ Examples:
884
+ >>> import mindspore
885
+ >>> import numpy as np
886
+ >>> from mindspore import Tensor, ops
887
+ >>> input_x = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
888
+ >>> output = ops.shape(input_x)
889
+ >>> print(output)
890
+ (3, 2, 1)
891
+ """
892
+ shape_op = _get_cache_prim(Shape)()
893
+ return shape_op(input_x)
894
+
895
+
896
+ class ScalarToTensor(PrimitiveWithInfer):
897
+ """
898
+ Converts a scalar to a `Tensor`, and converts the data type to the specified type.
899
+
900
+ Refer to :func:`mindspore.ops.scalar_to_tensor` for more details.
901
+
902
+ Inputs:
903
+ - **input_x** (Union[int, float]) - The input is a scalar. Only constant value is allowed.
904
+ - **dtype** (mindspore.dtype) - The target data type. Default: ``mindspore.float32`` . Only
905
+ constant value is allowed.
906
+
907
+ Outputs:
908
+ Tensor. 0-D Tensor and the content is the input.
909
+
910
+ Supported Platforms:
911
+ ``Ascend`` ``GPU`` ``CPU``
912
+
913
+ Examples:
914
+ >>> import mindspore
915
+ >>> from mindspore import ops
916
+ >>> op = ops.ScalarToTensor()
917
+ >>> data = 1
918
+ >>> output = op(data, mindspore.float32)
919
+ >>> print(output)
920
+ 1.0
921
+ """
922
+
923
+ @prim_attr_register
924
+ def __init__(self):
925
+ self.init_prim_io_names(inputs=['input_scalar', 'dtype'], outputs=['output_data'])
926
+
927
+ def __call__(self, x, dtype=mstype.float32):
928
+ validator.check_value_type("x", x, [bool, int, float], self.name)
929
+ validator.check_subclass("dtype", dtype, mstype.number, self.name)
930
+ data_type = mstype.dtype_to_nptype(dtype)
931
+ return Tensor(np.array(x, data_type), dtype=dtype)
932
+
933
+
934
+ class Tile(Primitive):
935
+ r"""
936
+ Replicates an input tensor with given multiple times.
937
+
938
+ Refer to :func:`mindspore.ops.tile` for more details.
939
+
940
+ Inputs:
941
+ - **input** (Tensor) - The tensor whose elements need to be repeated. Set the shape of input tensor as
942
+ :math:`(x_1, x_2, ..., x_S)` .
943
+ - **dims** (tuple[int]) - The parameter that specifies the number of replications,
944
+ the parameter type is tuple, and the data type is int, i.e., :math:`(y_1, y_2, ..., y_S)`.
945
+ Only constant value is allowed.
946
+
947
+ Outputs:
948
+ Tensor, has the same data type as the `input`. Suppose the length of `dims` is `d`,
949
+ the dimension of `input` is `input.dim`, and the shape of `input` is :math:`(x_1, x_2, ..., x_S)`.
950
+
951
+ - If `input.dim = d`, then the shape of their corresponding positions can be multiplied, and
952
+ the shape of Outputs is :math:`(x_1*y_1, x_2*y_2, ..., x_S*y_S)`.
953
+ - If `input.dim < d`, prepend 1 to the shape of `input` until their lengths are consistent.
954
+ Such as set the shape of `input` as :math:`(1, ..., x_1, x_2, ..., x_S)`,
955
+ then the shape of their corresponding positions can be multiplied, and the shape of Outputs is
956
+ :math:`(1*y_1, ..., x_R*y_R, x_S*y_S)`.
957
+ - If `input.dim > d`, prepend 1 to `dims` until their lengths are consistent. Such as set the
958
+ `dims` as :math:`(1, ..., y_1, y_2, ..., y_S)`, then the shape of their corresponding positions
959
+ can be multiplied, and the shape of Outputs is :math:`(x_1*1, ..., x_R*y_R, x_S*y_S)`.
960
+
961
+ Raises:
962
+ TypeError: If `dims` is not a tuple or its elements are not all int.
963
+ ValueError: If the elements of `dims` are not all greater than or equal to 0.
964
+
965
+ Supported Platforms:
966
+ ``Ascend`` ``GPU`` ``CPU``
967
+
968
+ Examples:
969
+ >>> import mindspore
970
+ >>> import numpy as np
971
+ >>> from mindspore import Tensor, ops
972
+ >>> tile = ops.Tile()
973
+ >>> input = Tensor(np.array([[1, 2], [3, 4]]), mindspore.float32)
974
+ >>> dims = (2, 3)
975
+ >>> output = tile(input, dims)
976
+ >>> print(output)
977
+ [[1. 2. 1. 2. 1. 2.]
978
+ [3. 4. 3. 4. 3. 4.]
979
+ [1. 2. 1. 2. 1. 2.]
980
+ [3. 4. 3. 4. 3. 4.]]
981
+ >>> dims = (2, 3, 2)
982
+ >>> output = tile(input, dims)
983
+ >>> print(output)
984
+ [[[1. 2. 1. 2.]
985
+ [3. 4. 3. 4.]
986
+ [1. 2. 1. 2.]
987
+ [3. 4. 3. 4.]
988
+ [1. 2. 1. 2.]
989
+ [3. 4. 3. 4.]]
990
+ [[1. 2. 1. 2.]
991
+ [3. 4. 3. 4.]
992
+ [1. 2. 1. 2.]
993
+ [3. 4. 3. 4.]
994
+ [1. 2. 1. 2.]
995
+ [3. 4. 3. 4.]]]
996
+ """
997
+
998
+ @prim_attr_register
999
+ def __init__(self):
1000
+ """Initialize."""
1001
+
1002
+ def __call__(self, input, dims):
1003
+ return _convert_stub(pyboost_tile(self, [input, dims]))
1004
+
1005
+ # pylint: disable=missing-docstring
1006
+ def check_elim(self, *args):
1007
+ base_tensor, dims = args
1008
+ if not isinstance(base_tensor, Tensor):
1009
+ raise TypeError(f"For '{self.name}', the type of 'input' must be Tensor, "
1010
+ f"but got {type(base_tensor).__name__}.")
1011
+ if not isinstance(dims, tuple):
1012
+ raise TypeError(f"For '{self.name}', the type of 'dims' must be tuple, "
1013
+ f"but got {type(dims).__name__}.")
1014
+
1015
+ if all(v == 1 for v in dims) and len(base_tensor.shape) >= len(dims):
1016
+ from mindspore.ops.auto_generate.gen_ops_def import Identity
1017
+ ret = Identity()(base_tensor)
1018
+ return (True, ret)
1019
+ return (False, None)
1020
+
1021
+
1022
+ def tile(input, dims):
1023
+ r"""
1024
+ Creates a new tensor by replicating `input` `dims` times. The i'th dimension of
1025
+ output tensor has `input.shape[i] * dims[i]` elements, and the values of `input`
1026
+ are replicated `dims[i]` times along the i'th dimension.
1027
+
1028
+ Args:
1029
+ input (Tensor): The tensor whose elements need to be repeated. Set the shape of input tensor as
1030
+ :math:`(x_1, x_2, ..., x_S)` .
1031
+
1032
+ dims (tuple[int]): The parameter that specifies the number of replications,
1033
+ the parameter type is tuple, and the data type is int, i.e., :math:`(y_1, y_2, ..., y_S)`.
1034
+ Only constant value is allowed.
1035
+
1036
+ Returns:
1037
+ Tensor, has the same data type as the `input`. Suppose the length of `dims` is `d`,
1038
+ the dimension of `input` is `input.dim`, and the shape of `input` is :math:`(x_1, x_2, ..., x_S)`.
1039
+
1040
+ - If `input.dim = d`, then the shape of their corresponding positions can be multiplied, and
1041
+ the shape of Outputs is :math:`(x_1*y_1, x_2*y_2, ..., x_S*y_S)`.
1042
+ - If `input.dim < d`, prepend 1 to the shape of `input` until their lengths are consistent.
1043
+ Such as set the shape of `input` as :math:`(1, ..., x_1, x_2, ..., x_S)`,
1044
+ then the shape of their corresponding positions can be multiplied, and the shape of Outputs is
1045
+ :math:`(1*y_1, ..., x_R*y_R, x_S*y_S)`.
1046
+ - If `input.dim > d`, prepend 1 to `dims` until their lengths are consistent. Such as set the
1047
+ `dims` as :math:`(1, ..., y_1, y_2, ..., y_S)`, then the shape of their corresponding positions
1048
+ can be multiplied, and the shape of Outputs is :math:`(x_1*1, ..., x_R*y_R, x_S*y_S)`.
1049
+
1050
+ Raises:
1051
+ TypeError: If `dims` is not a tuple or its elements are not all int.
1052
+ ValueError: If the elements of `dims` are not all greater than or equal to 0.
1053
+
1054
+ Supported Platforms:
1055
+ ``Ascend`` ``GPU`` ``CPU``
1056
+
1057
+ Examples:
1058
+ >>> import mindspore
1059
+ >>> import numpy as np
1060
+ >>> from mindspore import Tensor, ops
1061
+ >>> input = Tensor(np.array([[1, 2], [3, 4]]), mindspore.float32)
1062
+ >>> dims = (2, 3)
1063
+ >>> output = ops.tile(input, dims)
1064
+ >>> print(output)
1065
+ [[1. 2. 1. 2. 1. 2.]
1066
+ [3. 4. 3. 4. 3. 4.]
1067
+ [1. 2. 1. 2. 1. 2.]
1068
+ [3. 4. 3. 4. 3. 4.]]
1069
+ >>> dims = (2, 3, 2)
1070
+ >>> output = ops.tile(input, dims)
1071
+ >>> print(output)
1072
+ [[[1. 2. 1. 2.]
1073
+ [3. 4. 3. 4.]
1074
+ [1. 2. 1. 2.]
1075
+ [3. 4. 3. 4.]
1076
+ [1. 2. 1. 2.]
1077
+ [3. 4. 3. 4.]]
1078
+ [[1. 2. 1. 2.]
1079
+ [3. 4. 3. 4.]
1080
+ [1. 2. 1. 2.]
1081
+ [3. 4. 3. 4.]
1082
+ [1. 2. 1. 2.]
1083
+ [3. 4. 3. 4.]]]
1084
+ """
1085
+ tile_op = _get_cache_prim(Tile)()
1086
+ return tile_op(input, dims)
1087
+
1088
+
1089
+ def scalar_cast(input_x, input_y):
1090
+ r"""
1091
+ The interface is deprecated from version 2.3 and will be removed in a future version,
1092
+ please use `int(x)` or `float(x)` instead.
1093
+
1094
+ Casts the input scalar to another type.
1095
+
1096
+ Args:
1097
+ input_x (scalar): The input scalar.
1098
+ input_y (mindspore.dtype): The type to be cast. Only constant value is allowed.
1099
+ The value should only be mindspore.int64, mindspore.float64, or mindspore.bool\_.
1100
+
1101
+ Returns:
1102
+ Scalar, the type is the same as the python type corresponding to `input_y`.
1103
+
1104
+ Raises:
1105
+ ValueError: if input_y's value is invalid.
1106
+
1107
+ Supported Platforms:
1108
+ Deprecated
1109
+
1110
+ Examples:
1111
+ >>> import mindspore
1112
+ >>> from mindspore import ops
1113
+ >>> output = ops.scalar_cast(255.0, mindspore.int64)
1114
+ >>> print(output)
1115
+ 255
1116
+ """
1117
+ scalar_cast_op = _get_cache_prim(ScalarCast)()
1118
+ return scalar_cast_op(input_x, input_y)
1119
+
1120
+
1121
+ class Cast(Primitive):
1122
+ """
1123
+ Returns a tensor with the new specified data type.
1124
+
1125
+ Note:
1126
+ When converting complex numbers to boolean type, the imaginary part of the complex number is not
1127
+ taken into account. As long as the real part is non-zero, it returns True; otherwise, it returns False.
1128
+
1129
+ Inputs:
1130
+ - **input_x** (Union[Tensor, Number]) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
1131
+ The tensor to be cast.
1132
+ - **type** (dtype.Number) - The valid data type of the output tensor. Only constant value is allowed.
1133
+
1134
+ Outputs:
1135
+ Tensor, the shape of tensor is the same as `input_x`, :math:`(x_1, x_2, ..., x_R)`.
1136
+
1137
+ Raises:
1138
+ TypeError: If `input_x` is neither Tensor nor Number.
1139
+ TypeError: If `type` is not a Number.
1140
+
1141
+ Supported Platforms:
1142
+ ``Ascend`` ``GPU`` ``CPU``
1143
+
1144
+ Examples:
1145
+ >>> import mindspore
1146
+ >>> import numpy as np
1147
+ >>> from mindspore import Tensor, ops
1148
+ >>> input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
1149
+ >>> input_x = Tensor(input_np)
1150
+ >>> type_dst = mindspore.int32
1151
+ >>> cast = ops.Cast()
1152
+ >>> output = cast(input_x, type_dst)
1153
+ >>> print(output.dtype)
1154
+ Int32
1155
+ >>> print(output.shape)
1156
+ (2, 3, 4, 5)
1157
+ """
1158
+
1159
+ @prim_attr_register
1160
+ def __init__(self):
1161
+ """Initialize Cast"""
1162
+ self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])
1163
+
1164
+ def check_elim(self, x, dtype):
1165
+ if isinstance(x, (Tensor, numbers.Number, Parameter)):
1166
+ if isinstance(x, Parameter):
1167
+ data = x.data
1168
+ if data.dtype == dtype:
1169
+ return (True, x)
1170
+ if isinstance(x, Tensor) and x.dtype == dtype:
1171
+ x = Tensor(x)
1172
+ x.set_cast_dtype()
1173
+ return (True, x)
1174
+ if isinstance(x, numbers.Number):
1175
+ return (True, Tensor(x, dtype=dtype))
1176
+ return (False, None)
1177
+
1178
+ def __call__(self, input_x, dtype):
1179
+ should_elim, output = self.check_elim(input_x, dtype)
1180
+ if should_elim:
1181
+ return output
1182
+ return _convert_stub(pyboost_cast(self, [input_x, dtype_to_type_id('Cast', 'dtype', dtype)]))
1183
+
1184
+ # Following is Python Infer Value.
1185
+ # A valid infer value function should be:
1186
+ #
1187
+ # 1. named as infer_value_for_OpName
1188
+ # 2. All inputs should pass without default value.
1189
+ # 3. If not const input is given, return None. (for now)
1190
+
1191
+
1192
+ def infer_value_for_Tile(input, dims):
1193
+ """Infer value for Tile op."""
1194
+ if input is None or dims is None or None in dims:
1195
+ return None
1196
+ return Tensor(np.tile(input.asnumpy(), dims))
1197
+
1198
+
1199
+ def infer_value_for_Concat(tensors, axis):
1200
+ """Infer value for Concat op."""
1201
+ if not tensors or None in tensors or axis is None:
1202
+ return None
1203
+
1204
+ tensor_to_concat = [x.asnumpy() if x.dtype != mstype.bfloat16 else x.float().asnumpy() for x in tensors]
1205
+ return Tensor(np.concatenate(tensor_to_concat, axis), dtype=tensors[0].dtype)
1206
+
1207
+
1208
+ def infer_value_for_ReduceSum(input_x, axis, keep_dims, skip_mode):
1209
+ """Infer value for ReduceSum op."""
1210
+ value = None
1211
+ if input_x is not None and axis is not None:
1212
+ value = input_x.asnumpy()
1213
+ if isinstance(axis, int):
1214
+ pass
1215
+ elif axis:
1216
+ axis = tuple(set(axis))
1217
+ elif axis in ((), []) and skip_mode:
1218
+ return input_x
1219
+ else:
1220
+ axis = tuple(range(len(value.shape)))
1221
+ value = np.sum(value, axis, keepdims=keep_dims)
1222
+ value = np.array(value)
1223
+ value = Tensor(value)
1224
+ return value
1225
+
1226
+
1227
+ def _infer_value_for_Reduce(input_x, axis, keep_dims, prim_name):
1228
+ """Infer value for Common Reduce op."""
1229
+ value = None
1230
+ if input_x is not None and axis is not None:
1231
+ prim_map = {
1232
+ 'ReduceMax': np.max,
1233
+ 'ReduceMin': np.min,
1234
+ 'ReduceProd': np.prod,
1235
+ 'ReduceMean': np.mean,
1236
+ 'ReduceAll': np.all,
1237
+ 'ReduceAny': np.any,
1238
+ }
1239
+ np_reduce_func = prim_map.get(prim_name, None)
1240
+
1241
+ if np_reduce_func is not None:
1242
+ value = input_x.asnumpy()
1243
+ if isinstance(axis, int):
1244
+ pass
1245
+ elif axis:
1246
+ axis = tuple(set(axis))
1247
+ else:
1248
+ axis = tuple(range(len(value.shape)))
1249
+ value = np_reduce_func(value, axis, keepdims=keep_dims)
1250
+ value = np.array(value)
1251
+ value = Tensor(value)
1252
+ return value
1253
+
1254
+
1255
+ def _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, prim_name):
1256
+ """Infer value for Common ReduceExtand op."""
1257
+ value = None
1258
+ if input_x is not None:
1259
+ prim_map = {
1260
+ 'MeanExt': np.mean,
1261
+ 'SumExt': np.sum,
1262
+ 'ProdExt': np.prod,
1263
+ }
1264
+ np_reduce_extand_func = prim_map.get(prim_name, None)
1265
+
1266
+ if np_reduce_extand_func is not None:
1267
+ value = input_x.asnumpy()
1268
+ if isinstance(axis, int):
1269
+ pass
1270
+ elif axis:
1271
+ axis = tuple(set(axis))
1272
+ else:
1273
+ axis = tuple(range(len(value.shape)))
1274
+ if dtype is not None:
1275
+ np_dtype = mstype.dtype_to_nptype(typing.type_id_to_type(dtype))
1276
+ value = np_reduce_extand_func(value, axis, dtype=np_dtype, keepdims=keep_dims)
1277
+ else:
1278
+ value = np_reduce_extand_func(value, axis, keepdims=keep_dims)
1279
+
1280
+ value = np.array(value)
1281
+ value = Tensor(value)
1282
+ return value
1283
+
1284
+
1285
+ def _infer_value_for_max_min(input_x, prim_name):
1286
+ """Infer value for Max/Min op."""
1287
+ value = None
1288
+ if input_x is not None:
1289
+ prim_map = {
1290
+ 'Max': np.max,
1291
+ 'Min': np.min,
1292
+ }
1293
+ np_reduce_func = prim_map.get(prim_name, None)
1294
+
1295
+ if np_reduce_func is not None:
1296
+ value = input_x.asnumpy()
1297
+ value = np_reduce_func(value, None, keepdims=False)
1298
+ value = np.array(value)
1299
+ value = Tensor(value)
1300
+ return value
1301
+
1302
+
1303
+ def infer_value_for_Cast(x, dst_type_enum=None):
1304
+ """Infer value for Cast op."""
1305
+ if x is None or dst_type_enum is None:
1306
+ return None
1307
+ dst_type = typing.type_id_to_type(dst_type_enum)
1308
+ src_type = mstype.get_py_obj_dtype(x)
1309
+ validator.check_subclass("input_x", src_type, [mstype.tensor_type, mstype.number], "Cast")
1310
+ validator.check_subclass("type", dst_type, mstype.number, "Cast")
1311
+
1312
+ if isinstance(src_type, type(mstype.tensor_type)):
1313
+ src_type = src_type.element_type()
1314
+ if isinstance(dst_type, type(mstype.tensor_type)):
1315
+ dst_type = dst_type.element_type()
1316
+
1317
+ value = None
1318
+ np_dst_type = mstype.dtype_to_nptype(dst_type)
1319
+ if isinstance(x, (int, float)):
1320
+ value = Tensor(np.array(x).astype(np_dst_type), dtype=dst_type)
1321
+ else:
1322
+ if x.dtype == mstype.bfloat16:
1323
+ cpu_cast = Cast().set_device("CPU")
1324
+ x = cpu_cast(x, mstype.float32)
1325
+ value = Tensor_(x.asnumpy().astype(np_dst_type), dtype=dst_type)
1326
+ return value
1327
+
1328
+
1329
+ def infer_value_for_ReduceMax(input_x, axis, keep_dims):
1330
+ """Infer value for ReduceMax op."""
1331
+ return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceMax')
1332
+
1333
+
1334
+ def infer_value_for_Max(input_x):
1335
+ """Infer value for Max op."""
1336
+ return _infer_value_for_max_min(input_x, 'Max')
1337
+
1338
+
1339
+ def infer_value_for_ReduceMin(input_x, axis, keep_dims):
1340
+ """Infer value for ReduceMin op."""
1341
+ return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceMin')
1342
+
1343
+
1344
+ def infer_value_for_Min(input_x):
1345
+ """Infer value for Max op."""
1346
+ return _infer_value_for_max_min(input_x, 'Min')
1347
+
1348
+
1349
+ def infer_value_for_ReduceProd(input_x, axis, keep_dims):
1350
+ """Infer value for ReduceProd op."""
1351
+ return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceProd')
1352
+
1353
+
1354
+ def infer_value_for_ReduceMean(input_x, axis, keep_dims):
1355
+ """Infer value for ReduceMean op."""
1356
+ return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceMean')
1357
+
1358
+
1359
+ def infer_value_for_ReduceAll(input_x, axis, keep_dims):
1360
+ """Infer value for ReduceAll op."""
1361
+ return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceAll')
1362
+
1363
+
1364
+ def infer_value_for_ReduceAny(input_x, axis, keep_dims):
1365
+ """Infer value for ReduceAny op."""
1366
+ return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceAny')
1367
+
1368
+
1369
+ def infer_value_for_MeanExt(input_x, axis, keep_dims, dtype):
1370
+ """Infer value for MeanExt op."""
1371
+ return _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, 'MeanExt')
1372
+
1373
+
1374
+ def infer_value_for_SumExt(input_x, axis, keep_dims, dtype):
1375
+ """Infer value for SumExt op."""
1376
+ return _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, 'SumExt')
1377
+
1378
+
1379
+ def infer_value_for_ProdExt(input_x, axis, keep_dims, dtype):
1380
+ """Infer value for ProdExt op."""
1381
+ return _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, 'ProdExt')
1382
+
1383
+
1384
+ def infer_value_for_Diag(input_x):
1385
+ """Infer value for Diag op."""
1386
+ if input_x is None:
1387
+ return None
1388
+ # do constant-folding only when x rank is 1
1389
+ if len(input_x.shape) != 1:
1390
+ return None
1391
+ ret = np.diag(input_x.asnumpy())
1392
+ return Tensor(ret)
1393
+
1394
+
1395
+ def infer_value_for_BroadcastTo(x, shape):
1396
+ """Infer value for BroadcastTo op."""
1397
+ def none_in_tuple_or_list(x):
1398
+ return isinstance(x, (tuple, list)) and None in x
1399
+ if shape is None or none_in_tuple_or_list(shape) or x is None:
1400
+ return None
1401
+
1402
+ if isinstance(shape, (Tensor, Tensor_)):
1403
+ validator.check_tensor_dtype_valid("shape", mstype.TensorType(shape.dtype),
1404
+ [mstype.int32, mstype.int64], "BroadcastTo")
1405
+ shape = shape.asnumpy().tolist()
1406
+ else:
1407
+ validator.check_value_type("shape", shape, [tuple], "BroadcastTo")
1408
+ shape = list(shape)
1409
+
1410
+ np_data = np.broadcast_to(x.asnumpy(), shape)
1411
+ if 0 in shape:
1412
+ init_func = Zero()
1413
+ init_func.__enable_zero_dim__ = True
1414
+ out = Tensor(shape=shape, dtype=x.dtype, init=init_func)
1415
+ return out
1416
+ return Tensor(np_data)
1417
+
1418
+
1419
+ def infer_value_for_Reshape(x, shape):
1420
+ """Infer value for Reshape op."""
1421
+ def none_in_tuple_or_list(x):
1422
+ return isinstance(x, (tuple, list)) and None in x
1423
+ # for shape is not constant
1424
+ if shape is None or none_in_tuple_or_list(shape) or x is None:
1425
+ return None
1426
+
1427
+ if isinstance(shape, (Tensor, Tensor_)):
1428
+ validator.check_tensor_dtype_valid("shape", mstype.TensorType(shape.dtype),
1429
+ [mstype.int32, mstype.int64], "Reshape")
1430
+ shape = shape.asnumpy().tolist()
1431
+ else:
1432
+ validator.check_value_type("shape", shape, [tuple], "Reshape")
1433
+ shape = list(shape)
1434
+
1435
+ neg_index = -1
1436
+ dim_prod = 1
1437
+ for i, shp_i in enumerate(shape):
1438
+ validator.check_value_type("shape[%d]" % i, shp_i, [int], "Reshape")
1439
+ if shp_i == -1:
1440
+ if neg_index != -1:
1441
+ raise ValueError(f"For 'Reshape', there can be at most one '-1' in 'input_shape', "
1442
+ f"but got {shape}.")
1443
+ neg_index = i
1444
+ else:
1445
+ dim_prod *= shp_i
1446
+ out = None
1447
+ if not is_shape_unknown(x.shape):
1448
+ x_shp = x.shape
1449
+ if dim_prod < 0:
1450
+ raise ValueError(f"For 'Reshape', the shape of 'input_x' is {x_shp}, "
1451
+ f"the value of 'input_shape' is {shape}. "
1452
+ f"The product of 'input_shape' should > 0, but got {dim_prod}.")
1453
+ arr_prod = np.prod(x_shp)
1454
+ if neg_index != -1:
1455
+ shape[neg_index] = int(arr_prod // dim_prod)
1456
+ dim_prod *= shape[neg_index]
1457
+ if dim_prod != arr_prod:
1458
+ raise ValueError(f"For 'Reshape', the product of the 'input_x' shape "
1459
+ f"should be equal to product of 'input_shape', but got product of the"
1460
+ f" shape of 'input_x': {arr_prod}, product of 'input_shape': {dim_prod}.")
1461
+ if 0 in shape:
1462
+ init_func = Zero()
1463
+ init_func.__enable_zero_dim__ = True
1464
+ out = Tensor(shape=shape, dtype=x.dtype, init=init_func)
1465
+ else:
1466
+ out = Tensor(x.asnumpy().reshape(shape))
1467
+ return out
1468
+
1469
+
1470
+ class Ones(Primitive):
1471
+ r"""
1472
+ Creates a tensor filled with value ones.
1473
+
1474
+ Refer to :func:`mindspore.ops.ones` for more details.
1475
+
1476
+ .. warning::
1477
+ For argument `size`, Tensor type input will be deprecated in the future version.
1478
+
1479
+ Inputs:
1480
+ - **shape** (Union[tuple[int], List[int], int, Tensor]) - The specified shape of output tensor.
1481
+ - **type** (:class:`mindspore.dtype`) - The specified type of output tensor.
1482
+
1483
+ Outputs:
1484
+ Tensor, whose dtype and size are defined by input.
1485
+
1486
+ Raises:
1487
+ TypeError: If `shape` is neither an int nor an tuple/list/Tensor of int.
1488
+
1489
+ Supported Platforms:
1490
+ ``Ascend`` ``GPU`` ``CPU``
1491
+
1492
+ Examples:
1493
+ >>> import mindspore
1494
+ >>> from mindspore import ops
1495
+ >>> ones = ops.Ones()
1496
+ >>> output = ones((2, 2), mindspore.float32)
1497
+ >>> print(output)
1498
+ [[1. 1.]
1499
+ [1. 1.]]
1500
+ >>> output = ones((3, 3), mindspore.float32)
1501
+ >>> print(output)
1502
+ [[1. 1. 1.]
1503
+ [1. 1. 1.]
1504
+ [1. 1. 1.]]
1505
+ """
1506
+
1507
+ __mindspore_signature__ = (
1508
+ sig.make_sig('size'),
1509
+ sig.make_sig('type', default=None),
1510
+ )
1511
+
1512
+ @prim_arg_register
1513
+ def __init__(self):
1514
+ pass
1515
+
1516
+ def __call__(self, size, type=None):
1517
+ return _convert_stub(pyboost_ones(self, [size, type if type is None \
1518
+ else handler.dtype_to_type_id('Ones', 'type', type)]))
1519
+
1520
+
1521
+ class Zeros(Primitive):
1522
+ r"""
1523
+ Zeros will be deprecated in the future. Please use class `mindspore.ops.zeros` instead.
1524
+
1525
+ Creates a tensor filled with value zeros.
1526
+
1527
+ Creates a tensor with shape described by the first argument and
1528
+ fills it with value zeros in type of the second argument.
1529
+
1530
+ .. warning::
1531
+ For argument `size`, Tensor type input will be deprecated in the future version.
1532
+
1533
+ Inputs:
1534
+ - **shape** (tuple[int], List[int], int, Tensor) - The specified shape of output tensor.
1535
+ - **type** (mindspore.dtype) - The specified type of output tensor.
1536
+
1537
+ Outputs:
1538
+ Tensor, whose dtype and size are defined by input.
1539
+
1540
+ Raises:
1541
+ TypeError: If `shape` is neither an int nor an tuple/list/Tensor of int.
1542
+
1543
+ Supported Platforms:
1544
+ ``Ascend`` ``GPU`` ``CPU``
1545
+
1546
+ Examples:
1547
+ >>> import mindspore
1548
+ >>> from mindspore import ops
1549
+ >>> zeros = ops.Zeros()
1550
+ >>> output = zeros((2, 2), mindspore.float32)
1551
+ >>> print(output)
1552
+ [[0. 0.]
1553
+ [0. 0.]]
1554
+
1555
+ """
1556
+
1557
+ __mindspore_signature__ = (
1558
+ sig.make_sig('size'),
1559
+ sig.make_sig('type', default=None),
1560
+ )
1561
+
1562
+ @prim_arg_register
1563
+ def __init__(self):
1564
+ pass
1565
+
1566
+ def __call__(self, size, type=None):
1567
+ return _convert_stub(pyboost_zeros(self, [size, type if type is None else \
1568
+ handler.dtype_to_type_id('Zeros', 'type', type)]))
1569
+
1570
+
1571
+ def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mask=None, padding_mask=None,
1572
+ attn_mask=None, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, keep_prob=1.0,
1573
+ scalar_value=1.0, pre_tokens=2147483647, next_tokens=2147483647, inner_precise=0,
1574
+ input_layout='BSH', sparse_mode=0):
1575
+ r"""
1576
+ The interface is not open to the public, just for internal use,
1577
+
1578
+ .. math::
1579
+ \begin{array}{ll} \\
1580
+ y = Dropout(Softmax(Mask(scale_value \mul (real_shift + query * key), attn_mask), -1), keep_prob) \\
1581
+ \mul value \\
1582
+ \end{array}
1583
+
1584
+ B -- Batch size. Value range 1 to 2k.
1585
+ S1 -- Sequence length of query. Value range 1 to 512k.
1586
+ S2 -- Sequence length of key and value. Value range 1 to 512k.
1587
+ N1 -- Num heads of query. Value range 1 to 256.
1588
+ N2 -- Num heads of key and value, and N2 must be a factor of N1.
1589
+ D -- Head size. The value ranges is a multiple of 16, with the max value of 512.
1590
+ H1 -- Hidden size of query, which equals to N1 * D.
1591
+ H2 -- Hidden size of key and value, which equals to N2 * D.
1592
+
1593
+ .. warning::
1594
+ This is an experimental API that is subject to change or deletion. Only support on Atlas training series.
1595
+
1596
+ Args:
1597
+ query (Tensor[float16, bfloat16]): The query tensor. Input tensor of shape :math:`(B, S1, H1)`,
1598
+ `(B, N1, S1, D)`, `(S1, B, H1)`, `(B, S1, N1, D)` or `(T1, N1, D)`.
1599
+ key (Tensor[float16, bfloat16]): The key tensor. Input tensor of shape :math:`(B, S2, H2)`,
1600
+ `(B, N2, S2, D)`, `(S2, B, H2)`, `(B, S2, N2, D)` or `(T2, N2, D)`.
1601
+ value (Tensor[float16, bfloat16]): The value tensor. Input tensor of shape :math:`(B, S2, H2)`,
1602
+ `(B, N2, S2, D)`, `(S2, B, H2)`, `(B, S2, N2, D)` or `(T2, N2, D)`. The key and value have the same shape.
1603
+ head_num (int): The head num of query, equal to N1.
1604
+ real_shift (Union[Tensor[float16, bfloat16], None]): Also known as pse. The position embedding code. If S
1605
+ is greater than 1024 and the mask of the lower triangle is used, enter only the inverse 1024 lines of
1606
+ the lower triangle for memory optimization. Input tensor of shape :math:`(B, N1, S1, S2)`,
1607
+ `(1, N1, S1, S2)`, `(B, N1, 1024, S2)`, `(1, N1, 1024, S2)`.
1608
+
1609
+ - ALiBi scenario: real_shift must meet the ALiBi rule, and sparse_mode is 2 or 3 for the lower triangle.
1610
+ In this scenario, real_shift is `(B, N1, 1024, S2)`, `(1, N1, 1024, S2)`.
1611
+ - Non-ALiBi scenario: real_shift is `(B, N1, S1, S2)`, `(1, N1, S1, S2)`.
1612
+
1613
+ The shape of `real_shift` should be `(B, N1, 1024, S2)` and `(1, N1, 1024, S2)` when input_layout is
1614
+ `TND`.
1615
+ drop_mask (Union[Tensor[uint8], None]): The dropout mask tensor. Input tensor of shape :math:
1616
+ `(B, N1, S1, S2 // 8) or None`. S2 is a multiple of 8 when not None.
1617
+ padding_mask (None): Reserved parameter. Not implemented yet.
1618
+ attn_mask (Union[Tensor[uint8], Tensor[bool], None]): The attention mask tensor. For each element, 0
1619
+ indicates retention and 1 indicates discard. Input tensor of shape :math:`(B, N1, S1, S2)`,
1620
+ `(B, 1, S1, S2)`, `(S1, S2)` or `(2048, 2048)`. In compression scenario, sparse_mode is 2, 3, or 4,
1621
+ attn_mask must be `(2048, 2048)`. When sparse_mode is 5, attn_mask must be `(B, N1, S1, S2)`,
1622
+ `(B, 1, S1, S2)`. When sparse_mode is 0 and 1, attn_mask should be `(B, N1, S1, S2)`, `(B, 1, S1, S2)`,
1623
+ `(S1, S2)`.
1624
+ prefix (Union[List[int64], Tuple[int64] None]): N value of each Batch in the prefix sparse calculation
1625
+ scenario. Input tensor of shape :math:`(B,)`. B max value 32. Not none only when sparse_mode is 5.
1626
+ If S1 > S2, N ranges from 0 to S2. If S1 <= S2, N ranges from S2 - S1 to S2.
1627
+ actual_seq_qlen (Union[List[int64], Tuple[int64], None]): Size of query corresponding to each batch, array
1628
+ with increasing values and the last value equal to T1.
1629
+ actual_seq_kvlen (Union[List[int64], Tuple[int64], None]): Size of key and value corresponding to each batch,
1630
+ array with increasing values and the last value equal to T2.
1631
+ keep_prob (float): The keep probability of dropout. Value range is (0.0, 1.0]. Default: 1.0. When keep_prob
1632
+ is 1.0, drop_mask should be None.
1633
+ scale_value (float): The scale factor of score. Generally, the value is 1.0 / (D ** 0.5). Default: 1.0.
1634
+ pre_tokens (int): Parameter for sparse computation, represents how many tokens are counted forward.
1635
+ When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647.
1636
+ next_tokens (int): Parameter for sparse computation, represents how many tokens are counted backward.
1637
+ When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647.
1638
+ The value of pre_tokens corresponds to S1, and the value of next_tokens corresponds to S2. They define the
1639
+ valid area on the attn_mask matrix. It must ensure that the band is not empty.
1640
+ The following values are not allowed:
1641
+
1642
+ - pre_tokens < 0 and next_tokens < 0.
1643
+ - (pre_tokens < 0 and next_tokens >= 0) and (next_tokens < abs(pre_tokens) or abs(pre_tokens) >= S2).
1644
+ - (pre_tokens >= 0 and next_tokens < 0) and (abs(next_tokens) > pre_tokens or abs(next_tokens) >= S1).
1645
+
1646
+ inner_precise (int): The parameter is reserved and not implemented yet. Default: 0.
1647
+ input_layout (str): Specifies the layout of input `query`, key and value. The value can be "BSH", "BNSD",
1648
+ "SBH", "BSND" or "TND". "TND" is an experimental format. Default: "BSH".
1649
+ When input_layout is "TND", the following restrictions must be met.
1650
+ There are two lists that represent the length of the input sequence: list_seq_q and list_seq_k. Each
1651
+ value in the list indicates the length of the sequence in the batch. For example, list_seq_q = [4, 2, 6],
1652
+ list_seq_k = [10, 3, 9]. Each element of a list is the S value of one batch. T1 is sum(list_seq_q) = 12, T2 is
1653
+ sum(list_seq_k) = 22.
1654
+ max_seqlen_q = max(list_seq_q), max_seqlen_k = max(list_seq_k).
1655
+ qk_pointer = sum(list_seq_q * list_seq_k), which is the sum of the element multiplication.
1656
+
1657
+ - The two lists have the same length, which equals the batch size. The batch size must be less than or equal to 1024.
1658
+ - When input_layout is "TND", actual_seq_qlen and actual_seq_kvlen must not be None.
1659
+ Otherwise, they are none.
1660
+ - The actual_seq_qlen and actual_seq_kvlen are the cumulative sum of sequence of key/value, so they must
1661
+ be non-decreasing.
1662
+ - If real_shift is not None, list_seq_q and list_seq_k must be the same. The maximum value of list_seq_q and
1663
+ list_seq_k is greater than 1024. Real_shift should be `(B, N1, 1024, S2)` and `(1, N1, 1024, S2)`, and
1664
+ S2 is equal to max_seqlen_k.
1665
+ - Attn mask must be a lower triangular matrix, so sparse_mode should be 2 or 3. The shape of attn_mask
1666
+ should be `(2048, 2048)`.
1667
+ - The shape of drop_mask is (qk_pointer * N1 // 8,).
1668
+ - Prefix is none.
1669
+ - Next_tokens is 0, and pre_tokens is not less than max_seqlen_q.
1670
+ - When sparse_mode is 3, S1 of each batch should be less than or equal to S2.
1671
+ - 0 should not exist in list_seq_k.
1672
+
1673
+ sparse_mode (int): Indicates sparse mode. Default 0.
1674
+
1675
+ - 0: Indicates the defaultMask mode. If attn_mask is not passed, the mask operation is not performed,
1676
+ and preTokens and nextTokens(internally assigned as INT_MAX) are ignored. If passed in, the full
1677
+ attn_mask matrix (S1 * S2) needs to be passed in, indicating that the part between preTokens and
1678
+ nextTokens needs to be calculated.
1679
+ - 1: Represents allMask, that is, passing in the complete attn_mask matrix.
1680
+ - 2: Represents the leftUpCausal mode, which corresponds to the lower triangle scenario divided by the left
1681
+ vertex, and the optimized attn_mask matrix (2048*2048) is required.
1682
+ - 3: Represents the rightDownCausal mode, which corresponds to the lower triangle scenario divided by the lower
1683
+ right vertex, and the optimized attn_mask matrix (2048*2048) is required.
1684
+ - 4: Represents the band scenario, that is, the part between counting preTokens and nextTokens, and the
1685
+ optimized attn_mask matrix (2048*2048) is required.
1686
+ - 5: Represents the prefix scenario, that is, on the basis of rightDownCausal, a matrix with length S1 and
1687
+ width N is added to the left side. The value of N is obtained by the new input prefix, and the N value
1688
+ of each Batch axis is different, not implemented yet.
1689
+ - 6: Represents the global scenario, not implemented yet.
1690
+ - 7: Represents the dilated scenario, not implemented yet.
1691
+ - 8: Represents the block_local scenario, not implemented yet.
1692
+
1693
+ Returns:
1694
+ attention_out (Tensor[float16, bfloat16]): The output of attention; its shape and data type are the same
1695
+ as the query.
1696
+
1697
+ Supported Platforms:
1698
+ ``Ascend``
1699
+
1700
+ Examples:
1701
+ >>> import mindspore
1702
+ >>> import mindspore.common.dtype as mstype
1703
+ >>> import numpy as np
1704
+ >>> from mindspore import ops, Tensor
1705
+ >>> query = Tensor(np.ones([2, 4, 64]), dtype=mstype.float16)
1706
+ >>> key = Tensor(np.ones([2, 4, 64]), dtype=mstype.float16)
1707
+ >>> value = Tensor(np.ones([2, 4, 64]), dtype=mstype.float16)
1708
+ >>> head_num = 4
1709
+ >>> output = ops.flash_attention_score(query, key, value, head_num)
1710
+ >>> print(output.shape)
1711
+ (2, 4, 64)
1712
+ """
1713
+ rank_op = _get_cache_prim(FlashAttentionScore)(head_num, keep_prob, scalar_value, pre_tokens, next_tokens,
1714
+ inner_precise, input_layout, sparse_mode)
1715
+ return rank_op(query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen,
1716
+ actual_seq_kvlen)[3]