mindspore 2.2.14__cp39-cp39-win_amd64.whl → 2.4.0__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (1217)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +8 -5
  5. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +124 -25
  9. mindspore/_extends/builtin_operations.py +2 -1
  10. mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
  11. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
  12. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
  13. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
  14. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  15. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
  16. mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
  17. mindspore/_extends/parse/__init__.py +18 -14
  18. mindspore/_extends/parse/compile_config.py +299 -0
  19. mindspore/_extends/parse/namespace.py +2 -2
  20. mindspore/_extends/parse/parser.py +182 -68
  21. mindspore/_extends/parse/resources.py +45 -14
  22. mindspore/_extends/parse/standard_method.py +192 -252
  23. mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
  24. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  25. mindspore/_extends/remote/kernel_build_server.py +2 -0
  26. mindspore/_profiler.py +30 -0
  27. mindspore/amp.py +67 -26
  28. mindspore/atlprov.dll +0 -0
  29. mindspore/avcodec-59.dll +0 -0
  30. mindspore/avdevice-59.dll +0 -0
  31. mindspore/avfilter-8.dll +0 -0
  32. mindspore/avformat-59.dll +0 -0
  33. mindspore/avutil-57.dll +0 -0
  34. mindspore/boost/adasum.py +1 -1
  35. mindspore/boost/base.py +1 -1
  36. mindspore/boost/boost_cell_wrapper.py +2 -2
  37. mindspore/boost/grad_freeze.py +2 -2
  38. mindspore/boost/group_loss_scale_manager.py +1 -1
  39. mindspore/boost/less_batch_normalization.py +9 -6
  40. mindspore/c1.dll +0 -0
  41. mindspore/c1xx.dll +0 -0
  42. mindspore/c2.dll +0 -0
  43. mindspore/common/__init__.py +20 -7
  44. mindspore/common/_jit_fallback_utils.py +2 -3
  45. mindspore/common/_pijit_context.py +190 -0
  46. mindspore/common/_register_for_adapter.py +7 -0
  47. mindspore/common/_register_for_recompute.py +48 -0
  48. mindspore/common/_register_for_tensor.py +10 -10
  49. mindspore/common/_stub_tensor.py +7 -1
  50. mindspore/common/_tensor_overload.py +139 -0
  51. mindspore/common/_utils.py +5 -17
  52. mindspore/common/api.py +449 -129
  53. mindspore/common/auto_dynamic_shape.py +27 -14
  54. mindspore/common/dtype.py +17 -10
  55. mindspore/common/dump.py +8 -11
  56. mindspore/common/file_system.py +48 -0
  57. mindspore/common/generator.py +254 -0
  58. mindspore/common/hook_handle.py +65 -30
  59. mindspore/common/initializer.py +1 -1
  60. mindspore/common/jit_config.py +34 -14
  61. mindspore/common/lazy_inline.py +72 -19
  62. mindspore/common/mindir_util.py +12 -2
  63. mindspore/common/mutable.py +79 -14
  64. mindspore/common/no_inline.py +54 -0
  65. mindspore/common/np_dtype.py +25 -0
  66. mindspore/common/parameter.py +73 -21
  67. mindspore/common/recompute.py +292 -0
  68. mindspore/common/seed.py +9 -9
  69. mindspore/common/sparse_tensor.py +276 -24
  70. mindspore/common/symbol.py +122 -0
  71. mindspore/common/tensor.py +668 -514
  72. mindspore/communication/__init__.py +6 -11
  73. mindspore/communication/_comm_helper.py +43 -3
  74. mindspore/communication/comm_func.py +1395 -0
  75. mindspore/communication/management.py +117 -104
  76. mindspore/config/op_info.config +22 -54
  77. mindspore/context.py +455 -71
  78. mindspore/dataset/__init__.py +5 -5
  79. mindspore/dataset/audio/__init__.py +6 -6
  80. mindspore/dataset/audio/transforms.py +711 -158
  81. mindspore/dataset/callback/ds_callback.py +2 -2
  82. mindspore/dataset/core/config.py +7 -0
  83. mindspore/dataset/core/validator_helpers.py +7 -0
  84. mindspore/dataset/engine/cache_client.py +2 -2
  85. mindspore/dataset/engine/datasets.py +201 -116
  86. mindspore/dataset/engine/datasets_audio.py +14 -14
  87. mindspore/dataset/engine/datasets_standard_format.py +83 -3
  88. mindspore/dataset/engine/datasets_text.py +39 -39
  89. mindspore/dataset/engine/datasets_user_defined.py +230 -141
  90. mindspore/dataset/engine/datasets_vision.py +78 -74
  91. mindspore/dataset/engine/iterators.py +29 -0
  92. mindspore/dataset/engine/obs/util.py +7 -0
  93. mindspore/dataset/engine/offload.py +5 -7
  94. mindspore/dataset/engine/queue.py +138 -66
  95. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  96. mindspore/dataset/engine/validators.py +41 -15
  97. mindspore/dataset/text/__init__.py +2 -5
  98. mindspore/dataset/text/transforms.py +408 -121
  99. mindspore/dataset/text/utils.py +9 -9
  100. mindspore/dataset/transforms/__init__.py +0 -3
  101. mindspore/dataset/transforms/transforms.py +261 -76
  102. mindspore/dataset/utils/browse_dataset.py +9 -9
  103. mindspore/dataset/utils/line_reader.py +2 -0
  104. mindspore/dataset/vision/__init__.py +7 -10
  105. mindspore/dataset/vision/c_transforms.py +10 -10
  106. mindspore/dataset/vision/py_transforms_util.py +1 -1
  107. mindspore/dataset/vision/transforms.py +2844 -549
  108. mindspore/dataset/vision/utils.py +161 -10
  109. mindspore/dataset/vision/validators.py +16 -3
  110. mindspore/dnnl.dll +0 -0
  111. mindspore/dpcmi.dll +0 -0
  112. mindspore/{rewrite/ast_creator_register.py → experimental/es/__init__.py} +5 -20
  113. mindspore/experimental/es/embedding_service.py +883 -0
  114. mindspore/experimental/es/embedding_service_layer.py +581 -0
  115. mindspore/experimental/llm_boost/__init__.py +21 -0
  116. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  117. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  118. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  119. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  120. mindspore/experimental/llm_boost/register.py +129 -0
  121. mindspore/experimental/llm_boost/utils.py +31 -0
  122. mindspore/experimental/optim/__init__.py +12 -2
  123. mindspore/experimental/optim/adadelta.py +161 -0
  124. mindspore/experimental/optim/adagrad.py +168 -0
  125. mindspore/experimental/optim/adam.py +35 -34
  126. mindspore/experimental/optim/adamax.py +170 -0
  127. mindspore/experimental/optim/adamw.py +124 -15
  128. mindspore/experimental/optim/asgd.py +153 -0
  129. mindspore/experimental/optim/lr_scheduler.py +66 -121
  130. mindspore/experimental/optim/nadam.py +157 -0
  131. mindspore/experimental/optim/optimizer.py +18 -8
  132. mindspore/experimental/optim/radam.py +194 -0
  133. mindspore/experimental/optim/rmsprop.py +154 -0
  134. mindspore/experimental/optim/rprop.py +164 -0
  135. mindspore/experimental/optim/sgd.py +28 -19
  136. mindspore/hal/__init__.py +40 -0
  137. mindspore/hal/_ascend.py +57 -0
  138. mindspore/hal/_base.py +57 -0
  139. mindspore/hal/_cpu.py +56 -0
  140. mindspore/hal/_gpu.py +57 -0
  141. mindspore/hal/contiguous_tensors_handle.py +175 -0
  142. mindspore/hal/device.py +356 -0
  143. mindspore/hal/event.py +179 -0
  144. mindspore/hal/memory.py +326 -0
  145. mindspore/hal/stream.py +357 -0
  146. mindspore/include/api/data_type.h +2 -2
  147. mindspore/include/api/dual_abi_helper.h +16 -3
  148. mindspore/include/api/model.h +4 -3
  149. mindspore/include/api/model_group.h +13 -1
  150. mindspore/include/api/status.h +14 -0
  151. mindspore/include/api/types.h +10 -10
  152. mindspore/include/c_api/model_c.h +173 -0
  153. mindspore/include/c_api/types_c.h +19 -0
  154. mindspore/include/dataset/config.h +2 -2
  155. mindspore/include/dataset/constants.h +2 -2
  156. mindspore/include/dataset/execute.h +3 -5
  157. mindspore/include/dataset/vision.h +58 -2
  158. mindspore/jpeg62.dll +0 -0
  159. mindspore/log.py +3 -3
  160. mindspore/mindrecord/__init__.py +5 -1
  161. mindspore/mindrecord/config.py +809 -0
  162. mindspore/mindrecord/filereader.py +25 -0
  163. mindspore/mindrecord/filewriter.py +138 -103
  164. mindspore/mindrecord/mindpage.py +40 -6
  165. mindspore/mindrecord/shardutils.py +3 -2
  166. mindspore/mindrecord/shardwriter.py +7 -0
  167. mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
  168. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
  169. mindspore/mindrecord/tools/csv_to_mr.py +4 -9
  170. mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
  171. mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
  172. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
  173. mindspore/mindspore_backend.dll +0 -0
  174. mindspore/mindspore_common.dll +0 -0
  175. mindspore/mindspore_core.dll +0 -0
  176. mindspore/mindspore_glog.dll +0 -0
  177. mindspore/mindspore_np_dtype.dll +0 -0
  178. mindspore/mindspore_ops.dll +0 -0
  179. mindspore/mint/__init__.py +1586 -0
  180. mindspore/mint/distributed/__init__.py +31 -0
  181. mindspore/mint/distributed/distributed.py +254 -0
  182. mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
  183. mindspore/mint/nn/__init__.py +757 -0
  184. mindspore/mint/nn/functional.py +679 -0
  185. mindspore/mint/nn/layer/__init__.py +39 -0
  186. mindspore/mint/nn/layer/activation.py +133 -0
  187. mindspore/mint/nn/layer/normalization.py +477 -0
  188. mindspore/mint/nn/layer/pooling.py +110 -0
  189. mindspore/mint/optim/__init__.py +24 -0
  190. mindspore/mint/optim/adamw.py +206 -0
  191. mindspore/mint/special/__init__.py +63 -0
  192. mindspore/msobj140.dll +0 -0
  193. mindspore/mspdb140.dll +0 -0
  194. mindspore/mspdbcore.dll +0 -0
  195. mindspore/mspdbst.dll +0 -0
  196. mindspore/mspft140.dll +0 -0
  197. mindspore/msvcdis140.dll +0 -0
  198. mindspore/msvcp140_1.dll +0 -0
  199. mindspore/msvcp140_2.dll +0 -0
  200. mindspore/msvcp140_atomic_wait.dll +0 -0
  201. mindspore/msvcp140_codecvt_ids.dll +0 -0
  202. mindspore/multiprocessing/__init__.py +73 -0
  203. mindspore/nn/cell.py +461 -323
  204. mindspore/nn/dynamic_lr.py +2 -2
  205. mindspore/nn/layer/activation.py +292 -135
  206. mindspore/nn/layer/basic.py +288 -83
  207. mindspore/nn/layer/channel_shuffle.py +3 -16
  208. mindspore/nn/layer/container.py +3 -3
  209. mindspore/nn/layer/conv.py +75 -66
  210. mindspore/nn/layer/embedding.py +221 -45
  211. mindspore/nn/layer/image.py +4 -7
  212. mindspore/nn/layer/math.py +1 -1
  213. mindspore/nn/layer/normalization.py +150 -68
  214. mindspore/nn/layer/padding.py +64 -87
  215. mindspore/nn/layer/pooling.py +175 -12
  216. mindspore/nn/layer/rnn_cells.py +6 -16
  217. mindspore/nn/layer/rnns.py +6 -5
  218. mindspore/nn/layer/thor_layer.py +1 -2
  219. mindspore/nn/layer/timedistributed.py +1 -1
  220. mindspore/nn/layer/transformer.py +55 -53
  221. mindspore/nn/learning_rate_schedule.py +6 -5
  222. mindspore/nn/loss/__init__.py +2 -2
  223. mindspore/nn/loss/loss.py +145 -88
  224. mindspore/nn/optim/__init__.py +2 -1
  225. mindspore/nn/optim/ada_grad.py +4 -2
  226. mindspore/nn/optim/adadelta.py +4 -2
  227. mindspore/nn/optim/adafactor.py +1 -1
  228. mindspore/nn/optim/adam.py +102 -181
  229. mindspore/nn/optim/adamax.py +4 -2
  230. mindspore/nn/optim/adasum.py +3 -3
  231. mindspore/nn/optim/asgd.py +4 -2
  232. mindspore/nn/optim/ftrl.py +31 -61
  233. mindspore/nn/optim/lamb.py +5 -3
  234. mindspore/nn/optim/lars.py +2 -2
  235. mindspore/nn/optim/lazyadam.py +6 -4
  236. mindspore/nn/optim/momentum.py +13 -25
  237. mindspore/nn/optim/optimizer.py +6 -3
  238. mindspore/nn/optim/proximal_ada_grad.py +4 -2
  239. mindspore/nn/optim/rmsprop.py +9 -3
  240. mindspore/nn/optim/rprop.py +4 -2
  241. mindspore/nn/optim/sgd.py +5 -3
  242. mindspore/nn/optim/tft_wrapper.py +127 -0
  243. mindspore/nn/optim/thor.py +2 -2
  244. mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
  245. mindspore/nn/probability/distribution/beta.py +2 -2
  246. mindspore/nn/probability/distribution/categorical.py +4 -6
  247. mindspore/nn/probability/distribution/cauchy.py +2 -2
  248. mindspore/nn/probability/distribution/exponential.py +2 -2
  249. mindspore/nn/probability/distribution/geometric.py +1 -1
  250. mindspore/nn/probability/distribution/gumbel.py +2 -2
  251. mindspore/nn/probability/distribution/logistic.py +1 -1
  252. mindspore/nn/probability/distribution/poisson.py +2 -2
  253. mindspore/nn/probability/distribution/uniform.py +2 -2
  254. mindspore/nn/reinforcement/_tensors_queue.py +13 -1
  255. mindspore/nn/wrap/__init__.py +2 -1
  256. mindspore/nn/wrap/cell_wrapper.py +46 -12
  257. mindspore/nn/wrap/grad_reducer.py +148 -8
  258. mindspore/nn/wrap/loss_scale.py +44 -7
  259. mindspore/numpy/__init__.py +2 -0
  260. mindspore/numpy/array_creations.py +67 -68
  261. mindspore/numpy/array_ops.py +70 -66
  262. mindspore/numpy/dtypes.py +3 -3
  263. mindspore/numpy/fft.py +966 -0
  264. mindspore/numpy/logic_ops.py +11 -10
  265. mindspore/numpy/math_ops.py +147 -152
  266. mindspore/numpy/utils.py +3 -0
  267. mindspore/numpy/utils_const.py +4 -4
  268. mindspore/opencv_core452.dll +0 -0
  269. mindspore/opencv_imgcodecs452.dll +0 -0
  270. mindspore/opencv_imgproc452.dll +0 -0
  271. mindspore/ops/__init__.py +9 -6
  272. mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
  273. mindspore/ops/_grad_experimental/grad_comm_ops.py +135 -36
  274. mindspore/ops/_grad_experimental/grad_math_ops.py +61 -298
  275. mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
  276. mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
  277. mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
  278. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  279. mindspore/ops/_op_impl/__init__.py +0 -1
  280. mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
  281. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
  282. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
  283. mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
  284. mindspore/ops/_op_impl/cpu/__init__.py +1 -3
  285. mindspore/ops/_op_impl/cpu/adam.py +2 -2
  286. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
  287. mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
  288. mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
  289. mindspore/ops/_vmap/vmap_array_ops.py +162 -101
  290. mindspore/ops/_vmap/vmap_base.py +8 -1
  291. mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
  292. mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
  293. mindspore/ops/_vmap/vmap_image_ops.py +70 -13
  294. mindspore/ops/_vmap/vmap_math_ops.py +147 -59
  295. mindspore/ops/_vmap/vmap_nn_ops.py +292 -117
  296. mindspore/ops/_vmap/vmap_other_ops.py +1 -1
  297. mindspore/ops/auto_generate/__init__.py +31 -0
  298. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  299. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  300. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  301. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  302. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  303. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  304. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  305. mindspore/ops/composite/__init__.py +5 -2
  306. mindspore/ops/composite/base.py +201 -66
  307. mindspore/ops/composite/math_ops.py +10 -49
  308. mindspore/ops/composite/multitype_ops/_compile_utils.py +192 -618
  309. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
  310. mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
  311. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
  312. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
  313. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
  314. mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
  315. mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
  316. mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
  317. mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
  318. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
  319. mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
  320. mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
  321. mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
  322. mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
  323. mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
  324. mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
  325. mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
  326. mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
  327. mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
  328. mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
  329. mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
  330. mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
  331. mindspore/ops/composite/multitype_ops/not_in_impl.py +8 -3
  332. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
  333. mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
  334. mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
  335. mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
  336. mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
  337. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
  338. mindspore/ops/deprecated.py +14 -3
  339. mindspore/ops/function/__init__.py +53 -11
  340. mindspore/ops/function/array_func.py +1269 -1821
  341. mindspore/ops/function/clip_func.py +19 -31
  342. mindspore/ops/function/debug_func.py +114 -5
  343. mindspore/ops/function/fft_func.py +44 -0
  344. mindspore/ops/function/grad/grad_func.py +30 -22
  345. mindspore/ops/function/image_func.py +27 -21
  346. mindspore/ops/function/linalg_func.py +35 -68
  347. mindspore/ops/function/math_func.py +1170 -2697
  348. mindspore/ops/function/nn_func.py +2116 -1128
  349. mindspore/ops/function/other_func.py +8 -8
  350. mindspore/ops/function/parameter_func.py +5 -93
  351. mindspore/ops/function/random_func.py +435 -113
  352. mindspore/ops/function/reshard_func.py +104 -0
  353. mindspore/ops/function/sparse_func.py +4 -4
  354. mindspore/ops/function/sparse_unary_func.py +9 -16
  355. mindspore/ops/function/spectral_func.py +1 -1
  356. mindspore/ops/function/vmap_func.py +16 -15
  357. mindspore/ops/functional.py +355 -346
  358. mindspore/ops/op_info_register.py +18 -45
  359. mindspore/ops/operations/__init__.py +38 -24
  360. mindspore/ops/operations/_grad_ops.py +21 -927
  361. mindspore/ops/operations/_infer_ops.py +19 -0
  362. mindspore/ops/operations/_inner_ops.py +173 -607
  363. mindspore/ops/operations/_rl_inner_ops.py +2 -2
  364. mindspore/ops/operations/_scalar_ops.py +5 -480
  365. mindspore/ops/operations/_sequence_ops.py +6 -36
  366. mindspore/ops/operations/_tensor_array.py +8 -8
  367. mindspore/ops/operations/array_ops.py +106 -2837
  368. mindspore/ops/operations/comm_ops.py +799 -127
  369. mindspore/ops/operations/custom_ops.py +124 -119
  370. mindspore/ops/operations/debug_ops.py +142 -41
  371. mindspore/ops/operations/image_ops.py +1 -217
  372. mindspore/ops/operations/inner_ops.py +5 -40
  373. mindspore/ops/operations/linalg_ops.py +1 -49
  374. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  375. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  376. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  377. mindspore/ops/operations/math_ops.py +666 -4972
  378. mindspore/ops/operations/nn_ops.py +205 -2213
  379. mindspore/ops/operations/other_ops.py +60 -49
  380. mindspore/ops/operations/random_ops.py +50 -54
  381. mindspore/ops/operations/reshard_ops.py +53 -0
  382. mindspore/ops/operations/sparse_ops.py +4 -4
  383. mindspore/ops/primitive.py +216 -103
  384. mindspore/ops_generate/__init__.py +27 -0
  385. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  386. mindspore/ops_generate/arg_handler.py +197 -0
  387. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  388. mindspore/ops_generate/gen_constants.py +36 -0
  389. mindspore/ops_generate/gen_ops.py +1099 -0
  390. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  391. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  392. mindspore/ops_generate/gen_utils.py +209 -0
  393. mindspore/ops_generate/op_proto.py +145 -0
  394. mindspore/ops_generate/pyboost_utils.py +367 -0
  395. mindspore/ops_generate/template.py +261 -0
  396. mindspore/parallel/__init__.py +8 -4
  397. mindspore/parallel/_auto_parallel_context.py +100 -10
  398. mindspore/parallel/_cell_wrapper.py +99 -9
  399. mindspore/parallel/_cost_model_context.py +1 -1
  400. mindspore/parallel/_dp_allreduce_fusion.py +159 -159
  401. mindspore/parallel/_parallel_serialization.py +67 -23
  402. mindspore/parallel/_ps_context.py +1 -1
  403. mindspore/parallel/_recovery_context.py +1 -1
  404. mindspore/parallel/_tensor.py +99 -22
  405. mindspore/parallel/_transformer/__init__.py +1 -1
  406. mindspore/parallel/_transformer/layers.py +1 -1
  407. mindspore/parallel/_transformer/loss.py +1 -1
  408. mindspore/parallel/_transformer/moe.py +1 -1
  409. mindspore/parallel/_transformer/op_parallel_config.py +1 -1
  410. mindspore/parallel/_transformer/transformer.py +2 -2
  411. mindspore/parallel/_utils.py +173 -6
  412. mindspore/parallel/algo_parameter_config.py +8 -10
  413. mindspore/parallel/checkpoint_transform.py +204 -38
  414. mindspore/parallel/cluster/__init__.py +15 -0
  415. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  416. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  417. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  418. mindspore/parallel/cluster/run.py +136 -0
  419. mindspore/parallel/mpi/__init__.py +1 -1
  420. mindspore/parallel/mpi/_mpi_config.py +1 -1
  421. mindspore/parallel/parameter_broadcast.py +151 -0
  422. mindspore/parallel/shard.py +279 -37
  423. mindspore/parallel/transform_safetensors.py +993 -0
  424. mindspore/pgodb140.dll +0 -0
  425. mindspore/pgort140.dll +0 -0
  426. mindspore/profiler/__init__.py +4 -2
  427. mindspore/profiler/common/constant.py +29 -0
  428. mindspore/profiler/common/process_pool.py +41 -0
  429. mindspore/profiler/common/registry.py +47 -0
  430. mindspore/profiler/common/singleton.py +28 -0
  431. mindspore/profiler/common/util.py +153 -0
  432. mindspore/profiler/dynamic_profiler.py +694 -0
  433. mindspore/profiler/envprofiling.py +18 -20
  434. mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
  435. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  436. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  437. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  438. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  439. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  440. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  441. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  442. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  443. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  444. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  445. mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
  446. mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
  447. mindspore/profiler/parser/ascend_flops_generator.py +20 -4
  448. mindspore/profiler/parser/ascend_hccl_generator.py +29 -278
  449. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  450. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  451. mindspore/profiler/parser/ascend_msprof_exporter.py +148 -146
  452. mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
  453. mindspore/profiler/parser/ascend_op_generator.py +92 -42
  454. mindspore/profiler/parser/ascend_timeline_generator.py +298 -133
  455. mindspore/profiler/parser/base_timeline_generator.py +25 -25
  456. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  457. mindspore/profiler/parser/framework_parser.py +4 -393
  458. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  459. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  460. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  461. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  462. mindspore/profiler/parser/integrator.py +3 -1
  463. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  464. mindspore/profiler/parser/minddata_parser.py +72 -3
  465. mindspore/profiler/parser/profiler_info.py +94 -7
  466. mindspore/profiler/profiler.py +153 -0
  467. mindspore/profiler/profiling.py +631 -508
  468. mindspore/rewrite/__init__.py +2 -14
  469. mindspore/rewrite/api/node.py +122 -36
  470. mindspore/rewrite/api/pattern_engine.py +2 -3
  471. mindspore/rewrite/api/scoped_value.py +16 -15
  472. mindspore/rewrite/api/symbol_tree.py +45 -29
  473. mindspore/rewrite/ast_helpers/__init__.py +3 -6
  474. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  475. mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
  476. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  477. mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
  478. mindspore/rewrite/common/__init__.py +1 -2
  479. mindspore/rewrite/common/config.py +24 -0
  480. mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
  481. mindspore/rewrite/{namer.py → common/namer.py} +63 -18
  482. mindspore/rewrite/common/namespace.py +118 -0
  483. mindspore/rewrite/node/__init__.py +5 -5
  484. mindspore/rewrite/node/call_function.py +23 -7
  485. mindspore/rewrite/node/cell_container.py +7 -3
  486. mindspore/rewrite/node/control_flow.py +53 -28
  487. mindspore/rewrite/node/node.py +212 -196
  488. mindspore/rewrite/node/node_manager.py +51 -22
  489. mindspore/rewrite/node/node_topological_manager.py +3 -23
  490. mindspore/rewrite/parsers/__init__.py +12 -0
  491. mindspore/rewrite/parsers/arguments_parser.py +8 -9
  492. mindspore/rewrite/parsers/assign_parser.py +637 -413
  493. mindspore/rewrite/parsers/attribute_parser.py +3 -4
  494. mindspore/rewrite/parsers/class_def_parser.py +115 -148
  495. mindspore/rewrite/parsers/constant_parser.py +5 -5
  496. mindspore/rewrite/parsers/container_parser.py +4 -6
  497. mindspore/rewrite/parsers/expr_parser.py +55 -0
  498. mindspore/rewrite/parsers/for_parser.py +31 -98
  499. mindspore/rewrite/parsers/function_def_parser.py +13 -5
  500. mindspore/rewrite/parsers/if_parser.py +28 -10
  501. mindspore/rewrite/parsers/module_parser.py +8 -182
  502. mindspore/rewrite/parsers/parser.py +1 -5
  503. mindspore/rewrite/parsers/parser_register.py +1 -1
  504. mindspore/rewrite/parsers/return_parser.py +5 -10
  505. mindspore/rewrite/parsers/while_parser.py +59 -0
  506. mindspore/rewrite/sparsify/utils.py +1 -1
  507. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  508. mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +705 -186
  509. mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
  510. mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
  511. mindspore/run_check/_check_version.py +40 -115
  512. mindspore/run_check/run_check.py +1 -1
  513. mindspore/safeguard/rewrite_obfuscation.py +597 -263
  514. mindspore/swresample-4.dll +0 -0
  515. mindspore/swscale-6.dll +0 -0
  516. mindspore/tbbmalloc.dll +0 -0
  517. mindspore/tinyxml2.dll +0 -0
  518. mindspore/train/__init__.py +7 -5
  519. mindspore/train/_utils.py +204 -4
  520. mindspore/train/amp.py +335 -295
  521. mindspore/train/anf_ir_pb2.py +14 -2
  522. mindspore/train/callback/__init__.py +5 -2
  523. mindspore/train/callback/_backup_and_restore.py +5 -5
  524. mindspore/train/callback/_callback.py +4 -4
  525. mindspore/train/callback/_checkpoint.py +220 -43
  526. mindspore/train/callback/_cluster_monitor.py +201 -0
  527. mindspore/train/callback/_early_stop.py +2 -2
  528. mindspore/train/callback/_flops_collector.py +239 -0
  529. mindspore/train/callback/_landscape.py +15 -9
  530. mindspore/train/callback/_loss_monitor.py +5 -5
  531. mindspore/train/callback/_on_request_exit.py +136 -33
  532. mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
  533. mindspore/train/callback/_summary_collector.py +12 -12
  534. mindspore/train/callback/_tft_register.py +352 -0
  535. mindspore/train/callback/_time_monitor.py +3 -3
  536. mindspore/train/data_sink.py +6 -5
  537. mindspore/train/dataset_helper.py +66 -23
  538. mindspore/train/loss_scale_manager.py +2 -2
  539. mindspore/train/metrics/accuracy.py +7 -7
  540. mindspore/train/metrics/confusion_matrix.py +8 -6
  541. mindspore/train/metrics/cosine_similarity.py +6 -4
  542. mindspore/train/metrics/error.py +2 -2
  543. mindspore/train/metrics/metric.py +3 -3
  544. mindspore/train/metrics/perplexity.py +2 -1
  545. mindspore/train/metrics/roc.py +4 -4
  546. mindspore/train/metrics/topk.py +2 -2
  547. mindspore/train/mind_ir_pb2.py +116 -37
  548. mindspore/train/model.py +382 -76
  549. mindspore/train/serialization.py +787 -288
  550. mindspore/train/summary/_summary_adapter.py +1 -1
  551. mindspore/train/summary/summary_record.py +51 -28
  552. mindspore/train/train_thor/convert_utils.py +3 -3
  553. mindspore/turbojpeg.dll +0 -0
  554. mindspore/utils/__init__.py +21 -0
  555. mindspore/utils/utils.py +60 -0
  556. mindspore/vcmeta.dll +0 -0
  557. mindspore/vcruntime140.dll +0 -0
  558. mindspore/vcruntime140_1.dll +0 -0
  559. mindspore/version.py +1 -1
  560. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/METADATA +8 -4
  561. mindspore-2.4.0.dist-info/RECORD +1406 -0
  562. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +1 -0
  563. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
  564. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
  565. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
  566. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
  567. mindspore/gen_ops.py +0 -273
  568. mindspore/include/c_api/ms/abstract.h +0 -67
  569. mindspore/include/c_api/ms/attribute.h +0 -197
  570. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  571. mindspore/include/c_api/ms/base/macros.h +0 -32
  572. mindspore/include/c_api/ms/base/status.h +0 -33
  573. mindspore/include/c_api/ms/base/types.h +0 -282
  574. mindspore/include/c_api/ms/context.h +0 -102
  575. mindspore/include/c_api/ms/graph.h +0 -160
  576. mindspore/include/c_api/ms/node.h +0 -606
  577. mindspore/include/c_api/ms/tensor.h +0 -161
  578. mindspore/include/c_api/ms/value.h +0 -84
  579. mindspore/mindspore_shared_lib.dll +0 -0
  580. mindspore/nn/layer/flash_attention.py +0 -189
  581. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  582. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  583. mindspore/ops/_op_impl/cpu/concat.py +0 -39
  584. mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
  585. mindspore/ops/_op_impl/tbe/__init__.py +0 -47
  586. mindspore/ops/_op_impl/tbe/abs.py +0 -38
  587. mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
  588. mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
  589. mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
  590. mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
  591. mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
  592. mindspore/ops/_op_impl/tbe/acos.py +0 -37
  593. mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
  594. mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
  595. mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
  596. mindspore/ops/_op_impl/tbe/acosh.py +0 -37
  597. mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
  598. mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
  599. mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
  600. mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
  601. mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
  602. mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
  603. mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
  604. mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
  605. mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
  606. mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
  607. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
  608. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
  609. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
  610. mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
  611. mindspore/ops/_op_impl/tbe/add.py +0 -42
  612. mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
  613. mindspore/ops/_op_impl/tbe/add_n.py +0 -39
  614. mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
  615. mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
  616. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
  617. mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
  618. mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
  619. mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
  620. mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
  621. mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
  622. mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
  623. mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
  624. mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
  625. mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
  626. mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
  627. mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
  628. mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
  629. mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
  630. mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
  631. mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
  632. mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
  633. mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
  634. mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
  635. mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
  636. mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
  637. mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
  638. mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
  639. mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
  640. mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
  641. mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
  642. mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
  643. mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
  644. mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
  645. mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
  646. mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
  647. mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
  648. mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
  649. mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
  650. mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
  651. mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
  652. mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
  653. mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
  654. mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
  655. mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
  656. mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
  657. mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
  658. mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
  659. mindspore/ops/_op_impl/tbe/asin.py +0 -37
  660. mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
  661. mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
  662. mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
  663. mindspore/ops/_op_impl/tbe/asinh.py +0 -37
  664. mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
  665. mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
  666. mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
  667. mindspore/ops/_op_impl/tbe/assign.py +0 -79
  668. mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
  669. mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
  670. mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
  671. mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
  672. mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
  673. mindspore/ops/_op_impl/tbe/atan.py +0 -37
  674. mindspore/ops/_op_impl/tbe/atan2.py +0 -38
  675. mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
  676. mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
  677. mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
  678. mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
  679. mindspore/ops/_op_impl/tbe/atanh.py +0 -37
  680. mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
  681. mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
  682. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
  683. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
  684. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
  685. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
  686. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
  687. mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
  688. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
  689. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
  690. mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
  691. mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
  692. mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
  693. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
  694. mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
  695. mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
  696. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
  697. mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
  698. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
  699. mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
  700. mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
  701. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
  702. mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
  703. mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
  704. mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
  705. mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
  706. mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
  707. mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
  708. mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
  709. mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
  710. mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
  711. mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
  712. mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
  713. mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
  714. mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
  715. mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
  716. mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
  717. mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
  718. mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
  719. mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
  720. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
  721. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
  722. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
  723. mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
  724. mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
  725. mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
  726. mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
  727. mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
  728. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
  729. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
  730. mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
  731. mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
  732. mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
  733. mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
  734. mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
  735. mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
  736. mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
  737. mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
  738. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
  739. mindspore/ops/_op_impl/tbe/cast.py +0 -55
  740. mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
  741. mindspore/ops/_op_impl/tbe/cdist.py +0 -38
  742. mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
  743. mindspore/ops/_op_impl/tbe/ceil.py +0 -37
  744. mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
  745. mindspore/ops/_op_impl/tbe/celu.py +0 -39
  746. mindspore/ops/_op_impl/tbe/centralization.py +0 -39
  747. mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
  748. mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
  749. mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
  750. mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
  751. mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
  752. mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
  753. mindspore/ops/_op_impl/tbe/concat.py +0 -40
  754. mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
  755. mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
  756. mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
  757. mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
  758. mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
  759. mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
  760. mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
  761. mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
  762. mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
  763. mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
  764. mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
  765. mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
  766. mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
  767. mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
  768. mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
  769. mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
  770. mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
  771. mindspore/ops/_op_impl/tbe/cos.py +0 -37
  772. mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
  773. mindspore/ops/_op_impl/tbe/cosh.py +0 -37
  774. mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
  775. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
  776. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
  777. mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
  778. mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
  779. mindspore/ops/_op_impl/tbe/cummin.py +0 -41
  780. mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
  781. mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
  782. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
  783. mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
  784. mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
  785. mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
  786. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
  787. mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
  788. mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
  789. mindspore/ops/_op_impl/tbe/diag.py +0 -38
  790. mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
  791. mindspore/ops/_op_impl/tbe/dilation.py +0 -40
  792. mindspore/ops/_op_impl/tbe/div.py +0 -41
  793. mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
  794. mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
  795. mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
  796. mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
  797. mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
  798. mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
  799. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
  800. mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
  801. mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
  802. mindspore/ops/_op_impl/tbe/elu.py +0 -38
  803. mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
  804. mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
  805. mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
  806. mindspore/ops/_op_impl/tbe/equal.py +0 -42
  807. mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
  808. mindspore/ops/_op_impl/tbe/erf.py +0 -37
  809. mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
  810. mindspore/ops/_op_impl/tbe/erfc.py +0 -37
  811. mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
  812. mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
  813. mindspore/ops/_op_impl/tbe/exp.py +0 -40
  814. mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
  815. mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
  816. mindspore/ops/_op_impl/tbe/expm1.py +0 -37
  817. mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
  818. mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
  819. mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
  820. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
  821. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
  822. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
  823. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
  824. mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
  825. mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
  826. mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
  827. mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
  828. mindspore/ops/_op_impl/tbe/fill.py +0 -56
  829. mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
  830. mindspore/ops/_op_impl/tbe/flatten.py +0 -48
  831. mindspore/ops/_op_impl/tbe/floor.py +0 -37
  832. mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
  833. mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
  834. mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
  835. mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
  836. mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
  837. mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
  838. mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
  839. mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
  840. mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
  841. mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
  842. mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
  843. mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
  844. mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
  845. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  846. mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
  847. mindspore/ops/_op_impl/tbe/gelu.py +0 -37
  848. mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
  849. mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
  850. mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
  851. mindspore/ops/_op_impl/tbe/ger.py +0 -43
  852. mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
  853. mindspore/ops/_op_impl/tbe/greater.py +0 -43
  854. mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
  855. mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
  856. mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
  857. mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
  858. mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
  859. mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
  860. mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
  861. mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
  862. mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
  863. mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
  864. mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
  865. mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
  866. mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
  867. mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
  868. mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
  869. mindspore/ops/_op_impl/tbe/im2col.py +0 -42
  870. mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
  871. mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
  872. mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
  873. mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
  874. mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
  875. mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
  876. mindspore/ops/_op_impl/tbe/inv.py +0 -38
  877. mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
  878. mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
  879. mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
  880. mindspore/ops/_op_impl/tbe/invert.py +0 -37
  881. mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
  882. mindspore/ops/_op_impl/tbe/iou.py +0 -38
  883. mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
  884. mindspore/ops/_op_impl/tbe/is_close.py +0 -40
  885. mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
  886. mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
  887. mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
  888. mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
  889. mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
  890. mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
  891. mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
  892. mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
  893. mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
  894. mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
  895. mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
  896. mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
  897. mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
  898. mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
  899. mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
  900. mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
  901. mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
  902. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
  903. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
  904. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
  905. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
  906. mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
  907. mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
  908. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
  909. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
  910. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
  911. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
  912. mindspore/ops/_op_impl/tbe/lerp.py +0 -38
  913. mindspore/ops/_op_impl/tbe/less.py +0 -41
  914. mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
  915. mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
  916. mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
  917. mindspore/ops/_op_impl/tbe/log.py +0 -40
  918. mindspore/ops/_op_impl/tbe/log1p.py +0 -37
  919. mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
  920. mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
  921. mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
  922. mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
  923. mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
  924. mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
  925. mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
  926. mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
  927. mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
  928. mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
  929. mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
  930. mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
  931. mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
  932. mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
  933. mindspore/ops/_op_impl/tbe/lrn.py +0 -41
  934. mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
  935. mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
  936. mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
  937. mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
  938. mindspore/ops/_op_impl/tbe/matmul.py +0 -53
  939. mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
  940. mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
  941. mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
  942. mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
  943. mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
  944. mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
  945. mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
  946. mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
  947. mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
  948. mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
  949. mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
  950. mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
  951. mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
  952. mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
  953. mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
  954. mindspore/ops/_op_impl/tbe/maximum.py +0 -39
  955. mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
  956. mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
  957. mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
  958. mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
  959. mindspore/ops/_op_impl/tbe/minimum.py +0 -40
  960. mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
  961. mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
  962. mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
  963. mindspore/ops/_op_impl/tbe/mish.py +0 -37
  964. mindspore/ops/_op_impl/tbe/mod.py +0 -41
  965. mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
  966. mindspore/ops/_op_impl/tbe/mul.py +0 -37
  967. mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
  968. mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
  969. mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
  970. mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
  971. mindspore/ops/_op_impl/tbe/neg.py +0 -39
  972. mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
  973. mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
  974. mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
  975. mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
  976. mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
  977. mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
  978. mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
  979. mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
  980. mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
  981. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
  982. mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
  983. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
  984. mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
  985. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
  986. mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
  987. mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
  988. mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
  989. mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
  990. mindspore/ops/_op_impl/tbe/pack.py +0 -58
  991. mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
  992. mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
  993. mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
  994. mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
  995. mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
  996. mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
  997. mindspore/ops/_op_impl/tbe/pdist.py +0 -36
  998. mindspore/ops/_op_impl/tbe/pooling.py +0 -46
  999. mindspore/ops/_op_impl/tbe/population_count.py +0 -38
  1000. mindspore/ops/_op_impl/tbe/pow.py +0 -41
  1001. mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
  1002. mindspore/ops/_op_impl/tbe/prelu.py +0 -37
  1003. mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
  1004. mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
  1005. mindspore/ops/_op_impl/tbe/range.py +0 -39
  1006. mindspore/ops/_op_impl/tbe/real_div.py +0 -38
  1007. mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
  1008. mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
  1009. mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
  1010. mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
  1011. mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
  1012. mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
  1013. mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
  1014. mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
  1015. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
  1016. mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
  1017. mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
  1018. mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
  1019. mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
  1020. mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
  1021. mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
  1022. mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
  1023. mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
  1024. mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
  1025. mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
  1026. mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
  1027. mindspore/ops/_op_impl/tbe/relu.py +0 -39
  1028. mindspore/ops/_op_impl/tbe/relu6.py +0 -38
  1029. mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
  1030. mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
  1031. mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
  1032. mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
  1033. mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
  1034. mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
  1035. mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
  1036. mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
  1037. mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
  1038. mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
  1039. mindspore/ops/_op_impl/tbe/renorm.py +0 -39
  1040. mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
  1041. mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
  1042. mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
  1043. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
  1044. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
  1045. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
  1046. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
  1047. mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
  1048. mindspore/ops/_op_impl/tbe/rint.py +0 -37
  1049. mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
  1050. mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
  1051. mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
  1052. mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
  1053. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
  1054. mindspore/ops/_op_impl/tbe/roll.py +0 -42
  1055. mindspore/ops/_op_impl/tbe/round.py +0 -38
  1056. mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
  1057. mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
  1058. mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
  1059. mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
  1060. mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
  1061. mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
  1062. mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
  1063. mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
  1064. mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
  1065. mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
  1066. mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
  1067. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
  1068. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
  1069. mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
  1070. mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
  1071. mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
  1072. mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
  1073. mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
  1074. mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
  1075. mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
  1076. mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
  1077. mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
  1078. mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
  1079. mindspore/ops/_op_impl/tbe/select.py +0 -38
  1080. mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
  1081. mindspore/ops/_op_impl/tbe/selu.py +0 -39
  1082. mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
  1083. mindspore/ops/_op_impl/tbe/sgd.py +0 -62
  1084. mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
  1085. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
  1086. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
  1087. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
  1088. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
  1089. mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
  1090. mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
  1091. mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
  1092. mindspore/ops/_op_impl/tbe/sign.py +0 -38
  1093. mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
  1094. mindspore/ops/_op_impl/tbe/sin.py +0 -37
  1095. mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
  1096. mindspore/ops/_op_impl/tbe/sinh.py +0 -37
  1097. mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
  1098. mindspore/ops/_op_impl/tbe/slice.py +0 -58
  1099. mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
  1100. mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
  1101. mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
  1102. mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
  1103. mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
  1104. mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
  1105. mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
  1106. mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
  1107. mindspore/ops/_op_impl/tbe/softmax.py +0 -37
  1108. mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
  1109. mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
  1110. mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
  1111. mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
  1112. mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
  1113. mindspore/ops/_op_impl/tbe/softplus.py +0 -37
  1114. mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
  1115. mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
  1116. mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
  1117. mindspore/ops/_op_impl/tbe/softsign.py +0 -37
  1118. mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
  1119. mindspore/ops/_op_impl/tbe/sort.py +0 -38
  1120. mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
  1121. mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
  1122. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
  1123. mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
  1124. mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
  1125. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
  1126. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
  1127. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
  1128. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
  1129. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
  1130. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
  1131. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
  1132. mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
  1133. mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
  1134. mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
  1135. mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
  1136. mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
  1137. mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
  1138. mindspore/ops/_op_impl/tbe/split_d.py +0 -38
  1139. mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
  1140. mindspore/ops/_op_impl/tbe/split_v.py +0 -39
  1141. mindspore/ops/_op_impl/tbe/splitv.py +0 -39
  1142. mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
  1143. mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
  1144. mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
  1145. mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
  1146. mindspore/ops/_op_impl/tbe/square.py +0 -38
  1147. mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
  1148. mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
  1149. mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
  1150. mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
  1151. mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
  1152. mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
  1153. mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
  1154. mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
  1155. mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
  1156. mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
  1157. mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
  1158. mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
  1159. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
  1160. mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
  1161. mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
  1162. mindspore/ops/_op_impl/tbe/sub.py +0 -39
  1163. mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
  1164. mindspore/ops/_op_impl/tbe/tan.py +0 -38
  1165. mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
  1166. mindspore/ops/_op_impl/tbe/tanh.py +0 -37
  1167. mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
  1168. mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
  1169. mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
  1170. mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
  1171. mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
  1172. mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
  1173. mindspore/ops/_op_impl/tbe/tile.py +0 -37
  1174. mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
  1175. mindspore/ops/_op_impl/tbe/top_k.py +0 -42
  1176. mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
  1177. mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
  1178. mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
  1179. mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
  1180. mindspore/ops/_op_impl/tbe/transpose.py +0 -60
  1181. mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
  1182. mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
  1183. mindspore/ops/_op_impl/tbe/trunc.py +0 -39
  1184. mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
  1185. mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
  1186. mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
  1187. mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
  1188. mindspore/ops/_op_impl/tbe/unpack.py +0 -38
  1189. mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
  1190. mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
  1191. mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
  1192. mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
  1193. mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
  1194. mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
  1195. mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
  1196. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
  1197. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
  1198. mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
  1199. mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
  1200. mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
  1201. mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
  1202. mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
  1203. mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
  1204. mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
  1205. mindspore/ops/_tracefunc.py +0 -241
  1206. mindspore/ops/arg_dtype_cast.py +0 -54
  1207. mindspore/ops/silent_check.py +0 -162
  1208. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  1209. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  1210. mindspore/rewrite/api/tree_node_helper.py +0 -60
  1211. mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
  1212. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
  1213. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
  1214. mindspore/rewrite/namespace.py +0 -53
  1215. mindspore-2.2.14.dist-info/RECORD +0 -1924
  1216. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
  1217. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2020-2022 Huawei Technologies Co., Ltd
