mindspore-2.2.14-cp39-cp39-win_amd64.whl → mindspore-2.4.0-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (1217)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +8 -5
  5. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +124 -25
  9. mindspore/_extends/builtin_operations.py +2 -1
  10. mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
  11. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
  12. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
  13. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
  14. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  15. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
  16. mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
  17. mindspore/_extends/parse/__init__.py +18 -14
  18. mindspore/_extends/parse/compile_config.py +299 -0
  19. mindspore/_extends/parse/namespace.py +2 -2
  20. mindspore/_extends/parse/parser.py +182 -68
  21. mindspore/_extends/parse/resources.py +45 -14
  22. mindspore/_extends/parse/standard_method.py +192 -252
  23. mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
  24. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  25. mindspore/_extends/remote/kernel_build_server.py +2 -0
  26. mindspore/_profiler.py +30 -0
  27. mindspore/amp.py +67 -26
  28. mindspore/atlprov.dll +0 -0
  29. mindspore/avcodec-59.dll +0 -0
  30. mindspore/avdevice-59.dll +0 -0
  31. mindspore/avfilter-8.dll +0 -0
  32. mindspore/avformat-59.dll +0 -0
  33. mindspore/avutil-57.dll +0 -0
  34. mindspore/boost/adasum.py +1 -1
  35. mindspore/boost/base.py +1 -1
  36. mindspore/boost/boost_cell_wrapper.py +2 -2
  37. mindspore/boost/grad_freeze.py +2 -2
  38. mindspore/boost/group_loss_scale_manager.py +1 -1
  39. mindspore/boost/less_batch_normalization.py +9 -6
  40. mindspore/c1.dll +0 -0
  41. mindspore/c1xx.dll +0 -0
  42. mindspore/c2.dll +0 -0
  43. mindspore/common/__init__.py +20 -7
  44. mindspore/common/_jit_fallback_utils.py +2 -3
  45. mindspore/common/_pijit_context.py +190 -0
  46. mindspore/common/_register_for_adapter.py +7 -0
  47. mindspore/common/_register_for_recompute.py +48 -0
  48. mindspore/common/_register_for_tensor.py +10 -10
  49. mindspore/common/_stub_tensor.py +7 -1
  50. mindspore/common/_tensor_overload.py +139 -0
  51. mindspore/common/_utils.py +5 -17
  52. mindspore/common/api.py +449 -129
  53. mindspore/common/auto_dynamic_shape.py +27 -14
  54. mindspore/common/dtype.py +17 -10
  55. mindspore/common/dump.py +8 -11
  56. mindspore/common/file_system.py +48 -0
  57. mindspore/common/generator.py +254 -0
  58. mindspore/common/hook_handle.py +65 -30
  59. mindspore/common/initializer.py +1 -1
  60. mindspore/common/jit_config.py +34 -14
  61. mindspore/common/lazy_inline.py +72 -19
  62. mindspore/common/mindir_util.py +12 -2
  63. mindspore/common/mutable.py +79 -14
  64. mindspore/common/no_inline.py +54 -0
  65. mindspore/common/np_dtype.py +25 -0
  66. mindspore/common/parameter.py +73 -21
  67. mindspore/common/recompute.py +292 -0
  68. mindspore/common/seed.py +9 -9
  69. mindspore/common/sparse_tensor.py +276 -24
  70. mindspore/common/symbol.py +122 -0
  71. mindspore/common/tensor.py +668 -514
  72. mindspore/communication/__init__.py +6 -11
  73. mindspore/communication/_comm_helper.py +43 -3
  74. mindspore/communication/comm_func.py +1395 -0
  75. mindspore/communication/management.py +117 -104
  76. mindspore/config/op_info.config +22 -54
  77. mindspore/context.py +455 -71
  78. mindspore/dataset/__init__.py +5 -5
  79. mindspore/dataset/audio/__init__.py +6 -6
  80. mindspore/dataset/audio/transforms.py +711 -158
  81. mindspore/dataset/callback/ds_callback.py +2 -2
  82. mindspore/dataset/core/config.py +7 -0
  83. mindspore/dataset/core/validator_helpers.py +7 -0
  84. mindspore/dataset/engine/cache_client.py +2 -2
  85. mindspore/dataset/engine/datasets.py +201 -116
  86. mindspore/dataset/engine/datasets_audio.py +14 -14
  87. mindspore/dataset/engine/datasets_standard_format.py +83 -3
  88. mindspore/dataset/engine/datasets_text.py +39 -39
  89. mindspore/dataset/engine/datasets_user_defined.py +230 -141
  90. mindspore/dataset/engine/datasets_vision.py +78 -74
  91. mindspore/dataset/engine/iterators.py +29 -0
  92. mindspore/dataset/engine/obs/util.py +7 -0
  93. mindspore/dataset/engine/offload.py +5 -7
  94. mindspore/dataset/engine/queue.py +138 -66
  95. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  96. mindspore/dataset/engine/validators.py +41 -15
  97. mindspore/dataset/text/__init__.py +2 -5
  98. mindspore/dataset/text/transforms.py +408 -121
  99. mindspore/dataset/text/utils.py +9 -9
  100. mindspore/dataset/transforms/__init__.py +0 -3
  101. mindspore/dataset/transforms/transforms.py +261 -76
  102. mindspore/dataset/utils/browse_dataset.py +9 -9
  103. mindspore/dataset/utils/line_reader.py +2 -0
  104. mindspore/dataset/vision/__init__.py +7 -10
  105. mindspore/dataset/vision/c_transforms.py +10 -10
  106. mindspore/dataset/vision/py_transforms_util.py +1 -1
  107. mindspore/dataset/vision/transforms.py +2844 -549
  108. mindspore/dataset/vision/utils.py +161 -10
  109. mindspore/dataset/vision/validators.py +16 -3
  110. mindspore/dnnl.dll +0 -0
  111. mindspore/dpcmi.dll +0 -0
  112. mindspore/{rewrite/ast_creator_register.py → experimental/es/__init__.py} +5 -20
  113. mindspore/experimental/es/embedding_service.py +883 -0
  114. mindspore/experimental/es/embedding_service_layer.py +581 -0
  115. mindspore/experimental/llm_boost/__init__.py +21 -0
  116. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  117. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  118. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  119. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  120. mindspore/experimental/llm_boost/register.py +129 -0
  121. mindspore/experimental/llm_boost/utils.py +31 -0
  122. mindspore/experimental/optim/__init__.py +12 -2
  123. mindspore/experimental/optim/adadelta.py +161 -0
  124. mindspore/experimental/optim/adagrad.py +168 -0
  125. mindspore/experimental/optim/adam.py +35 -34
  126. mindspore/experimental/optim/adamax.py +170 -0
  127. mindspore/experimental/optim/adamw.py +124 -15
  128. mindspore/experimental/optim/asgd.py +153 -0
  129. mindspore/experimental/optim/lr_scheduler.py +66 -121
  130. mindspore/experimental/optim/nadam.py +157 -0
  131. mindspore/experimental/optim/optimizer.py +18 -8
  132. mindspore/experimental/optim/radam.py +194 -0
  133. mindspore/experimental/optim/rmsprop.py +154 -0
  134. mindspore/experimental/optim/rprop.py +164 -0
  135. mindspore/experimental/optim/sgd.py +28 -19
  136. mindspore/hal/__init__.py +40 -0
  137. mindspore/hal/_ascend.py +57 -0
  138. mindspore/hal/_base.py +57 -0
  139. mindspore/hal/_cpu.py +56 -0
  140. mindspore/hal/_gpu.py +57 -0
  141. mindspore/hal/contiguous_tensors_handle.py +175 -0
  142. mindspore/hal/device.py +356 -0
  143. mindspore/hal/event.py +179 -0
  144. mindspore/hal/memory.py +326 -0
  145. mindspore/hal/stream.py +357 -0
  146. mindspore/include/api/data_type.h +2 -2
  147. mindspore/include/api/dual_abi_helper.h +16 -3
  148. mindspore/include/api/model.h +4 -3
  149. mindspore/include/api/model_group.h +13 -1
  150. mindspore/include/api/status.h +14 -0
  151. mindspore/include/api/types.h +10 -10
  152. mindspore/include/c_api/model_c.h +173 -0
  153. mindspore/include/c_api/types_c.h +19 -0
  154. mindspore/include/dataset/config.h +2 -2
  155. mindspore/include/dataset/constants.h +2 -2
  156. mindspore/include/dataset/execute.h +3 -5
  157. mindspore/include/dataset/vision.h +58 -2
  158. mindspore/jpeg62.dll +0 -0
  159. mindspore/log.py +3 -3
  160. mindspore/mindrecord/__init__.py +5 -1
  161. mindspore/mindrecord/config.py +809 -0
  162. mindspore/mindrecord/filereader.py +25 -0
  163. mindspore/mindrecord/filewriter.py +138 -103
  164. mindspore/mindrecord/mindpage.py +40 -6
  165. mindspore/mindrecord/shardutils.py +3 -2
  166. mindspore/mindrecord/shardwriter.py +7 -0
  167. mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
  168. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
  169. mindspore/mindrecord/tools/csv_to_mr.py +4 -9
  170. mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
  171. mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
  172. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
  173. mindspore/mindspore_backend.dll +0 -0
  174. mindspore/mindspore_common.dll +0 -0
  175. mindspore/mindspore_core.dll +0 -0
  176. mindspore/mindspore_glog.dll +0 -0
  177. mindspore/mindspore_np_dtype.dll +0 -0
  178. mindspore/mindspore_ops.dll +0 -0
  179. mindspore/mint/__init__.py +1586 -0
  180. mindspore/mint/distributed/__init__.py +31 -0
  181. mindspore/mint/distributed/distributed.py +254 -0
  182. mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
  183. mindspore/mint/nn/__init__.py +757 -0
  184. mindspore/mint/nn/functional.py +679 -0
  185. mindspore/mint/nn/layer/__init__.py +39 -0
  186. mindspore/mint/nn/layer/activation.py +133 -0
  187. mindspore/mint/nn/layer/normalization.py +477 -0
  188. mindspore/mint/nn/layer/pooling.py +110 -0
  189. mindspore/mint/optim/__init__.py +24 -0
  190. mindspore/mint/optim/adamw.py +206 -0
  191. mindspore/mint/special/__init__.py +63 -0
  192. mindspore/msobj140.dll +0 -0
  193. mindspore/mspdb140.dll +0 -0
  194. mindspore/mspdbcore.dll +0 -0
  195. mindspore/mspdbst.dll +0 -0
  196. mindspore/mspft140.dll +0 -0
  197. mindspore/msvcdis140.dll +0 -0
  198. mindspore/msvcp140_1.dll +0 -0
  199. mindspore/msvcp140_2.dll +0 -0
  200. mindspore/msvcp140_atomic_wait.dll +0 -0
  201. mindspore/msvcp140_codecvt_ids.dll +0 -0
  202. mindspore/multiprocessing/__init__.py +73 -0
  203. mindspore/nn/cell.py +461 -323
  204. mindspore/nn/dynamic_lr.py +2 -2
  205. mindspore/nn/layer/activation.py +292 -135
  206. mindspore/nn/layer/basic.py +288 -83
  207. mindspore/nn/layer/channel_shuffle.py +3 -16
  208. mindspore/nn/layer/container.py +3 -3
  209. mindspore/nn/layer/conv.py +75 -66
  210. mindspore/nn/layer/embedding.py +221 -45
  211. mindspore/nn/layer/image.py +4 -7
  212. mindspore/nn/layer/math.py +1 -1
  213. mindspore/nn/layer/normalization.py +150 -68
  214. mindspore/nn/layer/padding.py +64 -87
  215. mindspore/nn/layer/pooling.py +175 -12
  216. mindspore/nn/layer/rnn_cells.py +6 -16
  217. mindspore/nn/layer/rnns.py +6 -5
  218. mindspore/nn/layer/thor_layer.py +1 -2
  219. mindspore/nn/layer/timedistributed.py +1 -1
  220. mindspore/nn/layer/transformer.py +55 -53
  221. mindspore/nn/learning_rate_schedule.py +6 -5
  222. mindspore/nn/loss/__init__.py +2 -2
  223. mindspore/nn/loss/loss.py +145 -88
  224. mindspore/nn/optim/__init__.py +2 -1
  225. mindspore/nn/optim/ada_grad.py +4 -2
  226. mindspore/nn/optim/adadelta.py +4 -2
  227. mindspore/nn/optim/adafactor.py +1 -1
  228. mindspore/nn/optim/adam.py +102 -181
  229. mindspore/nn/optim/adamax.py +4 -2
  230. mindspore/nn/optim/adasum.py +3 -3
  231. mindspore/nn/optim/asgd.py +4 -2
  232. mindspore/nn/optim/ftrl.py +31 -61
  233. mindspore/nn/optim/lamb.py +5 -3
  234. mindspore/nn/optim/lars.py +2 -2
  235. mindspore/nn/optim/lazyadam.py +6 -4
  236. mindspore/nn/optim/momentum.py +13 -25
  237. mindspore/nn/optim/optimizer.py +6 -3
  238. mindspore/nn/optim/proximal_ada_grad.py +4 -2
  239. mindspore/nn/optim/rmsprop.py +9 -3
  240. mindspore/nn/optim/rprop.py +4 -2
  241. mindspore/nn/optim/sgd.py +5 -3
  242. mindspore/nn/optim/tft_wrapper.py +127 -0
  243. mindspore/nn/optim/thor.py +2 -2
  244. mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
  245. mindspore/nn/probability/distribution/beta.py +2 -2
  246. mindspore/nn/probability/distribution/categorical.py +4 -6
  247. mindspore/nn/probability/distribution/cauchy.py +2 -2
  248. mindspore/nn/probability/distribution/exponential.py +2 -2
  249. mindspore/nn/probability/distribution/geometric.py +1 -1
  250. mindspore/nn/probability/distribution/gumbel.py +2 -2
  251. mindspore/nn/probability/distribution/logistic.py +1 -1
  252. mindspore/nn/probability/distribution/poisson.py +2 -2
  253. mindspore/nn/probability/distribution/uniform.py +2 -2
  254. mindspore/nn/reinforcement/_tensors_queue.py +13 -1
  255. mindspore/nn/wrap/__init__.py +2 -1
  256. mindspore/nn/wrap/cell_wrapper.py +46 -12
  257. mindspore/nn/wrap/grad_reducer.py +148 -8
  258. mindspore/nn/wrap/loss_scale.py +44 -7
  259. mindspore/numpy/__init__.py +2 -0
  260. mindspore/numpy/array_creations.py +67 -68
  261. mindspore/numpy/array_ops.py +70 -66
  262. mindspore/numpy/dtypes.py +3 -3
  263. mindspore/numpy/fft.py +966 -0
  264. mindspore/numpy/logic_ops.py +11 -10
  265. mindspore/numpy/math_ops.py +147 -152
  266. mindspore/numpy/utils.py +3 -0
  267. mindspore/numpy/utils_const.py +4 -4
  268. mindspore/opencv_core452.dll +0 -0
  269. mindspore/opencv_imgcodecs452.dll +0 -0
  270. mindspore/opencv_imgproc452.dll +0 -0
  271. mindspore/ops/__init__.py +9 -6
  272. mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
  273. mindspore/ops/_grad_experimental/grad_comm_ops.py +135 -36
  274. mindspore/ops/_grad_experimental/grad_math_ops.py +61 -298
  275. mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
  276. mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
  277. mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
  278. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  279. mindspore/ops/_op_impl/__init__.py +0 -1
  280. mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
  281. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
  282. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
  283. mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
  284. mindspore/ops/_op_impl/cpu/__init__.py +1 -3
  285. mindspore/ops/_op_impl/cpu/adam.py +2 -2
  286. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
  287. mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
  288. mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
  289. mindspore/ops/_vmap/vmap_array_ops.py +162 -101
  290. mindspore/ops/_vmap/vmap_base.py +8 -1
  291. mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
  292. mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
  293. mindspore/ops/_vmap/vmap_image_ops.py +70 -13
  294. mindspore/ops/_vmap/vmap_math_ops.py +147 -59
  295. mindspore/ops/_vmap/vmap_nn_ops.py +292 -117
  296. mindspore/ops/_vmap/vmap_other_ops.py +1 -1
  297. mindspore/ops/auto_generate/__init__.py +31 -0
  298. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  299. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  300. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  301. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  302. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  303. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  304. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  305. mindspore/ops/composite/__init__.py +5 -2
  306. mindspore/ops/composite/base.py +201 -66
  307. mindspore/ops/composite/math_ops.py +10 -49
  308. mindspore/ops/composite/multitype_ops/_compile_utils.py +192 -618
  309. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
  310. mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
  311. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
  312. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
  313. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
  314. mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
  315. mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
  316. mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
  317. mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
  318. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
  319. mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
  320. mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
  321. mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
  322. mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
  323. mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
  324. mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
  325. mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
  326. mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
  327. mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
  328. mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
  329. mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
  330. mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
  331. mindspore/ops/composite/multitype_ops/not_in_impl.py +8 -3
  332. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
  333. mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
  334. mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
  335. mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
  336. mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
  337. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
  338. mindspore/ops/deprecated.py +14 -3
  339. mindspore/ops/function/__init__.py +53 -11
  340. mindspore/ops/function/array_func.py +1269 -1821
  341. mindspore/ops/function/clip_func.py +19 -31
  342. mindspore/ops/function/debug_func.py +114 -5
  343. mindspore/ops/function/fft_func.py +44 -0
  344. mindspore/ops/function/grad/grad_func.py +30 -22
  345. mindspore/ops/function/image_func.py +27 -21
  346. mindspore/ops/function/linalg_func.py +35 -68
  347. mindspore/ops/function/math_func.py +1170 -2697
  348. mindspore/ops/function/nn_func.py +2116 -1128
  349. mindspore/ops/function/other_func.py +8 -8
  350. mindspore/ops/function/parameter_func.py +5 -93
  351. mindspore/ops/function/random_func.py +435 -113
  352. mindspore/ops/function/reshard_func.py +104 -0
  353. mindspore/ops/function/sparse_func.py +4 -4
  354. mindspore/ops/function/sparse_unary_func.py +9 -16
  355. mindspore/ops/function/spectral_func.py +1 -1
  356. mindspore/ops/function/vmap_func.py +16 -15
  357. mindspore/ops/functional.py +355 -346
  358. mindspore/ops/op_info_register.py +18 -45
  359. mindspore/ops/operations/__init__.py +38 -24
  360. mindspore/ops/operations/_grad_ops.py +21 -927
  361. mindspore/ops/operations/_infer_ops.py +19 -0
  362. mindspore/ops/operations/_inner_ops.py +173 -607
  363. mindspore/ops/operations/_rl_inner_ops.py +2 -2
  364. mindspore/ops/operations/_scalar_ops.py +5 -480
  365. mindspore/ops/operations/_sequence_ops.py +6 -36
  366. mindspore/ops/operations/_tensor_array.py +8 -8
  367. mindspore/ops/operations/array_ops.py +106 -2837
  368. mindspore/ops/operations/comm_ops.py +799 -127
  369. mindspore/ops/operations/custom_ops.py +124 -119
  370. mindspore/ops/operations/debug_ops.py +142 -41
  371. mindspore/ops/operations/image_ops.py +1 -217
  372. mindspore/ops/operations/inner_ops.py +5 -40
  373. mindspore/ops/operations/linalg_ops.py +1 -49
  374. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  375. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  376. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  377. mindspore/ops/operations/math_ops.py +666 -4972
  378. mindspore/ops/operations/nn_ops.py +205 -2213
  379. mindspore/ops/operations/other_ops.py +60 -49
  380. mindspore/ops/operations/random_ops.py +50 -54
  381. mindspore/ops/operations/reshard_ops.py +53 -0
  382. mindspore/ops/operations/sparse_ops.py +4 -4
  383. mindspore/ops/primitive.py +216 -103
  384. mindspore/ops_generate/__init__.py +27 -0
  385. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  386. mindspore/ops_generate/arg_handler.py +197 -0
  387. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  388. mindspore/ops_generate/gen_constants.py +36 -0
  389. mindspore/ops_generate/gen_ops.py +1099 -0
  390. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  391. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  392. mindspore/ops_generate/gen_utils.py +209 -0
  393. mindspore/ops_generate/op_proto.py +145 -0
  394. mindspore/ops_generate/pyboost_utils.py +367 -0
  395. mindspore/ops_generate/template.py +261 -0
  396. mindspore/parallel/__init__.py +8 -4
  397. mindspore/parallel/_auto_parallel_context.py +100 -10
  398. mindspore/parallel/_cell_wrapper.py +99 -9
  399. mindspore/parallel/_cost_model_context.py +1 -1
  400. mindspore/parallel/_dp_allreduce_fusion.py +159 -159
  401. mindspore/parallel/_parallel_serialization.py +67 -23
  402. mindspore/parallel/_ps_context.py +1 -1
  403. mindspore/parallel/_recovery_context.py +1 -1
  404. mindspore/parallel/_tensor.py +99 -22
  405. mindspore/parallel/_transformer/__init__.py +1 -1
  406. mindspore/parallel/_transformer/layers.py +1 -1
  407. mindspore/parallel/_transformer/loss.py +1 -1
  408. mindspore/parallel/_transformer/moe.py +1 -1
  409. mindspore/parallel/_transformer/op_parallel_config.py +1 -1
  410. mindspore/parallel/_transformer/transformer.py +2 -2
  411. mindspore/parallel/_utils.py +173 -6
  412. mindspore/parallel/algo_parameter_config.py +8 -10
  413. mindspore/parallel/checkpoint_transform.py +204 -38
  414. mindspore/parallel/cluster/__init__.py +15 -0
  415. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  416. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  417. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  418. mindspore/parallel/cluster/run.py +136 -0
  419. mindspore/parallel/mpi/__init__.py +1 -1
  420. mindspore/parallel/mpi/_mpi_config.py +1 -1
  421. mindspore/parallel/parameter_broadcast.py +151 -0
  422. mindspore/parallel/shard.py +279 -37
  423. mindspore/parallel/transform_safetensors.py +993 -0
  424. mindspore/pgodb140.dll +0 -0
  425. mindspore/pgort140.dll +0 -0
  426. mindspore/profiler/__init__.py +4 -2
  427. mindspore/profiler/common/constant.py +29 -0
  428. mindspore/profiler/common/process_pool.py +41 -0
  429. mindspore/profiler/common/registry.py +47 -0
  430. mindspore/profiler/common/singleton.py +28 -0
  431. mindspore/profiler/common/util.py +153 -0
  432. mindspore/profiler/dynamic_profiler.py +694 -0
  433. mindspore/profiler/envprofiling.py +18 -20
  434. mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
  435. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  436. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  437. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  438. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  439. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  440. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  441. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  442. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  443. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  444. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  445. mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
  446. mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
  447. mindspore/profiler/parser/ascend_flops_generator.py +20 -4
  448. mindspore/profiler/parser/ascend_hccl_generator.py +29 -278
  449. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  450. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  451. mindspore/profiler/parser/ascend_msprof_exporter.py +148 -146
  452. mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
  453. mindspore/profiler/parser/ascend_op_generator.py +92 -42
  454. mindspore/profiler/parser/ascend_timeline_generator.py +298 -133
  455. mindspore/profiler/parser/base_timeline_generator.py +25 -25
  456. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  457. mindspore/profiler/parser/framework_parser.py +4 -393
  458. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  459. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  460. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  461. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  462. mindspore/profiler/parser/integrator.py +3 -1
  463. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  464. mindspore/profiler/parser/minddata_parser.py +72 -3
  465. mindspore/profiler/parser/profiler_info.py +94 -7
  466. mindspore/profiler/profiler.py +153 -0
  467. mindspore/profiler/profiling.py +631 -508
  468. mindspore/rewrite/__init__.py +2 -14
  469. mindspore/rewrite/api/node.py +122 -36
  470. mindspore/rewrite/api/pattern_engine.py +2 -3
  471. mindspore/rewrite/api/scoped_value.py +16 -15
  472. mindspore/rewrite/api/symbol_tree.py +45 -29
  473. mindspore/rewrite/ast_helpers/__init__.py +3 -6
  474. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  475. mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
  476. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  477. mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
  478. mindspore/rewrite/common/__init__.py +1 -2
  479. mindspore/rewrite/common/config.py +24 -0
  480. mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
  481. mindspore/rewrite/{namer.py → common/namer.py} +63 -18
  482. mindspore/rewrite/common/namespace.py +118 -0
  483. mindspore/rewrite/node/__init__.py +5 -5
  484. mindspore/rewrite/node/call_function.py +23 -7
  485. mindspore/rewrite/node/cell_container.py +7 -3
  486. mindspore/rewrite/node/control_flow.py +53 -28
  487. mindspore/rewrite/node/node.py +212 -196
  488. mindspore/rewrite/node/node_manager.py +51 -22
  489. mindspore/rewrite/node/node_topological_manager.py +3 -23
  490. mindspore/rewrite/parsers/__init__.py +12 -0
  491. mindspore/rewrite/parsers/arguments_parser.py +8 -9
  492. mindspore/rewrite/parsers/assign_parser.py +637 -413
  493. mindspore/rewrite/parsers/attribute_parser.py +3 -4
  494. mindspore/rewrite/parsers/class_def_parser.py +115 -148
  495. mindspore/rewrite/parsers/constant_parser.py +5 -5
  496. mindspore/rewrite/parsers/container_parser.py +4 -6
  497. mindspore/rewrite/parsers/expr_parser.py +55 -0
  498. mindspore/rewrite/parsers/for_parser.py +31 -98
  499. mindspore/rewrite/parsers/function_def_parser.py +13 -5
  500. mindspore/rewrite/parsers/if_parser.py +28 -10
  501. mindspore/rewrite/parsers/module_parser.py +8 -182
  502. mindspore/rewrite/parsers/parser.py +1 -5
  503. mindspore/rewrite/parsers/parser_register.py +1 -1
  504. mindspore/rewrite/parsers/return_parser.py +5 -10
  505. mindspore/rewrite/parsers/while_parser.py +59 -0
  506. mindspore/rewrite/sparsify/utils.py +1 -1
  507. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  508. mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +705 -186
  509. mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
  510. mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
  511. mindspore/run_check/_check_version.py +40 -115
  512. mindspore/run_check/run_check.py +1 -1
  513. mindspore/safeguard/rewrite_obfuscation.py +597 -263
  514. mindspore/swresample-4.dll +0 -0
  515. mindspore/swscale-6.dll +0 -0
  516. mindspore/tbbmalloc.dll +0 -0
  517. mindspore/tinyxml2.dll +0 -0
  518. mindspore/train/__init__.py +7 -5
  519. mindspore/train/_utils.py +204 -4
  520. mindspore/train/amp.py +335 -295
  521. mindspore/train/anf_ir_pb2.py +14 -2
  522. mindspore/train/callback/__init__.py +5 -2
  523. mindspore/train/callback/_backup_and_restore.py +5 -5
  524. mindspore/train/callback/_callback.py +4 -4
  525. mindspore/train/callback/_checkpoint.py +220 -43
  526. mindspore/train/callback/_cluster_monitor.py +201 -0
  527. mindspore/train/callback/_early_stop.py +2 -2
  528. mindspore/train/callback/_flops_collector.py +239 -0
  529. mindspore/train/callback/_landscape.py +15 -9
  530. mindspore/train/callback/_loss_monitor.py +5 -5
  531. mindspore/train/callback/_on_request_exit.py +136 -33
  532. mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
  533. mindspore/train/callback/_summary_collector.py +12 -12
  534. mindspore/train/callback/_tft_register.py +352 -0
  535. mindspore/train/callback/_time_monitor.py +3 -3
  536. mindspore/train/data_sink.py +6 -5
  537. mindspore/train/dataset_helper.py +66 -23
  538. mindspore/train/loss_scale_manager.py +2 -2
  539. mindspore/train/metrics/accuracy.py +7 -7
  540. mindspore/train/metrics/confusion_matrix.py +8 -6
  541. mindspore/train/metrics/cosine_similarity.py +6 -4
  542. mindspore/train/metrics/error.py +2 -2
  543. mindspore/train/metrics/metric.py +3 -3
  544. mindspore/train/metrics/perplexity.py +2 -1
  545. mindspore/train/metrics/roc.py +4 -4
  546. mindspore/train/metrics/topk.py +2 -2
  547. mindspore/train/mind_ir_pb2.py +116 -37
  548. mindspore/train/model.py +382 -76
  549. mindspore/train/serialization.py +787 -288
  550. mindspore/train/summary/_summary_adapter.py +1 -1
  551. mindspore/train/summary/summary_record.py +51 -28
  552. mindspore/train/train_thor/convert_utils.py +3 -3
  553. mindspore/turbojpeg.dll +0 -0
  554. mindspore/utils/__init__.py +21 -0
  555. mindspore/utils/utils.py +60 -0
  556. mindspore/vcmeta.dll +0 -0
  557. mindspore/vcruntime140.dll +0 -0
  558. mindspore/vcruntime140_1.dll +0 -0
  559. mindspore/version.py +1 -1
  560. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/METADATA +8 -4
  561. mindspore-2.4.0.dist-info/RECORD +1406 -0
  562. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +1 -0
  563. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
  564. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
  565. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
  566. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
  567. mindspore/gen_ops.py +0 -273
  568. mindspore/include/c_api/ms/abstract.h +0 -67
  569. mindspore/include/c_api/ms/attribute.h +0 -197
  570. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  571. mindspore/include/c_api/ms/base/macros.h +0 -32
  572. mindspore/include/c_api/ms/base/status.h +0 -33
  573. mindspore/include/c_api/ms/base/types.h +0 -282
  574. mindspore/include/c_api/ms/context.h +0 -102
  575. mindspore/include/c_api/ms/graph.h +0 -160
  576. mindspore/include/c_api/ms/node.h +0 -606
  577. mindspore/include/c_api/ms/tensor.h +0 -161
  578. mindspore/include/c_api/ms/value.h +0 -84
  579. mindspore/mindspore_shared_lib.dll +0 -0
  580. mindspore/nn/layer/flash_attention.py +0 -189
  581. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  582. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  583. mindspore/ops/_op_impl/cpu/concat.py +0 -39
  584. mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
  585. mindspore/ops/_op_impl/tbe/__init__.py +0 -47
  586. mindspore/ops/_op_impl/tbe/abs.py +0 -38
  587. mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
  588. mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
  589. mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
  590. mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
  591. mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
  592. mindspore/ops/_op_impl/tbe/acos.py +0 -37
  593. mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
  594. mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
  595. mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
  596. mindspore/ops/_op_impl/tbe/acosh.py +0 -37
  597. mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
  598. mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
  599. mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
  600. mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
  601. mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
  602. mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
  603. mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
  604. mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
  605. mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
  606. mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
  607. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
  608. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
  609. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
  610. mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
  611. mindspore/ops/_op_impl/tbe/add.py +0 -42
  612. mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
  613. mindspore/ops/_op_impl/tbe/add_n.py +0 -39
  614. mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
  615. mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
  616. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
  617. mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
  618. mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
  619. mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
  620. mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
  621. mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
  622. mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
  623. mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
  624. mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
  625. mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
  626. mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
  627. mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
  628. mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
  629. mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
  630. mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
  631. mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
  632. mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
  633. mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
  634. mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
  635. mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
  636. mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
  637. mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
  638. mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
  639. mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
  640. mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
  641. mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
  642. mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
  643. mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
  644. mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
  645. mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
  646. mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
  647. mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
  648. mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
  649. mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
  650. mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
  651. mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
  652. mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
  653. mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
  654. mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
  655. mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
  656. mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
  657. mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
  658. mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
  659. mindspore/ops/_op_impl/tbe/asin.py +0 -37
  660. mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
  661. mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
  662. mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
  663. mindspore/ops/_op_impl/tbe/asinh.py +0 -37
  664. mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
  665. mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
  666. mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
  667. mindspore/ops/_op_impl/tbe/assign.py +0 -79
  668. mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
  669. mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
  670. mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
  671. mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
  672. mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
  673. mindspore/ops/_op_impl/tbe/atan.py +0 -37
  674. mindspore/ops/_op_impl/tbe/atan2.py +0 -38
  675. mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
  676. mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
  677. mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
  678. mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
  679. mindspore/ops/_op_impl/tbe/atanh.py +0 -37
  680. mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
  681. mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
  682. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
  683. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
  684. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
  685. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
  686. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
  687. mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
  688. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
  689. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
  690. mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
  691. mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
  692. mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
  693. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
  694. mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
  695. mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
  696. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
  697. mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
  698. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
  699. mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
  700. mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
  701. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
  702. mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
  703. mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
  704. mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
  705. mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
  706. mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
  707. mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
  708. mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
  709. mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
  710. mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
  711. mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
  712. mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
  713. mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
  714. mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
  715. mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
  716. mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
  717. mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
  718. mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
  719. mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
  720. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
  721. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
  722. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
  723. mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
  724. mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
  725. mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
  726. mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
  727. mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
  728. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
  729. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
  730. mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
  731. mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
  732. mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
  733. mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
  734. mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
  735. mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
  736. mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
  737. mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
  738. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
  739. mindspore/ops/_op_impl/tbe/cast.py +0 -55
  740. mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
  741. mindspore/ops/_op_impl/tbe/cdist.py +0 -38
  742. mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
  743. mindspore/ops/_op_impl/tbe/ceil.py +0 -37
  744. mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
  745. mindspore/ops/_op_impl/tbe/celu.py +0 -39
  746. mindspore/ops/_op_impl/tbe/centralization.py +0 -39
  747. mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
  748. mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
  749. mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
  750. mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
  751. mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
  752. mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
  753. mindspore/ops/_op_impl/tbe/concat.py +0 -40
  754. mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
  755. mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
  756. mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
  757. mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
  758. mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
  759. mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
  760. mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
  761. mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
  762. mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
  763. mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
  764. mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
  765. mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
  766. mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
  767. mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
  768. mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
  769. mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
  770. mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
  771. mindspore/ops/_op_impl/tbe/cos.py +0 -37
  772. mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
  773. mindspore/ops/_op_impl/tbe/cosh.py +0 -37
  774. mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
  775. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
  776. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
  777. mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
  778. mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
  779. mindspore/ops/_op_impl/tbe/cummin.py +0 -41
  780. mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
  781. mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
  782. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
  783. mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
  784. mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
  785. mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
  786. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
  787. mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
  788. mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
  789. mindspore/ops/_op_impl/tbe/diag.py +0 -38
  790. mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
  791. mindspore/ops/_op_impl/tbe/dilation.py +0 -40
  792. mindspore/ops/_op_impl/tbe/div.py +0 -41
  793. mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
  794. mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
  795. mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
  796. mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
  797. mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
  798. mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
  799. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
  800. mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
  801. mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
  802. mindspore/ops/_op_impl/tbe/elu.py +0 -38
  803. mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
  804. mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
  805. mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
  806. mindspore/ops/_op_impl/tbe/equal.py +0 -42
  807. mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
  808. mindspore/ops/_op_impl/tbe/erf.py +0 -37
  809. mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
  810. mindspore/ops/_op_impl/tbe/erfc.py +0 -37
  811. mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
  812. mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
  813. mindspore/ops/_op_impl/tbe/exp.py +0 -40
  814. mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
  815. mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
  816. mindspore/ops/_op_impl/tbe/expm1.py +0 -37
  817. mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
  818. mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
  819. mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
  820. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
  821. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
  822. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
  823. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
  824. mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
  825. mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
  826. mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
  827. mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
  828. mindspore/ops/_op_impl/tbe/fill.py +0 -56
  829. mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
  830. mindspore/ops/_op_impl/tbe/flatten.py +0 -48
  831. mindspore/ops/_op_impl/tbe/floor.py +0 -37
  832. mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
  833. mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
  834. mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
  835. mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
  836. mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
  837. mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
  838. mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
  839. mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
  840. mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
  841. mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
  842. mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
  843. mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
  844. mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
  845. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  846. mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
  847. mindspore/ops/_op_impl/tbe/gelu.py +0 -37
  848. mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
  849. mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
  850. mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
  851. mindspore/ops/_op_impl/tbe/ger.py +0 -43
  852. mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
  853. mindspore/ops/_op_impl/tbe/greater.py +0 -43
  854. mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
  855. mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
  856. mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
  857. mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
  858. mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
  859. mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
  860. mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
  861. mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
  862. mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
  863. mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
  864. mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
  865. mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
  866. mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
  867. mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
  868. mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
  869. mindspore/ops/_op_impl/tbe/im2col.py +0 -42
  870. mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
  871. mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
  872. mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
  873. mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
  874. mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
  875. mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
  876. mindspore/ops/_op_impl/tbe/inv.py +0 -38
  877. mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
  878. mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
  879. mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
  880. mindspore/ops/_op_impl/tbe/invert.py +0 -37
  881. mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
  882. mindspore/ops/_op_impl/tbe/iou.py +0 -38
  883. mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
  884. mindspore/ops/_op_impl/tbe/is_close.py +0 -40
  885. mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
  886. mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
  887. mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
  888. mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
  889. mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
  890. mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
  891. mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
  892. mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
  893. mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
  894. mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
  895. mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
  896. mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
  897. mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
  898. mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
  899. mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
  900. mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
  901. mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
  902. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
  903. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
  904. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
  905. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
  906. mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
  907. mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
  908. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
  909. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
  910. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
  911. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
  912. mindspore/ops/_op_impl/tbe/lerp.py +0 -38
  913. mindspore/ops/_op_impl/tbe/less.py +0 -41
  914. mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
  915. mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
  916. mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
  917. mindspore/ops/_op_impl/tbe/log.py +0 -40
  918. mindspore/ops/_op_impl/tbe/log1p.py +0 -37
  919. mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
  920. mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
  921. mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
  922. mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
  923. mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
  924. mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
  925. mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
  926. mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
  927. mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
  928. mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
  929. mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
  930. mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
  931. mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
  932. mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
  933. mindspore/ops/_op_impl/tbe/lrn.py +0 -41
  934. mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
  935. mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
  936. mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
  937. mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
  938. mindspore/ops/_op_impl/tbe/matmul.py +0 -53
  939. mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
  940. mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
  941. mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
  942. mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
  943. mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
  944. mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
  945. mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
  946. mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
  947. mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
  948. mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
  949. mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
  950. mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
  951. mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
  952. mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
  953. mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
  954. mindspore/ops/_op_impl/tbe/maximum.py +0 -39
  955. mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
  956. mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
  957. mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
  958. mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
  959. mindspore/ops/_op_impl/tbe/minimum.py +0 -40
  960. mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
  961. mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
  962. mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
  963. mindspore/ops/_op_impl/tbe/mish.py +0 -37
  964. mindspore/ops/_op_impl/tbe/mod.py +0 -41
  965. mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
  966. mindspore/ops/_op_impl/tbe/mul.py +0 -37
  967. mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
  968. mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
  969. mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
  970. mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
  971. mindspore/ops/_op_impl/tbe/neg.py +0 -39
  972. mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
  973. mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
  974. mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
  975. mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
  976. mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
  977. mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
  978. mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
  979. mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
  980. mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
  981. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
  982. mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
  983. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
  984. mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
  985. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
  986. mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
  987. mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
  988. mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
  989. mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
  990. mindspore/ops/_op_impl/tbe/pack.py +0 -58
  991. mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
  992. mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
  993. mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
  994. mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
  995. mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
  996. mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
  997. mindspore/ops/_op_impl/tbe/pdist.py +0 -36
  998. mindspore/ops/_op_impl/tbe/pooling.py +0 -46
  999. mindspore/ops/_op_impl/tbe/population_count.py +0 -38
  1000. mindspore/ops/_op_impl/tbe/pow.py +0 -41
  1001. mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
  1002. mindspore/ops/_op_impl/tbe/prelu.py +0 -37
  1003. mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
  1004. mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
  1005. mindspore/ops/_op_impl/tbe/range.py +0 -39
  1006. mindspore/ops/_op_impl/tbe/real_div.py +0 -38
  1007. mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
  1008. mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
  1009. mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
  1010. mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
  1011. mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
  1012. mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
  1013. mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
  1014. mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
  1015. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
  1016. mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
  1017. mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
  1018. mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
  1019. mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
  1020. mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
  1021. mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
  1022. mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
  1023. mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
  1024. mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
  1025. mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
  1026. mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
  1027. mindspore/ops/_op_impl/tbe/relu.py +0 -39
  1028. mindspore/ops/_op_impl/tbe/relu6.py +0 -38
  1029. mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
  1030. mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
  1031. mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
  1032. mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
  1033. mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
  1034. mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
  1035. mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
  1036. mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
  1037. mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
  1038. mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
  1039. mindspore/ops/_op_impl/tbe/renorm.py +0 -39
  1040. mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
  1041. mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
  1042. mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
  1043. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
  1044. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
  1045. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
  1046. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
  1047. mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
  1048. mindspore/ops/_op_impl/tbe/rint.py +0 -37
  1049. mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
  1050. mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
  1051. mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
  1052. mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
  1053. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
  1054. mindspore/ops/_op_impl/tbe/roll.py +0 -42
  1055. mindspore/ops/_op_impl/tbe/round.py +0 -38
  1056. mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
  1057. mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
  1058. mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
  1059. mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
  1060. mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
  1061. mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
  1062. mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
  1063. mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
  1064. mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
  1065. mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
  1066. mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
  1067. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
  1068. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
  1069. mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
  1070. mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
  1071. mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
  1072. mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
  1073. mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
  1074. mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
  1075. mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
  1076. mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
  1077. mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
  1078. mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
  1079. mindspore/ops/_op_impl/tbe/select.py +0 -38
  1080. mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
  1081. mindspore/ops/_op_impl/tbe/selu.py +0 -39
  1082. mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
  1083. mindspore/ops/_op_impl/tbe/sgd.py +0 -62
  1084. mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
  1085. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
  1086. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
  1087. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
  1088. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
  1089. mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
  1090. mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
  1091. mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
  1092. mindspore/ops/_op_impl/tbe/sign.py +0 -38
  1093. mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
  1094. mindspore/ops/_op_impl/tbe/sin.py +0 -37
  1095. mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
  1096. mindspore/ops/_op_impl/tbe/sinh.py +0 -37
  1097. mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
  1098. mindspore/ops/_op_impl/tbe/slice.py +0 -58
  1099. mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
  1100. mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
  1101. mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
  1102. mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
  1103. mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
  1104. mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
  1105. mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
  1106. mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
  1107. mindspore/ops/_op_impl/tbe/softmax.py +0 -37
  1108. mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
  1109. mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
  1110. mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
  1111. mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
  1112. mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
  1113. mindspore/ops/_op_impl/tbe/softplus.py +0 -37
  1114. mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
  1115. mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
  1116. mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
  1117. mindspore/ops/_op_impl/tbe/softsign.py +0 -37
  1118. mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
  1119. mindspore/ops/_op_impl/tbe/sort.py +0 -38
  1120. mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
  1121. mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
  1122. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
  1123. mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
  1124. mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
  1125. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
  1126. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
  1127. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
  1128. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
  1129. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
  1130. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
  1131. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
  1132. mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
  1133. mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
  1134. mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
  1135. mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
  1136. mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
  1137. mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
  1138. mindspore/ops/_op_impl/tbe/split_d.py +0 -38
  1139. mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
  1140. mindspore/ops/_op_impl/tbe/split_v.py +0 -39
  1141. mindspore/ops/_op_impl/tbe/splitv.py +0 -39
  1142. mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
  1143. mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
  1144. mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
  1145. mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
  1146. mindspore/ops/_op_impl/tbe/square.py +0 -38
  1147. mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
  1148. mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
  1149. mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
  1150. mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
  1151. mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
  1152. mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
  1153. mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
  1154. mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
  1155. mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
  1156. mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
  1157. mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
  1158. mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
  1159. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
  1160. mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
  1161. mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
  1162. mindspore/ops/_op_impl/tbe/sub.py +0 -39
  1163. mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
  1164. mindspore/ops/_op_impl/tbe/tan.py +0 -38
  1165. mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
  1166. mindspore/ops/_op_impl/tbe/tanh.py +0 -37
  1167. mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
  1168. mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
  1169. mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
  1170. mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
  1171. mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
  1172. mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
  1173. mindspore/ops/_op_impl/tbe/tile.py +0 -37
  1174. mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
  1175. mindspore/ops/_op_impl/tbe/top_k.py +0 -42
  1176. mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
  1177. mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
  1178. mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
  1179. mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
  1180. mindspore/ops/_op_impl/tbe/transpose.py +0 -60
  1181. mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
  1182. mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
  1183. mindspore/ops/_op_impl/tbe/trunc.py +0 -39
  1184. mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
  1185. mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
  1186. mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
  1187. mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
  1188. mindspore/ops/_op_impl/tbe/unpack.py +0 -38
  1189. mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
  1190. mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
  1191. mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
  1192. mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
  1193. mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
  1194. mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
  1195. mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
  1196. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
  1197. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
  1198. mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
  1199. mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
  1200. mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
  1201. mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
  1202. mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
  1203. mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
  1204. mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
  1205. mindspore/ops/_tracefunc.py +0 -241
  1206. mindspore/ops/arg_dtype_cast.py +0 -54
  1207. mindspore/ops/silent_check.py +0 -162
  1208. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  1209. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  1210. mindspore/rewrite/api/tree_node_helper.py +0 -60
  1211. mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
  1212. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
  1213. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
  1214. mindspore/rewrite/namespace.py +0 -53
  1215. mindspore-2.2.14.dist-info/RECORD +0 -1924
  1216. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
  1217. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2271 @@
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """Operators for nn."""
16
+ from __future__ import absolute_import
17
+ from __future__ import division
18
+
19
+ import numbers
20
+ import math
21
+ import types
22
+ import numpy as np
23
+ from mindspore.ops import signature as sig
24
+ from mindspore.ops.primitive import Primitive, prim_attr_register, prim_arg_register, PrimitiveWithInfer
25
+ from mindspore.ops._primitive_cache import _get_cache_prim
26
+ from mindspore.ops.auto_generate import gen_arg_handler as handler
27
+ from mindspore.common import Tensor, CSRTensor, COOTensor
28
+ from mindspore.common._stub_tensor import _convert_stub
29
+ from mindspore._c_expression import typing
30
+ from mindspore._c_expression import Tensor as Tensor_
31
+ from mindspore._c_expression import pyboost_cast, pyboost_tile, pyboost_zeros, pyboost_ones
32
+ from mindspore.common import dtype as mstype
33
+ from mindspore.common._utils import is_shape_unknown
34
+ from mindspore import _checkparam as validator
35
+ from mindspore.ops.operations.manually_defined._inner import ScalarCast
36
+ from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum
37
+ from mindspore.common.initializer import Zero
38
+ from mindspore.common.parameter import Parameter
39
+ from mindspore.ops.auto_generate.gen_ops_prim import FlashAttentionScore
40
+
41
+
42
+ dtype_to_type_id = DtypeToEnum()
43
+
44
+
46
+
47
+
48
+ class ScalarDiv(Primitive):
49
+ r"""
50
+ Computes the quotient of dividing the first input scalar by the second input scalar element-wise.
51
+
52
+ .. math::
53
+
54
+ out_{i} = \frac{x_i}{y_i}
55
+
56
+ .. note::
57
+ The inputs can be constant/variable value. Usage is the same as '/' in Python.
58
+ This primitive only has a 'CPU' implementation; on other platforms, it runs via heterogeneous execution.
59
+
60
+ Inputs:
61
+ - **x** (Scalar) - A constant or variable scalar.
62
+ - **y** (Scalar) - A constant or variable scalar.
63
+
64
+ Outputs:
65
+ Scalar, the type of scalar is float.
66
+
67
+ Raises:
68
+ TypeError: If `x` and `y` are not scalar.
69
+ ValueError: If `y` is 0.
70
+
71
+ Supported Platforms:
72
+ ``Ascend`` ``GPU`` ``CPU``
73
+ """
74
+ @prim_attr_register
75
+ def __init__(self):
76
+ """Initialize ScalarDiv"""
77
+
78
+ def __call__(self, x, y):
79
+ if y == 0:
80
+ raise ValueError('The divisor could not be zero. But the divisor is zero now.')
81
+ return x / y
82
+
83
+
84
+ class ScalarFloorDiv(Primitive):
85
+ r"""
86
+ Computes the quotient of dividing the first input scalar by the second input scalar element-wise.
87
+
88
+ .. math::
89
+
90
+ out_{i} = \frac{x_i}{y_i}
91
+
92
+ .. note::
93
+ The inputs can be constant/variable value. Usage is the same as '//' in Python.
94
+ This primitive only has a 'CPU' implementation; on other platforms, it runs via heterogeneous execution.
95
+
96
+ Inputs:
97
+ - **x** (Scalar) - A constant or variable scalar.
98
+ - **y** (Scalar) - A constant or variable scalar.
99
+
100
+ Outputs:
101
+ Scalar, the type of scalar is float.
102
+
103
+ Raises:
104
+ TypeError: If `x` and `y` are not scalar.
105
+ ValueError: If `y` is 0.
106
+
107
+ Supported Platforms:
108
+ ``Ascend`` ``GPU`` ``CPU``
109
+ """
110
+ @prim_attr_register
111
+ def __init__(self):
112
+ """Initialize ScalarFloorDiv"""
113
+ self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
114
+
115
+ def __call__(self, x, y):
116
+ if y == 0:
117
+ raise ValueError('The divisor could not be zero. But the divisor is zero now.')
118
+ return x // y
119
+
120
+
121
+ class ScalarAdd(Primitive):
122
+ r"""
123
+ Adds two input scalars.
124
+
125
+ .. note::
126
+ The inputs can be constant/variable value. Usage is the same as '+' in Python.
127
+ This primitive only has a 'CPU' implementation; on other platforms, it runs via heterogeneous execution.
128
+
129
+ Inputs:
130
+ - **x** (Scalar) - A constant or variable scalar.
131
+ - **y** (Scalar) - A constant or variable scalar.
132
+
133
+ Outputs:
134
+ Scalar, and the data type is the one with higher precision or higher digits among the two inputs.
135
+
136
+ Raises:
137
+ TypeError: If `x` and `y` are not scalar.
138
+
139
+ Supported Platforms:
140
+ ``Ascend`` ``GPU`` ``CPU``
141
+ """
142
+ @prim_attr_register
143
+ def __init__(self):
144
+ """Initialize ScalarAdd"""
145
+
146
+ def __call__(self, x, y):
147
+ return x + y
148
+
149
+
150
+ class ScalarPow(Primitive):
151
+ r"""
152
+ Computes the power of two input scalars.
153
+
154
+ .. note::
155
+ The inputs can be constant/variable value. Usage is the same as '**' in Python.
156
+ This primitive only has a 'CPU' implementation; on other platforms, it runs via heterogeneous execution.
157
+
158
+ Inputs:
159
+ - **x** (Scalar) - A constant or variable scalar.
160
+ - **y** (Scalar) - A constant or variable scalar.
161
+
162
+ Outputs:
163
+ Scalar, and the data type is the one with higher precision or higher digits among the two inputs.
164
+
165
+ Raises:
166
+ TypeError: If `x` and `y` are not scalar.
167
+
168
+ Supported Platforms:
169
+ ``Ascend`` ``GPU`` ``CPU``
170
+ """
171
+ @prim_attr_register
172
+ def __init__(self):
173
+ """Initialize ScalarPow"""
174
+
175
+ def __call__(self, x, y):
176
+ return pow(x, y)
177
+
178
+
179
+ class ScalarLog(Primitive):
180
+ r"""
181
+ Computes the natural logarithm of the input scalar.
182
+
183
+ .. note::
184
+ The input can be a constant/variable value. Usage is the same as `math.log(x)` in Python.
185
+ This primitive only has a 'CPU' implementation; on other platforms, it runs via heterogeneous execution.
186
+
187
+ Inputs:
188
+ - **x** (Scalar) - A constant or variable scalar.
189
+
190
+ Outputs:
191
+ Scalar, the type of scalar is float.
192
+
193
+ Raises:
194
+ TypeError: If `x` is not a scalar.
195
+
196
+ Supported Platforms:
197
+ ``Ascend`` ``GPU`` ``CPU``
198
+ """
199
+ @prim_attr_register
200
+ def __init__(self):
201
+ """Initialize ScalarAdd"""
202
+
203
+ def __call__(self, x):
204
+ return math.log(x)
205
+
206
+
207
+ class ScalarUadd(Primitive):
208
+ r"""
209
+ Returns the input scalar unchanged (unary plus).
210
+
211
+ .. note::
212
+ The input can be a constant/variable value. Usage is the same as unary '+' in Python.
213
+ This primitive only has a 'CPU' implementation; on other platforms, it runs via heterogeneous execution.
214
+
215
+ Inputs:
216
+ - **x** (Scalar) - A constant or variable scalar.
217
+
218
+ Outputs:
219
+ Scalar, with the same type as the input.
220
+
221
+ Raises:
222
+ TypeError: If `x` is not a scalar.
223
+
224
+ Supported Platforms:
225
+ ``Ascend`` ``GPU`` ``CPU``
226
+ """
227
+ @prim_attr_register
228
+ def __init__(self):
229
+ """Initialize ScalarAdd"""
230
+
231
+ def __call__(self, x):
232
+ return x
233
+
234
+
235
+ class ScalarUsub(Primitive):
236
+ r"""
237
+ Negates the input scalar (unary minus).
238
+
239
+ .. note::
240
+ The input can be a constant/variable value. Usage is the same as unary '-' in Python.
241
+ This primitive only has a 'CPU' implementation; on other platforms, it runs via heterogeneous execution.
242
+
243
+ Inputs:
244
+ - **x** (Scalar) - A constant or variable scalar.
246
+
247
+ Outputs:
248
+ Scalar, with the same type as the input.
249
+
250
+ Raises:
251
+ TypeError: If `x` is not a scalar.
252
+
253
+ Supported Platforms:
254
+ ``Ascend`` ``GPU`` ``CPU``
255
+ """
256
+ @prim_attr_register
257
+ def __init__(self):
258
+ """Initialize ScalarUsub"""
259
+
260
+ def __call__(self, x):
261
+ return -x
262
+
263
+
264
+ class ScalarSub(Primitive):
265
+ r"""
266
+ Subtracts the second input Scalar from the first input Scalar.
267
+
268
+ .. note::
269
+ The inputs can be constant/variable value. Usage is the same as '-' in Python.
270
+ This primitive only has a 'CPU' implementation; on other platforms, it runs via heterogeneous execution.
271
+
272
+ Inputs:
273
+ - **x** (Scalar) - A constant or variable scalar.
274
+ - **y** (Scalar) - A constant or variable scalar.
275
+
276
+ Outputs:
277
+ Scalar, and the data type is the one with higher precision or higher digits among the two inputs.
278
+
279
+ Raises:
280
+ TypeError: If `x` and `y` are not scalar.
281
+
282
+ Supported Platforms:
283
+ ``Ascend`` ``GPU`` ``CPU``
284
+ """
285
+ @prim_attr_register
286
+ def __init__(self):
287
+ """Initialize ScalarSub"""
288
+
289
+ def __call__(self, x, y):
290
+ return x - y
291
+
292
+
293
+ class ScalarMul(Primitive):
294
+ r"""
295
+ Multiplies two input scalars.
296
+
297
+ .. note::
298
+ The inputs can be constant/variable value. Usage is the same as '*' in Python.
299
+ This primitive only has a 'CPU' implementation; on other platforms, it runs via heterogeneous execution.
300
+
301
+ Inputs:
302
+ - **x** (Scalar) - A constant or variable scalar.
303
+ - **y** (Scalar) - A constant or variable scalar.
304
+
305
+ Outputs:
306
+ Scalar, and the data type is the one with higher precision or higher digits among the two inputs.
307
+
308
+ Raises:
309
+ TypeError: If `x` and `y` are not scalar.
310
+
311
+ Supported Platforms:
312
+ ``Ascend`` ``GPU`` ``CPU``
313
+ """
314
+ @prim_attr_register
315
+ def __init__(self):
316
+ """Initialize ScalarMul"""
317
+
318
+ def __call__(self, x, y):
319
+ return x * y
320
+
321
+
322
+ class ScalarEq(Primitive):
323
+ r"""
324
+ Computes the equivalence between two Scalars.
325
+
326
+ .. note::
327
+ The inputs can be constant/variable value. Usage is the same as '==' in Python.
328
+ This primitive only has a 'CPU' implementation; on other platforms, it runs via heterogeneous execution.
329
+
330
+ Inputs:
331
+ - **x** (Scalar) - A constant or variable scalar.
332
+ - **y** (Scalar) - A constant or variable scalar.
333
+
334
+ Outputs:
335
+ Scalar, the type of scalar is bool.
336
+
337
+ Raises:
338
+ TypeError: If `x` and `y` are not scalar.
339
+
340
+ Supported Platforms:
341
+ ``Ascend`` ``GPU`` ``CPU``
342
+ """
343
+ @prim_attr_register
344
+ def __init__(self):
345
+ """Initialize ScalarEq"""
346
+
347
+ def __call__(self, x, y):
348
+ return x == y
349
+
350
+
351
+ class ScalarGt(Primitive):
352
+ r"""
353
+ Compare the value of the input scalars :math:`x,y`, and the output result is a bool value.
354
+
355
+ .. note::
356
+ The inputs can be constant/variable value. Usage is the same as '>' in Python.
357
+ This primitive only has a 'CPU' implementation; on other platforms, it runs via heterogeneous execution.
358
+
359
+ Inputs:
360
+ - **x** (Scalar) - A constant or variable scalar.
361
+ - **y** (Scalar) - A constant or variable scalar.
362
+
363
+ Outputs:
364
+ Scalar, the type of scalar is bool.
365
+
366
+ Raises:
367
+ TypeError: If `x` and `y` are not scalar.
368
+
369
+ Supported Platforms:
370
+ ``Ascend`` ``GPU`` ``CPU``
371
+ """
372
+ @prim_attr_register
373
+ def __init__(self):
374
+ """Initialize scalar_gt"""
375
+
376
+ def __call__(self, x, y):
377
+ return x > y
378
+
379
+
380
+ class ScalarLt(Primitive):
381
+ r"""
382
+ Computes the boolean value of :math:`x < y`.
383
+
384
+ .. note::
385
+ The inputs can be constant/variable value. Usage is the same as '<' in Python.
386
+ This primitive only has a 'CPU' implementation; on other platforms, it runs via heterogeneous execution.
387
+
388
+ Inputs:
389
+ - **x** (Scalar) - A constant or variable scalar.
390
+ - **y** (Scalar) - A constant or variable scalar.
391
+
392
+ Outputs:
393
+ Scalar, the type of scalar is bool.
394
+
395
+ Raises:
396
+ TypeError: If `x` and `y` are not scalar.
397
+
398
+ Supported Platforms:
399
+ ``Ascend`` ``GPU`` ``CPU``
400
+ """
401
+ @prim_attr_register
402
+ def __init__(self):
403
+ """Initialize scalar_lt"""
404
+
405
+ def __call__(self, x, y):
406
+ return x < y
407
+
408
+
409
+ class ScalarGe(Primitive):
410
+ r"""
411
+ Compare the value of the input scalars :math:`x,y`, and the output result is a bool value.
412
+
413
+ .. note::
414
+ The inputs can be constant/variable value. Usage is the same as '>=' in Python.
415
+ This primitive only has a 'CPU' implementation; on other platforms, it runs via heterogeneous execution.
416
+
417
+ Inputs:
418
+ - **x** (Scalar) - A constant or variable scalar.
419
+ - **y** (Scalar) - A constant or variable scalar.
420
+
421
+ Outputs:
422
+ Scalar, the type of scalar is bool.
423
+
424
+ Raises:
425
+ TypeError: If `x` and `y` are not scalar.
426
+
427
+ Supported Platforms:
428
+ ``Ascend`` ``GPU`` ``CPU``
429
+ """
430
+ @prim_attr_register
431
+ def __init__(self):
432
+ """Initialize scalar_ge"""
433
+
434
+ def __call__(self, x, y):
435
+ return x >= y
436
+
437
+
438
+ class ScalarLe(Primitive):
439
+ r"""
440
+ Compare the value of the input scalars :math:`x,y`, and the output result is a bool value.
441
+
442
+ .. note::
443
+ The inputs can be constant/variable value. Usage is the same as '<=' in Python.
444
+ This primitive only has a 'CPU' implementation; on other platforms, it runs via heterogeneous execution.
445
+
446
+ Inputs:
447
+ - **x** (Scalar) - A constant or variable scalar.
448
+ - **y** (Scalar) - A constant or variable scalar.
449
+
450
+ Outputs:
451
+ Scalar, the type of scalar is bool.
452
+
453
+ Raises:
454
+ TypeError: If `x` and `y` are not scalar.
455
+
456
+ Supported Platforms:
457
+ ``Ascend`` ``GPU`` ``CPU``
458
+ """
459
+ @prim_attr_register
460
+ def __init__(self):
461
+ """Initialize scalar_le"""
462
+
463
+ def __call__(self, x, y):
464
+ return x <= y
465
+
466
+
467
+ class ScalarMod(Primitive):
468
+ r"""
469
+ Computes the remainder of dividing the first input scalar by the second input scalar element-wise.
470
+
471
+ .. math::
472
+
473
+ out_{i} = x_{i} \text{ % } y_{i}
474
+
475
+ .. note::
476
+ The inputs can be constant/variable value. Usage is the same as '%' in Python.
477
+ This primitive only has a 'CPU' implementation; on other platforms, it runs via heterogeneous execution.
478
+
479
+ Inputs:
480
+ - **x** (Scalar) - A constant or variable scalar.
481
+ - **y** (Scalar) - A constant or variable scalar.
482
+
483
+ Outputs:
484
+ Scalar, the type is the one with higher precision or higher digits among the two inputs.
485
+
486
+ Raises:
487
+ TypeError: If `x` and `y` are not scalar.
488
+
489
+ Supported Platforms:
490
+ ``Ascend`` ``GPU`` ``CPU``
491
+ """
492
+ @prim_attr_register
493
+ def __init__(self):
494
+ """Initialize ScalarMod"""
495
+
496
+ def __call__(self, x, y):
497
+ if y == 0:
498
+ raise ValueError('Cannot perform modulo operation on zero.')
499
+ return x % y
500
+
501
+
502
+ class ScalarBool(Primitive):
503
+ r"""
504
+ Computes the input scalar true or false.
505
+
506
+ .. note::
507
+ The inputs can be constant/variable value.
508
+ This primitive only has a 'CPU' implementation; on other platforms, it runs via heterogeneous execution.
509
+
510
+ Inputs:
511
+ - **x** (Scalar) - A constant or variable scalar.
512
+
513
+ Outputs:
514
+ Scalar, the type is bool.
515
+
516
+ Raises:
517
+ TypeError: If `x` is not a scalar.
518
+
519
+ Supported Platforms:
520
+ ``Ascend`` ``GPU`` ``CPU``
521
+ """
522
+ @prim_attr_register
523
+ def __init__(self):
524
+ """Initialize ScalarBool"""
525
+
526
+ def __call__(self, x):
527
+ return bool(x)
528
+
529
+
530
+ scalar_div = ScalarDiv()
531
+ scalar_mod = ScalarMod()
532
+ scalar_add = ScalarAdd()
533
+ scalar_mul = ScalarMul()
534
+ scalar_sub = ScalarSub()
535
+ scalar_gt = ScalarGt()
536
+ scalar_ge = ScalarGe()
537
+ scalar_le = ScalarLe()
538
+ scalar_lt = ScalarLt()
539
+ scalar_eq = ScalarEq()
540
+ scalar_bool = ScalarBool()
541
+ scalar_floordiv = ScalarFloorDiv()
542
+ scalar_log = ScalarLog()
543
+ scalar_pow = ScalarPow()
544
+ scalar_uadd = ScalarUadd()
545
+ scalar_usub = ScalarUsub()
546
+
547
+
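# ---------------------------------------------------------------------------
# Editor's note: a minimal usage sketch (not part of the diff). The module-level
# instances above simply mirror Python's scalar operators, as their __call__
# bodies show. The import path below is an assumption about where this file
# lives inside the package and may differ between builds.
from mindspore.ops.operations.manually_defined.ops_def import (  # assumed path
    scalar_add, scalar_pow, scalar_div, scalar_floordiv)

assert scalar_add(2, 3) == 5         # same as 2 + 3
assert scalar_pow(2, 10) == 1024     # same as 2 ** 10
assert scalar_div(1, 4) == 0.25      # same as 1 / 4; ValueError if the divisor is 0
assert scalar_floordiv(7, 2) == 3    # same as 7 // 2
# ---------------------------------------------------------------------------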
548
+ class BatchNorm(Primitive):
549
+ r"""
550
+ Batch Normalization for input data and updated parameters.
551
+
552
+ Batch Normalization is widely used in convolutional neural networks. This operation
553
+ applies Batch Normalization over inputs to avoid internal covariate shift as described
554
+ in the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal
555
+ Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
556
+ features using a mini-batch of data and the learned parameters can be described
557
+ in the following formula,
558
+
559
+ .. math::
560
+
561
+ y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
562
+
563
+ where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon,
564
+ :math:`mean` is the mean of :math:`x`,
565
+ :math:`variance` is the variance of :math:`x`.
566
+
567
+ .. warning::
568
+ - If the operation is used for inference, and outputs "reserve_space_1" and "reserve_space_2" are available,
569
+ then "reserve_space_1" has the same value as "mean" and "reserve_space_2" has the same value as "variance".
570
+ - For Ascend 310, the result accuracy fails to reach 1‰ due to the square root instruction.
571
+
572
+ Args:
573
+ is_training (bool): If `is_training` is ``True`` , `mean` and `variance` are computed during training.
574
+ If `is_training` is ``False`` , they're loaded from checkpoint during inference. Default: ``False`` .
575
+ epsilon (float): A small value added for numerical stability. Default: ``1e-5``, value must be (0, 1] .
576
+ momentum (float): The hyper parameter to compute moving average for running_mean and running_var
577
+ (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
578
+ Momentum value must be [0, 1]. Default: ``0.1`` .
579
+ data_format (str): The optional value for data format, is ``'NHWC'`` or ``'NCHW'``, and the ``'NHWC'`` format
580
+ is only supported in GPU target. Default: ``"NCHW"`` .
581
+
582
+ Inputs:
583
+ If `is_training` is ``False`` , inputs are Tensors.
584
+
585
+ - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
586
+ - **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
587
+ - **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
588
+ - **mean** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
589
+ - **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
590
+
591
+ If `is_training` is ``True`` , `scale`, `bias`, `mean` and `variance` are Parameters.
592
+
593
+ - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
594
+ - **scale** (Parameter) - Parameter of shape :math:`(C,)`, with float16 or float32 data type.
595
+ - **bias** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `scale`.
596
+ - **mean** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `scale`.
597
+ - **variance** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `scale`.
598
+
599
+ Outputs:
600
+ Tuple of 5 Tensors, the normalized inputs and the updated parameters.
601
+
602
+ - **output_x** (Tensor) - The same type and shape as the input_x. The shape is :math:`(N, C)`.
603
+ - **batch_mean** (Tensor) - The mean calculated per-dimension over the mini-batches,
604
+ shape is :math:`(C,)`.
605
+ - **batch_variance** (Tensor) - The variance calculated per-dimension over the mini-batches,
606
+ shape is :math:`(C,)`.
607
+ - **reserve_space_1** (Tensor) - The mean that needs to be reused when calculating gradients,
608
+ one-dimensional Tensor. The shape is :math:`(C,)`.
609
+ - **reserve_space_2** (Tensor) - The variance that needs to be reused when calculating gradients,
610
+ one-dimensional Tensor. The shape is :math:`(C,)`.
611
+
612
+ Raises:
613
+ TypeError: If `is_training` is not a bool.
614
+ TypeError: If dtype of `epsilon` or `momentum` is not float.
615
+ TypeError: If `data_format` is not a str.
616
+ TypeError: If `input_x`, `scale`, `bias`, `mean` or `variance` is not a Tensor.
617
+ TypeError: If dtype of `input_x`, `scale` is neither float16 nor float32.
618
+
619
+ Supported Platforms:
620
+ ``Ascend`` ``GPU`` ``CPU``
621
+
622
+ Examples:
623
+ >>> import mindspore
624
+ >>> import numpy as np
625
+ >>> from mindspore import Tensor, ops
626
+ >>> input_x = Tensor(np.ones([2, 2]), mindspore.float32)
627
+ >>> scale = Tensor(np.ones([2]), mindspore.float32)
628
+ >>> bias = Tensor(np.ones([2]), mindspore.float32)
629
+ >>> mean = Tensor(np.ones([2]), mindspore.float32)
630
+ >>> variance = Tensor(np.ones([2]), mindspore.float32)
631
+ >>> batch_norm = ops.BatchNorm()
632
+ >>> output = batch_norm(input_x, scale, bias, mean, variance)
633
+ >>> print(output[0])
634
+ [[1. 1.]
635
+ [1. 1.]]
636
+ """
637
+ __mindspore_signature__ = (sig.make_sig('input_x', dtype=sig.sig_dtype.T1),
638
+ sig.make_sig('scale',
639
+ sig.sig_rw.RW_WRITE,
640
+ dtype=sig.sig_dtype.T2),
641
+ sig.make_sig('bias',
642
+ sig.sig_rw.RW_WRITE,
643
+ dtype=sig.sig_dtype.T2),
644
+ sig.make_sig('mean',
645
+ sig.sig_rw.RW_WRITE,
646
+ dtype=sig.sig_dtype.T3),
647
+ sig.make_sig('variance',
648
+ sig.sig_rw.RW_WRITE,
649
+ dtype=sig.sig_dtype.T3))
650
+
651
+ @prim_arg_register
652
+ def __init__(self,
653
+ is_training=False,
654
+ epsilon=1e-5,
655
+ momentum=0.1,
656
+ data_format="NCHW"):
657
+ """Initialize BatchNorm."""
658
+ if is_training is False:
659
+ self.set_signatures(tuple())
660
+ else:
661
+ self.add_prim_attr('side_effect_mem', True)
662
+ self.is_training = is_training
663
+ self.epsilon = epsilon
664
+ self.momentum = momentum
665
+ self.data_format = handler.str_to_enum("BatchNorm", "data_format", data_format)
666
+
667
+ def __call__(self, *args):
668
+ return super().__call__(*args, self.is_training, self.epsilon,
669
+ self.momentum, self.data_format)
670
+
671
+
672
+ def batch_norm_(input_x,
673
+ scale,
674
+ bias,
675
+ mean,
676
+ variance,
677
+ is_training=False,
678
+ epsilon=1e-5,
679
+ momentum=0.1,
680
+ data_format="NCHW"):
681
+ r"""
682
+ Batch Normalization for input data and updated parameters.
683
+
684
+ Batch Normalization is widely used in convolutional neural networks. This operation
685
+ applies Batch Normalization over inputs to avoid internal covariate shift as described
686
+ in the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal
687
+ Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
688
+ features using a mini-batch of data and the learned parameters can be described
689
+ in the following formula,
690
+
691
+ .. math::
692
+
693
+ y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
694
+
695
+ where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon,
696
+ :math:`mean` is the mean of :math:`x`,
697
+ :math:`variance` is the variance of :math:`x`.
698
+
699
+ .. warning::
700
+ - If the operation is used for inference, and outputs "reserve_space_1" and "reserve_space_2" are available,
701
+ then "reserve_space_1" has the same value as "mean" and "reserve_space_2" has the same value as "variance".
702
+ - For Atlas 200/300/500 inference product,
703
+ the result accuracy fails to reach 1‰ due to the square root instruction.
704
+
705
+ Note:
706
+ - If `is_training` is `False`, `scale`, `bias`, `mean` and `variance` are Tensors.
707
+ - If `is_training` is `True`, `scale`, `bias`, `mean` and `variance` are Parameters.
708
+
709
+ Args:
710
+ input_x (tensor): tensor of shape :math:`(N, C)`, with float16 or float32 data type.
711
+ scale (Union[Tensor, Parameter]): The shape :math:`(C,)`, with float16 or float32 data type.
712
+ bias (Union[Tensor, Parameter]): The shape :math:`(C,)`, has the same data type as `scale`.
713
+ mean (Union[tensor, Parameter]): The shape :math:`(C,)`, with float16 or float32 data type.
714
+ variance (Union[Tensor, Parameter]): The shape :math:`(C,)`, has the same data type as `scale`.
715
+ is_training (bool, optional): If `is_training` is `True`, `mean` and `variance` are computed during training.
716
+ If `is_training` is `False`, they're loaded from checkpoint during inference. Default: ``False`` .
717
+ epsilon (float): A small value added for numerical stability.
718
+ Default: ``1e-5``, value must be (0, 1] .
719
+ momentum (float): The hyper parameter to compute moving average for running_mean and running_var
720
+ (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
721
+ Momentum value must be [0, 1].
722
+ Default: ``0.1`` .
723
+ data_format (str): The optional value for data format, is ``'NHWC'`` or ``'NCHW'``,
724
+ and the ``'NHWC'`` format is only supported in GPU target.
725
+ Default: ``"NCHW"`` .
726
+
727
+ Returns:
728
+ output_x (Tensor): The same type and shape as the input_x. The shape is :math:`(N, C)`.
729
+ batch_mean (Tensor): Tensor of shape :math:`(C,)`.
730
+ batch_variance (Tensor): Tensor of shape :math:`(C,)`.
731
+ reserve_space_1 (Tensor): Tensor of shape :math:`(C,)`.
732
+ reserve_space_2 (Tensor): Tensor of shape :math:`(C,)`.
733
+
734
+ Raises:
735
+ TypeError: If `is_training` is not a bool.
736
+ TypeError: If dtype of `epsilon` or `momentum` is not float.
737
+ TypeError: If `data_format` is not a str.
738
+ TypeError: If `input_x`, `scale`, `bias`, `mean` or `variance` is not a Tensor.
739
+ TypeError: If dtype of `input_x`, `scale` is neither float16 nor float32.
740
+
741
+ Supported Platforms:
742
+ ``Ascend`` ``GPU`` ``CPU``
743
+
744
+ Examples:
745
+ >>> import mindspore
746
+ >>> import numpy as np
747
+ >>> from mindspore import Tensor, ops
748
+ >>> input_x = Tensor(np.ones([2, 2]), mindspore.float32)
749
+ >>> scale = Tensor(np.ones([2]), mindspore.float32)
750
+ >>> bias = Tensor(np.ones([2]), mindspore.float32)
751
+ >>> mean = Tensor(np.ones([2]), mindspore.float32)
752
+ >>> variance = Tensor(np.ones([2]), mindspore.float32)
753
+ >>> output = ops.batch_norm_(input_x, scale, bias, mean, variance, is_training=False, epsilon=1e-5, momentum=0.1, data_format="NCHW")
754
+ >>> print(output[0])
755
+ [[1. 1.]
756
+ [1. 1.]]
757
+ """
758
+ batch_norm_op = _get_cache_prim(BatchNorm)(is_training, epsilon, momentum,
759
+ data_format)
760
+ return batch_norm_op(input_x, scale, bias, mean, variance)
761
+
762
+
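# ---------------------------------------------------------------------------
# Editor's note: a short NumPy cross-check (not part of the diff) of the formula
# quoted in the BatchNorm docstrings above,
#     y = (x - mean) / sqrt(variance + eps) * gamma + beta.
# With x, mean, variance, gamma and beta all equal to 1 it reproduces the
# documented all-ones output.
import numpy as np

x = np.ones((2, 2), dtype=np.float32)
gamma = beta = mean = variance = np.ones(2, dtype=np.float32)
eps = 1e-5
y = (x - mean) / np.sqrt(variance + eps) * gamma + beta
print(np.round(y, 3))   # [[1. 1.]
                        #  [1. 1.]]
# ---------------------------------------------------------------------------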
763
+ class Rank(Primitive):
764
+ """
765
+ Returns the rank of a tensor.
766
+
767
+ Refer to :func:`mindspore.ops.rank` for more details.
768
+
769
+ Supported Platforms:
770
+ ``Ascend`` ``GPU`` ``CPU``
771
+
772
+ Examples:
773
+ >>> import mindspore
774
+ >>> import numpy as np
775
+ >>> from mindspore import Tensor, ops
776
+ >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
777
+ >>> rank = ops.Rank()
778
+ >>> output = rank(input_tensor)
779
+ >>> print(output)
780
+ 2
781
+ >>> print(type(output))
782
+ <class 'int'>
783
+ """
784
+
785
+ @prim_attr_register
786
+ def __init__(self):
787
+ """Initialize Rank"""
788
+
789
+ def __call__(self, x):
790
+ if not isinstance(x, (Tensor, Tensor_)):
791
+ raise TypeError("the input x must be Tensor!")
792
+ return len(x.shape)
793
+
794
+
795
+ def rank(input_x):
796
+ """
797
+ Returns the rank of a tensor.
798
+
799
+ Returns a 0-D int32 Tensor representing the rank of input; the rank of a tensor
800
+ is the number of indices required to uniquely select each element of the tensor.
801
+
802
+ Args:
803
+ input_x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The data type is Number.
804
+
805
+ Returns:
806
+ Tensor. 0-D int32 Tensor representing the rank of input, i.e., :math:`R`. The data type is an int.
807
+
808
+ Raises:
809
+ TypeError: If `input_x` is not a Tensor.
810
+
811
+ Supported Platforms:
812
+ ``Ascend`` ``GPU`` ``CPU``
813
+
814
+ Examples:
815
+ >>> import mindspore
816
+ >>> import numpy as np
817
+ >>> from mindspore import Tensor, ops
818
+ >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
819
+ >>> output = ops.rank(input_tensor)
820
+ >>> print(output)
821
+ 2
822
+ >>> print(type(output))
823
+ <class 'int'>
824
+
825
+ """
826
+ rank_op = _get_cache_prim(Rank)()
827
+ return rank_op(input_x)
828
+
829
+
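# ---------------------------------------------------------------------------
# Editor's note: a minimal sketch (not part of the diff). As Rank.__call__ above
# shows, the rank is simply len(x.shape), so ops.rank agrees with the length of
# the tensor's shape tuple.
import mindspore
import numpy as np
from mindspore import Tensor, ops

t = Tensor(np.zeros((3, 2, 1)), mindspore.float32)
assert ops.rank(t) == len(t.shape) == 3
# ---------------------------------------------------------------------------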
830
+ class Shape(Primitive):
831
+ """
832
+ Returns the shape of the input tensor.
833
+
834
+ Refer to :func:`mindspore.ops.shape` for more details.
835
+
836
+ Inputs:
837
+ - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
838
+
839
+ Outputs:
840
+ tuple[int], the output tuple is constructed by multiple integers,
841
+ :math:`(x_1, x_2, ..., x_R)`.
842
+
843
+ Supported Platforms:
844
+ ``Ascend`` ``GPU`` ``CPU``
845
+
846
+ Examples:
847
+ >>> import mindspore
848
+ >>> import numpy as np
849
+ >>> from mindspore import Tensor, ops
850
+ >>> input_x = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
851
+ >>> shape = ops.Shape()
852
+ >>> output = shape(input_x)
853
+ >>> print(output)
854
+ (3, 2, 1)
855
+ """
856
+
857
+ @prim_attr_register
858
+ def __init__(self):
859
+ """Initialize Shape"""
860
+
861
+ def __call__(self, x):
862
+ if isinstance(x, (Tensor, COOTensor, CSRTensor, Tensor_)):
863
+ return x.shape
864
+ raise TypeError(f"For primitive[{self.name}], the input argument must be Tensor, but got {type(x)}.")
865
+
866
+
867
+ def shape_(input_x):
868
+ """
869
+ Returns the shape of the input tensor.
870
+
871
+ Args:
872
+ input_x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
873
+
874
+ Returns:
875
+ tuple[int], the output tuple is constructed by multiple integers,
876
+ :math:`(x_1, x_2, ..., x_R)`.
877
+
878
+ Raises:
879
+ TypeError: If `input_x` is not a Tensor.
880
+
881
+ Supported Platforms:
882
+ ``Ascend`` ``GPU`` ``CPU``
883
+
884
+ Examples:
885
+ >>> import mindspore
886
+ >>> import numpy as np
887
+ >>> from mindspore import Tensor, ops
888
+ >>> input_x = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
889
+ >>> output = ops.shape(input_x)
890
+ >>> print(output)
891
+ (3, 2, 1)
892
+ """
893
+ shape_op = _get_cache_prim(Shape)()
894
+ return shape_op(input_x)
895
+
896
+
897
+ class ScalarToTensor(PrimitiveWithInfer):
898
+ """
899
+ Converts a scalar to a `Tensor`, and converts the data type to the specified type.
900
+
901
+ Refer to :func:`mindspore.ops.scalar_to_tensor` for more details.
902
+
903
+ Inputs:
904
+ - **input_x** (Union[int, float]) - The input is a scalar. Only constant value is allowed.
905
+ - **dtype** (mindspore.dtype) - The target data type. Default: ``mindspore.float32`` . Only
906
+ constant value is allowed.
907
+
908
+ Outputs:
909
+ Tensor. 0-D Tensor and the content is the input.
910
+
911
+ Supported Platforms:
912
+ ``Ascend`` ``GPU`` ``CPU``
913
+
914
+ Examples:
915
+ >>> import mindspore
916
+ >>> from mindspore import ops
917
+ >>> op = ops.ScalarToTensor()
918
+ >>> data = 1
919
+ >>> output = op(data, mindspore.float32)
920
+ >>> print(output)
921
+ 1.0
922
+ """
923
+
924
+ @prim_attr_register
925
+ def __init__(self):
926
+ self.init_prim_io_names(inputs=['input_scalar', 'dtype'], outputs=['output_data'])
927
+
928
+ def __call__(self, x, dtype=mstype.float32):
929
+ validator.check_value_type("x", x, [bool, int, float], self.name)
930
+ validator.check_subclass("dtype", dtype, mstype.number, self.name)
931
+ data_type = mstype.dtype_to_nptype(dtype)
932
+ return Tensor(np.array(x, data_type), dtype=dtype)
933
+
934
+
935
+ class Tile(Primitive):
936
+ r"""
937
+ Replicates an input tensor with given multiple times.
938
+
939
+ Refer to :func:`mindspore.ops.tile` for more details.
940
+
941
+ Note:
942
+ On Ascend, the number of `dims` should not exceed 8, and currently does not support scenarios
943
+ where more than 4 dimensions are repeated simultaneously.
944
+
945
+ Inputs:
946
+ - **input** (Tensor) - The tensor whose elements need to be repeated. Set the shape of input tensor as
947
+ :math:`(x_1, x_2, ..., x_S)` .
948
+ - **dims** (tuple[int]) - The parameter that specifies the number of replications,
949
+ the parameter type is tuple, and the data type is int, i.e., :math:`(y_1, y_2, ..., y_S)`.
950
+ Only constant value is allowed.
951
+
952
+ Outputs:
953
+ Tensor, has the same data type as the `input`. Suppose the length of `dims` is `d`,
954
+ the dimension of `input` is `input.dim`, and the shape of `input` is :math:`(x_1, x_2, ..., x_S)`.
955
+
956
+ - If `input.dim = d`, then the shape of their corresponding positions can be multiplied, and
957
+ the shape of Outputs is :math:`(x_1*y_1, x_2*y_2, ..., x_S*y_S)`.
958
+ - If `input.dim < d`, prepend 1 to the shape of `input` until their lengths are consistent.
959
+ Such as set the shape of `input` as :math:`(1, ..., x_1, x_2, ..., x_S)`,
960
+ then the shape of their corresponding positions can be multiplied, and the shape of Outputs is
961
+ :math:`(1*y_1, ..., x_R*y_R, x_S*y_S)`.
962
+ - If `input.dim > d`, prepend 1 to `dims` until their lengths are consistent. Such as set the
963
+ `dims` as :math:`(1, ..., y_1, y_2, ..., y_S)`, then the shape of their corresponding positions
964
+ can be multiplied, and the shape of Outputs is :math:`(x_1*1, ..., x_R*y_R, x_S*y_S)`.
965
+
966
+ Raises:
967
+ TypeError: If `dims` is not a tuple or its elements are not all int.
968
+ ValueError: If the elements of `dims` are not all greater than or equal to 0.
969
+
970
+ Supported Platforms:
971
+ ``Ascend`` ``GPU`` ``CPU``
972
+
973
+ Examples:
974
+ >>> import mindspore
975
+ >>> import numpy as np
976
+ >>> from mindspore import Tensor, ops
977
+ >>> tile = ops.Tile()
978
+ >>> input = Tensor(np.array([[1, 2], [3, 4]]), mindspore.float32)
979
+ >>> dims = (2, 3)
980
+ >>> output = tile(input, dims)
981
+ >>> print(output)
982
+ [[1. 2. 1. 2. 1. 2.]
983
+ [3. 4. 3. 4. 3. 4.]
984
+ [1. 2. 1. 2. 1. 2.]
985
+ [3. 4. 3. 4. 3. 4.]]
986
+ >>> dims = (2, 3, 2)
987
+ >>> output = tile(input, dims)
988
+ >>> print(output)
989
+ [[[1. 2. 1. 2.]
990
+ [3. 4. 3. 4.]
991
+ [1. 2. 1. 2.]
992
+ [3. 4. 3. 4.]
993
+ [1. 2. 1. 2.]
994
+ [3. 4. 3. 4.]]
995
+ [[1. 2. 1. 2.]
996
+ [3. 4. 3. 4.]
997
+ [1. 2. 1. 2.]
998
+ [3. 4. 3. 4.]
999
+ [1. 2. 1. 2.]
1000
+ [3. 4. 3. 4.]]]
1001
+ """
1002
+
1003
+ @prim_attr_register
1004
+ def __init__(self):
1005
+ """Initialize."""
1006
+
1007
+ def __call__(self, input, dims):
1008
+ return _convert_stub(pyboost_tile(self, [input, dims]))
1009
+
1010
+ # pylint: disable=missing-docstring
1011
+ def check_elim(self, *args):
1012
+ base_tensor, dims = args
1013
+ if not isinstance(base_tensor, Tensor):
1014
+ raise TypeError(f"For '{self.name}', the type of 'input' must be Tensor, "
1015
+ f"but got {type(base_tensor).__name__}.")
1016
+ if not isinstance(dims, tuple):
1017
+ raise TypeError(f"For '{self.name}', the type of 'dims' must be tuple, "
1018
+ f"but got {type(dims).__name__}.")
1019
+
1020
+ if all(v == 1 for v in dims) and len(base_tensor.shape) >= len(dims):
1021
+ from mindspore.ops.auto_generate.gen_ops_def import Identity
1022
+ ret = Identity()(base_tensor)
1023
+ return (True, ret)
1024
+ return (False, None)
1025
+
1026
+
1027
+ def tile(input, dims):
1028
+ r"""
1029
+ Creates a new tensor by replicating `input` `dims` times. The i'th dimension of
1030
+ output tensor has `input.shape[i] * dims[i]` elements, and the values of `input`
1031
+ are replicated `dims[i]` times along the i'th dimension.
1032
+
1033
+ Note:
1034
+ On Ascend, the number of `dims` should not exceed 8, and currently does not support scenarios
1035
+ where more than 4 dimensions are repeated simultaneously.
1036
+
1037
+ Args:
1038
+ input (Tensor): The tensor whose elements need to be repeated. Set the shape of input tensor as
1039
+ :math:`(x_1, x_2, ..., x_S)` .
1040
+
1041
+ dims (tuple[int]): The parameter that specifies the number of replications,
1042
+ the parameter type is tuple, and the data type is int, i.e., :math:`(y_1, y_2, ..., y_S)`.
1043
+ Only constant value is allowed.
1044
+
1045
+ Returns:
1046
+ Tensor, has the same data type as the `input`. Suppose the length of `dims` is `d`,
1047
+ the dimension of `input` is `input.dim`, and the shape of `input` is :math:`(x_1, x_2, ..., x_S)`.
1048
+
1049
+ - If `input.dim = d`, then the shape of their corresponding positions can be multiplied, and
1050
+ the shape of Outputs is :math:`(x_1*y_1, x_2*y_2, ..., x_S*y_S)`.
1051
+ - If `input.dim < d`, prepend 1 to the shape of `input` until their lengths are consistent.
1052
+ Such as set the shape of `input` as :math:`(1, ..., x_1, x_2, ..., x_S)`,
1053
+ then the shape of their corresponding positions can be multiplied, and the shape of Outputs is
1054
+ :math:`(1*y_1, ..., x_R*y_R, x_S*y_S)`.
1055
+ - If `input.dim > d`, prepend 1 to `dims` until their lengths are consistent. Such as set the
1056
+ `dims` as :math:`(1, ..., y_1, y_2, ..., y_S)`, then the shape of their corresponding positions
1057
+ can be multiplied, and the shape of Outputs is :math:`(x_1*1, ..., x_R*y_R, x_S*y_S)`.
1058
+
1059
+ Raises:
1060
+ TypeError: If `dims` is not a tuple or its elements are not all int.
1061
+ ValueError: If the elements of `dims` are not all greater than or equal to 0.
1062
+
1063
+ Supported Platforms:
1064
+ ``Ascend`` ``GPU`` ``CPU``
1065
+
1066
+ Examples:
1067
+ >>> import mindspore
1068
+ >>> import numpy as np
1069
+ >>> from mindspore import Tensor, ops
1070
+ >>> input = Tensor(np.array([[1, 2], [3, 4]]), mindspore.float32)
1071
+ >>> dims = (2, 3)
1072
+ >>> output = ops.tile(input, dims)
1073
+ >>> print(output)
1074
+ [[1. 2. 1. 2. 1. 2.]
1075
+ [3. 4. 3. 4. 3. 4.]
1076
+ [1. 2. 1. 2. 1. 2.]
1077
+ [3. 4. 3. 4. 3. 4.]]
1078
+ >>> dims = (2, 3, 2)
1079
+ >>> output = ops.tile(input, dims)
1080
+ >>> print(output)
1081
+ [[[1. 2. 1. 2.]
1082
+ [3. 4. 3. 4.]
1083
+ [1. 2. 1. 2.]
1084
+ [3. 4. 3. 4.]
1085
+ [1. 2. 1. 2.]
1086
+ [3. 4. 3. 4.]]
1087
+ [[1. 2. 1. 2.]
1088
+ [3. 4. 3. 4.]
1089
+ [1. 2. 1. 2.]
1090
+ [3. 4. 3. 4.]
1091
+ [1. 2. 1. 2.]
1092
+ [3. 4. 3. 4.]]]
1093
+ """
1094
+ tile_op = _get_cache_prim(Tile)()
1095
+ return tile_op(input, dims)
1096
+
1097
+
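# ---------------------------------------------------------------------------
# Editor's note: a sketch (not part of the diff) of the dims-prepending rule
# described above for the case input.dim < len(dims). NumPy's np.tile follows
# the same rule, so it can be used to sanity-check the expected shape: a (2, 2)
# input tiled with dims (2, 3, 2) is first viewed as (1, 2, 2) and the result
# has shape (1*2, 2*3, 2*2) = (2, 6, 4).
import numpy as np

x = np.array([[1, 2], [3, 4]], dtype=np.float32)
ref = np.tile(x, (2, 3, 2))
print(ref.shape)   # (2, 6, 4), matching the documented ops.tile example above
# ---------------------------------------------------------------------------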
1098
+ def scalar_cast(input_x, input_y):
1099
+ r"""
1100
+ The interface is deprecated from version 2.3 and will be removed in a future version,
1101
+ please use `int(x)` or `float(x)` instead.
1102
+
1103
+ Casts the input scalar to another type.
1104
+
1105
+ Args:
1106
+ input_x (scalar): The input scalar.
1107
+ input_y (mindspore.dtype): The type to be cast. Only constant value is allowed.
1108
+ The value should only be mindspore.int64, mindspore.float64, or mindspore.bool\_.
1109
+
1110
+ Returns:
1111
+ Scalar, the type is the same as the python type corresponding to `input_y`.
1112
+
1113
+ Raises:
1114
+ ValueError: if input_y's value is invalid.
1115
+
1116
+ Supported Platforms:
1117
+ Deprecated
1118
+
1119
+ Examples:
1120
+ >>> import mindspore
1121
+ >>> from mindspore import ops
1122
+ >>> output = ops.scalar_cast(255.0, mindspore.int64)
1123
+ >>> print(output)
1124
+ 255
1125
+ """
1126
+ scalar_cast_op = _get_cache_prim(ScalarCast)()
1127
+ return scalar_cast_op(input_x, input_y)
1128
+
1129
+
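# ---------------------------------------------------------------------------
# Editor's note (not part of the diff): per the deprecation note above, plain
# Python casts are the recommended replacement for ops.scalar_cast.
assert int(255.0) == 255    # instead of ops.scalar_cast(255.0, mindspore.int64)
assert float(3) == 3.0      # instead of casting to mindspore.float64
# ---------------------------------------------------------------------------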
1130
+ class Cast(Primitive):
1131
+ """
1132
+ Returns a tensor with the new specified data type.
1133
+
1134
+ Note:
1135
+ When converting complex numbers to boolean type, the imaginary part of the complex number is not
1136
+ taken into account. As long as the real part is non-zero, it returns True; otherwise, it returns False.
1137
+
1138
+ Inputs:
1139
+ - **input** (Union[Tensor, Number]) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
1140
+ The tensor to be cast.
1141
+ - **dtype** (dtype.Number) - The valid data type of the output tensor. Only constant value is allowed.
1142
+
1143
+ Outputs:
1144
+ Tensor, the shape of tensor is the same as `input`, :math:`(x_1, x_2, ..., x_R)`.
1145
+
1146
+ Raises:
1147
+ TypeError: If `input` is neither Tensor nor Number.
1148
+ TypeError: If `dtype` is not a Number.
1149
+
1150
+ Supported Platforms:
1151
+ ``Ascend`` ``GPU`` ``CPU``
1152
+
1153
+ Examples:
1154
+ >>> import mindspore
1155
+ >>> import numpy as np
1156
+ >>> from mindspore import Tensor, ops
1157
+ >>> input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
1158
+ >>> input = Tensor(input_np)
1159
+ >>> dtype = mindspore.int32
1160
+ >>> cast = ops.Cast()
1161
+ >>> output = cast(input, dtype)
1162
+ >>> print(output.dtype)
1163
+ Int32
1164
+ >>> print(output.shape)
1165
+ (2, 3, 4, 5)
1166
+ """
1167
+
1168
+ @prim_attr_register
1169
+ def __init__(self):
1170
+ """Initialize Cast"""
1171
+ self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])
1172
+
1173
+ def check_elim(self, x, dtype):
1174
+ if isinstance(x, (Tensor, numbers.Number, Parameter)):
1175
+ if isinstance(x, Parameter):
1176
+ data = x.data
1177
+ if data.dtype == dtype:
1178
+ return (True, x)
1179
+ if isinstance(x, Tensor) and x.dtype == dtype:
1180
+ x = Tensor(x)
1181
+ x.set_cast_dtype()
1182
+ return (True, x)
1183
+ if isinstance(x, numbers.Number):
1184
+ return (True, Tensor(x, dtype=dtype))
1185
+ return (False, None)
1186
+
1187
+ def __call__(self, input_x, dtype):
1188
+ should_elim, output = self.check_elim(input_x, dtype)
1189
+ if should_elim:
1190
+ return output
1191
+ return _convert_stub(pyboost_cast(self, [input_x, dtype_to_type_id('Cast', 'dtype', dtype)]))
1192
+
1193
+
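# ---------------------------------------------------------------------------
# Editor's note: a minimal sketch (not part of the diff) of the check_elim fast
# path in Cast above: a tensor that already holds the target dtype is returned
# without dispatching, and a Python number is wrapped into a Tensor of the
# requested dtype.
import mindspore
import numpy as np
from mindspore import Tensor, ops

cast = ops.Cast()
x = Tensor(np.ones((2, 2)), mindspore.float32)
print(cast(x, mindspore.float32).dtype)   # Float32 (cast eliminated)
print(cast(3, mindspore.int32).dtype)     # Int32 (number wrapped into a Tensor)
# ---------------------------------------------------------------------------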
1194
+ def to_sequence(val):
1195
+ """
1196
+ to_sequence
1197
+ """
1198
+ if isinstance(val, (tuple, list)):
1199
+ return tuple(val)
1200
+ return (val,)
1201
+
1202
+
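# ---------------------------------------------------------------------------
# Editor's note (not part of the diff): to_sequence above normalizes its
# argument to a tuple, which is how the embedding helpers below accept either a
# single value or a per-table sequence for their parameters.
#     to_sequence(5)      -> (5,)
#     to_sequence([1, 2]) -> (1, 2)
#     to_sequence((3,))   -> (3,)
# ---------------------------------------------------------------------------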
1203
+ class EmbeddingTableExport(Primitive):
1204
+ """
1205
+ EmbeddingTableExport
1206
+ """
1207
+
1208
+ @prim_attr_register
1209
+ def __init__(self, embedding_dim, value_total_len, export_mode="all",
1210
+ only_var_flag=False, file_type="bin", table_name=(),
1211
+ filter_export_flag=False, steps_to_live_list=()):
1212
+ """Initialize EmbeddingTableExport"""
1213
+ self.add_prim_attr("_process_node_engine_id", "PS")
1214
+
1215
+
1216
+ class EmbeddingTableImport(Primitive):
1217
+ """
1218
+ EmbeddingTableImport
1219
+ """
1220
+
1221
+ @prim_attr_register
1222
+ def __init__(self, embedding_dim, value_total_len,
1223
+ only_var_flag=False, file_type="bin", table_name=()):
1224
+ """Initialize EmbeddingTableImport"""
1225
+ self.add_prim_attr("_process_node_engine_id", "PS")
1226
+
1227
+
1228
+ class EmbeddingComputeVarImport(Primitive):
1229
+ """
1230
+ EmbeddingComputeVarImport
1231
+ """
1232
+
1233
+ @prim_attr_register
1234
+ def __init__(self, table_name=()):
1235
+ """Initialize EmbeddingComputeVarImport"""
1236
+ self.add_prim_attr("_process_node_engine_id", "PS")
1237
+
1238
+
1239
+ class EmbeddingComputeVarExport(Primitive):
1240
+ """
1241
+ EmbeddingComputeVarExport
1242
+ """
1243
+
1244
+ @prim_attr_register
1245
+ def __init__(self, table_name=()):
1246
+ """Initialize EmbeddingComputeVarExport"""
1247
+ self.add_prim_attr("_process_node_engine_id", "PS")
1248
+
1249
+
1250
+ class InitEmbeddingHashmap(Primitive):
1251
+ """
1252
+ InitEmbeddingHashmap
1253
+ """
1254
+ @prim_attr_register
1255
+ def __init__(self, value_total_len, embedding_dim, _table_id,
1256
+ bucket_size=0, dtype=mstype.float32, initializer_mode="",
1257
+ constant_valu=0., min=-2., max=2., mu=0., sigma=1., seed=0,
1258
+ seed2=0, filter_mode="no_filter", optimizer_mode="",
1259
+ optimizer_params=()):
1260
+ self.add_prim_attr("_process_node_engine_id", "PS")
1261
+
1262
+
1263
+ def init_embedding_hashmap(table_id, value_total_len, embedding_dim, _table_id,
1264
+ bucket_size=0, dtype=mstype.float32, initializer_mode='',
1265
+ constant_value=0.0, min=-2.0, max=2.0, mu=0.0, sigma=1.0,
1266
+ seed=0, seed2=0, filter_mode='no_filter',
1267
+ optimizer_mode='', optimizer_params=()):
1268
+ """
1269
+ init_embedding_hashmap
1270
+ """
1271
+ op = _get_cache_prim(InitEmbeddingHashmap)(value_total_len, embedding_dim, _table_id,
1272
+ bucket_size, dtype, initializer_mode,
1273
+ constant_value, min, max, mu, sigma, seed,
1274
+ seed2, filter_mode, optimizer_mode, optimizer_params)
1275
+ return op(table_id)
1276
+
1277
+
1278
+ class InitPartitionMap(Primitive):
1279
+ """
1280
+ InitPartitionMap
1281
+ """
1282
+ @prim_attr_register
1283
+ def __init__(self, _embedding_dim, _max_key_num,
1284
+ _ps_num=1, partition_num=65537):
1285
+ self.add_prim_attr("_process_node_engine_id", "PS")
1286
+
1287
+
1288
+ def init_partition_map(ps_num, ps_ids, _embedding_dim, _max_key_num,
1289
+ _ps_num=1, partition_num=65537):
1290
+ """
1291
+ init_partition_map
1292
+ """
1293
+ op = _get_cache_prim(InitPartitionMap)(_embedding_dim, _max_key_num, _ps_num, partition_num)
1294
+ return op(ps_num, ps_ids)
1295
+
1296
+
1297
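The functional wrappers in this block (`init_embedding_hashmap`, `init_partition_map`, and the lookup helpers below) all route their attribute arguments through `_get_cache_prim`, so repeated calls with identical attributes reuse a single Primitive instance rather than constructing a new one per call. The effect is roughly the following caching pattern (a simplified sketch, not MindSpore's actual implementation; `FakePrim` is a stand-in class):

```python
from functools import lru_cache

class FakePrim:
    """Stand-in for a primitive whose construction we want to cache."""
    def __init__(self, *attrs):
        self.attrs = attrs

@lru_cache(maxsize=None)
def get_cached_prim(prim_cls, *attrs):
    # Hashable attribute tuples map to one shared operator instance.
    return prim_cls(*attrs)

p1 = get_cached_prim(FakePrim, 80, "no_filter")
p2 = get_cached_prim(FakePrim, 80, "no_filter")
assert p1 is p2  # the same instance is reused across calls
```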
+ class EmbeddingApplyAdam(Primitive):
1298
+ """
1299
+ EmbeddingApplyAdam
1300
+ """
1301
+ @prim_attr_register
1302
+ def __init__(self, embedding_dim, _max_key_num, mask_zero=(0,),
1303
+ padding_key=(0,), padding_key_mask=(1,),
1304
+ completion_key=(0,), completion_key_mask=(1,)):
1305
+ self.add_prim_attr("_process_node_engine_id", "PS")
1306
+
1307
+
1308
+ class EmbeddingApplyAdamW(Primitive):
1309
+ """
1310
+ EmbeddingApplyAdamW
1311
+ """
1312
+ @prim_attr_register
1313
+ def __init__(self, embedding_dim, _max_key_num, amsgrad=(0,),
1314
+ maximize=(0,), mask_zero=(0,), padding_key=(0,),
1315
+ padding_key_mask=(1,), completion_key=(0,), completion_key_mask=(1,)):
1316
+ self.add_prim_attr("_process_node_engine_id", "PS")
1317
+
1318
+
1319
+ class EmbeddingApplyAdaGrad(Primitive):
1320
+ """
1321
+ EmbeddingApplyAdaGrad
1322
+ """
1323
+ @prim_attr_register
1324
+ def __init__(self, embedding_dim, _max_key_num, mask_zero=(0,),
1325
+ padding_key=(0,), padding_key_mask=(1,),
1326
+ completion_key=(0,), completion_key_mask=(1,)):
1327
+ self.add_prim_attr("_process_node_engine_id", "PS")
1328
+
1329
+
1330
+ class EmbeddingApplyFtrl(Primitive):
1331
+ """
1332
+ EmbeddingApplyFtrl
1333
+ """
1334
+ @prim_attr_register
1335
+ def __init__(self, embedding_dim, _max_key_num, mask_zero=(0,),
1336
+ padding_key=(0,), padding_key_mask=(1,),
1337
+ completion_key=(0,), completion_key_mask=(1,)):
1338
+ self.add_prim_attr("_process_node_engine_id", "PS")
1339
+
1340
+
1341
+ class EmbeddingTableFind(Primitive):
1342
+ """
1343
+ EmbeddingTableFind
1344
+ """
1345
+ @prim_attr_register
1346
+ def __init__(self, embedding_dim, _embedding_dim, _max_key_num,
1347
+ _table_id, default_value=(-1.,), _use_counter_filter=0):
1348
+ self.add_prim_attr("_process_node_engine_id", "PS")
1349
+ self.add_prim_attr("_execute_times", 2)
1350
+
1351
+
1352
+ def embedding_table_find(table_id, keys, embedding_dim, _max_key_num,
1353
+ _table_id, default_value=(-1.0,), _use_counter_filter=0):
1354
+ r"""
1355
+ embedding_table_find
1356
+ """
1357
+ _embedding_dim = embedding_dim if isinstance(embedding_dim, int) else embedding_dim[_table_id]
1358
+ op = _get_cache_prim(EmbeddingTableFind)(to_sequence(embedding_dim), _embedding_dim,
1359
+ _max_key_num, _table_id,
1360
+ to_sequence(default_value),
1361
+ _use_counter_filter)
1362
+ return op(table_id, keys)
1363
+
1364
+
1365
+ class EmbeddingTableFindAndInit(Primitive):
1366
+ """
1367
+ EmbeddingTableFindAndInit
1368
+ """
1369
+ @prim_attr_register
1370
+ def __init__(self, embedding_dim, value_total_len, _embedding_dim, _table_id,
1371
+ _max_key_num, initializer_mode=("random_uniform",),
1372
+ constant_value=(0.,), min=(-2.,), max=(2.,), mu=(0.,),
1373
+ sigma=(1.,), seed=(0,), seed2=(0,),
1374
+ filter_mode=("no_filter",), filter_freq=(0,),
1375
+ default_key_or_value=(0,), default_key=(0,),
1376
+ default_value=(0.,), completion_key=(0,),
1377
+ completion_key_mask=(1,), optimizer_mode=(),
1378
+ optimizer_params=(), _use_counter_filter=0,
1379
+ backward_mode="adam",
1380
+ backward_int_params=((0,), (0,), (0,), (1,)),
1381
+ backward_float_params=(0.9, 0.99, 0.001, 0.9, 0.999, 1e-08)):
1382
+ self.add_prim_attr("_process_node_engine_id", "PS")
1383
+ self.add_prim_attr("_execute_times", 2)
1384
+
1385
+
1386
+ def embedding_table_find_and_init(table_id, keys, max_grad_norm, parameter, embedding_dim,
1387
+ value_total_len, _table_id, _max_key_num,
1388
+ initializer_mode=('random_uniform',), constant_value=(0.,),
1389
+ min=(-2.,), max=(2.,), mu=(0.,), sigma=(1.,), seed=(0,),
1390
+ seed2=(0,), filter_mode=("no_filter",),
1391
+ filter_freq=(0,), default_key_or_value=(0,),
1392
+ default_key=(0,), default_value=(0.,),
1393
+ completion_key=(0,), completion_key_mask=(1,),
1394
+ optimizer_mode=(), optimizer_params=(), _use_counter_filter=0,
1395
+ backward_mode="adam", backward_int_params=((0,), (0,), (0,), (1,)),
1396
+ backward_float_params=(0.9, 0.99, 0.001, 0.9, 0.999, 1e-08)):
1397
+ """
1398
+ embedding_table_find_and_init
1399
+
1400
+ backward_int_params (Union[tuple[tuple[int]], list[list[int]]]):
1401
+ - when the backward_mode is 'adam', 'ftrl' or 'adagrad',
1402
+ it means [[global_step], mask_zero, padding_key, padding_key_mask]
1403
+ - when the backward_mode is 'adamw', it means:
1404
+ [[global_step], amsgrad, maximize, mask_zero, padding_key, padding_key_mask]
1405
+ backward_float_params (Union[tuple[float], list[float]]):
1406
+ - when the backward_mode is 'adam', it means:
1407
+ [beta1_power, beta2_power, lr, beta1, beta2, epsilon]
1408
+ - when the backward_mode is 'ftrl', it means:
1409
+ [lr, lr_power, lambda1, lambda2]
1410
+ - when the backward_mode is 'adamw', it means:
1411
+ [beta1_power, beta2_power, lr, weight_decay, beta1, beta2, epsilon]
1412
+ - when the backward_mode is 'adagrad', it means [lr,]
1413
+ """
1414
+ _embedding_dim = embedding_dim if isinstance(embedding_dim, int) else embedding_dim[_table_id]
1415
+ op = _get_cache_prim(EmbeddingTableFindAndInit)(to_sequence(embedding_dim), to_sequence(value_total_len),
1416
+ _embedding_dim, _table_id, _max_key_num,
1417
+ to_sequence(initializer_mode),
1418
+ to_sequence(constant_value), to_sequence(min),
1419
+ to_sequence(max), to_sequence(mu),
1420
+ to_sequence(sigma), to_sequence(seed),
1421
+ to_sequence(seed2), to_sequence(filter_mode),
1422
+ to_sequence(filter_freq), to_sequence(default_key_or_value),
1423
+ to_sequence(default_key), to_sequence(default_value),
1424
+ to_sequence(completion_key), to_sequence(completion_key_mask),
1425
+ to_sequence(optimizer_mode), to_sequence(optimizer_params),
1426
+ _use_counter_filter,
1427
+ backward_mode, backward_int_params, backward_float_params)
1428
+ return op(table_id, keys, max_grad_norm, parameter)
1429
+
1430
+
1431
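The docstring above fixes the layout of `backward_int_params` and `backward_float_params` per optimizer. A plain-Python illustration of how those tuples line up for the default `'adam'` mode (the numeric values are placeholders for the example, not recommended hyperparameters):

```python
# backward_mode == 'adam':
#   int params  : [[global_step], mask_zero, padding_key, padding_key_mask]
#   float params: [beta1_power, beta2_power, lr, beta1, beta2, epsilon]
global_step = 100
backward_int_params = ((global_step,), (0,), (0,), (1,))
backward_float_params = (0.9, 0.99, 0.001, 0.9, 0.999, 1e-08)

beta1_power, beta2_power, lr, beta1, beta2, epsilon = backward_float_params
assert len(backward_int_params) == 4 and len(backward_float_params) == 6
```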
+ class FakeRemoteLookupUniqued(Primitive):
1432
+
1433
+ """
1434
+ FakeRemoteLookupUniqued
1435
+ """
1436
+ @prim_attr_register
1437
+ def __init__(self, embedding_dim, value_total_len, _embedding_dim, _table_id,
1438
+ _max_key_num, initializer_mode=('random_uniform',), constant_value=(0.,),
1439
+ min=(-2.,), max=(2.,), mu=(0.,), sigma=(1.,), seed=(0,), seed2=(0,),
1440
+ filter_mode=("no_filter",), filter_freq=(0,),
1441
+ default_key_or_value=(0,), default_key=(0,), default_value=(0.,),
1442
+ completion_key=(0,), completion_key_mask=(1,),
1443
+ optimizer_mode=(), optimizer_params=(), _use_counter_filter=0,
1444
+ backward_mode="adam", backward_int_params=((0,), (0,), (0,), (1,)),
1445
+ backward_float_params=(0.9, 0.99, 0.001, 0.9, 0.999, 1e-08)):
1446
+ self.add_prim_attr("_process_node_engine_id", "PS")
1447
+ self.add_prim_attr("_execute_times", 2)
1448
+
1449
+
1450
+ def fake_remote_lookup_uniqued(table_id, keys, actual_keys_num, unique_indices,
1451
+ key_count, max_grad_norm, parameter,
1452
+ embedding_dim, value_total_len, _table_id, _max_key_num,
1453
+ initializer_mode=('random_uniform',), constant_value=(0.,),
1454
+ min=(-2.,), max=(2.,), mu=(0.,), sigma=(1.,), seed=(0,),
1455
+ seed2=(0,), filter_mode=("no_filter",),
1456
+ filter_freq=(0,), default_key_or_value=(0,),
1457
+ default_key=(0,), default_value=(0.,),
1458
+ completion_key=(0,), completion_key_mask=(1,),
1459
+ optimizer_mode=(), optimizer_params=(), _use_counter_filter=0,
1460
+ backward_mode='adam', backward_int_params=((0,), (0,), (0,), (1,)),
1461
+ backward_float_params=(0.9, 0.99, 0.001, 0.9, 0.999, 1e-08)):
1462
+ """
1463
+ fake_remote_lookup_uniqued
1464
+
1465
+ backward_mode (str): determine the optimizer used by backpropagation,
1466
+ valid values are ["adam", "adamw", "adagrad", "ftrl"]
1467
+ backward_int_params (Union[tuple[tuple[int]], list[list[int]]]):
1468
+ - when the backward_mode is 'adam', 'ftrl' or 'adagrad',
1469
+ it means [[global_step], mask_zero, padding_key, padding_key_mask]
1470
+ - when the backward_mode is 'adamw', it means:
1471
+ [[global_step], amsgrad, maximize, mask_zero, padding_key, padding_key_mask]
1472
+ backward_float_params (Union[tuple[float], list[float]]):
1473
+ - when the backward_mode is 'adam', it means:
1474
+ [beta1_power, beta2_power, lr, beta1, beta2, epsilon]
1475
+ - when the backward_mode is 'ftrl', it means:
1476
+ [lr, lr_power, lambda1, lambda2]
1477
+ - when the backward_mode is 'adamw', it means:
1478
+ [beta1_power, beta2_power, lr, weight_decay, beta1, beta2, epsilon]
1479
+ - when the backward_mode is 'adagrad', it means [lr,]
1480
+ """
1481
+ _embedding_dim = embedding_dim if isinstance(embedding_dim, int) else embedding_dim[_table_id]
1482
+ op = _get_cache_prim(FakeRemoteLookupUniqued)(to_sequence(embedding_dim), to_sequence(value_total_len),
1483
+ _embedding_dim, _table_id, _max_key_num,
1484
+ to_sequence(initializer_mode), to_sequence(constant_value),
1485
+ to_sequence(min), to_sequence(max), to_sequence(mu),
1486
+ to_sequence(sigma), to_sequence(seed), to_sequence(seed2),
1487
+ to_sequence(filter_mode), to_sequence(filter_freq),
1488
+ to_sequence(default_key_or_value), to_sequence(default_key),
1489
+ to_sequence(default_value), to_sequence(completion_key),
1490
+ to_sequence(completion_key_mask), to_sequence(optimizer_mode),
1491
+ to_sequence(optimizer_params), _use_counter_filter,
1492
+ backward_mode, backward_int_params, backward_float_params)
1493
+ return op(table_id, keys, actual_keys_num, unique_indices, key_count, max_grad_norm, parameter)
1494
+
1495
+
1496
+ # Following is Python Infer Value.
1497
+ # A valid infer value function should be:
1498
+ #
1499
+ # 1. named as infer_value_for_OpName
1500
+ # 2. All inputs should be passed without default values.
1501
+ # 3. If a non-constant input is given, return None. (for now)
1502
+
1503
+
1504
+ def infer_value_for_Tile(input, dims):
1505
+ """Infer value for Tile op."""
1506
+ if input is None or dims is None or None in dims:
1507
+ return None
1508
+ return Tensor(np.tile(input.asnumpy(), dims))
1509
+
1510
+
1511
+ def infer_value_for_Concat(tensors, axis):
1512
+ """Infer value for Concat op."""
1513
+ if not tensors or None in tensors or axis is None:
1514
+ return None
1515
+
1516
+ tensor_to_concat = [x.asnumpy() if x.dtype != mstype.bfloat16 else x.float().asnumpy() for x in tensors]
1517
+ return Tensor(np.concatenate(tensor_to_concat, axis), dtype=tensors[0].dtype)
1518
+
1519
+
1520
+ def infer_value_for_ReduceSum(input_x, axis, keep_dims, skip_mode):
1521
+ """Infer value for ReduceSum op."""
1522
+ value = None
1523
+ if input_x is not None and axis is not None:
1524
+ value = input_x.asnumpy()
1525
+ if isinstance(axis, int):
1526
+ pass
1527
+ elif axis:
1528
+ axis = tuple(set(axis))
1529
+ elif axis in ((), []) and skip_mode:
1530
+ return input_x
1531
+ else:
1532
+ axis = tuple(range(len(value.shape)))
1533
+ value = np.sum(value, axis, keepdims=keep_dims)
1534
+ value = np.array(value)
1535
+ value = Tensor(value)
1536
+ return value
1537
+
1538
+
1539
+ def _infer_value_for_Reduce(input_x, axis, keep_dims, prim_name):
1540
+ """Infer value for Common Reduce op."""
1541
+ value = None
1542
+ if input_x is not None and axis is not None:
1543
+ prim_map = {
1544
+ 'ReduceMax': np.max,
1545
+ 'ReduceMin': np.min,
1546
+ 'ReduceProd': np.prod,
1547
+ 'ReduceMean': np.mean,
1548
+ 'ReduceAll': np.all,
1549
+ 'ReduceAny': np.any,
1550
+ }
1551
+ np_reduce_func = prim_map.get(prim_name, None)
1552
+
1553
+ if np_reduce_func is not None:
1554
+ value = input_x.asnumpy()
1555
+ if isinstance(axis, int):
1556
+ pass
1557
+ elif axis:
1558
+ axis = tuple(set(axis))
1559
+ else:
1560
+ axis = tuple(range(len(value.shape)))
1561
+ value = np_reduce_func(value, axis, keepdims=keep_dims)
1562
+ value = np.array(value)
1563
+ value = Tensor(value)
1564
+ return value
1565
+
1566
+
1567
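The reduce infer-value helpers above all normalize `axis` the same way before folding: an int is used as-is, a non-empty sequence is de-duplicated into a tuple, and an empty sequence means "all axes". A small NumPy-only sketch that mirrors that normalization and the resulting fold (it does not call the private helper itself):

```python
import numpy as np

def normalize_axis(axis, rank):
    if isinstance(axis, int):
        return axis
    if axis:                       # non-empty tuple/list: de-duplicate
        return tuple(set(axis))
    return tuple(range(rank))      # empty means reduce over all axes

data = np.arange(6, dtype=np.float32).reshape(2, 3)
axis = normalize_axis((), data.ndim)              # () -> (0, 1)
print(np.max(data, axis=axis, keepdims=False))    # 5.0, the value ReduceMax would fold to
```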
+ def _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, prim_name):
1568
+ """Infer value for Common ReduceExtand op."""
1569
+ value = None
1570
+ if input_x is not None:
1571
+ prim_map = {
1572
+ 'MeanExt': np.mean,
1573
+ 'SumExt': np.sum,
1574
+ 'ProdExt': np.prod,
1575
+ }
1576
+ np_reduce_extand_func = prim_map.get(prim_name, None)
1577
+
1578
+ if np_reduce_extand_func is not None:
1579
+ value = input_x.asnumpy()
1580
+ if isinstance(axis, int):
1581
+ pass
1582
+ elif axis:
1583
+ axis = tuple(set(axis))
1584
+ else:
1585
+ axis = tuple(range(len(value.shape)))
1586
+ if dtype is not None:
1587
+ np_dtype = mstype.dtype_to_nptype(typing.type_id_to_type(dtype))
1588
+ value = np_reduce_extand_func(value, axis, dtype=np_dtype, keepdims=keep_dims)
1589
+ else:
1590
+ value = np_reduce_extand_func(value, axis, keepdims=keep_dims)
1591
+
1592
+ value = np.array(value)
1593
+ value = Tensor(value)
1594
+ return value
1595
+
1596
+
1597
+ def _infer_value_for_max_min(input_x, prim_name):
1598
+ """Infer value for Max/Min op."""
1599
+ value = None
1600
+ if input_x is not None:
1601
+ prim_map = {
1602
+ 'Max': np.max,
1603
+ 'Min': np.min,
1604
+ }
1605
+ np_reduce_func = prim_map.get(prim_name, None)
1606
+
1607
+ if np_reduce_func is not None:
1608
+ value = input_x.asnumpy()
1609
+ value = np_reduce_func(value, None, keepdims=False)
1610
+ value = np.array(value)
1611
+ value = Tensor(value)
1612
+ return value
1613
+
1614
+
1615
+ def infer_value_for_Cast(x, dst_type_enum=None):
1616
+ """Infer value for Cast op."""
1617
+ if x is None or dst_type_enum is None:
1618
+ return None
1619
+ dst_type = typing.type_id_to_type(dst_type_enum)
1620
+ src_type = mstype.get_py_obj_dtype(x)
1621
+ validator.check_subclass("input_x", src_type, [mstype.tensor_type, mstype.number], "Cast")
1622
+ validator.check_subclass("type", dst_type, mstype.number, "Cast")
1623
+
1624
+ if isinstance(src_type, type(mstype.tensor_type)):
1625
+ src_type = src_type.element_type()
1626
+ if isinstance(dst_type, type(mstype.tensor_type)):
1627
+ dst_type = dst_type.element_type()
1628
+
1629
+ value = None
1630
+ np_dst_type = mstype.dtype_to_nptype(dst_type)
1631
+ if isinstance(x, (int, float)):
1632
+ value = Tensor(np.array(x).astype(np_dst_type), dtype=dst_type)
1633
+ else:
1634
+ value = Tensor_(x.asnumpy().astype(np_dst_type), dtype=dst_type)
1635
+ return value
1636
+
1637
+
1638
+ def infer_value_for_ReduceMax(input_x, axis, keep_dims):
1639
+ """Infer value for ReduceMax op."""
1640
+ return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceMax')
1641
+
1642
+
1643
+ def infer_value_for_Max(input_x):
1644
+ """Infer value for Max op."""
1645
+ return _infer_value_for_max_min(input_x, 'Max')
1646
+
1647
+
1648
+ def infer_value_for_ReduceMin(input_x, axis, keep_dims):
1649
+ """Infer value for ReduceMin op."""
1650
+ return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceMin')
1651
+
1652
+
1653
+ def infer_value_for_Min(input_x):
1654
+ """Infer value for Max op."""
1655
+ return _infer_value_for_max_min(input_x, 'Min')
1656
+
1657
+
1658
+ def infer_value_for_ReduceProd(input_x, axis, keep_dims):
1659
+ """Infer value for ReduceProd op."""
1660
+ return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceProd')
1661
+
1662
+
1663
+ def infer_value_for_ReduceMean(input_x, axis, keep_dims):
1664
+ """Infer value for ReduceMean op."""
1665
+ return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceMean')
1666
+
1667
+
1668
+ def infer_value_for_ReduceAll(input_x, axis, keep_dims):
1669
+ """Infer value for ReduceAll op."""
1670
+ return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceAll')
1671
+
1672
+
1673
+ def infer_value_for_ReduceAny(input_x, axis, keep_dims):
1674
+ """Infer value for ReduceAny op."""
1675
+ return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceAny')
1676
+
1677
+
1678
+ def infer_value_for_MeanExt(input_x, axis, keep_dims, dtype):
1679
+ """Infer value for MeanExt op."""
1680
+ return _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, 'MeanExt')
1681
+
1682
+
1683
+ def infer_value_for_SumExt(input_x, axis, keep_dims, dtype):
1684
+ """Infer value for SumExt op."""
1685
+ return _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, 'SumExt')
1686
+
1687
+
1688
+ def infer_value_for_ProdExt(input_x, axis, keep_dims, dtype):
1689
+ """Infer value for ProdExt op."""
1690
+ return _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, 'ProdExt')
1691
+
1692
+
1693
+ def infer_value_for_Diag(input_x):
1694
+ """Infer value for Diag op."""
1695
+ if input_x is None:
1696
+ return None
1697
+ # do constant-folding only when x rank is 1
1698
+ if len(input_x.shape) != 1:
1699
+ return None
1700
+ ret = np.diag(input_x.asnumpy())
1701
+ return Tensor(ret)
1702
+
1703
+
1704
+ def infer_value_for_BroadcastTo(x, shape):
1705
+ """Infer value for BroadcastTo op."""
1706
+ def none_in_tuple_or_list(x):
1707
+ return isinstance(x, (tuple, list)) and None in x
1708
+ if shape is None or none_in_tuple_or_list(shape) or x is None:
1709
+ return None
1710
+
1711
+ if isinstance(shape, (Tensor, Tensor_)):
1712
+ validator.check_tensor_dtype_valid("shape", mstype.TensorType(shape.dtype),
1713
+ [mstype.int32, mstype.int64], "BroadcastTo")
1714
+ shape = shape.asnumpy().tolist()
1715
+ else:
1716
+ validator.check_value_type("shape", shape, [tuple], "BroadcastTo")
1717
+ shape = list(shape)
1718
+
1719
+ np_data = np.broadcast_to(x.asnumpy(), shape)
1720
+ if 0 in shape:
1721
+ init_func = Zero()
1722
+ init_func.__enable_zero_dim__ = True
1723
+ out = Tensor(shape=shape, dtype=x.dtype, init=init_func)
1724
+ return out
1725
+ return Tensor(np_data)
1726
+
1727
+
1728
+ def infer_value_for_Reshape(x, shape):
1729
+ """Infer value for Reshape op."""
1730
+ def none_in_tuple_or_list(x):
1731
+ return isinstance(x, (tuple, list)) and None in x
1732
+ # for shape is not constant
1733
+ if shape is None or none_in_tuple_or_list(shape) or x is None:
1734
+ return None
1735
+
1736
+ if isinstance(shape, (Tensor, Tensor_)):
1737
+ validator.check_tensor_dtype_valid("shape", mstype.TensorType(shape.dtype),
1738
+ [mstype.int32, mstype.int64], "Reshape")
1739
+ shape = shape.asnumpy().tolist()
1740
+ else:
1741
+ validator.check_value_type("shape", shape, [tuple], "Reshape")
1742
+ shape = list(shape)
1743
+
1744
+ neg_index = -1
1745
+ dim_prod = 1
1746
+ for i, shp_i in enumerate(shape):
1747
+ validator.check_value_type("shape[%d]" % i, shp_i, [int], "Reshape")
1748
+ if shp_i == -1:
1749
+ if neg_index != -1:
1750
+ raise ValueError(f"For 'Reshape', there can be at most one '-1' in 'input_shape', "
1751
+ f"but got {shape}.")
1752
+ neg_index = i
1753
+ else:
1754
+ dim_prod *= shp_i
1755
+ out = None
1756
+ if not is_shape_unknown(x.shape):
1757
+ x_shp = x.shape
1758
+ if dim_prod < 0:
1759
+ raise ValueError(f"For 'Reshape', the shape of 'input_x' is {x_shp}, "
1760
+ f"the value of 'input_shape' is {shape}. "
1761
+ f"The product of 'input_shape' should > 0, but got {dim_prod}.")
1762
+ arr_prod = np.prod(x_shp)
1763
+ if neg_index != -1:
1764
+ shape[neg_index] = int(arr_prod // dim_prod)
1765
+ dim_prod *= shape[neg_index]
1766
+ if dim_prod != arr_prod:
1767
+ raise ValueError(f"For 'Reshape', the product of the 'input_x' shape "
1768
+ f"should be equal to product of 'input_shape', but got product of the"
1769
+ f" shape of 'input_x': {arr_prod}, product of 'input_shape': {dim_prod}.")
1770
+ if 0 in shape:
1771
+ init_func = Zero()
1772
+ init_func.__enable_zero_dim__ = True
1773
+ out = Tensor(shape=shape, dtype=x.dtype, init=init_func)
1774
+ else:
1775
+ out = Tensor(x.asnumpy().reshape(shape))
1776
+ return out
1777
+
1778
+
1779
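`infer_value_for_Reshape` resolves a single `-1` entry from the product of the remaining dimensions, then checks that the totals match. The same arithmetic in plain NumPy terms (a worked example under those assumptions, not a call into the helper itself):

```python
import numpy as np

x_shape = (2, 3, 4)                                           # 24 elements in total
requested = [4, -1]
dim_prod = int(np.prod([d for d in requested if d != -1]))    # 4
neg_dim = int(np.prod(x_shape)) // dim_prod                   # 24 // 4 = 6
resolved = [d if d != -1 else neg_dim for d in requested]
assert resolved == [4, 6]
assert int(np.prod(resolved)) == int(np.prod(x_shape))        # totals must agree
```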
+ class Ones(Primitive):
1780
+ r"""
1781
+ Creates a tensor filled with value ones.
1782
+
1783
+ Refer to :func:`mindspore.ops.ones` for more details.
1784
+
1785
+ .. warning::
1786
+ For argument `size`, Tensor type input will be deprecated in the future version.
1787
+
1788
+ Inputs:
1789
+ - **shape** (Union[tuple[int], List[int], int, Tensor]) - The specified shape of output tensor.
1790
+ - **type** (:class:`mindspore.dtype`) - The specified type of output tensor.
1791
+
1792
+ Outputs:
1793
+ Tensor, whose dtype and size are defined by input.
1794
+
1795
+ Raises:
1796
+ TypeError: If `shape` is neither an int nor a tuple/list/Tensor of int.
1797
+
1798
+ Supported Platforms:
1799
+ ``Ascend`` ``GPU`` ``CPU``
1800
+
1801
+ Examples:
1802
+ >>> import mindspore
1803
+ >>> from mindspore import ops
1804
+ >>> ones = ops.Ones()
1805
+ >>> output = ones((2, 2), mindspore.float32)
1806
+ >>> print(output)
1807
+ [[1. 1.]
1808
+ [1. 1.]]
1809
+ >>> output = ones((3, 3), mindspore.float32)
1810
+ >>> print(output)
1811
+ [[1. 1. 1.]
1812
+ [1. 1. 1.]
1813
+ [1. 1. 1.]]
1814
+ """
1815
+
1816
+ __mindspore_signature__ = (
1817
+ sig.make_sig('size'),
1818
+ sig.make_sig('type', default=None),
1819
+ )
1820
+
1821
+ @prim_arg_register
1822
+ def __init__(self):
1823
+ pass
1824
+
1825
+ def __call__(self, size, type=None):
1826
+ return _convert_stub(pyboost_ones(self, [size, type if type is None \
1827
+ else handler.dtype_to_type_id('Ones', 'type', type)]))
1828
+
1829
+
1830
+ class Zeros(Primitive):
1831
+ r"""
1832
+ Zeros will be deprecated in the future. Please use the function `mindspore.ops.zeros` instead.
1833
+
1834
+ Creates a tensor filled with value zeros.
1835
+
1836
+ Creates a tensor with shape described by the first argument and
1837
+ fills it with value zeros in type of the second argument.
1838
+
1839
+ .. warning::
1840
+ For argument `size`, Tensor type input will be deprecated in the future version.
1841
+
1842
+ Inputs:
1843
+ - **shape** (tuple[int], List[int], int, Tensor) - The specified shape of output tensor.
1844
+ - **type** (mindspore.dtype) - The specified type of output tensor.
1845
+
1846
+ Outputs:
1847
+ Tensor, whose dtype and size are defined by input.
1848
+
1849
+ Raises:
1850
+ TypeError: If `shape` is neither an int nor a tuple/list/Tensor of int.
1851
+
1852
+ Supported Platforms:
1853
+ ``Ascend`` ``GPU`` ``CPU``
1854
+
1855
+ Examples:
1856
+ >>> import mindspore
1857
+ >>> from mindspore import ops
1858
+ >>> zeros = ops.Zeros()
1859
+ >>> output = zeros((2, 2), mindspore.float32)
1860
+ >>> print(output)
1861
+ [[0. 0.]
1862
+ [0. 0.]]
1863
+
1864
+ """
1865
+
1866
+ __mindspore_signature__ = (
1867
+ sig.make_sig('size'),
1868
+ sig.make_sig('type', default=None),
1869
+ )
1870
+
1871
+ @prim_arg_register
1872
+ def __init__(self):
1873
+ pass
1874
+
1875
+ def __call__(self, size, type=None):
1876
+ return _convert_stub(pyboost_zeros(self, [size, type if type is None else \
1877
+ handler.dtype_to_type_id('Zeros', 'type', type)]))
1878
+
1879
+
1880
+ def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mask=None, padding_mask=None,
1881
+ attn_mask=None, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, keep_prob=1.0,
1882
+ scalar_value=1.0, pre_tokens=2147483647, next_tokens=2147483647, inner_precise=0,
1883
+ input_layout='BSH', sparse_mode=0):
1884
+ r"""
1885
+ This interface is not open to the public; it is for internal use only.
1886
+
1887
+ .. math::
1888
+ \begin{array}{ll} \\
1889
+ y = Dropout(Softmax(Mask(scalar\_value \cdot (real\_shift + query \cdot key), attn\_mask), -1), keep\_prob) \\
1890
+ \cdot value \\
1891
+ \end{array}
1892
+
1893
+ B -- Batch size. Value range 1 to 2k.
1894
+ S1 -- Sequence length of query. Value range 1 to 512k.
1895
+ S2 -- Sequence length of key and value. Value range 1 to 512k.
1896
+ N1 -- Num heads of query. Value range 1 to 256.
1897
+ N2 -- Num heads of key and value, and N2 must be a factor of N1.
1898
+ D -- Head size. The value must be a multiple of 16, with a maximum of 512.
1899
+ H1 -- Hidden size of query, which equals to N1 * D.
1900
+ H2 -- Hidden size of key and value, which equals to N2 * D.
1901
+
1902
+ .. warning::
1903
+ This is an experimental API that is subject to change or deletion. Only support on Atlas A2 training series.
1904
+
1905
+ Args:
1906
+ query (Tensor[float16, bfloat16]): The query tensor. Input tensor of shape :math:`(B, S1, H1)`,
1907
+ `(B, N1, S1, D)`, `(S1, B, H1)`, `(B, S1, N1, D)` or `(T1, N1, D)`.
1908
+ key (Tensor[float16, bfloat16]): The key tensor. Input tensor of shape :math:`(B, S2, H2)`,
1909
+ `(B, N2, S2, D)`, `(S2, B, H2)`, `(B, S2, N2, D)` or `(T2, N2, D)`.
1910
+ value (Tensor[float16, bfloat16]): The value tensor. Input tensor of shape :math:`(B, S2, H2)`,
1911
+ `(B, N2, S2, D)`, `(S2, B, H2)`, `(B, S2, N2, D)` or `(T2, N2, D)`. The key and value have the same shape.
1912
+ head_num (int): The head num of query, equal to N1.
1913
+ real_shift (Union[Tensor[float16, bfloat16], None]): Also known as pse. The position embedding code. If S
1914
+ is greater than 1024 and the mask of the lower triangle is used, enter only the inverse 1024 lines of
1915
+ the lower triangle for memory optimization. Input tensor of shape :math:`(B, N1, S1, S2)`,
1916
+ `(1, N1, S1, S2)`, `(B, N1, 1024, S2)`, `(1, N1, 1024, S2)`.
1917
+
1918
+ - ALiBi scenario: real_shift must meet the ALiBi rule, and sparse_mode is 2 or 3 for the lower triangle.
1919
+ In this scenario, real_shift is `(B, N1, 1024, S2)`, `(1, N1, 1024, S2)`.
1920
+ - Non-ALiBi scenario: real_shift is `(B, N1, S1, S2)`, `(1, N1, S1, S2)`.
1921
+
1922
+ The shape of `real_shift` should be `(B, N1, 1024, S2)` and `(1, N1, 1024, S2)` when input_layout is
1923
+ `TND`.
1924
+ drop_mask (Union[Tensor[uint8], None]): The dropout mask tensor. Input tensor of shape :math:
1925
+ `(B, N1, S1, S2 // 8) or None`. S2 is a multiple of 8 when not None.
1926
+ padding_mask (None): Reserved parameter. Not implemented yet.
1927
+ attn_mask (Union[Tensor[uint8], Tensor[bool], None]): The attention mask tensor. For each element, 0
1928
+ indicates retention and 1 indicates discard. Input tensor of shape :math:`(B, N1, S1, S2)`,
1929
+ `(B, 1, S1, S2)`, `(S1, S2)` or `(2048, 2048)`. In compression scenario, sparse_mode is 2, 3, or 4,
1930
+ attn_mask must be `(2048, 2048)`. When sparse_mode is 5, attn_mask must be `(B, N1, S1, S2)`,
1931
+ `(B, 1, S1, S2)`. When sparse_mode is 0 and 1, attn_mask should be `(B, N1, S1, S2)`, `(B, 1, S1, S2)`,
1932
+ `(S1, S2)`.
1933
+ prefix (Union[List[int64], Tuple[int64] None]): N value of each Batch in the prefix sparse calculation
1934
+ scenario. Input tensor of shape :math:`(B,)`. B max value 32. Not none only when sparse_mode is 5.
1935
+ If S1 > S2, N ranges from 0 to S2. If S1 <= S2, N ranges from S2 - S1 to S2.
1936
+ actual_seq_qlen (Union[List[int64], Tuple[int64], None]): Size of query corresponding to each batch, array
1937
+ with increasing values and the last value equal to T1.
1938
+ actual_seq_kvlen (Union[List[int64], Tuple[int64], None]): Size of key and value corresponding to each batch,
1939
+ array with increasing values and the last value equal to T2.
1940
+ keep_prob (float): The keep probability of dropout. Value range is (0.0, 1.0]. Default: 1.0. when keep_prob
1941
+ is 1.0, drop_mask should be none.
1942
+ scalar_value (float): The scale factor of score. Generally, the value is 1.0 / (D ** 0.5). Default: 1.0.
1943
+ pre_tokens (int): Parameter for sparse computation, represents how many tokens are counted forward.
1944
+ When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647.
1945
+ next_tokens (int): Parameter for sparse computation, represents how many tokens are counted backward.
1946
+ When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647.
1947
+ The value of pre_tokens corresponds to S1, and the value of next_tokens corresponds to S2. They define the
1948
+ valid area on the attn_mask matrix. It must ensure that the band is not empty.
1949
+ The following values are not allowed:
1950
+
1951
+ - pre_tokens < 0 and next_tokens < 0.
1952
+ - (pre_tokens < 0 and next_tokens >= 0) and (next_tokens < abs(pre_tokens) or abs(pre_tokens) >= S2).
1953
+ - (pre_tokens >= 0 and next_tokens < 0) and (abs(next_tokens) > pre_tokens or abs(next_tokens) >= S1).
1954
+
1955
+ inner_precise (int): The parameter is reserved and not implemented yet. Default: 0.
1956
+ input_layout (str): Specifies the layout of input `query`, key and value. The value can be "BSH", "BNSD",
1957
+ "SBH", "BSND" or "TND". "TND" is an experimental format. Default: "BSH".
1958
+ When input_layout is "TND", the following restrictions must be met.
1959
+ There are two lists that represent the length of the input sequence: list_seq_q and list_seq_k. Each
1960
+ value in the list indicates the length of the sequence in the batch. For example, list_seq_q = [4, 2, 6],
1961
+ list_seq_k = [10, 3, 9]. The element of list indicate S. T1 is sum(list_seq_q) = 12, T2 is
1962
+ sum(list_seq_k) = 22.
1963
+ max_seqlen_q = max(list_seq_q), max_seqlen_k = max(list_seq_k).
1964
+ qk_pointer = sum(list_seq_q * list_seq_k), which is the sum of the element multiplication.
1965
+
1966
+ - The lengths of two lists are the same, and size of list is batch. batch is less than or equal to 1024.
1967
+ - When input_layout is "TND", actual_seq_qlen and actual_seq_kvlen must be not none.
1968
+ Otherwise, they are none.
1969
+ - The actual_seq_qlen and actual_seq_kvlen are the cumulative sums of the query and key/value sequence lengths, so they must
1970
+ be non-decreasing.
1971
+ - If real_shift is not none, list_seq_q and list_seq_k must be same. The maximum value of list_seq_q and
1972
+ list_seq_k is greater than 1024. Real_shift should be `(B, N1, 1024, S2)` and `(1, N1, 1024, S2)`, and
1973
+ S2 is equal to max_seqlen_k.
1974
+ - Attn mask must be a lower triangular matrix, so sparse_mode should be 2 or 3. The shape of attn_mask
1975
+ should be `(2048, 2048)`.
1976
+ - The shape of drop_mask is (qk_pointer * N1 // 8,).
1977
+ - Prefix is none.
1978
+ - Next_tokens is 0, and pre_tokens is not less than max_seqlen_q.
1979
+ - When sparse_mode is 3, S1 of each batch should be less than or equal to S2.
1980
+ - 0 should not exist in list_seq_k.
1981
+
1982
+ sparse_mode (int): Indicates sparse mode. Default 0.
1983
+
1984
+ - 0: Indicates the defaultMask mode. If attn_mask is not passed, the mask operation is not performed,
1985
+ and preTokens and nextTokens(internally assigned as INT_MAX) are ignored. If passed in, the full
1986
+ attn_mask matrix (S1 * S2) needs to be passed in, indicating that the part between preTokens and
1987
+ nextTokens needs to be calculated.
1988
+ - 1: Represents allMask, that is, passing in the complete attn_mask matrix.
1989
+ - 2: Representing the leftUpCausal mode corresponds to the lower triangle scenario divided by the left
1990
+ vertex, and the optimized attn_mask matrix (2048*2048) is required.
1991
+ - 3: Representing the rightDownCausal model corresponds to the lower triangle scene divided by the lower
1992
+ right vertex, and the optimized attn_mask matrix (2048*2048) is required.
1993
+ - 4: Represents the band scenario, that is, the part between counting preTokens and nextTokens, and the
1994
+ optimized attn_mask matrix (2048*2048) is required.
1995
+ - 5: Represents the prefix scenario, that is, on the basis of rightDownCasual, a matrix with length S1 and
1996
+ width N is added to the left side. The value of N is obtained by the new input prefix, and the N value
1997
+ of each Batch axis is different, not implemented yet.
1998
+ - 6: Represents the global scenario, not implemented yet.
1999
+ - 7: Represents the dilated scenario, not implemented yet.
2000
+ - 8: Represents the block_local scenario, not implemented yet.
2001
+
2002
+ Returns:
2003
+ attention_out (Tensor[float16, bfloat16]), The output of attention, its shape, and data type are the same
2004
+ as the query.
2005
+
2006
+ Supported Platforms:
2007
+ ``Ascend``
2008
+
2009
+ Examples:
2010
+ >>> import mindspore
2011
+ >>> import mindspore.common.dtype as mstype
2012
+ >>> import numpy as np
2013
+ >>> from mindspore import ops, Tensor
2014
+ >>> query = Tensor(np.ones([2, 4, 64]), dtype=mstype.float16)
2015
+ >>> key = Tensor(np.ones([2, 4, 64]), dtype=mstype.float16)
2016
+ >>> value = Tensor(np.ones([2, 4, 64]), dtype=mstype.float16)
2017
+ >>> head_num = 4
2018
+ >>> output = ops.flash_attention_score(query, key, value, head_num)
2019
+ >>> print(output.shape)
2020
+ (2, 4, 64)
2021
+ """
2022
+ rank_op = _get_cache_prim(FlashAttentionScore)(head_num, keep_prob, scalar_value, pre_tokens, next_tokens,
2023
+ inner_precise, input_layout, sparse_mode)
2024
+ return rank_op(query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen,
2025
+ actual_seq_kvlen)[3]
2026
+
2027
+
2028
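For the experimental "TND" layout described above, `actual_seq_qlen` / `actual_seq_kvlen` are the cumulative sums of the per-batch sequence lengths, and T1 / T2 are their totals. A small arithmetic check of that relationship, reusing the example lists from the docstring:

```python
import itertools

list_seq_q = [4, 2, 6]
list_seq_k = [10, 3, 9]

actual_seq_qlen = list(itertools.accumulate(list_seq_q))    # [4, 6, 12]
actual_seq_kvlen = list(itertools.accumulate(list_seq_k))   # [10, 13, 22]

assert actual_seq_qlen[-1] == sum(list_seq_q) == 12    # T1
assert actual_seq_kvlen[-1] == sum(list_seq_k) == 22   # T2
```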
+ class WhileLoop(Primitive):
2029
+ """
2030
+ Provides a useful op for reducing the compilation time of a while loop.
2031
+ The execution logic of the WhileLoop operator can be roughly represented by the following code:
2032
+
2033
+ .. code-block:: python
2034
+
2035
+ def WhileLoop(cond_func, loop_func, init_val):
2036
+ while(cond_func(init_val)):
2037
+ init_val = loop_func(init_val)
2038
+ return init_val
2039
+
2040
+ The current WhileLoop operator has the following syntactic limitations:
2041
+
2042
+ - Using a side-effect function as `loop_func` is currently not supported,
2043
+ such as operations that modify parameters, global variables, etc.
2044
+ - The return value of `loop_func` being of a different type or shape
2045
+ from the `init_val` is currently not supported.
2046
+
2047
+ .. warning::
2048
+ This is an experimental API that is subject to change or deletion.
2049
+
2050
+ Inputs:
2051
+ - **cond_func** (Function) - The condition function.
2052
+ - **loop_func** (Function) - The loop function, which takes one argument and
2053
+ returns a value of the same type as its input argument.
2054
+ - **init_val** (Union[Tensor, number, str, bool, list, tuple, dict]) - The initial value.
2055
+
2056
+ Outputs:
2057
+ Union[Tensor, number, str, bool, list, tuple, dict], the final result of the while loop,
2058
+ has the same type and shape as the input `init_val` .
2059
+
2060
+ Raises:
2061
+ TypeError: If `cond_func` is not a function.
2062
+ TypeError: If `loop_func` is not a function.
2063
+ ValueError: If `loop_func` cannot take `init_val` as input or has different
2064
+ output type or shape from `init_val` .
2065
+
2066
+ Supported Platforms:
2067
+ ``Ascend`` ``GPU`` ``CPU``
2068
+
2069
+ Examples:
2070
+ >>> from mindspore import ops
2071
+ >>> def loop_while_fun(init_val):
2072
+ ... val = init_val
2073
+ ... val = val + 1
2074
+ ... return val
2075
+ ...
2076
+ >>> init_state = 10
2077
+ >>> while_loop = ops.WhileLoop()
2078
+ >>> result = while_loop(lambda x : x < 100, loop_while_fun, init_state)
2079
+ >>> print(result)
2080
+ 100
2081
+ """
2082
+
2083
+ @prim_attr_register
2084
+ def __init__(self):
2085
+ """Initialize WhileLoop."""
2086
+
2087
+ def __call__(self, cond_func, loop_func, init_val):
2088
+ validator.check_value_type("cond_func", cond_func,
2089
+ [types.FunctionType, types.MethodType], "WhileLoop")
2090
+ validator.check_value_type("loop_func", loop_func,
2091
+ [types.FunctionType, types.MethodType], "WhileLoop")
2092
+ val = init_val
2093
+ try:
2094
+ while cond_func(val):
2095
+ val = loop_func(val)
2096
+ except Exception as e:
2097
+ raise ValueError("Invalid loop_func, please check input arguments and \
2098
+ return value, error info: {}".format(e))
2099
+ return val
2100
+
2101
+
2102
+ class Scan(Primitive):
2103
+ """
2104
+ Scans a function over an array, where the processing of the current element
2105
+ depends on the execution result of the previous element.
2106
+ The execution logic of the Scan operator can be roughly represented by the following code:
2107
+
2108
+ .. code-block:: python
2109
+
2110
+ def Scan(loop_func, init, xs, length=None):
2111
+ if xs is None:
2112
+ xs = [None] * length
2113
+ carry = init
2114
+ ys = []
2115
+ for x in xs:
2116
+ carry, y = loop_func(carry, x)
2117
+ ys.append(y)
2118
+ return carry, ys
2119
+
2120
+ The current Scan operator has the following syntactic limitations:
2121
+
2122
+ - Using a side-effect function as `loop_func` is currently not supported,
2123
+ such as operations that modify parameters, global variables, etc.
2124
+ - The first element of the return value of `loop_func` being of a different
2125
+ type or shape from the `init_val` is currently not supported.
2126
+
2127
+ .. warning::
2128
+ This is an experimental API that is subject to change or deletion.
2129
+
2130
+ Inputs:
2131
+ - **loop_func** (Function) - The loop function.
2132
+ - **init** (Union[Tensor, number, str, bool, list, tuple, dict]) - An initial loop carry value.
2133
+ - **xs** (Union[tuple, list, None]) - The value over which to scan.
2134
+ - **length** (Union[int, None], optional) - The size of xs. Default: ``None`` .
2135
+ - **unroll** (bool, optional) - The flag for whether to perform loop unrolling in compile process.
2136
+ Default: ``True`` .
2137
+
2138
+ Outputs:
2139
+ Tuple(Union[Tensor, number, str, bool, list, tuple, dict], list). Output of scan loop,
2140
+ a tuple with two elements; the first element has the same type and shape as the `init` argument,
2141
+ and the second is a list containing the results of each loop.
2142
+
2143
+ Raises:
2144
+ TypeError: If `loop_func` is not a function.
2145
+ TypeError: If `xs` is not a tuple, a list or None.
2146
+ TypeError: If `length` is not an int or None.
2147
+ TypeError: If `unroll` is not a bool.
2148
+ ValueError: If `loop_func` cannot take `init` and element of `xs` as inputs.
2149
+ ValueError: If the return value of `loop_func` is not a tuple with two elements,
2150
+ or the first element of the tuple has different type or shape from `init` .
2151
+
2152
+ Supported Platforms:
2153
+ ``Ascend`` ``GPU`` ``CPU``
2154
+
2155
+ Examples:
2156
+ >>> from mindspore import ops
2157
+ >>> def cumsum(res, el):
2158
+ ... res = res + el
2159
+ ... return res, res
2160
+ ...
2161
+ >>> a = [1, 2, 3, 4]
2162
+ >>> result_init = 0
2163
+ >>> scan_op = ops.Scan()
2164
+ >>> result = scan_op(cumsum, result_init, a)
2165
+ >>> print(result == (10, [1, 3, 6, 10]))
2166
+ True
2167
+ """
2168
+
2169
+ @prim_attr_register
2170
+ def __init__(self):
2171
+ """Initialize Scan."""
2172
+
2173
+ def __call__(self, loop_func, init, xs, length=None, unroll=True):
2174
+ validator.check_value_type("loop_func", loop_func,
2175
+ [types.FunctionType, types.MethodType], "Scan")
2176
+ validator.check_value_type("xs", xs, [list, tuple, None], "Scan")
2177
+ if xs is None:
2178
+ validator.check_value_type("length", length, [int], "Scan")
2179
+ xs = [None] * length
2180
+ carry = init
2181
+ length = len(xs)
2182
+ if not length:
2183
+ return init, []
2184
+ try:
2185
+ carry, y = loop_func(carry, xs[0])
2186
+ ys = [y]
2187
+ i = 1
2188
+ while i < length:
2189
+ carry, y = loop_func(carry, xs[i])
2190
+ ys.append(y)
2191
+ i = i + 1
2192
+ except Exception as e:
2193
+ raise ValueError("Invalid loop_func, please check input arguments and \
2194
+ return value, error info: {}".format(e))
2195
+ return carry, ys
2196
+
2197
+
2198
+ class ForiLoop(Primitive):
2199
+ """
2200
+ Provides a useful op for looping from lower to upper.
2201
+ The execution logic of the ForiLoop operator can be roughly represented by the following code:
2202
+
2203
+ .. code-block:: python
2204
+
2205
+ def ForiLoop(lower, upper, loop_func, init_val):
2206
+ for i in range(lower, upper):
2207
+ init_val = loop_func(i, init_val)
2208
+ return init_val
2209
+
2210
+ The current ForiLoop operator has the following syntactic limitations:
2211
+
2212
+ - Using a side-effect function as `loop_func` is currently not supported,
2213
+ such as operations that modify parameters, global variables, etc.
2214
+ - The return value of `loop_func` being of a different type or shape
2215
+ from the `init_val` is currently not supported.
2216
+ - Negative numbers or custom increments are currently not supported.
2217
+
2218
+ .. warning::
2219
+ This is an experimental API that is subject to change or deletion.
2220
+
2221
+ Inputs:
2222
+ - **lower** (Union[int, Tensor]) - The start index of loop.
2223
+ - **upper** (Union[int, Tensor]) - The end index of loop.
2224
+ - **loop_func** (Function) - The loop function, takes two arguments.
2225
+ - **init_val** (Union[Tensor, number, str, bool, list, tuple, dict]) - The init value.
2226
+ - **unroll** (bool, optional) - The flag for whether to unroll in the compile process,
2227
+ only valid when the number of loop iterations is determined. Default: ``True`` .
2228
+
2229
+ Outputs:
2230
+ Union[Tensor, number, str, bool, list, tuple, dict], the final result of the loop,
2231
+ has the same type and shape as the input `init_val` .
2232
+
2233
+ Raises:
2234
+ TypeError: If `lower` is not an int or a Tensor.
2235
+ TypeError: If `upper` is not an int or a Tensor.
2236
+ TypeError: If `loop_func` is not a function.
2237
+ ValueError: If `loop_func` cannot take index and `init_val` as arguments or if the type
2238
+ of output it produces is different from the type or shape of `init_val` .
2239
+
2240
+ Supported Platforms:
2241
+ ``Ascend`` ``GPU`` ``CPU``
2242
+
2243
+ Examples:
2244
+ >>> from mindspore import ops
2245
+ >>> def cumsum(index, res):
2246
+ ... return index + res
2247
+ ...
2248
+ >>> result_init = 0
2249
+ >>> fori_loop = ops.ForiLoop()
2250
+ >>> result = fori_loop(0, 4, cumsum, result_init)
2251
+ >>> print(result)
2252
+ 6
2253
+ """
2254
+
2255
+ @prim_attr_register
2256
+ def __init__(self):
2257
+ """Initialize ForiLoop."""
2258
+
2259
+ def __call__(self, lower, upper, loop_func, init_val, unroll=True):
2260
+ validator.check_value_type("lower", lower, [int, Tensor], "ForiLoop")
2261
+ validator.check_value_type("upper", upper, [int, Tensor], "ForiLoop")
2262
+ validator.check_value_type("loop_func", loop_func,
2263
+ [types.FunctionType, types.MethodType], "ForiLoop")
2264
+ val = init_val
2265
+ try:
2266
+ for i in range(lower, upper):
2267
+ val = loop_func(i, val)
2268
+ except Exception as e:
2269
+ raise ValueError("Invalid loop_func, please check input arguments and \
2270
+ return value, error info: {}".format(e))
2271
+ return val