1
+ # Copyright 2020-2024 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
14
14
  # ============================================================================
15
15
 
16
16
  """Operators for gradients."""
17
+ # pylint: disable=unused-import
17
18
  from __future__ import absolute_import
18
19
 
19
20
  from __future__ import division
@@ -27,6 +28,15 @@ from mindspore import _checkparam as validator
27
28
  from mindspore.common import dtype as mstype
28
29
  from mindspore.communication.management import GlobalComm
29
30
  from mindspore.common._utils import is_shape_unknown, is_dim_unknown
31
+ from ..auto_generate import (AbsGrad, ACosGrad, LogitGrad, AcoshGrad, AsinGrad, AsinhGrad, ReciprocalGrad, RsqrtGrad,
32
+ SqrtGrad, BatchNormGrad, BatchNormGradGrad, BiasAddGrad, GeLUGrad, FastGeLUGrad,
33
+ AvgPoolGrad, MinimumGrad, LogSoftmaxGrad, PReLUGrad, ReluGrad, ReLU6Grad, EluGrad,
34
+ GatherDGradV2, ResizeBilinearGrad, ResizeLinear1DGrad, ResizeNearestNeighborV2Grad,
35
+ SigmoidGrad, HSwishGrad, NLLLossGrad, AtanGrad, GridSampler3DGrad, GridSampler2DGrad,
36
+ ResizeBicubicGrad, HSigmoidGrad, CholeskyGrad, ResizeNearestNeighborGrad, LayerNormGrad,
37
+ HShrinkGrad, LayerNormGradGrad, SiLUGrad, MaximumGrad, MaximumGradGrad, RmsNormGrad,
38
+ FlashAttentionScoreGrad, UpsampleTrilinear3DGrad, UpsampleNearest3DGrad, MaskedSelectGrad,
39
+ BinaryCrossEntropyGrad, SoftShrinkGrad, SeluGrad)
30
40
 
31
41
 
32
42
  class SparseFillEmptyRowsGrad(Primitive):
@@ -39,92 +49,6 @@ class SparseFillEmptyRowsGrad(Primitive):
39
49
  outputs=['y_values', 'y_default_value'])
40
50
 
41
51
 
42
- class AbsGrad(PrimitiveWithInfer):
43
- """Computes gradients for abs operation."""
44
-
45
- @prim_attr_register
46
- def __init__(self):
47
- """Initialize AbsGrad"""
48
-
49
-
50
- class ACosGrad(Primitive):
51
- """
52
- Computes ACosGrad of input element-wise.
53
-
54
- Returns:
55
- Tensor, has the same type as input.
56
- """
57
-
58
- @prim_attr_register
59
- def __init__(self):
60
- """Initialize ACosGrad"""
61
- self.init_prim_io_names(inputs=['y', 'dy'], outputs=['z'])
62
-
63
-
64
- class LogitGrad(Primitive):
65
- """
66
- Computes LogitGrad of input element-wise.
67
-
68
- Returns:
69
- Tensor, has the same type as input.
70
- """
71
- @prim_attr_register
72
- def __init__(self, eps=-1.0):
73
- """Initialize Exp"""
74
- self.init_prim_io_names(inputs=['grad', 'input'], outputs=['dx'])
75
- validator.check_value_type("eps", eps, [float], self.name)
76
- self.add_prim_attr('eps', eps)
77
-
78
-
79
- class AcoshGrad(Primitive):
80
- """Performs grad of Acosh operation."""
81
-
82
- @prim_attr_register
83
- def __init__(self):
84
- """Initialize AcoshGrad"""
85
- self.init_prim_io_names(inputs=['y', 'dy'], outputs=['z'])
86
-
87
-
88
- class AsinGrad(Primitive):
89
- """
90
- Computes AsinGrad of input element-wise.
91
-
92
- Returns:
93
- Tensor, has the same type as input.
94
- """
95
-
96
- @prim_attr_register
97
- def __init__(self):
98
- """Initialize AsinGrad"""
99
- self.init_prim_io_names(inputs=['y', 'dy'], outputs=['z'])
100
-
101
-
102
- class AsinhGrad(Primitive):
103
- """Performs grad of Asinh operation."""
104
-
105
- @prim_attr_register
106
- def __init__(self):
107
- """Initialize AsinhGrad"""
108
- self.init_prim_io_names(inputs=['y', 'dy'], outputs=['z'])
109
-
110
-
111
- class ReciprocalGrad(Primitive):
112
- """Performs grad of Reciprocal operation."""
113
-
114
- @prim_attr_register
115
- def __init__(self):
116
- """Initialize ReciprocalGrad"""
117
- self.init_prim_io_names(inputs=['y', 'dy'], outputs=['z'])
118
-
119
-
120
- class RsqrtGrad(Primitive):
121
- """Performs grad of Rsqrt operation."""
122
-
123
- @prim_attr_register
124
- def __init__(self):
125
- """Initialize RsqrtGrad"""
126
-
127
-
128
52
  class ScaleAndTranslateGrad(Primitive):
129
53
  """Performs grad of ScaleAndTranslate operation."""
130
54
 
@@ -137,39 +61,15 @@ class ScaleAndTranslateGrad(Primitive):
137
61
  validator.check_value_type("antialias", antialias, [bool], self.name)
138
62
 
139
63
 
140
- class SoftmaxGrad(ReciprocalGrad):
64
+ class SoftmaxGrad(Primitive):
141
65
  """Performs grad of Softmax operation."""
142
66
 
143
-
144
- class SqrtGrad(Primitive):
145
- """Performs grad of Sqrt operation."""
146
-
147
67
  @prim_attr_register
148
68
  def __init__(self):
149
- """Initialize SqrtGrad"""
69
+ """Initialize SoftmaxGrad"""
150
70
  self.init_prim_io_names(inputs=['y', 'dy'], outputs=['z'])
151
71
 
152
72
 
153
- class BatchNormGrad(Primitive):
154
- """Performs grad of BatchNorm operation."""
155
-
156
- @prim_attr_register
157
- def __init__(self, is_training=False, epsilon=1e-5, data_format='NCHW'):
158
- self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
159
- self.epsilon = validator.check_float_range(epsilon, 0, 1, validator.INC_RIGHT, 'epsilon', self.name)
160
- self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
161
-
162
-
163
- class BatchNormGradGrad(Primitive):
164
- """Performs grad of BatchNormGrad operation."""
165
-
166
- @prim_attr_register
167
- def __init__(self, is_training=False, epsilon=1e-5, data_format='NCHW'):
168
- self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
169
- self.epsilon = validator.check_float_range(epsilon, 0, 1, validator.INC_RIGHT, 'epsilon', self.name)
170
- self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
171
-
172
-
173
73
  class SyncBatchNormGrad(Primitive):
174
74
  """Performs grad of SyncBatchNorm operation."""
175
75
 
@@ -181,18 +81,6 @@ class SyncBatchNormGrad(Primitive):
181
81
  validator.check_int(device_num, 2, validator.GE, "device_num", self.name)
182
82
 
183
83
 
184
- class BiasAddGrad(Primitive):
185
- """Computes gradients of BiasAdd."""
186
-
187
- @prim_attr_register
188
- def __init__(self, data_format="NCHW"):
189
- self.init_prim_io_names(inputs=['dout'], outputs=['output'])
190
- self.format = validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.name)
191
- if self.format == "NCDHW":
192
- self.format = "NCHW"
193
- self.add_prim_attr('data_format', self.format)
194
-
195
-
196
84
  class KLDivLossGrad(Primitive):
197
85
  """Computes gradients for `KLDivLoss` operation."""
198
86
 
@@ -210,14 +98,6 @@ class KLDivLossGrad(Primitive):
210
98
  self.reduction = validator.check_string(reduction, support_mode, 'reduction', self.name)
211
99
 
212
100
 
213
- class BinaryCrossEntropyGrad(Primitive):
214
- """Computes gradients for `BinaryCrossEntropy` operation."""
215
-
216
- @prim_attr_register
217
- def __init__(self, reduction='mean'):
218
- self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
219
-
220
-
221
101
  class LuUnpackGrad(Primitive):
222
102
  """Computes gradients for `LuUnpack` operation."""
223
103
 
@@ -713,22 +593,6 @@ class NeighborExchangeV2Grad(PrimitiveWithInfer):
713
593
  'value': None}
714
594
 
715
595
 
716
- class GeLUGrad(Primitive):
717
- """Gradients of GeLU operation."""
718
-
719
- @prim_attr_register
720
- def __init__(self):
721
- self.init_prim_io_names(inputs=['dy', 'x', 'y'], outputs=['z'])
722
-
723
-
724
- class FastGeLUGrad(Primitive):
725
- """Gradients of FastGeLU operation."""
726
-
727
- @prim_attr_register
728
- def __init__(self):
729
- """init FastGeLUGrad"""
730
-
731
-
732
596
  class _PoolGrad(PrimitiveWithInfer):
733
597
  """Gradients of the max/avg pool operation."""
734
598
 
@@ -813,20 +677,6 @@ class AvgPoolGradGe(_PoolGrad):
813
677
  return out
814
678
 
815
679
 
816
- class AvgPoolGrad(_PoolGrad):
817
- """Gradients of the avg pool operation."""
818
-
819
- @prim_attr_register
820
- def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
821
- super(AvgPoolGrad, self).__init__(kernel_size, strides, pad_mode, data_format)
822
-
823
- def infer_shape(self, x1_shape, x2_shape, grad_shape):
824
- return x1_shape
825
-
826
- def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
827
- return x1_dtype
828
-
829
-
830
680
  class AvgPoolGradV1(Primitive):
831
681
  """Gradients of the AvgPoolV1 operation."""
832
682
 
@@ -1192,25 +1042,6 @@ class MaxPool3DGradGrad(PrimitiveWithInfer):
1192
1042
  return x_dtype
1193
1043
 
1194
1044
 
1195
- class MaximumGrad(Primitive):
1196
- """Grad for maximum."""
1197
-
1198
- @prim_attr_register
1199
- def __init__(self, grad_x=True, grad_y=True):
1200
- """Initialize MaximumGrad"""
1201
- self.init_prim_io_names(inputs=['x1', 'x2', 'grads'], outputs=['y1', 'y2'])
1202
-
1203
-
1204
- class MaximumGradGrad(Primitive):
1205
- """Grad for maximum grad."""
1206
-
1207
- @prim_attr_register
1208
- def __init__(self, grad_x=True, grad_y=True):
1209
- """Initialize MaximumGradGrad"""
1210
- super().__init__("MaximumGradGrad")
1211
- self.init_prim_io_names(inputs=['x1', 'x2', 'dy1', 'dy2'], outputs=['sopd_x1', 'sopd_x2', 'sopd_grad'])
1212
-
1213
-
1214
1045
  class MaxPoolGradWithArgmax(Primitive):
1215
1046
  """Computes the gradients of MaxPoolWithArgmax."""
1216
1047
  @prim_attr_register
@@ -1359,15 +1190,6 @@ class MaxPoolGradGradWithArgmax(_PoolGrad):
1359
1190
  return grad_dtype
1360
1191
 
1361
1192
 
1362
- class MinimumGrad(Primitive):
1363
- """Grad for minimum."""
1364
-
1365
- @prim_attr_register
1366
- def __init__(self, grad_x=True, grad_y=True):
1367
- """Initialize MinimumGrad"""
1368
- self.init_prim_io_names(inputs=['x1', 'x2', 'grads'], outputs=['y1', 'y2'])
1369
-
1370
-
1371
1193
  class MinimumGradGrad(Primitive):
1372
1194
  """Grad for minimum_grad."""
1373
1195
  @prim_attr_register
@@ -1406,79 +1228,6 @@ class L2NormalizeGrad(Primitive):
1406
1228
  raise TypeError("The length of axis must be 1, later will support multiple axis!")
1407
1229
 
1408
1230
 
1409
- class LayerNormGrad(Primitive):
1410
- """
1411
- Applies the layer Normalization to the input array.
1412
-
1413
- This operator will calculate the input gradients of layernorm.
1414
-
1415
- Args:
1416
- begin_norm_axis (int): The begin axis for the input to apply layernorm. Default: 1.
1417
- begin_params_axis (int): The begin axis for the parameter input to apply layernorm. Default: 1.
1418
-
1419
- Returns:
1420
- tuple[int], tuple of 3 values (the gradients of layernorm input, gamma, beta).
1421
- """
1422
-
1423
- @prim_attr_register
1424
- def __init__(self, begin_norm_axis=1, begin_params_axis=1):
1425
- """init"""
1426
- self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name)
1427
- self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name)
1428
-
1429
-
1430
- class LayerNormGradGrad(Primitive):
1431
- """
1432
- Gets the gradient of LayerNormGrad operation.
1433
-
1434
- Args:
1435
- begin_norm_axis (int): The begin axis for the input to apply layernorm. Default: 1.
1436
- begin_params_axis (int): The begin axis for the parameter input to apply layernorm. Default: 1.
1437
-
1438
- Inputs:
1439
- - **x** (Tensor) - The input tensor to be normalized, float32 or float16.
1440
- - **dy** (Tensor) - The gradient of LayerNorm's output y, float32 or float16.
1441
- - **variance** (Tensor) - The variance of x, float32 or float16.
1442
- - **mean** (Tensor) - The mean of x, float32 or float16.
1443
- - **gamma** (Tensor) - The original value of weight gamma initialized in LayerNorm, float32 or float16.
1444
- Default: 'ones'.
1445
- - **d_dx** (Tensor) - The gradient of dx, where dx is the gradient of LayerNorm's input x, float32 or float16.
1446
- - **d_dg** (Tensor) - The gradient of dg, where dg is the gradient of LayerNorm's weight gamma,
1447
- float32 or float16.
1448
- - **d_db** (Tensor) - The gradient of db, where db is the gradient of LayerNorm's weight beta,
1449
- float32 or float16.
1450
-
1451
- Returns:
1452
- Tuple[Tensor], tuple of 3 Tensors (the gradients of layernormgrad x, dy, gamma).
1453
-
1454
- Raises:
1455
- TypeError: If the 8 inputs don't have the same dtype.
1456
- ValueError: If x, dy, d_dx don't have the same shape.
1457
- ValueError: If variance, mean don't have the same shape.
1458
- ValueError: If gamma, d_dg, d_db don't have the same shape.
1459
-
1460
- Supported Platforms:
1461
- ``Ascend`` ``GPU`` ``CPU``
1462
- """
1463
-
1464
- @prim_attr_register
1465
- def __init__(self, begin_norm_axis=1, begin_params_axis=1):
1466
- """init"""
1467
- self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name)
1468
- self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name)
1469
- self.init_prim_io_names(inputs=['x', 'dy', 'variance', 'mean', 'gamma', 'd_dx', 'd_dg', 'd_db'],
1470
- outputs=['sopd_x', 'sopd_dy', 'sopd_gamma'])
1471
-
1472
-
1473
- class LogSoftmaxGrad(Primitive):
1474
- """Computes gradient for the Log Softmax activation."""
1475
-
1476
- @prim_attr_register
1477
- def __init__(self, axis=-1):
1478
- """Initialize LogSoftmaxGrad"""
1479
- validator.check_value_type("axis", axis, [int], self.name)
1480
-
1481
-
1482
1231
  class LSTMGradData(Primitive):
1483
1232
  """Computes the data gradients of LSTM."""
1484
1233
 
@@ -1741,27 +1490,6 @@ class DynamicGRUV2Grad(Primitive):
1741
1490
  ])
1742
1491
 
1743
1492
 
1744
- class PReLUGrad(Primitive):
1745
- r"""
1746
- Gradients of PReLU operation.
1747
-
1748
- Note:
1749
- 1-dimensional input_x is not supported.
1750
-
1751
- Inputs:
1752
- - **y_backprop** (Tensor) - Representing the backprop of the next layer.
1753
- - **input_x** (Tensor) - Must be the input `input_x` of forward operator PRelu.
1754
- - **weight** (Tensor) - Float Tensor, w > 0, must be the input `weight` of forward operator PRelu.
1755
-
1756
- Outputs:
1757
- Tensor, with the same type as `input_x`.
1758
- """
1759
-
1760
- @prim_attr_register
1761
- def __init__(self):
1762
- pass
1763
-
1764
-
1765
1493
  class RandomGammaGrad(Primitive):
1766
1494
  r"""
1767
1495
  Computes the derivative of a random sample of Gamma with respect to alpha.:
@@ -1800,180 +1528,6 @@ class RandomGammaGrad(Primitive):
1800
1528
  self.add_prim_attr("side_effect_hidden", True)
1801
1529
 
1802
1530
 
1803
- class ReluGrad(Primitive):
1804
- """Performs grad of Relu operation."""
1805
-
1806
- @prim_attr_register
1807
- def __init__(self):
1808
- """Initialize ReluGrad"""
1809
- self.init_prim_io_names(inputs=['y_backprop', 'x'], outputs=['output'])
1810
-
1811
-
1812
- class SiLUGrad(Primitive):
1813
- """Performs grad of SiLU operation."""
1814
-
1815
- @prim_attr_register
1816
- def __init__(self):
1817
- """Initialize SiLUGrad"""
1818
- self.init_prim_io_names(inputs=['dout', 'out'], outputs=['output'])
1819
-
1820
-
1821
- class ReLU6Grad(Primitive):
1822
- """Performs grad of ReLU6 operation."""
1823
-
1824
- @prim_attr_register
1825
- def __init__(self):
1826
- self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
1827
-
1828
-
1829
- class ReluGradV2(Primitive):
1830
- """Performs grad of ReLUV2 operation."""
1831
-
1832
- @prim_attr_register
1833
- def __init__(self):
1834
- self.init_prim_io_names(inputs=['gradients', 'mask'], outputs=['output'])
1835
-
1836
-
1837
- class EluGrad(Primitive):
1838
- """Performs grad of Elu operation."""
1839
-
1840
- @prim_attr_register
1841
- def __init__(self):
1842
- """Initialize EluGrad"""
1843
- self.init_prim_io_names(inputs=['y_backprop', 'x'], outputs=['output'])
1844
-
1845
-
1846
- class GatherDGrad(Primitive):
1847
- """Performs grad of GatherD operation."""
1848
-
1849
- @prim_attr_register
1850
- def __init__(self, dim=0, shape=None):
1851
- """Initialize GatherDGrad"""
1852
- validator.check_is_int(dim, int)
1853
- self.add_prim_attr("dim", dim)
1854
- self.dim = dim
1855
- self.out_shape = shape
1856
- self.init_prim_io_names(inputs=['index', 'grad'], outputs=['output'])
1857
-
1858
-
1859
- class GatherDGradV2(Primitive):
1860
- """Performs grad of GatherD operation."""
1861
-
1862
- @prim_attr_register
1863
- def __init__(self):
1864
- """Initialize GatherDGradV2"""
1865
- self.init_prim_io_names(inputs=['x', 'dim', 'index', 'grad'], outputs=['output'])
1866
-
1867
-
1868
- class ResizeBilinearGrad(Primitive):
1869
- """Performs grad of ResizeBilinear operation."""
1870
-
1871
- @prim_attr_register
1872
- def __init__(self, align_corners=False, half_pixel_centers=False):
1873
- """init"""
1874
- validator.check_value_type("align_corners", align_corners, [bool], self.name)
1875
- validator.check_value_type("half_pixel_centers", half_pixel_centers, [bool], self.name)
1876
- self.align_corners = validator.check_value_type("align_corners", align_corners, [bool], self.name)
1877
- self.half_pixel_centers = validator.check_value_type("half_pixel_centers",
1878
- half_pixel_centers, [bool], self.name)
1879
- self.init_prim_io_names(inputs=['grads', 'original_image'], outputs=['y'])
1880
- if half_pixel_centers and align_corners:
1881
- raise ValueError(f"If half_pixel_centers is True, align_corners must be False, but got {align_corners}")
1882
-
1883
-
1884
- class ResizeNearestNeighborGrad(Primitive):
1885
- """
1886
- Compute gradient of `ResizeNearestNeighbor` operator.
1887
-
1888
- Note:
1889
- The shape of input parameter `size` must be (height, width).
1890
-
1891
- Args:
1892
- align_corners (bool): Whether the centers of the 4 corner pixels of the input
1893
- and output tensors are aligned. Default: ``False``.
1894
- """
1895
-
1896
- @prim_attr_register
1897
- def __init__(self, align_corners=False):
1898
- """Initialize ResizeNearestNeighborGrad"""
1899
- self.init_prim_io_names(inputs=['grads', 'size'], outputs=['y'])
1900
-
1901
-
1902
- class ResizeLinear1DGrad(Primitive):
1903
- """
1904
- Compute gradient of `ResizeLinear1D` operator.
1905
-
1906
- .. warning::
1907
- This is an experimental API that is subject to change.
1908
-
1909
- Args:
1910
- coordinate_transformation_mode (string): Default is 'align_corners'. Describes how to transform the coordinate
1911
- in the resized tensor to the coordinate in the original tensor. Other optional: 'half_pixel'.
1912
- """
1913
-
1914
- @prim_attr_register
1915
- def __init__(self, coordinate_transformation_mode="align_corners"):
1916
- """Initialize ResizeLinear1DGrad"""
1917
- self.init_prim_io_names(
1918
- inputs=['grads', 'input_x'], outputs=['y'])
1919
- validator.check_value_type(
1920
- "coordinate_transformation_mode", coordinate_transformation_mode, [str], self.name)
1921
- validator.check_string(coordinate_transformation_mode, ["align_corners", "half_pixel"],
1922
- "coordinate_transformation_mode", self.name)
1923
-
1924
-
1925
- class ResizeNearestNeighborV2Grad(Primitive):
1926
- """
1927
- Compute gradient of `ResizeNearestNeighborV2` operator.
1928
-
1929
- Args:
1930
- align_corners (bool): Whether the centers of the 4 corner pixels of the input
1931
- and output tensors are aligned. Default: ``False``.
1932
- half_pixel_centers (bool): Default: ``False``.
1933
- """
1934
-
1935
- @prim_attr_register
1936
- def __init__(self, align_corners=False, half_pixel_centers=False):
1937
- """Initialize ResizeNearestNeighborV2Grad"""
1938
- self.init_prim_io_names(inputs=['grads', 'size'], outputs=['y'])
1939
- validator.check_value_type('align_corners', align_corners, [bool], self.name)
1940
- validator.check_value_type('half_pixel_centers', half_pixel_centers, [bool], self.name)
1941
-
1942
-
1943
- class UpsampleNearest3DGrad(Primitive):
1944
- """
1945
- Upsample the 3-D gradient data with the nearest neighbor interpolation algorithm.
1946
-
1947
- Note:
1948
- Only one of 'scales' and 'output_size' can be specified, and it is an error if both are specified.
1949
-
1950
- Inputs:
1951
- - **dy** (Tensor) - Tensor of shape [N, C, D, H, W], Must be one of the following types:
1952
- float16, float32, float64.
1953
- - **input_size** (listInt): An required listInt, which contain 5 elements:
1954
- [min_batch, channels, depth, height, width].
1955
- Must: input_size[0] == dy_tensor_size[0], input_size[1] == dy_tensor_size[1].
1956
- - **output_size** (listInt): An optional listInt. Default: ``None``.
1957
- It contains 3 elements: depth, height, width, whose elements should be the same as `dy`.
1958
- Must:
1959
- dy_tensor_size[2] == floor(input_size[2] * scales[0]) == output_size[0],
1960
- dy_tensor_size[3] == floor(input_size[3] * scales[1]) == output_size[1],
1961
- dy_tensor_size[4] == floor(input_size[4] * scales[2]) == output_size[2].
1962
- - **scales** (listFloat): An optional listFloat. Default: ``None``.
1963
- The scale array along each dimension, contain 3 elements: scale_depth, scale_height, scale_width.
1964
- The number of elements of 'scales' should be the same as the rank of `dy`.
1965
-
1966
- Outputs:
1967
- - **dx**- (Tensor) - A 5-D tensor. Has the same type as `dy`, shape depends on `input_size`.
1968
- """
1969
- @prim_attr_register
1970
- def __init__(self):
1971
- """Initialize UpsampleNearest3DGrad."""
1972
- self.init_prim_io_names(
1973
- inputs=['dy', 'input_size', 'output_size', 'scales'],
1974
- outputs=['dx'])
1975
-
1976
-
1977
1531
  class ROIAlignGrad(Primitive):
1978
1532
  """
1979
1533
  ROIAlignGrad operator.
@@ -2034,15 +1588,6 @@ class PsROIPoolingGrad(PrimitiveWithInfer):
2034
1588
  return ydiff_type
2035
1589
 
2036
1590
 
2037
- class SigmoidGrad(Primitive):
2038
- """Gets the gradient of Sigmoid operation."""
2039
-
2040
- @prim_attr_register
2041
- def __init__(self):
2042
- """Initialize SigmoidGrad"""
2043
- self.init_prim_io_names(inputs=['y', 'dy'], outputs=['output'])
2044
-
2045
-
2046
1591
  class _ActivationGrad(PrimitiveWithInfer):
2047
1592
  """_ActivationGrad base class."""
2048
1593
 
@@ -2060,14 +1605,6 @@ class _ActivationGrad(PrimitiveWithInfer):
2060
1605
  return x_dtype
2061
1606
 
2062
1607
 
2063
- class HSwishGrad(Primitive):
2064
- """Gets the gradient of HSwish operation."""
2065
- @prim_attr_register
2066
- def __init__(self):
2067
- """Initialize HSwishGrad"""
2068
- self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
2069
-
2070
-
2071
1608
  class SigmoidCrossEntropyWithLogitsGrad(Primitive):
2072
1609
  """Computes the gradients of `SigmoidCrossEntropyWithLogits`."""
2073
1610
 
@@ -2102,19 +1639,6 @@ class SliceGrad(PrimitiveWithInfer):
2102
1639
  'value': None}
2103
1640
 
2104
1641
 
2105
- class NLLLossGrad(PrimitiveWithInfer):
2106
- """Computes the gradients of `NLLLoss`."""
2107
-
2108
- @prim_attr_register
2109
- def __init__(self, reduction="mean", ignore_index=-100):
2110
- """Initialize NLLLoss"""
2111
- self.init_prim_io_names(inputs=['x', 'loss_grad', 'target', 'weight', 'total_weight'], outputs=['x_grad'])
2112
- self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
2113
- self.ignore_index = ignore_index
2114
- self.add_prim_attr('reduction', self.reduction)
2115
- self.add_prim_attr('ignore_index', self.ignore_index)
2116
-
2117
-
2118
1642
  class SmoothL1LossGrad(Primitive):
2119
1643
  """Computes gradient for prediction on SmoothL1Loss."""
2120
1644
 
@@ -2134,35 +1658,6 @@ class SoftMarginLossGrad(Primitive):
2134
1658
  self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
2135
1659
 
2136
1660
 
2137
- class StridedSliceV2Grad(Primitive):
2138
- """
2139
- Performs grad of StridedSliceV2 operation.
2140
-
2141
- Inputs:
2142
- - **shapex** (Tensor) - StridedSliceV2 shape of input
2143
- - **begin** (tuple[int]) - A tuple which represents the location where to start. Only
2144
- constant value is allowed.
2145
- - **end** (tuple[int]) - A tuple or which represents the maximum location where to end.
2146
- Only constant value is allowed.
2147
- - **strides** (tuple[int]) - A tuple which represents the stride is continuously added
2148
- before reaching the maximum location. Only constant value is allowed.
2149
- - **dy** (Tensor) - The output of StridedSliceV2
2150
-
2151
- Outputs:
2152
- Tensor, the shape same as the input of StridedSliceV2
2153
- """
2154
-
2155
- @prim_attr_register
2156
- def __init__(self,
2157
- begin_mask=0,
2158
- end_mask=0,
2159
- ellipsis_mask=0,
2160
- new_axis_mask=0,
2161
- shrink_axis_mask=0):
2162
- """Initialize StridedSliceV2Grad"""
2163
- self.init_prim_io_names(inputs=['shapex', 'begin', 'end', 'strides', 'dy'], outputs=['output'])
2164
-
2165
-
2166
1661
  class StridedSliceGrad(Primitive):
2167
1662
  """
2168
1663
  Performs grad of StridedSlice operation.
@@ -2301,19 +1796,6 @@ class RefToEmbed(Primitive):
2301
1796
  pass
2302
1797
 
2303
1798
 
2304
- class AtanGrad(Primitive):
2305
- """
2306
- Computes AtanGrad of input element-wise.
2307
-
2308
- Returns:
2309
- Tensor, has the same type as input.
2310
- """
2311
-
2312
- @prim_attr_register
2313
- def __init__(self):
2314
- """Initialize AtanGrad"""
2315
-
2316
-
2317
1799
  class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
2318
1800
  """Computes the state gradients of BasicLSTMCell."""
2319
1801
 
@@ -2480,51 +1962,6 @@ class MvlgammaGrad(Primitive):
2480
1962
  self.p = validator.check_value_type('p', p, [int], self.name)
2481
1963
 
2482
1964
 
2483
- class MaskedSelectGrad(PrimitiveWithInfer):
2484
- """Computes gradient for MaskedSelect."""
2485
-
2486
- @prim_attr_register
2487
- def __init__(self):
2488
- pass
2489
-
2490
- def infer_shape(self, x, mask, grad):
2491
- return x
2492
-
2493
- def infer_dtype(self, x, mask, grad):
2494
- return x
2495
-
2496
-
2497
- class SoftShrinkGrad(Primitive):
2498
- r"""
2499
- Gradients for SoftShrink operation.
2500
-
2501
- Args:
2502
- lambd – The \lambdaλ (must be no less than zero) value for the Softshrink formulation. Default: 0.5.
2503
-
2504
- Inputs:
2505
- - **input_grad** (Tensor) - The input gradient.
2506
- - **input_x** (Tensor) - The input of SoftShrink with data type of float16 or float32.
2507
- Any number of additional dimensions.
2508
-
2509
- Outputs:
2510
- output - Tensor, has the same shape and data type as input_x.
2511
-
2512
- Raises:
2513
- TypeError: If lambd is not a float.
2514
- TypeError: If dtype of input_x is neither float16 nor float32.
2515
- ValueError: If lambd is less than to 0.
2516
-
2517
- Supported Platforms:
2518
- ``Ascend``
2519
- """
2520
-
2521
- @prim_attr_register
2522
- def __init__(self, lambd=0.5):
2523
- self.init_prim_io_names(inputs=['input_grad', 'input_x'], outputs=['output'])
2524
- validator.check_value_type("lambd", lambd, [float], self.name)
2525
- validator.check_number("lambd", lambd, 0, validator.GE, self.name)
2526
-
2527
-
2528
1965
  class CdistGrad(Primitive):
2529
1966
  """Computes gradient for Cdist."""
2530
1967
 
@@ -2616,40 +2053,6 @@ class MultilabelMarginLossGrad(Primitive):
2616
2053
  self.init_prim_io_names(inputs=['y_grad', 'x', 'target', 'is_target'], outputs=['x_grad'])
2617
2054
 
2618
2055
 
2619
- class HShrinkGrad(Primitive):
2620
- """
2621
- Computes gradients for HShrinkGrad operation.
2622
-
2623
- Args:
2624
- lambd (float): the λ value for the Hardshrink formulation. Default: 0.5
2625
-
2626
- Inputs:
2627
- - **Gradients** (Tensor) - the gradients of loss to output of HShrink function.
2628
- Currently gradients data type only support float16 and float32.
2629
- - **Features** (Tensor) - Must be the input `input_x` of the forward operator HSHrink.
2630
- Currently features data type only support float16 and float32.
2631
-
2632
- Outputs:
2633
- backprops - Tensor, with the same shape and data type as `features`.
2634
-
2635
- Rasise:
2636
- ValueError: If `lambd` is not a float.
2637
- ValueError: If shape of `gradients` is not the same as `features`.
2638
- TypeError: If dtype of `gradients` is not the same as `features`.
2639
- TypeError: If dtype of `gradients` or `features` is neither float16 nor float32.
2640
-
2641
- Supported Platforms:
2642
- ``Ascend`` ``GPU`` ``CPU``
2643
- """
2644
-
2645
- @prim_attr_register
2646
- def __init__(self, lambd=0.5):
2647
- validator.check_value_type("lambd", lambd, [float], self.name)
2648
- if lambd < 0.0:
2649
- lambd = 0.0
2650
- self.add_prim_attr('lambd', lambd)
2651
-
2652
-
2653
2056
  class Dilation2DBackpropInput(Primitive):
2654
2057
  """
2655
2058
  Computes the gradient of morphological 2-D dilation with respect to the input.
@@ -2962,6 +2365,12 @@ class MultiMarginLossGrad(Primitive):
2962
2365
  Supported Platforms:
2963
2366
  ``Ascend`` ``CPU``
2964
2367
  """
2368
+ __mindspore_signature__ = (
2369
+ sig.make_sig('y_grad'),
2370
+ sig.make_sig('x'),
2371
+ sig.make_sig('target'),
2372
+ sig.make_sig('weight', default=None)
2373
+ )
2965
2374
 
2966
2375
  @prim_attr_register
2967
2376
  def __init__(self, p=1, margin=1.0, reduction="mean"):
@@ -2972,96 +2381,8 @@ class MultiMarginLossGrad(Primitive):
2972
2381
  self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
2973
2382
  self.init_prim_io_names(inputs=['y_grad', 'x', 'target', 'weight'], outputs=['x_grad'])
2974
2383
 
2975
-
2976
- class UpsampleTrilinear3DGrad(Primitive):
2977
- r"""
2978
- Upsample the 3-D gradient data with trilinear interpolation algorithm.
2979
-
2980
- Note:
2981
- One of 'scales' and 'output_size' must be specified. And it is an error if both are specified.
2982
-
2983
- Args:
2984
- align_corners (bool): An optional bool. Default: ``False``.
2985
-
2986
- Inputs:
2987
- - **dy** (Tensor) - Tensor of shape [N, C, D, H, W]. Must be one of the following types:
2988
- float16, float32, float64.
2989
- - **input_size** (Union[tuple[int], list[int]]): An required listInt which contains 5 elements:
2990
- [batch, channels, depth, height, width]. Must:
2991
- input_size[0] == dy_tensor_size[0]
2992
- input_size[1] == dy_tensor_size[1].
2993
- - **output_size** (Union[tuple[int], list[int]]): An optional listInt. Default: ``None``.
2994
- It contains 3 elements: depth, height, width, whose elements should be the same as `dy`. Must:
2995
- dy_tensor_size[2] == floor(input_size[2] * scales[0]) == output_size[0]
2996
- dy_tensor_size[3] == floor(input_size[3] * scales[1]) == output_size[1]
2997
- dy_tensor_size[4] == floor(input_size[4] * scales[2]) == output_size[2].
2998
- - **scales** (Union[tuple[float], list[float]]): An optional listFloat. Default: ``None``.
2999
- The scale array along each dimension, contain 3 elements: scale_depth, scale_height, scale_width.
3000
- The number of elements of 'scales' should be the same as the rank of input `dy`.
3001
-
3002
- Outputs:
3003
- - **dx** (Tensor) - A Tensor with shape depending on intput_size, and its' dtype is the same as `dy`.
3004
- """
3005
- @prim_attr_register
3006
- def __init__(self, align_corners=False):
3007
- """Initialize UpsampleTrilinear3DGrad."""
3008
- self.init_prim_io_names(
3009
- inputs=['dy', 'input_size', 'output_size', 'scales'],
3010
- outputs=['dx'])
3011
- self.align_corners = align_corners
3012
- self.add_prim_attr('align_corners', self.align_corners)
3013
-
3014
-
3015
- class GridSampler3DGrad(Primitive):
3016
- """
3017
- Computes gradients for GridSampler3D operation.
3018
-
3019
- Args:
3020
- interpolation_mode (str): An optional string specifying the interpolation method. The optional values are
3021
- "bilinear" or "nearest". Default: "bilinear".
3022
- padding_mode (str): An optional string specifying the pad method. The optional values are "zeros", "border" or
3023
- "reflection". Default: "zeros".
3024
- align_corners (bool): An optional bool. If "true", the centers of the corner pixels of the input and output
3025
- tensors are aligned. Defaults to "false".
3026
-
3027
- Inputs:
3028
- - **grad** (Tensor) - A 5-D tensor whose dtype is float32 or float64 and whose shape is :math:`(N, C, D_{out},
3029
- H_{out}, W_{out})`. The shape is inconsistent with the shape of the output result of forward calculation.
3030
- - **input_x** (Tensor) - A 5-D tensor whose dtype is the same as `grad` and whose shape is :math:`(N, C,
3031
- D_{in}, H_{in}, W_{in})`.
3032
- - **grid** (Tensor) - A 5-D tensor whose dtype is the same as `grad` and whose shape is :math:`(N, D_{out},
3033
- H_{out}, W_{out}, 3)`.
3034
-
3035
- Outputs:
3036
- - **dx** (Tensor) - A 5-D tensor whose dtype and shape are the same as `input_x`.
3037
- - **dgrid** (Tensor) - A 5-D tensor whose dtype and shape are the same as `grid`.
3038
-
3039
- Raises:
3040
- TypeError: If `grad`, `input_x` or `grid` is not a Tensor.
3041
- TypeError: If the dtypes of `grad`, `input_x` and `grid` are inconsistent.
3042
- TypeError: If the dtype of `grad`, `input_x` or `grid` is not a valid type.
3043
- TypeError: If `align_corners` is not a boolean value.
3044
- ValueError: If the rank of `grad`, `input_x` or `grid` is not equal to 5.
3045
- ValueError: If the first dimension of `grad`, `input_x` and `grid` are inconsistent.
3046
- ValueError: If the last dimension of `grid` is not equal to 3.
3047
- ValueError: If `interpolation_mode` is not "bilinear", "nearest" or a string value.
3048
- ValueError: If `padding_mode` is not "zeros", "border", "reflection" or a string value.
3049
- ValueError: If the shape of `grad` is inconsistent with the shape of the output result of forward calculation.
3050
-
3051
- Supported Platforms:
3052
- ``GPU`` ``CPU``
3053
- """
3054
-
3055
- @prim_attr_register
3056
- def __init__(self, interpolation_mode='bilinear', padding_mode='zeros', align_corners=False):
3057
- """Initialize GridSampler3DGrad."""
3058
- validator.check_string(interpolation_mode, ['bilinear', 'nearest'], 'interpolation_mode', self.name)
3059
- validator.check_string(padding_mode, ['zeros', 'border', 'reflection'], 'padding_mode', self.name)
3060
- validator.check_bool(align_corners, 'align_corners', self.name)
3061
- self.init_prim_io_names(inputs=['grad', 'input_x', 'grid'], outputs=['dx', 'dgrid'])
3062
- self.add_prim_attr('interpolation_mode', interpolation_mode)
3063
- self.add_prim_attr('padding_mode', padding_mode)
3064
- self.add_prim_attr('align_corners', align_corners)
2384
+ def __call__(self, y_grad, x, target, weight=None):
2385
+ return super().__call__(y_grad, x, target, weight)
3065
2386
 
3066
2387
 
3067
2388
  class SparseSegmentMeanGrad(Primitive):
@@ -3466,136 +2787,6 @@ class SparseSegmentSqrtNGrad(Primitive):
3466
2787
  self.init_prim_io_names(inputs=['x', 'indices', 'segment_ids', 'output_dim0'], outputs=['y'])
3467
2788
 
3468
2789
 
3469
- class GridSampler2DGrad(Primitive):
3470
- """
3471
- Computes gradients for GridSampler2D operation.
3472
-
3473
- Args:
3474
- interpolation_mode (str): An optional string specifying the interpolation method. The optional values are
3475
- "bilinear" or "nearest". Default: "bilinear".
3476
- padding_mode (str): An optional string specifying the pad method. The optional values are "zeros", "border" or
3477
- "reflection". Default: "zeros".
3478
- align_corners (bool): An optional bool. If "true", the centers of the corner pixels of the input and output
3479
- tensors are aligned. Defaults to "false".
3480
-
3481
- Inputs:
3482
- - **grad** (Tensor) - A 4-D tensor whose dtype is float16 or float32 and whose shape is :math:`(N, C,
3483
- H_{out}, W_{out})`. The shape is inconsistent with the shape of the output result of forward calculation.
3484
- - **input_x** (Tensor) - A 4-D tensor whose dtype is the same as `grad` and whose shape is :math:`(N, C,
3485
- H_{in}, W_{in})`.
3486
- - **grid** (Tensor) - A 4-D tensor whose dtype is the same as `grad` and whose
3487
- shape is :math:`(N, H_{out}, W_{out}, 2)`.
3488
-
3489
- Outputs:
3490
- - **dx** (Tensor) - A 4-D tensor whose dtype and shape are the same as `input_x`.
3491
- - **dgrid** (Tensor) - A 4-D tensor whose dtype and shape are the same as `grid`.
3492
-
3493
- Raises:
3494
- TypeError: If `grad`, `input_x` or `grid` is not a Tensor.
3495
- TypeError: If the dtypes of `grad`, `input_x` and `grid` are inconsistent.
3496
- TypeError: If the dtype of `grad`, `input_x` or `grid` is not a valid type.
3497
- TypeError: If `align_corners` is not a boolean value.
3498
- ValueError: If the rank of `grad`, `input_x` or `grid` is not equal to 4.
3499
- ValueError: If the first dimension of `grad`, `input_x` and `grid` are inconsistent.
3500
- ValueError: If the last dimension of `grid` is not equal to 2.
3501
- ValueError: If `interpolation_mode` is not "bilinear", "nearest" or a string value.
3502
- ValueError: If `padding_mode` is not "zeros", "border", "reflection" or a string value.
3503
- ValueError: If the shape of `grad` is inconsistent with the shape of the output result of forward calculation.
3504
-
3505
- Supported Platforms:
3506
- ``GPU`` ``CPU``
3507
- """
3508
-
3509
- @prim_attr_register
3510
- def __init__(self, interpolation_mode='bilinear', padding_mode='zeros', align_corners=False):
3511
- """Initialize GridSampler2DGrad."""
3512
- validator.check_string(interpolation_mode, ['bilinear', 'nearest'], 'interpolation_mode', self.name)
3513
- validator.check_string(padding_mode, ['zeros', 'border', 'reflection'], 'padding_mode', self.name)
3514
- validator.check_bool(align_corners, 'align_corners', self.name)
3515
- self.init_prim_io_names(inputs=['grad', 'input_x', 'grid'], outputs=['dx', 'dgrid'])
3516
- self.add_prim_attr('interpolation_mode', interpolation_mode)
3517
- self.add_prim_attr('padding_mode', padding_mode)
3518
- self.add_prim_attr('align_corners', align_corners)
3519
-
3520
-
3521
- class ResizeBicubicGrad(Primitive):
3522
- """
3523
- Computes gradients for ResizeBicubicGrad operation.
3524
-
3525
- Args:
3526
- align_corners (bool):If true, the centers of the 4 corner pixels of the input
3527
- and output tensors are aligned, preserving the values at the corner pixels.Default: ``False``.
3528
- half_pixel_centers (bool): An optional bool. Default: ``False``.
3529
-
3530
- Inputs:
3531
- - **grads** (Tensor) - A Tensor of type float. 4-D with shape
3532
- [batch, height, width,channels]. The format must be NHWC.
3533
- - **original_image** (Tensor) - A Tensor. Must be one of the following types: float,double.
3534
- 4-D with shape [batch, orig_height, orig_width, channels], The image tensor that was resized.
3535
- The format must be NHWC.
3536
-
3537
- Outputs:
3538
- A 4-D Tensor , with the same shape and data type as `original_image`.
3539
-
3540
- Rasise:
3541
- TypeError: If `grads` is not allowed.
3542
- TypeError: If `original_image` is not allowed.
3543
- ValueError: If `images` dim is not 4.
3544
- ValueError: If `size` dim is not 4.
3545
-
3546
- Supported Platforms:
3547
- ``Ascend`` ``GPU`` ``CPU``
3548
- """
3549
- @prim_attr_register
3550
- def __init__(self, align_corners=False, half_pixel_centers=False):
3551
- """Initialize CropAndResize"""
3552
- validator.check_value_type('align_corners', align_corners, bool, self.name)
3553
- validator.check_value_type('half_pixel_centers', half_pixel_centers, bool, self.name)
3554
- self.init_prim_io_names(inputs=['grads', 'original_image'], outputs=['y'])
3555
-
3556
- def __infer__(self, grads, original_image):
3557
- # get shape
3558
- grads_shape = list(grads['shape'])
3559
- original_image_shape = list(original_image['shape'])
3560
- # get value
3561
- if grads['value'] is None:
3562
- raise ValueError(
3563
- f"For '{self.name}', the 'grads' cannot be None, but got {grads['value']}."
3564
- )
3565
- if original_image['value'] is None:
3566
- raise ValueError(
3567
- f"For '{self.name}', the 'original_image' cannot be None, but got {original_image['value']}."
3568
- )
3569
- # get dtype
3570
- grads_dtype = grads['dtype']
3571
- original_image_dtype = original_image['dtype']
3572
- # check dytpe
3573
- validator.check_tensor_dtype_valid("grads", grads_dtype,
3574
- [mstype.float32], self.name)
3575
- validator.check_tensor_dtype_valid("original_image", original_image_dtype,
3576
- [mstype.float32, mstype.float64], self.name)
3577
- # check input shape rank
3578
- validator.check("grads rank", len(grads_shape), "expected", 4, validator.EQ, self.name)
3579
- validator.check("original_image rank", len(original_image_shape), "expected", 4, validator.EQ, self.name)
3580
- validator.check("batch_size equal", grads_shape[0], "expected",
3581
- original_image_shape[0], validator.EQ, self.name)
3582
- validator.check("channel equal", grads_shape[3], "expected", original_image_shape[3], validator.EQ, self.name)
3583
- # check original_image_shape and grads_shape
3584
- validator.check("original_image[0] and grads[0]", original_image_shape[0],
3585
- "expected", grads_shape[0], validator.EQ, self.name)
3586
- validator.check("original_image[3] and grads[3]", original_image_shape[3],
3587
- "expected", grads_shape[3], validator.EQ, self.name)
3588
-
3589
- batch_size = grads_shape[0]
3590
- height = original_image_shape[1]
3591
- width = original_image_shape[2]
3592
- channel = grads_shape[3]
3593
- out_shape = (batch_size, height, width, channel)
3594
- return {'shape': out_shape,
3595
- 'dtype': original_image_dtype,
3596
- 'value': None}
3597
-
3598
-
3599
2790
  class SparseSliceGrad(Primitive):
3600
2791
  r"""
3601
2792
  Computes gradients for SparseSlice operation.
@@ -3717,13 +2908,6 @@ class AffineGridGrad(Primitive):
3717
2908
  self.init_prim_io_names(inputs=['y_grad', 'x_size'], outputs=['x_grad'])
3718
2909
 
3719
2910
 
3720
- class HSigmoidGrad(Primitive):
3721
- """Gets the gradient of HSigmoid operation."""
3722
- @prim_attr_register
3723
- def __init__(self):
3724
- """Initialize HSigmoidGrad"""
3725
- self.init_prim_io_names(inputs=['grads', 'input_x'], outputs=['output'])
3726
-
3727
2911
 
3728
2912
  class GluGrad(Primitive):
3729
2913
  """
@@ -3737,46 +2921,6 @@ class GluGrad(Primitive):
3737
2921
  validator.check_value_type("axis", axis, [int], self.name)
3738
2922
 
3739
2923
 
3740
- class CholeskyGrad(Primitive):
3741
- r"""
3742
- Computes the reverse mode backpropgated gradient of the Cholesky algorithm.
3743
-
3744
- Inputs:
3745
- - **x** (Tensor) - A tensor with float32 or float64 data type.
3746
- - **grad** (Tensor) - A tensor with float32 or float64 data type. `x` should have
3747
- the same dtype with `a`.
3748
-
3749
- Outputs:
3750
- Tensor, has the same dtype as `a` and `x`.
3751
-
3752
- Raises:
3753
- TypeError: If x is not Tensor.
3754
- TypeError: If grad is not Tensor.
3755
- TypeError: If dtype of input x and grad is not float64 nor float32,
3756
- TypeError: If x has different dtype with grad.
3757
- ValueError: If input tensor's last two dims are not equal,
3758
- ValueError: If the shape of x and grad mismatch.
3759
-
3760
- Supported Platforms:
3761
- ``Ascend``
3762
-
3763
- Examples:
3764
- >>> x = Tensor(np.array([[4, 2],[2, 3]]), mstype.float64)
3765
- >>> grad = Tensor(np.array([[4, 2],[2, 3]]), mstype.float64)
3766
- >>> choleskygrad = G.CholeskyGrad()
3767
- >>> output = choleskygrad(x, grad)
3768
- >>> print (output)
3769
- [[0.5 0. ]
3770
- [0. 0.5]]
3771
-
3772
- """
3773
-
3774
- @prim_attr_register
3775
- def __init__(self):
3776
- """Initialize CholeskyGrad"""
3777
- self.init_prim_io_names(inputs=['x', 'grad'], outputs=['y'])
3778
-
3779
-
3780
2924
  class MapTensorGetGrad(Primitive):
3781
2925
  """
3782
2926
  Computes gradients for MapTensorGet operation.
@@ -3832,53 +2976,3 @@ class WKVGrad(Primitive):
3832
2976
  """Initialize WKVGrad."""
3833
2977
  self.init_prim_io_names(inputs=["time_first", "time_decay", "key", "value", "gy"],
3834
2978
  outputs=["gw", "gu", "gk", "gv"])
3835
-
3836
-
3837
- class FlashAttentionScoreGrad(Primitive):
3838
- r"""
3839
- Calculates the gradient of FlashAttentionScore operation.
3840
- .. warning::
3841
- This is an experimental API that is subject to change or deletion.
3842
-
3843
- Supported Platforms:
3844
- ``Ascend``
3845
- """
3846
- @prim_attr_register
3847
- def __init__(self, head_num, keep_prob=1.0, scale_value=1.0, pre_tokens=65536, next_tokens=65536, inner_precise=1,
3848
- input_layout='BSH', sparse_mode=0):
3849
- """Initialize FlashAttentionScoreGrad."""
3850
- validator.check_value_type('head_num', head_num, [int], self.name)
3851
- validator.check_value_type('keep_prob', keep_prob, [int, float], self.name)
3852
- validator.check_float(keep_prob, 0.0, validator.GE, "keep_prob", self.name)
3853
- validator.check_float(keep_prob, 1.0, validator.LE, "keep_prob", self.name)
3854
- validator.check_value_type('scale_value', scale_value, [float], self.name)
3855
- validator.check_value_type('pre_tokens', pre_tokens, [int], self.name)
3856
- validator.check_value_type('next_tokens', next_tokens, [int], self.name)
3857
- validator.check_value_type('inner_precise', inner_precise, [int], self.name)
3858
- validator.check_value_type('sparse_mode', sparse_mode, [int], self.name)
3859
- if inner_precise not in [0, 1]:
3860
- raise ValueError(f"Attribute 'inner_precise' must be either 0 or 1, but got {inner_precise}")
3861
- validator.check_value_type('input_layout', input_layout, [str], self.name)
3862
- if input_layout not in ["BSH", "BNSD"]:
3863
- raise ValueError(f"Attribute 'input_layout' must be either 'BSH' or 'BNSD', but got {input_layout}")
3864
- self.init_prim_io_names(inputs=['query', 'key', 'value', 'dy', 'pse_shift', 'drop_mask', "padding_mask",
3865
- 'attn_mask', 'softmax_max', 'softmax_sum', 'softmax_out', 'attention_in',
3866
- 'prefix'],
3867
- outputs=['dq', 'dk', 'dv', 'dpse'])
3868
-
3869
-
3870
- class RmsNormGrad(Primitive):
3871
- r"""
3872
- Calculates the gradient of RmsNorm operation.
3873
- .. warning::
3874
- This is an experimental API that is subject to change or deletion.
3875
-
3876
- Supported Platforms:
3877
- ``Ascend``
3878
- """
3879
-
3880
- @prim_attr_register
3881
- def __init__(self):
3882
- """Initialize RmsNormGrad."""
3883
- self.init_prim_io_names(inputs=["dy", "x", "rstd", "gamma"],
3884
- outputs=["dx", "dgamma"])