mindspore-2.4.0-cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (1387)
  1. mindspore/.commit_id +1 -0
  2. mindspore/__init__.py +53 -0
  3. mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
  4. mindspore/_c_expression.cpython-311-darwin.so +0 -0
  5. mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
  6. mindspore/_check_jit_forbidden_api.py +106 -0
  7. mindspore/_checkparam.py +1419 -0
  8. mindspore/_extends/__init__.py +23 -0
  9. mindspore/_extends/builtin_operations.py +224 -0
  10. mindspore/_extends/graph_kernel/__init__.py +17 -0
  11. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  12. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  13. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  14. mindspore/_extends/graph_kernel/model/model.py +553 -0
  15. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  16. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  17. mindspore/_extends/graph_kernel/splitter.py +140 -0
  18. mindspore/_extends/graph_kernel/utils.py +28 -0
  19. mindspore/_extends/parallel_compile/__init__.py +19 -0
  20. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  21. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  22. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  23. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  24. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  25. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  26. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  27. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  28. mindspore/_extends/parse/__init__.py +49 -0
  29. mindspore/_extends/parse/compile_config.py +299 -0
  30. mindspore/_extends/parse/namespace.py +136 -0
  31. mindspore/_extends/parse/parser.py +1448 -0
  32. mindspore/_extends/parse/resources.py +213 -0
  33. mindspore/_extends/parse/standard_method.py +4475 -0
  34. mindspore/_extends/parse/trope.py +97 -0
  35. mindspore/_extends/pijit/__init__.py +23 -0
  36. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  37. mindspore/_extends/remote/__init__.py +19 -0
  38. mindspore/_extends/remote/kernel_build_server.py +199 -0
  39. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  40. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  41. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  42. mindspore/_extends/utils.py +68 -0
  43. mindspore/_install_custom.py +43 -0
  44. mindspore/_profiler.py +30 -0
  45. mindspore/amp.py +433 -0
  46. mindspore/boost/__init__.py +42 -0
  47. mindspore/boost/adasum.py +319 -0
  48. mindspore/boost/base.py +535 -0
  49. mindspore/boost/boost.py +400 -0
  50. mindspore/boost/boost_cell_wrapper.py +790 -0
  51. mindspore/boost/dim_reduce.py +323 -0
  52. mindspore/boost/grad_accumulation.py +79 -0
  53. mindspore/boost/grad_freeze.py +382 -0
  54. mindspore/boost/group_loss_scale_manager.py +166 -0
  55. mindspore/boost/less_batch_normalization.py +174 -0
  56. mindspore/common/__init__.py +86 -0
  57. mindspore/common/_auto_dynamic.py +68 -0
  58. mindspore/common/_decorator.py +50 -0
  59. mindspore/common/_jit_fallback_utils.py +110 -0
  60. mindspore/common/_monad.py +25 -0
  61. mindspore/common/_pijit_context.py +190 -0
  62. mindspore/common/_register_for_adapter.py +74 -0
  63. mindspore/common/_register_for_recompute.py +48 -0
  64. mindspore/common/_register_for_tensor.py +46 -0
  65. mindspore/common/_stub_tensor.py +210 -0
  66. mindspore/common/_tensor_overload.py +139 -0
  67. mindspore/common/_utils.py +122 -0
  68. mindspore/common/api.py +2064 -0
  69. mindspore/common/auto_dynamic_shape.py +507 -0
  70. mindspore/common/dtype.py +422 -0
  71. mindspore/common/dump.py +130 -0
  72. mindspore/common/file_system.py +48 -0
  73. mindspore/common/generator.py +254 -0
  74. mindspore/common/hook_handle.py +143 -0
  75. mindspore/common/initializer.py +880 -0
  76. mindspore/common/jit_config.py +98 -0
  77. mindspore/common/lazy_inline.py +240 -0
  78. mindspore/common/mindir_util.py +111 -0
  79. mindspore/common/mutable.py +234 -0
  80. mindspore/common/no_inline.py +54 -0
  81. mindspore/common/np_dtype.py +25 -0
  82. mindspore/common/parameter.py +1081 -0
  83. mindspore/common/recompute.py +292 -0
  84. mindspore/common/seed.py +260 -0
  85. mindspore/common/sparse_tensor.py +1175 -0
  86. mindspore/common/symbol.py +122 -0
  87. mindspore/common/tensor.py +5039 -0
  88. mindspore/communication/__init__.py +37 -0
  89. mindspore/communication/_comm_helper.py +501 -0
  90. mindspore/communication/_hccl_management.py +297 -0
  91. mindspore/communication/comm_func.py +1395 -0
  92. mindspore/communication/management.py +673 -0
  93. mindspore/config/op_info.config +533 -0
  94. mindspore/context.py +2077 -0
  95. mindspore/dataset/__init__.py +90 -0
  96. mindspore/dataset/audio/__init__.py +61 -0
  97. mindspore/dataset/audio/transforms.py +3690 -0
  98. mindspore/dataset/audio/utils.py +386 -0
  99. mindspore/dataset/audio/validators.py +1172 -0
  100. mindspore/dataset/callback/__init__.py +20 -0
  101. mindspore/dataset/callback/ds_callback.py +368 -0
  102. mindspore/dataset/callback/validators.py +32 -0
  103. mindspore/dataset/core/__init__.py +13 -0
  104. mindspore/dataset/core/config.py +1095 -0
  105. mindspore/dataset/core/datatypes.py +101 -0
  106. mindspore/dataset/core/py_util_helpers.py +65 -0
  107. mindspore/dataset/core/validator_helpers.py +781 -0
  108. mindspore/dataset/debug/__init__.py +21 -0
  109. mindspore/dataset/debug/debug_hook.py +97 -0
  110. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  111. mindspore/dataset/engine/__init__.py +124 -0
  112. mindspore/dataset/engine/cache_admin.py +47 -0
  113. mindspore/dataset/engine/cache_client.py +129 -0
  114. mindspore/dataset/engine/datasets.py +4582 -0
  115. mindspore/dataset/engine/datasets_audio.py +911 -0
  116. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  117. mindspore/dataset/engine/datasets_text.py +2161 -0
  118. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  119. mindspore/dataset/engine/datasets_vision.py +4816 -0
  120. mindspore/dataset/engine/iterators.py +371 -0
  121. mindspore/dataset/engine/obs/__init__.py +23 -0
  122. mindspore/dataset/engine/obs/config_loader.py +68 -0
  123. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  124. mindspore/dataset/engine/obs/util.py +482 -0
  125. mindspore/dataset/engine/offload.py +596 -0
  126. mindspore/dataset/engine/queue.py +304 -0
  127. mindspore/dataset/engine/samplers.py +895 -0
  128. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  129. mindspore/dataset/engine/validators.py +2895 -0
  130. mindspore/dataset/text/__init__.py +51 -0
  131. mindspore/dataset/text/transforms.py +1703 -0
  132. mindspore/dataset/text/utils.py +715 -0
  133. mindspore/dataset/text/validators.py +642 -0
  134. mindspore/dataset/transforms/__init__.py +45 -0
  135. mindspore/dataset/transforms/c_transforms.py +638 -0
  136. mindspore/dataset/transforms/py_transforms.py +393 -0
  137. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  138. mindspore/dataset/transforms/transforms.py +1260 -0
  139. mindspore/dataset/transforms/validators.py +410 -0
  140. mindspore/dataset/utils/__init__.py +19 -0
  141. mindspore/dataset/utils/browse_dataset.py +190 -0
  142. mindspore/dataset/utils/line_reader.py +126 -0
  143. mindspore/dataset/vision/__init__.py +65 -0
  144. mindspore/dataset/vision/c_transforms.py +2641 -0
  145. mindspore/dataset/vision/py_transforms.py +2120 -0
  146. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  147. mindspore/dataset/vision/transforms.py +7295 -0
  148. mindspore/dataset/vision/utils.py +863 -0
  149. mindspore/dataset/vision/validators.py +1483 -0
  150. mindspore/default_config.py +2 -0
  151. mindspore/experimental/__init__.py +20 -0
  152. mindspore/experimental/es/__init__.py +22 -0
  153. mindspore/experimental/es/embedding_service.py +883 -0
  154. mindspore/experimental/es/embedding_service_layer.py +581 -0
  155. mindspore/experimental/llm_boost/__init__.py +21 -0
  156. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  157. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  158. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  159. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  160. mindspore/experimental/llm_boost/register.py +129 -0
  161. mindspore/experimental/llm_boost/utils.py +31 -0
  162. mindspore/experimental/map_parameter.py +309 -0
  163. mindspore/experimental/optim/__init__.py +40 -0
  164. mindspore/experimental/optim/adadelta.py +161 -0
  165. mindspore/experimental/optim/adagrad.py +168 -0
  166. mindspore/experimental/optim/adam.py +193 -0
  167. mindspore/experimental/optim/adamax.py +170 -0
  168. mindspore/experimental/optim/adamw.py +290 -0
  169. mindspore/experimental/optim/asgd.py +153 -0
  170. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  171. mindspore/experimental/optim/nadam.py +157 -0
  172. mindspore/experimental/optim/optimizer.py +262 -0
  173. mindspore/experimental/optim/radam.py +194 -0
  174. mindspore/experimental/optim/rmsprop.py +154 -0
  175. mindspore/experimental/optim/rprop.py +164 -0
  176. mindspore/experimental/optim/sgd.py +156 -0
  177. mindspore/hal/__init__.py +40 -0
  178. mindspore/hal/_ascend.py +57 -0
  179. mindspore/hal/_base.py +57 -0
  180. mindspore/hal/_cpu.py +56 -0
  181. mindspore/hal/_gpu.py +57 -0
  182. mindspore/hal/contiguous_tensors_handle.py +175 -0
  183. mindspore/hal/device.py +356 -0
  184. mindspore/hal/event.py +179 -0
  185. mindspore/hal/memory.py +326 -0
  186. mindspore/hal/stream.py +357 -0
  187. mindspore/include/OWNERS +7 -0
  188. mindspore/include/api/allocator.h +97 -0
  189. mindspore/include/api/callback/callback.h +93 -0
  190. mindspore/include/api/callback/ckpt_saver.h +41 -0
  191. mindspore/include/api/callback/loss_monitor.h +33 -0
  192. mindspore/include/api/callback/lr_scheduler.h +51 -0
  193. mindspore/include/api/callback/time_monitor.h +34 -0
  194. mindspore/include/api/callback/train_accuracy.h +37 -0
  195. mindspore/include/api/cell.h +90 -0
  196. mindspore/include/api/cfg.h +82 -0
  197. mindspore/include/api/context.h +602 -0
  198. mindspore/include/api/data_type.h +47 -0
  199. mindspore/include/api/delegate.h +178 -0
  200. mindspore/include/api/delegate_api.h +75 -0
  201. mindspore/include/api/dual_abi_helper.h +208 -0
  202. mindspore/include/api/format.h +28 -0
  203. mindspore/include/api/graph.h +46 -0
  204. mindspore/include/api/kernel.h +58 -0
  205. mindspore/include/api/kernel_api.h +168 -0
  206. mindspore/include/api/metrics/accuracy.h +36 -0
  207. mindspore/include/api/metrics/metrics.h +41 -0
  208. mindspore/include/api/model.h +438 -0
  209. mindspore/include/api/model_group.h +91 -0
  210. mindspore/include/api/model_parallel_runner.h +168 -0
  211. mindspore/include/api/serialization.h +185 -0
  212. mindspore/include/api/status.h +192 -0
  213. mindspore/include/api/types.h +431 -0
  214. mindspore/include/api/visible.h +41 -0
  215. mindspore/include/c_api/context_c.h +179 -0
  216. mindspore/include/c_api/data_type_c.h +52 -0
  217. mindspore/include/c_api/format_c.h +46 -0
  218. mindspore/include/c_api/model_c.h +347 -0
  219. mindspore/include/c_api/status_c.h +79 -0
  220. mindspore/include/c_api/tensor_c.h +146 -0
  221. mindspore/include/c_api/types_c.h +67 -0
  222. mindspore/include/dataset/config.h +163 -0
  223. mindspore/include/dataset/constants.h +363 -0
  224. mindspore/include/dataset/execute.h +196 -0
  225. mindspore/include/dataset/text.h +1092 -0
  226. mindspore/include/dataset/transforms.h +638 -0
  227. mindspore/include/dataset/vision.h +2129 -0
  228. mindspore/include/dataset/vision_ascend.h +206 -0
  229. mindspore/include/dataset/vision_lite.h +625 -0
  230. mindspore/lib/libavcodec.59.dylib +0 -0
  231. mindspore/lib/libavdevice.59.dylib +0 -0
  232. mindspore/lib/libavfilter.8.dylib +0 -0
  233. mindspore/lib/libavformat.59.dylib +0 -0
  234. mindspore/lib/libavutil.57.dylib +0 -0
  235. mindspore/lib/libdnnl.2.dylib +0 -0
  236. mindspore/lib/libicudata.69.dylib +0 -0
  237. mindspore/lib/libicui18n.69.dylib +0 -0
  238. mindspore/lib/libicuuc.69.dylib +0 -0
  239. mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
  240. mindspore/lib/libmindspore_backend.dylib +0 -0
  241. mindspore/lib/libmindspore_common.dylib +0 -0
  242. mindspore/lib/libmindspore_core.dylib +0 -0
  243. mindspore/lib/libmindspore_glog.0.dylib +0 -0
  244. mindspore/lib/libmindspore_gpr.15.dylib +0 -0
  245. mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
  246. mindspore/lib/libmindspore_grpc.15.dylib +0 -0
  247. mindspore/lib/libmindspore_np_dtype.dylib +0 -0
  248. mindspore/lib/libmindspore_ops.dylib +0 -0
  249. mindspore/lib/libmindspore_upb.15.dylib +0 -0
  250. mindspore/lib/libnnacl.dylib +0 -0
  251. mindspore/lib/libopencv_core.4.5.dylib +0 -0
  252. mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
  253. mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
  254. mindspore/lib/libps_cache.dylib +0 -0
  255. mindspore/lib/libswresample.4.dylib +0 -0
  256. mindspore/lib/libswscale.6.dylib +0 -0
  257. mindspore/lib/libtinyxml2.8.dylib +0 -0
  258. mindspore/log.py +633 -0
  259. mindspore/mindrecord/__init__.py +43 -0
  260. mindspore/mindrecord/common/__init__.py +17 -0
  261. mindspore/mindrecord/common/constant.py +20 -0
  262. mindspore/mindrecord/common/enums.py +44 -0
  263. mindspore/mindrecord/common/exceptions.py +311 -0
  264. mindspore/mindrecord/config.py +809 -0
  265. mindspore/mindrecord/filereader.py +174 -0
  266. mindspore/mindrecord/filewriter.py +722 -0
  267. mindspore/mindrecord/mindpage.py +210 -0
  268. mindspore/mindrecord/shardheader.py +141 -0
  269. mindspore/mindrecord/shardindexgenerator.py +74 -0
  270. mindspore/mindrecord/shardreader.py +117 -0
  271. mindspore/mindrecord/shardsegment.py +128 -0
  272. mindspore/mindrecord/shardutils.py +185 -0
  273. mindspore/mindrecord/shardwriter.py +237 -0
  274. mindspore/mindrecord/tools/__init__.py +17 -0
  275. mindspore/mindrecord/tools/cifar10.py +140 -0
  276. mindspore/mindrecord/tools/cifar100.py +153 -0
  277. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  278. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  279. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  280. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  281. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  282. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  283. mindspore/mint/__init__.py +1586 -0
  284. mindspore/mint/distributed/__init__.py +31 -0
  285. mindspore/mint/distributed/distributed.py +254 -0
  286. mindspore/mint/linalg/__init__.py +22 -0
  287. mindspore/mint/nn/__init__.py +757 -0
  288. mindspore/mint/nn/functional.py +679 -0
  289. mindspore/mint/nn/layer/__init__.py +39 -0
  290. mindspore/mint/nn/layer/activation.py +133 -0
  291. mindspore/mint/nn/layer/normalization.py +477 -0
  292. mindspore/mint/nn/layer/pooling.py +110 -0
  293. mindspore/mint/optim/__init__.py +24 -0
  294. mindspore/mint/optim/adamw.py +206 -0
  295. mindspore/mint/special/__init__.py +63 -0
  296. mindspore/multiprocessing/__init__.py +73 -0
  297. mindspore/nn/__init__.py +47 -0
  298. mindspore/nn/cell.py +2787 -0
  299. mindspore/nn/dynamic_lr.py +482 -0
  300. mindspore/nn/grad/__init__.py +21 -0
  301. mindspore/nn/grad/cell_grad.py +196 -0
  302. mindspore/nn/layer/__init__.py +63 -0
  303. mindspore/nn/layer/activation.py +1822 -0
  304. mindspore/nn/layer/basic.py +1629 -0
  305. mindspore/nn/layer/channel_shuffle.py +90 -0
  306. mindspore/nn/layer/combined.py +248 -0
  307. mindspore/nn/layer/container.py +734 -0
  308. mindspore/nn/layer/conv.py +1505 -0
  309. mindspore/nn/layer/dense.py +204 -0
  310. mindspore/nn/layer/embedding.py +869 -0
  311. mindspore/nn/layer/image.py +661 -0
  312. mindspore/nn/layer/math.py +1069 -0
  313. mindspore/nn/layer/normalization.py +1273 -0
  314. mindspore/nn/layer/padding.py +880 -0
  315. mindspore/nn/layer/pooling.py +2302 -0
  316. mindspore/nn/layer/rnn_cells.py +388 -0
  317. mindspore/nn/layer/rnns.py +849 -0
  318. mindspore/nn/layer/thor_layer.py +963 -0
  319. mindspore/nn/layer/timedistributed.py +155 -0
  320. mindspore/nn/layer/transformer.py +823 -0
  321. mindspore/nn/learning_rate_schedule.py +512 -0
  322. mindspore/nn/loss/__init__.py +36 -0
  323. mindspore/nn/loss/loss.py +2924 -0
  324. mindspore/nn/metrics.py +53 -0
  325. mindspore/nn/optim/__init__.py +45 -0
  326. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  327. mindspore/nn/optim/ada_grad.py +217 -0
  328. mindspore/nn/optim/adadelta.py +206 -0
  329. mindspore/nn/optim/adafactor.py +448 -0
  330. mindspore/nn/optim/adam.py +1297 -0
  331. mindspore/nn/optim/adamax.py +220 -0
  332. mindspore/nn/optim/adasum.py +548 -0
  333. mindspore/nn/optim/asgd.py +216 -0
  334. mindspore/nn/optim/ftrl.py +401 -0
  335. mindspore/nn/optim/lamb.py +296 -0
  336. mindspore/nn/optim/lars.py +202 -0
  337. mindspore/nn/optim/lazyadam.py +533 -0
  338. mindspore/nn/optim/momentum.py +239 -0
  339. mindspore/nn/optim/optimizer.py +1034 -0
  340. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  341. mindspore/nn/optim/rmsprop.py +264 -0
  342. mindspore/nn/optim/rprop.py +251 -0
  343. mindspore/nn/optim/sgd.py +237 -0
  344. mindspore/nn/optim/tft_wrapper.py +127 -0
  345. mindspore/nn/optim/thor.py +1310 -0
  346. mindspore/nn/probability/__init__.py +22 -0
  347. mindspore/nn/probability/bijector/__init__.py +35 -0
  348. mindspore/nn/probability/bijector/bijector.py +337 -0
  349. mindspore/nn/probability/bijector/exp.py +65 -0
  350. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  351. mindspore/nn/probability/bijector/invert.py +126 -0
  352. mindspore/nn/probability/bijector/power_transform.py +196 -0
  353. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  354. mindspore/nn/probability/bijector/softplus.py +189 -0
  355. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  356. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  357. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  358. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  359. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  360. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  361. mindspore/nn/probability/distribution/__init__.py +56 -0
  362. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  363. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  364. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  365. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  366. mindspore/nn/probability/distribution/beta.py +391 -0
  367. mindspore/nn/probability/distribution/categorical.py +435 -0
  368. mindspore/nn/probability/distribution/cauchy.py +383 -0
  369. mindspore/nn/probability/distribution/distribution.py +827 -0
  370. mindspore/nn/probability/distribution/exponential.py +350 -0
  371. mindspore/nn/probability/distribution/gamma.py +391 -0
  372. mindspore/nn/probability/distribution/geometric.py +335 -0
  373. mindspore/nn/probability/distribution/gumbel.py +257 -0
  374. mindspore/nn/probability/distribution/half_normal.py +133 -0
  375. mindspore/nn/probability/distribution/laplace.py +128 -0
  376. mindspore/nn/probability/distribution/log_normal.py +272 -0
  377. mindspore/nn/probability/distribution/logistic.py +379 -0
  378. mindspore/nn/probability/distribution/normal.py +336 -0
  379. mindspore/nn/probability/distribution/poisson.py +288 -0
  380. mindspore/nn/probability/distribution/student_t.py +149 -0
  381. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  382. mindspore/nn/probability/distribution/uniform.py +375 -0
  383. mindspore/nn/reinforcement/__init__.py +24 -0
  384. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  385. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  386. mindspore/nn/reinforcement/tensor_array.py +145 -0
  387. mindspore/nn/sparse/__init__.py +23 -0
  388. mindspore/nn/sparse/sparse.py +147 -0
  389. mindspore/nn/wrap/__init__.py +49 -0
  390. mindspore/nn/wrap/cell_wrapper.py +968 -0
  391. mindspore/nn/wrap/grad_reducer.py +608 -0
  392. mindspore/nn/wrap/loss_scale.py +694 -0
  393. mindspore/numpy/__init__.py +121 -0
  394. mindspore/numpy/array_creations.py +2731 -0
  395. mindspore/numpy/array_ops.py +2629 -0
  396. mindspore/numpy/dtypes.py +185 -0
  397. mindspore/numpy/fft.py +966 -0
  398. mindspore/numpy/logic_ops.py +936 -0
  399. mindspore/numpy/math_ops.py +5911 -0
  400. mindspore/numpy/utils.py +214 -0
  401. mindspore/numpy/utils_const.py +565 -0
  402. mindspore/ops/__init__.py +56 -0
  403. mindspore/ops/_constants.py +30 -0
  404. mindspore/ops/_grad_experimental/__init__.py +31 -0
  405. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  406. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  407. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  408. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  409. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  410. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  411. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  412. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  413. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  414. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  415. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  416. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  417. mindspore/ops/_op_impl/__init__.py +23 -0
  418. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  419. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  420. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  421. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  422. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  423. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  424. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  425. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  426. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  427. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  428. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  429. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  430. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  431. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  432. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  433. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  434. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  435. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  436. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  437. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  438. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  439. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  440. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  441. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  442. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  443. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  444. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  445. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  446. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  447. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  448. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  449. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  450. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  451. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  452. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  453. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  454. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  455. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  456. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  457. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  458. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  459. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  460. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  461. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  462. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  463. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  464. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  465. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  466. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  467. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  468. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  469. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  470. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  471. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  472. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  473. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  474. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  475. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  476. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  477. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  478. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  479. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  480. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  481. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  482. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  483. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  484. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  485. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  486. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  487. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  488. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  490. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  491. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  493. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  494. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  495. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  496. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  497. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  498. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  499. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  500. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  501. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  502. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  503. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  504. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  505. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  506. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  507. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  508. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  509. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  510. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  511. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  513. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  514. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  515. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  516. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  517. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  518. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  519. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  520. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  521. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  522. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  523. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  524. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  526. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  527. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  528. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  529. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  530. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  531. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  532. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  533. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  534. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  535. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  537. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  538. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  539. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  540. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  541. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  542. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  543. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  544. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  545. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  546. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  547. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  548. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  549. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  550. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  551. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  552. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  553. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  554. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  555. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  556. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  557. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  558. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  559. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  560. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  561. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  562. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  563. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  564. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  565. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  566. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  567. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  568. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  569. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  570. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  571. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  572. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  574. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  575. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  576. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  577. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  578. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  579. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  580. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  581. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  582. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  583. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  584. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  585. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  586. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  587. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  588. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  589. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  590. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  591. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  592. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  593. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  594. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  595. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  596. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  597. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  598. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  599. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  600. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  601. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  603. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  604. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  605. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  606. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  608. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  609. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  610. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  611. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  612. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  613. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  614. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  615. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  616. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  617. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  618. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  619. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  620. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  621. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  622. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  623. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  624. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  625. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  626. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  627. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  628. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  629. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  630. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  632. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  633. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  634. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  635. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  636. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  637. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  638. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  639. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  640. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  641. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  642. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  643. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  644. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  645. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  646. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  647. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  648. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  650. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  651. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  652. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  653. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  654. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  655. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  656. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  657. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  658. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  659. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  660. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  661. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  662. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  663. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  664. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  665. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  666. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  667. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  668. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  669. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  670. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  673. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  675. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  676. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  677. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  679. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  680. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  681. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  682. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  683. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  684. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  685. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  686. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  687. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  688. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  689. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  690. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  691. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  692. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  693. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  694. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  695. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  696. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  697. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  698. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  699. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  700. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  701. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  702. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  703. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  704. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  705. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  706. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  707. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  708. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  709. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  710. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  711. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  712. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  713. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  714. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  715. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  716. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  717. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  718. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  719. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  720. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  721. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  722. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  723. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  724. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  725. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  726. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  727. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  728. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  729. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  730. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  731. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  732. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  733. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  734. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  735. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  736. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  737. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  738. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  739. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  740. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  741. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  742. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  743. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  744. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  745. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  746. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  747. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  748. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  749. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  750. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  751. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  752. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  753. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  754. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  755. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  756. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  757. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  758. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  759. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  760. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  761. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  762. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  763. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  764. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  765. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  766. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  767. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  768. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  769. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  770. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  772. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  773. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  774. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  775. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  776. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  777. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  778. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  779. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  780. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  781. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  782. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  783. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  784. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  785. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  786. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  787. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  788. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  789. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  790. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  791. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  792. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  793. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  794. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  795. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  796. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  797. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  798. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  799. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  800. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  801. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  802. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  803. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  804. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  805. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  806. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  808. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  809. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  810. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  811. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  812. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  813. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  814. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  815. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  816. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  817. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  818. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  819. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  820. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  821. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  822. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  823. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  825. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  826. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  827. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  828. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  829. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  841. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  843. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  844. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  845. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  846. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  847. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  848. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  849. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  850. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  851. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  852. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  853. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  854. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  855. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  856. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  857. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  858. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  859. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  860. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  861. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  862. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  863. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  864. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  865. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  866. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  868. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  869. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  870. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  871. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  872. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  873. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  874. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  876. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  877. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  878. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  879. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  880. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  881. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  882. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  883. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  884. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  885. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  886. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  887. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  888. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  889. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  890. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  891. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  892. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  893. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  894. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  895. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  896. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  897. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  898. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  899. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  900. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  901. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  902. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  903. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  904. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  905. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  906. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  907. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  908. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  909. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  910. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  911. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  912. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  913. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  914. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  915. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  916. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  917. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  918. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  919. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  920. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  921. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  922. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  923. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  924. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  925. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  926. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  927. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  929. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  930. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  931. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  932. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  933. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  934. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  935. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  936. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  937. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  938. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  939. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  940. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  941. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  942. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  943. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  944. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  945. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  946. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  947. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  948. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  949. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  950. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  951. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  952. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  953. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  954. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  955. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  956. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  957. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  958. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  959. mindspore/ops/_op_impl/cpu/div.py +32 -0
  960. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  961. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  962. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  963. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  964. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  965. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  966. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  967. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  968. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  969. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  970. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  971. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  972. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  973. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  974. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  975. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  976. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  977. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  978. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  979. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  980. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  981. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  982. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  983. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  984. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  985. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  986. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  987. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  988. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  989. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  990. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  991. mindspore/ops/_op_impl/cpu/range.py +34 -0
  992. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  993. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  994. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  995. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  996. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  997. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  998. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  999. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1000. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1001. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1002. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1003. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1004. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1005. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1006. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1007. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1009. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1010. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1011. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1012. mindspore/ops/_primitive_cache.py +90 -0
  1013. mindspore/ops/_register_for_op.py +73 -0
  1014. mindspore/ops/_utils/__init__.py +20 -0
  1015. mindspore/ops/_utils/utils.py +147 -0
  1016. mindspore/ops/_vmap/__init__.py +25 -0
  1017. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1018. mindspore/ops/_vmap/vmap_base.py +533 -0
  1019. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1020. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1021. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1022. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1023. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1024. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1025. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1026. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1027. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1028. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1029. mindspore/ops/auto_generate/__init__.py +31 -0
  1030. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1031. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1032. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1033. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1034. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1035. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1036. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1037. mindspore/ops/composite/__init__.py +71 -0
  1038. mindspore/ops/composite/base.py +1318 -0
  1039. mindspore/ops/composite/env_ops.py +41 -0
  1040. mindspore/ops/composite/math_ops.py +125 -0
  1041. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1042. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1043. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1044. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1045. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1046. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1047. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1048. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1049. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1050. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1051. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1052. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1053. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1054. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1055. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1056. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1057. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1058. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1059. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1060. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1061. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1062. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1063. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1064. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1065. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1066. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1067. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1068. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1069. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1070. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1071. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1072. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1073. mindspore/ops/deprecated.py +315 -0
  1074. mindspore/ops/function/__init__.py +782 -0
  1075. mindspore/ops/function/array_func.py +7226 -0
  1076. mindspore/ops/function/clip_func.py +384 -0
  1077. mindspore/ops/function/debug_func.py +181 -0
  1078. mindspore/ops/function/fft_func.py +44 -0
  1079. mindspore/ops/function/grad/__init__.py +34 -0
  1080. mindspore/ops/function/grad/grad_func.py +1425 -0
  1081. mindspore/ops/function/image_func.py +292 -0
  1082. mindspore/ops/function/linalg_func.py +416 -0
  1083. mindspore/ops/function/math_func.py +12228 -0
  1084. mindspore/ops/function/nn_func.py +8609 -0
  1085. mindspore/ops/function/other_func.py +115 -0
  1086. mindspore/ops/function/parameter_func.py +134 -0
  1087. mindspore/ops/function/random_func.py +1715 -0
  1088. mindspore/ops/function/reshard_func.py +104 -0
  1089. mindspore/ops/function/sparse_func.py +884 -0
  1090. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1091. mindspore/ops/function/spectral_func.py +150 -0
  1092. mindspore/ops/function/vmap_func.py +117 -0
  1093. mindspore/ops/functional.py +464 -0
  1094. mindspore/ops/op_info_register.py +1572 -0
  1095. mindspore/ops/operations/__init__.py +722 -0
  1096. mindspore/ops/operations/_csr_ops.py +403 -0
  1097. mindspore/ops/operations/_custom_grad.py +181 -0
  1098. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1099. mindspore/ops/operations/_grad_ops.py +2978 -0
  1100. mindspore/ops/operations/_infer_ops.py +19 -0
  1101. mindspore/ops/operations/_inner_ops.py +2544 -0
  1102. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1103. mindspore/ops/operations/_ms_kernel.py +601 -0
  1104. mindspore/ops/operations/_ocr_ops.py +379 -0
  1105. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1106. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1107. mindspore/ops/operations/_quant_ops.py +1844 -0
  1108. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1109. mindspore/ops/operations/_scalar_ops.py +106 -0
  1110. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1111. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1112. mindspore/ops/operations/_tensor_array.py +359 -0
  1113. mindspore/ops/operations/_thor_ops.py +807 -0
  1114. mindspore/ops/operations/array_ops.py +6124 -0
  1115. mindspore/ops/operations/comm_ops.py +1985 -0
  1116. mindspore/ops/operations/control_ops.py +127 -0
  1117. mindspore/ops/operations/custom_ops.py +1129 -0
  1118. mindspore/ops/operations/debug_ops.py +678 -0
  1119. mindspore/ops/operations/image_ops.py +1041 -0
  1120. mindspore/ops/operations/inner_ops.py +697 -0
  1121. mindspore/ops/operations/linalg_ops.py +95 -0
  1122. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1123. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1124. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1125. mindspore/ops/operations/math_ops.py +5095 -0
  1126. mindspore/ops/operations/nn_ops.py +9575 -0
  1127. mindspore/ops/operations/other_ops.py +874 -0
  1128. mindspore/ops/operations/random_ops.py +1288 -0
  1129. mindspore/ops/operations/reshard_ops.py +53 -0
  1130. mindspore/ops/operations/rl_ops.py +288 -0
  1131. mindspore/ops/operations/sparse_ops.py +2753 -0
  1132. mindspore/ops/operations/spectral_ops.py +111 -0
  1133. mindspore/ops/primitive.py +1046 -0
  1134. mindspore/ops/signature.py +54 -0
  1135. mindspore/ops/vm_impl_registry.py +91 -0
  1136. mindspore/ops_generate/__init__.py +27 -0
  1137. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1138. mindspore/ops_generate/arg_handler.py +197 -0
  1139. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1140. mindspore/ops_generate/gen_constants.py +36 -0
  1141. mindspore/ops_generate/gen_ops.py +1099 -0
  1142. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1143. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1144. mindspore/ops_generate/gen_utils.py +209 -0
  1145. mindspore/ops_generate/op_proto.py +145 -0
  1146. mindspore/ops_generate/pyboost_utils.py +367 -0
  1147. mindspore/ops_generate/template.py +261 -0
  1148. mindspore/parallel/__init__.py +30 -0
  1149. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1150. mindspore/parallel/_cell_wrapper.py +174 -0
  1151. mindspore/parallel/_cost_model_context.py +700 -0
  1152. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1153. mindspore/parallel/_offload_context.py +275 -0
  1154. mindspore/parallel/_parallel_serialization.py +561 -0
  1155. mindspore/parallel/_ps_context.py +242 -0
  1156. mindspore/parallel/_recovery_context.py +110 -0
  1157. mindspore/parallel/_tensor.py +730 -0
  1158. mindspore/parallel/_transformer/__init__.py +35 -0
  1159. mindspore/parallel/_transformer/layers.py +765 -0
  1160. mindspore/parallel/_transformer/loss.py +251 -0
  1161. mindspore/parallel/_transformer/moe.py +693 -0
  1162. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1163. mindspore/parallel/_transformer/transformer.py +3119 -0
  1164. mindspore/parallel/_utils.py +612 -0
  1165. mindspore/parallel/algo_parameter_config.py +400 -0
  1166. mindspore/parallel/checkpoint_transform.py +650 -0
  1167. mindspore/parallel/cluster/__init__.py +15 -0
  1168. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1169. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1170. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1171. mindspore/parallel/cluster/run.py +136 -0
  1172. mindspore/parallel/mpi/__init__.py +14 -0
  1173. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1174. mindspore/parallel/parameter_broadcast.py +151 -0
  1175. mindspore/parallel/shard.py +481 -0
  1176. mindspore/parallel/transform_safetensors.py +993 -0
  1177. mindspore/profiler/__init__.py +28 -0
  1178. mindspore/profiler/common/__init__.py +14 -0
  1179. mindspore/profiler/common/constant.py +29 -0
  1180. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1181. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1182. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1183. mindspore/profiler/common/process_pool.py +41 -0
  1184. mindspore/profiler/common/registry.py +47 -0
  1185. mindspore/profiler/common/singleton.py +28 -0
  1186. mindspore/profiler/common/struct_type.py +118 -0
  1187. mindspore/profiler/common/util.py +472 -0
  1188. mindspore/profiler/common/validator/__init__.py +14 -0
  1189. mindspore/profiler/common/validator/validate_path.py +84 -0
  1190. mindspore/profiler/dynamic_profiler.py +694 -0
  1191. mindspore/profiler/envprofiling.py +254 -0
  1192. mindspore/profiler/parser/__init__.py +14 -0
  1193. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1194. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1195. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1196. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1197. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1198. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1199. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1200. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1201. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1202. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1203. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1205. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1206. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1207. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1208. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1209. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1210. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1211. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1212. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1213. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1214. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1215. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1216. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1217. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1218. mindspore/profiler/parser/container.py +229 -0
  1219. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1220. mindspore/profiler/parser/flops_parser.py +531 -0
  1221. mindspore/profiler/parser/framework_enum.py +111 -0
  1222. mindspore/profiler/parser/framework_parser.py +464 -0
  1223. mindspore/profiler/parser/framework_struct.py +61 -0
  1224. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1225. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1226. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1227. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1228. mindspore/profiler/parser/hccl_parser.py +573 -0
  1229. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1230. mindspore/profiler/parser/integrator.py +526 -0
  1231. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1232. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1233. mindspore/profiler/parser/minddata_parser.py +186 -0
  1234. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1235. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1236. mindspore/profiler/parser/optime_parser.py +250 -0
  1237. mindspore/profiler/parser/profiler_info.py +213 -0
  1238. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1239. mindspore/profiler/profiler.py +153 -0
  1240. mindspore/profiler/profiling.py +1922 -0
  1241. mindspore/rewrite/__init__.py +28 -0
  1242. mindspore/rewrite/api/__init__.py +17 -0
  1243. mindspore/rewrite/api/node.py +519 -0
  1244. mindspore/rewrite/api/node_type.py +53 -0
  1245. mindspore/rewrite/api/pattern_engine.py +490 -0
  1246. mindspore/rewrite/api/scoped_value.py +181 -0
  1247. mindspore/rewrite/api/symbol_tree.py +497 -0
  1248. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1249. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1250. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1251. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1252. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1253. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1254. mindspore/rewrite/common/__init__.py +19 -0
  1255. mindspore/rewrite/common/config.py +24 -0
  1256. mindspore/rewrite/common/error_log.py +39 -0
  1257. mindspore/rewrite/common/event.py +28 -0
  1258. mindspore/rewrite/common/namer.py +271 -0
  1259. mindspore/rewrite/common/namespace.py +118 -0
  1260. mindspore/rewrite/common/observable.py +44 -0
  1261. mindspore/rewrite/common/observer.py +54 -0
  1262. mindspore/rewrite/node/__init__.py +22 -0
  1263. mindspore/rewrite/node/call_function.py +95 -0
  1264. mindspore/rewrite/node/cell_container.py +139 -0
  1265. mindspore/rewrite/node/control_flow.py +113 -0
  1266. mindspore/rewrite/node/node.py +1428 -0
  1267. mindspore/rewrite/node/node_manager.py +283 -0
  1268. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1269. mindspore/rewrite/parsers/__init__.py +29 -0
  1270. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1271. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1272. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1273. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1274. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1275. mindspore/rewrite/parsers/container_parser.py +88 -0
  1276. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1277. mindspore/rewrite/parsers/for_parser.py +61 -0
  1278. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1279. mindspore/rewrite/parsers/if_parser.py +85 -0
  1280. mindspore/rewrite/parsers/module_parser.py +117 -0
  1281. mindspore/rewrite/parsers/parser.py +43 -0
  1282. mindspore/rewrite/parsers/parser_register.py +86 -0
  1283. mindspore/rewrite/parsers/return_parser.py +37 -0
  1284. mindspore/rewrite/parsers/while_parser.py +59 -0
  1285. mindspore/rewrite/sparsify/__init__.py +0 -0
  1286. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1287. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1288. mindspore/rewrite/sparsify/utils.py +179 -0
  1289. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1290. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1291. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1292. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1293. mindspore/run_check/__init__.py +20 -0
  1294. mindspore/run_check/_check_version.py +507 -0
  1295. mindspore/run_check/run_check.py +66 -0
  1296. mindspore/safeguard/__init__.py +18 -0
  1297. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1298. mindspore/scipy/__init__.py +18 -0
  1299. mindspore/scipy/fft.py +264 -0
  1300. mindspore/scipy/linalg.py +919 -0
  1301. mindspore/scipy/ops.py +165 -0
  1302. mindspore/scipy/ops_grad.py +115 -0
  1303. mindspore/scipy/ops_wrapper.py +74 -0
  1304. mindspore/scipy/optimize/__init__.py +20 -0
  1305. mindspore/scipy/optimize/_bfgs.py +230 -0
  1306. mindspore/scipy/optimize/_lagrange.py +201 -0
  1307. mindspore/scipy/optimize/_lbfgs.py +146 -0
  1308. mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
  1309. mindspore/scipy/optimize/line_search.py +370 -0
  1310. mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
  1311. mindspore/scipy/optimize/minimize.py +200 -0
  1312. mindspore/scipy/utils.py +156 -0
  1313. mindspore/scipy/utils_const.py +246 -0
  1314. mindspore/train/__init__.py +48 -0
  1315. mindspore/train/_utils.py +465 -0
  1316. mindspore/train/amp.py +935 -0
  1317. mindspore/train/anf_ir_pb2.py +1517 -0
  1318. mindspore/train/callback/__init__.py +44 -0
  1319. mindspore/train/callback/_backup_and_restore.py +117 -0
  1320. mindspore/train/callback/_callback.py +613 -0
  1321. mindspore/train/callback/_checkpoint.py +814 -0
  1322. mindspore/train/callback/_cluster_monitor.py +201 -0
  1323. mindspore/train/callback/_dataset_graph.py +150 -0
  1324. mindspore/train/callback/_early_stop.py +239 -0
  1325. mindspore/train/callback/_flops_collector.py +239 -0
  1326. mindspore/train/callback/_history.py +92 -0
  1327. mindspore/train/callback/_lambda_callback.py +80 -0
  1328. mindspore/train/callback/_landscape.py +1049 -0
  1329. mindspore/train/callback/_loss_monitor.py +107 -0
  1330. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1331. mindspore/train/callback/_on_request_exit.py +298 -0
  1332. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1333. mindspore/train/callback/_summary_collector.py +1184 -0
  1334. mindspore/train/callback/_tft_register.py +352 -0
  1335. mindspore/train/callback/_time_monitor.py +141 -0
  1336. mindspore/train/checkpoint_pb2.py +233 -0
  1337. mindspore/train/data_sink.py +219 -0
  1338. mindspore/train/dataset_helper.py +692 -0
  1339. mindspore/train/lineage_pb2.py +1260 -0
  1340. mindspore/train/loss_scale_manager.py +213 -0
  1341. mindspore/train/memory_profiling_pb2.py +298 -0
  1342. mindspore/train/metrics/__init__.py +175 -0
  1343. mindspore/train/metrics/accuracy.py +133 -0
  1344. mindspore/train/metrics/auc.py +129 -0
  1345. mindspore/train/metrics/bleu_score.py +170 -0
  1346. mindspore/train/metrics/confusion_matrix.py +700 -0
  1347. mindspore/train/metrics/cosine_similarity.py +109 -0
  1348. mindspore/train/metrics/dice.py +116 -0
  1349. mindspore/train/metrics/error.py +175 -0
  1350. mindspore/train/metrics/fbeta.py +167 -0
  1351. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1352. mindspore/train/metrics/loss.py +97 -0
  1353. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1354. mindspore/train/metrics/metric.py +373 -0
  1355. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1356. mindspore/train/metrics/perplexity.py +133 -0
  1357. mindspore/train/metrics/precision.py +160 -0
  1358. mindspore/train/metrics/recall.py +159 -0
  1359. mindspore/train/metrics/roc.py +223 -0
  1360. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1361. mindspore/train/metrics/topk.py +167 -0
  1362. mindspore/train/mind_ir_pb2.py +1908 -0
  1363. mindspore/train/model.py +2252 -0
  1364. mindspore/train/node_strategy_pb2.py +653 -0
  1365. mindspore/train/print_pb2.py +184 -0
  1366. mindspore/train/profiling_parallel_pb2.py +151 -0
  1367. mindspore/train/serialization.py +3325 -0
  1368. mindspore/train/summary/__init__.py +23 -0
  1369. mindspore/train/summary/_lineage_adapter.py +41 -0
  1370. mindspore/train/summary/_summary_adapter.py +496 -0
  1371. mindspore/train/summary/_writer_pool.py +207 -0
  1372. mindspore/train/summary/enums.py +56 -0
  1373. mindspore/train/summary/summary_record.py +581 -0
  1374. mindspore/train/summary/writer.py +167 -0
  1375. mindspore/train/summary_pb2.py +1165 -0
  1376. mindspore/train/train_thor/__init__.py +20 -0
  1377. mindspore/train/train_thor/convert_utils.py +268 -0
  1378. mindspore/train/train_thor/dataset_helper.py +192 -0
  1379. mindspore/train/train_thor/model_thor.py +257 -0
  1380. mindspore/utils/__init__.py +21 -0
  1381. mindspore/utils/utils.py +60 -0
  1382. mindspore/version.py +1 -0
  1383. mindspore-2.4.0.dist-info/METADATA +352 -0
  1384. mindspore-2.4.0.dist-info/RECORD +1387 -0
  1385. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1386. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1387. mindspore-2.4.0.dist-info/top_level.txt +1 -0
mindspore/ops/operations/_grad_ops.py
@@ -0,0 +1,2978 @@
1
+ # Copyright 2020-2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ """Operators for gradients."""
17
+ # pylint: disable=unused-import
18
+ from __future__ import absolute_import
19
+
20
+ from __future__ import division
21
+ from mindspore._checkparam import _check_3d_int_or_tuple
22
+ from mindspore.ops.operations.nn_ops import _check_positive_int_or_tuple
23
+ from mindspore.ops import signature as sig
24
+ from mindspore.ops._utils import get_concat_offset
25
+ from mindspore.ops.primitive import Primitive, PrimitiveWithInfer, prim_attr_register
26
+ import mindspore.context as context
27
+ from mindspore import _checkparam as validator
28
+ from mindspore.common import dtype as mstype
29
+ from mindspore.communication.management import GlobalComm
30
+ from mindspore.common._utils import is_shape_unknown, is_dim_unknown
31
+ from ..auto_generate import (AbsGrad, ACosGrad, LogitGrad, AcoshGrad, AsinGrad, AsinhGrad, ReciprocalGrad, RsqrtGrad,
32
+ SqrtGrad, BatchNormGrad, BatchNormGradGrad, BiasAddGrad, GeLUGrad, FastGeLUGrad,
33
+ AvgPoolGrad, MinimumGrad, LogSoftmaxGrad, PReLUGrad, ReluGrad, ReLU6Grad, EluGrad,
34
+ GatherDGradV2, ResizeBilinearGrad, ResizeLinear1DGrad, ResizeNearestNeighborV2Grad,
35
+ SigmoidGrad, HSwishGrad, NLLLossGrad, AtanGrad, GridSampler3DGrad, GridSampler2DGrad,
36
+ ResizeBicubicGrad, HSigmoidGrad, CholeskyGrad, ResizeNearestNeighborGrad, LayerNormGrad,
37
+ HShrinkGrad, LayerNormGradGrad, SiLUGrad, MaximumGrad, MaximumGradGrad, RmsNormGrad,
38
+ FlashAttentionScoreGrad, UpsampleTrilinear3DGrad, UpsampleNearest3DGrad, MaskedSelectGrad,
39
+ BinaryCrossEntropyGrad, SoftShrinkGrad, SeluGrad)
40
+
41
+
42
+ class SparseFillEmptyRowsGrad(Primitive):
43
+ """Performs grad of SparseFillEmptyRows operation."""
44
+
45
+ @prim_attr_register
46
+ def __init__(self):
47
+ """Initialize SparseFillEmptyRowsGrad."""
48
+ self.init_prim_io_names(inputs=['reverse_index_map', 'grad_values'],
49
+ outputs=['y_values', 'y_default_value'])
50
+
51
+
52
+ class ScaleAndTranslateGrad(Primitive):
53
+ """Performs grad of ScaleAndTranslate operation."""
54
+
55
+ @prim_attr_register
56
+ def __init__(self, kernel_type="lanczos3", antialias=True):
57
+ """Initialize ScaleAndTranslateGrad"""
58
+ validator.check_value_type("kernel_type", kernel_type, [str], self.name)
59
+ validator.check_string(kernel_type, ["lanczos1", "lanczos3", "lanczos5", "gaussian", "box", "triangle",
60
+ "keyscubic", "mitchellcubic"], "kernel_type", self.name)
61
+ validator.check_value_type("antialias", antialias, [bool], self.name)
62
+
63
+
64
+ class SoftmaxGrad(Primitive):
65
+ """Performs grad of Softmax operation."""
66
+
67
+ @prim_attr_register
68
+ def __init__(self):
69
+ """Initialize SoftmaxGrad"""
70
+ self.init_prim_io_names(inputs=['y', 'dy'], outputs=['z'])
71
+
72
+
73
+ class SyncBatchNormGrad(Primitive):
74
+ """Performs grad of SyncBatchNorm operation."""
75
+
76
+ @prim_attr_register
77
+ def __init__(self, epsilon=1e-5, group="group0", device_num=2):
78
+ validator.check_float_range(epsilon, 0, 1, validator.INC_RIGHT, 'epsilon', self.name)
79
+ if not isinstance(group, str):
80
+ raise TypeError("The group attr of SyncBatchNormGrad must be str.")
81
+ validator.check_int(device_num, 2, validator.GE, "device_num", self.name)
82
+
83
+
84
+ class KLDivLossGrad(Primitive):
85
+ """Computes gradients for `KLDivLoss` operation."""
86
+
87
+ @prim_attr_register
88
+ def __init__(self, reduction='mean'):
89
+ device_target = context.get_context("device_target")
90
+ if device_target == "CPU":
91
+ support_mode = ['none', 'mean', 'batchmean', 'sum']
92
+ elif device_target == "GPU":
93
+ support_mode = ['none', 'mean', 'sum']
94
+ elif device_target == "Ascend":
95
+ support_mode = ['none', 'mean', 'batchmean', 'sum']
96
+ else:
97
+ raise ValueError(f"'{self.name}' unknown device target: '{device_target}'")
98
+ self.reduction = validator.check_string(reduction, support_mode, 'reduction', self.name)
99
+
100
+
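The constructor above only changes which `reduction` strings are accepted per device target. A minimal pure-Python sketch of that lookup, assuming nothing beyond the code shown (the dict and the helper name `pick_reduction` are illustrative, not part of the package):

# Sketch: mirrors the per-device reduction whitelist used by KLDivLossGrad above.
_SUPPORTED_REDUCTIONS = {
    "CPU": ['none', 'mean', 'batchmean', 'sum'],
    "GPU": ['none', 'mean', 'sum'],            # no 'batchmean' on GPU in this version
    "Ascend": ['none', 'mean', 'batchmean', 'sum'],
}

def pick_reduction(device_target, reduction='mean'):
    if device_target not in _SUPPORTED_REDUCTIONS:
        raise ValueError(f"'KLDivLossGrad' unknown device target: '{device_target}'")
    support_mode = _SUPPORTED_REDUCTIONS[device_target]
    if reduction not in support_mode:
        raise ValueError(f"reduction must be one of {support_mode}, but got '{reduction}'")
    return reduction

assert pick_reduction("GPU", "sum") == "sum"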
101
+ class LuUnpackGrad(Primitive):
102
+ """Computes gradients for `LuUnpack` operation."""
103
+
104
+ @prim_attr_register
105
+ def __init__(self, L_grad_flag, U_grad_flag):
106
+ validator.check_value_type("L_grad_flag", L_grad_flag, [bool], self.name)
107
+ validator.check_value_type("U_grad_flag", U_grad_flag, [bool], self.name)
108
+ self.add_prim_attr("cust_aicpu", self.name)
109
+
110
+
111
+ class ConcatOffset(PrimitiveWithInfer):
112
+ """primitive for computing Concat's gradient."""
113
+
114
+ @prim_attr_register
115
+ def __init__(self, N=2, axis=0):
116
+ """Initialize ConcatOffset"""
117
+
118
+ def __infer__(self, input_x):
119
+ axis = self.axis
120
+ x_shp = input_x['shape']
121
+ x_type = input_x['dtype']
122
+ self.add_prim_attr('T', x_type[0].element_type())
123
+
124
+ # input_x is dynamic rank
125
+ rank = -1
126
+ is_dyn_rank = False
127
+ for _, sh in enumerate(x_shp):
128
+ if is_dim_unknown(sh):
129
+ is_dyn_rank = True
130
+ else:
131
+ rank = len(sh)
132
+ if is_dyn_rank:
133
+ return {
134
+ 'shape': [len(x_shp), rank],
135
+ 'dtype': mstype.int64,
136
+ 'value': None
137
+ }
138
+
139
+ # if the dimension of input_x on the axis is dynamic
140
+ if axis < -rank or axis >= rank:
141
+ raise ValueError("For 'ConcatOffset', 'axis' must be in range [{}, {}), but got {}"
142
+ .format(-rank, rank, axis))
143
+ if axis < 0:
144
+ axis = axis + rank
145
+ for each in x_shp:
146
+ if each[axis] == -1:
147
+ return {
148
+ 'shape': [len(x_shp), len(x_shp[0])],
149
+ 'dtype': mstype.int64,
150
+ 'value': None
151
+ }
152
+
153
+ offset, _, axis = get_concat_offset(x_shp, x_type, axis, self.name)
154
+ offset_values = []
155
+ for i in range(len(x_shp)):
156
+ values = []
157
+ for j in range(len(x_shp[0])):
158
+ value = 0
159
+ if j == axis:
160
+ value = offset[i]
161
+ values.append(value)
162
+ offset_values.append(tuple(values))
163
+ out = {'shape': None,
164
+ 'dtype': None,
165
+ 'value': tuple(offset_values)}
166
+ return out
167
+
168
+
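For fully static shapes, the `__infer__` above boils down to a running sum of the input sizes along the concat axis. A small stand-alone sketch of that computation (`concat_offsets` is an illustrative name, not an API of the package):

def concat_offsets(shapes, axis=0):
    """Offset of each input along `axis` when the inputs are concatenated, as per-dim offset tuples."""
    rank = len(shapes[0])
    axis = axis + rank if axis < 0 else axis
    offsets, acc = [], 0
    for shp in shapes:
        row = [0] * rank
        row[axis] = acc          # only the concat axis gets a non-zero offset
        acc += shp[axis]
        offsets.append(tuple(row))
    return tuple(offsets)

# Two tensors of shape (2, 3) and (4, 3) concatenated along axis 0: the second starts at row 2.
assert concat_offsets([(2, 3), (4, 3)], axis=0) == ((0, 0), (2, 0))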
169
+ class Conv3DBackpropFilter(Primitive):
170
+ """
171
+ Computes the gradients of convolution 3D with respect to the filter.
172
+
173
+ Args:
174
+ out_channel (int): The dimension of the output.
175
+ kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution.
176
+ mode (int): Modes for different convolutions. Not currently used.
177
+ pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
178
+ pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
179
+ head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six
180
+ integers, the padding of head, tail, top, bottom, left and right equal to pad[0], pad[1], pad[2],
181
+ pad[3], pad[4] and pad[5] correspondingly.
182
+ stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1.
183
+ dilation (Union(int, tuple[int])): Specifies the space to use between kernel elements. Default: 1.
184
+ group (int): Splits input into groups. Default: 1.
185
+ data_format (str): The optional value for data format. Currently only 'NCDHW' is supported.
186
+
187
+ Inputs:
188
+ - **x** (Tensor) - The input of the convolution, with shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
189
+ Currently x data type only supports float16 and float32.
190
+ - **dout** (Tensor) - The gradients w.r.t the output of the convolution. The shape conforms to the default
191
+ data_format :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`. Currently dout data type only supports float16
192
+ and float32.
193
+ - **w_size** (tuple(int)) - A tuple describes the shape of the weight which conforms to the format
194
+ :math:`(C_{out}, C_{in}, K_d, K_h, K_w)`.
195
+
196
+ Outputs:
197
+ Tensor, the gradients w.r.t the weight of convolution 3D. It has the same shape as the weight.
198
+
199
+ Supported Platforms:
200
+ ``Ascend``
201
+
202
+ Examples:
203
+ >>> x = Tensor(np.ones([16, 32, 13, 37, 33]), mindspore.float16)
204
+ >>> dout = Tensor(np.ones([16, 32, 10, 32, 32]), mindspore.float16)
205
+ >>> w = Tensor(np.ones([32, 32, 4, 6, 2]), mindspore.float16)
206
+ >>> conv3d_backprop_filter = P.Conv3DBackpropFilter(out_channel=32, kernel_size=(4, 6, 2))
207
+ >>> output = conv3d_backprop_filter(x, dout, F.shape(w))
208
+ >>> print(output.shape)
209
+ (32, 32, 4, 6, 2)
210
+ """
211
+
212
+ @prim_attr_register
213
+ def __init__(self,
214
+ out_channel,
215
+ kernel_size,
216
+ mode=1,
217
+ pad_mode="valid",
218
+ pad=0,
219
+ stride=(1, 1, 1, 1, 1),
220
+ dilation=(1, 1, 1, 1, 1),
221
+ group=1,
222
+ data_format="NCDHW"):
223
+ """Initialize Convolution"""
224
+ self.init_prim_io_names(inputs=['x', 'out_backprop', 'filter_size'], outputs=['y'])
225
+ self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
226
+ self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
227
+ self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True)
228
+ self.add_prim_attr('strides', self.stride)
229
+ self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True, ret_five=True)
230
+ self.add_prim_attr('dilations', self.dilation)
231
+ validator.check_value_type('pad', pad, (int, tuple), self.name)
232
+ if isinstance(pad, int):
233
+ pad = (pad,) * 6
234
+ validator.check_equal_int(len(pad), 6, 'pad size', self.name)
235
+ self.add_prim_attr('pad', pad)
236
+ self.pad_list = pad
237
+ self.add_prim_attr('pad_list', self.pad_list)
238
+
239
+ validator.check_value_type('pad_mode', pad_mode, [str], self.name)
240
+ self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
241
+ if self.pad_mode != 'pad' and self.pad_list != (0, 0, 0, 0, 0, 0):
242
+ raise ValueError(f"For '{self.name}', when pad is not 0, pad_mode must be set as 'pad'.")
243
+ if self.pad_mode == 'pad':
244
+ for item in pad:
245
+ validator.check_non_negative_int(item, 'pad item', self.name)
246
+ self.add_prim_attr('pad_mode', self.pad_mode)
247
+
248
+ self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
249
+ self.add_prim_attr('mode', self.mode)
250
+ self.group = validator.check_positive_int(group, 'group', self.name)
251
+ self.add_prim_attr('groups', self.group)
252
+ self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
253
+ self.add_prim_attr('data_format', self.format)
254
+
255
+
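The constructor above accepts `pad` either as a single int or as a tuple of six ints (head, tail, top, bottom, left, right), and requires `pad_mode='pad'` whenever any padding is non-zero. A sketch of just that normalization, assuming only what the code above shows (`normalize_pad3d` is an illustrative name):

def normalize_pad3d(pad, pad_mode="valid"):
    """Expand an int pad to the 6-tuple form and apply the same consistency checks as above."""
    if isinstance(pad, int):
        pad = (pad,) * 6                      # head, tail, top, bottom, left, right
    if len(pad) != 6:
        raise ValueError(f"pad size must be 6, but got {len(pad)}")
    if pad_mode != 'pad' and pad != (0, 0, 0, 0, 0, 0):
        raise ValueError("when pad is not 0, pad_mode must be set as 'pad'")
    if pad_mode == 'pad' and any(p < 0 for p in pad):
        raise ValueError("pad items must be non-negative")
    return pad

assert normalize_pad3d(0) == (0, 0, 0, 0, 0, 0)
assert normalize_pad3d((1, 1, 2, 2, 0, 0), pad_mode="pad") == (1, 1, 2, 2, 0, 0)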
256
+ class Conv2DBackpropFilter(Primitive):
257
+ """
258
+ Computes the gradients of convolution with respect to the filter.
259
+
260
+ Args:
261
+ out_channel (int): The dimensionality of the output space.
262
+ kernel_size (Union[int, tuple[int]]): The size of the convolution window.
263
+ pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
264
+ pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
265
+ top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers, the
266
+ padding of top, bottom, left and right equal to pad[0], pad[1], pad[2], and pad[3] correspondingly.
267
+ pad_list (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0).
268
+ mode (int): Modes for different convolutions. 0 Math convolution, 1 cross-correlation convolution,
269
+ 2 deconvolution, 3 depthwise convolution. Default: 1.
270
+ stride (tuple): The stride to be applied to the convolution filter. Default: (1, 1).
271
+ dilation (tuple): Specifies the dilation rate to be used for the dilated convolution. Default: (1, 1, 1, 1).
272
+ group (int): Splits input into groups. Default: 1.
273
+ data_format (str): The format of input and output data. It should be 'NHWC' or 'NCHW'.
274
+ Default: 'NCHW'.
275
+
276
+ Returns:
277
+ Tensor, the gradients of convolution.
278
+ """
279
+
280
+ @prim_attr_register
281
+ def __init__(self,
282
+ out_channel,
283
+ kernel_size,
284
+ pad_mode="valid",
285
+ pad=0,
286
+ pad_list=(0, 0, 0, 0),
287
+ mode=1,
288
+ stride=(1, 1),
289
+ dilation=(1, 1, 1, 1),
290
+ group=1,
291
+ data_format="NCHW"):
292
+ """Initialize Convolution"""
293
+ self.init_prim_io_names(inputs=['out_backprop', 'input', 'filter_sizes'], outputs=['output'])
294
+ self.out_channel = out_channel
295
+ self.kernel_size = kernel_size
296
+ self.mode = mode
297
+ pad_mode = pad_mode.upper()
298
+ self.add_prim_attr('pad_mode', pad_mode)
299
+ if isinstance(pad, int):
300
+ pad = (pad,) * 4
301
+ else:
302
+ validator.check_equal_int(len(pad), 4, 'pad size', self.name)
303
+ self.add_prim_attr("pad", pad)
304
+ self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
305
+ if context.get_context("device_target") != "GPU" and self.format == "NHWC":
306
+ raise ValueError("NHWC format only support in GPU target.")
307
+ self.add_prim_attr('data_format', self.format)
308
+ self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True)
309
+ self.add_prim_attr('stride', self.stride)
310
+ self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
311
+ self.add_prim_attr('dilation', self.dilation)
312
+ self.group = group
313
+ self.add_prim_attr('groups', group)
314
+ if pad_list:
315
+ for x in pad_list:
316
+ if x != -1:
317
+ validator.check_non_negative_int(x, 'element of pad_list', self.name)
318
+ self.pad_list = pad_list
319
+
320
+
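For intuition about what Conv2DBackpropFilter returns, the weight gradient of a stride-1, unpadded cross-correlation can be written directly with NumPy. This is only a reference for the math implied by the docstring above, not the operator's actual kernel, and `conv2d_filter_grad_ref` is a hypothetical name:

import numpy as np

def conv2d_filter_grad_ref(x, dout, kernel_size):
    """Reference dL/dW for a stride-1, unpadded, cross-correlation Conv2D in NCHW layout."""
    _, c_in, _, _ = x.shape
    _, c_out, h_out, w_out = dout.shape
    kh, kw = kernel_size
    dw = np.zeros((c_out, c_in, kh, kw), dtype=x.dtype)
    for p in range(kh):
        for q in range(kw):
            # window of x that every output position multiplies with W[:, :, p, q]
            patch = x[:, :, p:p + h_out, q:q + w_out]          # (N, C_in, H_out, W_out)
            dw[:, :, p, q] = np.einsum('nohw,nihw->oi', dout, patch)
    return dw

x = np.random.randn(2, 3, 8, 8).astype(np.float32)
dout = np.random.randn(2, 4, 6, 6).astype(np.float32)
print(conv2d_filter_grad_ref(x, dout, (3, 3)).shape)           # (4, 3, 3, 3), same shape as the weight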
321
+ class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):
322
+ """
323
+ Returns the gradient of filter for DepthwiseConv2dNative.
324
+
325
+ Applies depthwise conv2d for the input, which will generate more channels with channel_multiplier.
326
+
327
+ Refer to class DepthwiseConv2dNative for more details.
328
+
329
+ Args:
330
+ channel_multiplier (int): The multiplier for the original output conv.
331
+ kernel_size (int or tuple): The size of the conv kernel.
332
+ mode (int): Modes for different convolutions. 0 Math convolution, 1 cross-correlation convolution,
333
+ 2 deconvolution, 3 depthwise convolution. Default: 3.
334
+ pad_mode (str): The mode to fill padding which can be: "valid", "same" or "pad". Default: "valid".
335
+ pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
336
+ top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers, the
337
+ padding of top, bottom, left and right equal to pad[0], pad[1], pad[2], and pad[3] correspondingly.
338
+ pad_list (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0).
339
+ stride (int): The stride to be applied to the convolution filter. Default: 1.
340
+ dilation (int): Specifies the space to use between kernel elements. Default: 1.
341
+ group (int): Splits input into groups. Default: 1.
342
+
343
+ Returns:
344
+ Tensor, the value is the gradient of filter for DepthwiseConv2dNative.
345
+ """
346
+
347
+ @prim_attr_register
348
+ def __init__(self,
349
+ channel_multiplier,
350
+ kernel_size,
351
+ pad_mode="valid",
352
+ pad=0,
353
+ pad_list=(0, 0, 0, 0),
354
+ mode=3,
355
+ stride=1,
356
+ dilation=1,
357
+ group=1):
358
+ """Initialize Convolution"""
359
+ self.init_prim_io_names(inputs=['input', 'filter_size', 'dout'], outputs=['output'])
360
+ self.channel_multiplier = channel_multiplier
361
+ self.kernel_size = kernel_size
362
+ self.mode = mode
363
+ self.pad_mode = pad_mode
364
+ if isinstance(pad, int):
365
+ pad = (pad,) * 4
366
+ else:
367
+ validator.check_equal_int(len(pad), 4, 'pad size', self.name)
368
+ self.add_prim_attr("pad", pad)
369
+ self.pad_list = pad_list
370
+ self.stride = stride
371
+ self.dilation = dilation
372
+ self.group = group
373
+ self.add_prim_attr('data_format', "NCHW")
374
+
375
+ def __infer__(self, x, w_size, dout):
376
+ w_size_v = w_size['value']
377
+ args = {'x': x['dtype'], 'dout': dout['dtype']}
378
+ validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
379
+ out = {
380
+ 'value': None,
381
+ 'shape': w_size_v,
382
+ 'dtype': dout['dtype'],
383
+ }
384
+ return out
385
+
386
+
387
+ class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer):
388
+ """
389
+ Returns the gradient of input for DepthwiseConv2dNative.
390
+
391
+ Applies depthwise conv2d for the input, which will generate more channels with channel_multiplier.
392
+
393
+ Args:
394
+ channel_multiplier (int): The multiplier for the original output conv.
395
+ kernel_size (int or tuple): The size of the conv kernel.
396
+ mode (int): Modes for different convolutions. 0 Math convolution, 1 cross-correlation convolution,
397
+ 2 deconvolution, 3 depthwise convolution. Default: 3.
398
+ pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
399
+ pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
400
+ top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers, the
401
+ padding of top, bottom, left and right equal to pad[0], pad[1], pad[2], and pad[3] correspondingly.
402
+ pad_list (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0).
403
+ stride (int): The stride to be applied to the convolution filter. Default: 1.
404
+ dilation (int): Specifies the space to use between kernel elements. Default: 1.
405
+ group (int): Splits input into groups. Default: 1.
406
+
407
+ Returns:
408
+ Tensor, the value is the gradient of input for DepthwiseConv2dNative.
409
+ """
410
+
411
+ @prim_attr_register
412
+ def __init__(self,
413
+ channel_multiplier,
414
+ kernel_size,
415
+ pad_mode="valid",
416
+ pad=0,
417
+ pad_list=(0, 0, 0, 0),
418
+ mode=3,
419
+ stride=1,
420
+ dilation=1,
421
+ group=1):
422
+ """Initialize Convolution"""
423
+ self.init_prim_io_names(inputs=['input_size', 'filter', 'dout'], outputs=['output'])
424
+ self.channel_multiplier = channel_multiplier
425
+ self.kernel_size = kernel_size
426
+ self.mode = mode
427
+ self.pad_mode = pad_mode
428
+ if isinstance(pad, int):
429
+ pad = (pad,) * 4
430
+ else:
431
+ validator.check_equal_int(len(pad), 4, 'pad size', self.name)
432
+ self.add_prim_attr("pad", pad)
433
+ self.pad_list = pad_list
434
+ self.stride = stride
435
+ self.dilation = dilation
436
+ self.group = group
437
+ self.add_prim_attr('data_format', "NCHW")
438
+
439
+ def __infer__(self, x_size, w, dout):
440
+ args = {'w': w['dtype'], 'dout': dout['dtype']}
441
+ validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
442
+ x_size_v = x_size['value']
443
+ out = {
444
+ 'value': None,
445
+ 'shape': x_size_v,
446
+ 'dtype': dout['dtype'],
447
+ }
448
+ return out
449
+
450
+
451
+ class DropoutGrad(Primitive):
452
+ """
453
+ The gradient of Dropout. During training, randomly zeroes some of the elements
454
+ of the input tensor with probability 1 - keep_prob.
455
+
456
+ Args:
457
+ keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9,
458
+ means dropping out 10% of input units. Default: 0.5.
459
+
460
+ Inputs:
461
+ - **shape** (tuple[int]) - The shape of target mask.
462
+
463
+ Outputs:
464
+ Tensor, the value of generated mask for input shape.
465
+
466
+ Examples:
467
+ >>> dropout_grad = ops.DropoutGrad(keep_prob=0.5)
468
+ >>> in_tensor = Tensor((20, 16, 50, 50))
469
+ >>> out = dropout_grad(in_tensor)
470
+ """
471
+
472
+ @prim_attr_register
473
+ def __init__(self, keep_prob=0.5):
474
+ self.keep_prob = validator.check_float_range(keep_prob, 0, 1, validator.INC_RIGHT, "keep_prob", self.name)
475
+
476
+
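Functionally, the dropout gradient re-applies the forward mask and the 1/keep_prob scaling to the incoming gradient (inverted dropout). A NumPy sketch of that relationship, intended only as a reference for the math rather than the primitive's kernel (`dropout_grad_ref` is a hypothetical name):

import numpy as np

def dropout_grad_ref(dy, mask, keep_prob=0.5):
    """dx = dy * mask / keep_prob, matching the inverted-dropout forward y = x * mask / keep_prob."""
    return dy * mask / keep_prob

rng = np.random.default_rng(0)
dy = rng.standard_normal((4, 3)).astype(np.float32)
mask = (rng.random((4, 3)) < 0.5).astype(np.float32)   # 1 where the unit was kept
print(dropout_grad_ref(dy, mask).shape)                 # (4, 3)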
477
+ class FlattenGrad(PrimitiveWithInfer):
478
+ """Performs gradients of Flatten."""
479
+
480
+ @prim_attr_register
481
+ def __init__(self):
482
+ self.init_prim_io_names(inputs=['x', 'shape'], outputs=['output'])
483
+
484
+ def __infer__(self, *args):
485
+ out = {
486
+ 'value': None,
487
+ 'shape': args[1]['value'],
488
+ 'dtype': args[0]['dtype'],
489
+ }
490
+ return out
491
+
492
+
493
+ class InstanceNormGrad(PrimitiveWithInfer):
494
+ """Gradients of InstanceNorm operation."""
495
+
496
+ @prim_attr_register
497
+ def __init__(self, epsilon=0.0, momentum=0.1):
498
+ self.init_prim_io_names(inputs=['dy', 'x', 'gamma', 'save_mean', 'save_variance'],
499
+ outputs=['dx', 'bn_gamma', 'bn_beta'])
500
+
501
+
502
+ class InstanceNormV2Grad(Primitive):
503
+ """Gradients of InstanceNormV2 operation."""
504
+
505
+ @prim_attr_register
506
+ def __init__(self, is_training=True, epsilon=1e-5):
507
+ self.init_prim_io_names(inputs=['dy', 'x', 'gamma', 'mean', 'variance', 'save_mean', 'save_variance'],
508
+ outputs=['pd_x', 'pd_gamma', 'pd_beta'])
509
+ validator.check_is_float(epsilon, 'epsilon', self.name)
510
+ validator.check_float_range(epsilon, 0, 1, validator.INC_RIGHT, 'epsilon', self.name)
511
+ validator.check_bool(is_training, "is_training", self.name)
512
+
513
+
514
+ class EinsumGrad(PrimitiveWithInfer):
515
+ """Gradients of Einsum."""
516
+
517
+ @prim_attr_register
518
+ def __init__(self, equation):
519
+ pass
520
+
521
+ def infer_shape(self, x_shapes, dout_shape):
522
+ out_shape = ()
523
+ for dim in x_shapes:
524
+ out_shape += (dim,)
525
+ return out_shape
526
+
527
+ def infer_dtype(self, x_types, dout_shape):
528
+ out_type = ()
529
+ for cur_type in x_types:
530
+ out_type += (cur_type,)
531
+ return out_type
532
+
533
+
534
+ class UniqueGrad(Primitive):
535
+ """Gradients of Unique operation."""
536
+
537
+ @prim_attr_register
538
+ def __init__(self):
539
+ self.init_prim_io_names(inputs=['dy', 'y'], outputs=['dx'])
540
+
541
+
542
+ class BNTrainingReduceGrad(Primitive):
543
+ """Gradients of FusedBatchNorm operation."""
544
+
545
+ @prim_attr_register
546
+ def __init__(self, epsilon=0.0001, data_format='NCHW'):
547
+ self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
548
+ _inputs = ['grads', 'x', 'diff_scale', 'diff_offset', 'scale', 'batch_mean', 'batch_variance']
549
+ self.init_prim_io_names(inputs=_inputs, outputs=['y'])
550
+
551
+
552
+ class BNTrainingUpdateGrad(Primitive):
553
+ """Gradients of FusedBatchNorm operation."""
554
+
555
+ @prim_attr_register
556
+ def __init__(self, epsilon=0.0001, data_format='NCHW'):
557
+ self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
558
+ self.init_prim_io_names(inputs=['grads', 'x', 'batch_mean', 'batch_variance'],
559
+ outputs=['diff_scale', 'diff_offset'])
560
+
561
+
562
+ class NeighborExchangeV2Grad(PrimitiveWithInfer):
563
+ """"Gradients of NeighborExchangeV2 operation."""
564
+
565
+ @prim_attr_register
566
+ def __init__(self, send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format,
567
+ group=GlobalComm.WORLD_COMM_GROUP):
568
+ self.init_prim_io_names(inputs=['dy'], outputs=['dx'])
569
+ self.send_rank_ids = send_rank_ids
570
+ self.recv_rank_ids = recv_rank_ids
571
+ self.send_lens = send_lens
572
+ self.recv_lens = recv_lens
573
+ self.format = validator.check_string(data_format, ['NCHW'], 'format', self.name)
574
+ self.add_prim_attr('no_elimilate', True)
575
+
576
+ def __infer__(self, dy):
577
+ dy_shape = dy['shape']
578
+ validator.check('dy_shape.size()', len(dy_shape), '4', 4, validator.EQ, self.name)
579
+ if self.send_rank_ids[5] != -1 or self.send_rank_ids[6] != -1 or self.send_rank_ids[7] != -1:
580
+ dy_shape[3] -= self.send_lens[2]
581
+
582
+ if self.send_rank_ids[1] != -1 or self.send_rank_ids[2] != -1 or self.send_rank_ids[3] != -1:
583
+ dy_shape[3] -= self.send_lens[3]
584
+
585
+ if self.send_rank_ids[0] != -1 or self.send_rank_ids[1] != -1 or self.send_rank_ids[7] != -1:
586
+ dy_shape[2] -= self.send_lens[0]
587
+
588
+ if self.send_rank_ids[3] != -1 or self.send_rank_ids[4] != -1 or self.send_rank_ids[5] != -1:
589
+ dy_shape[2] -= self.send_lens[1]
590
+
591
+ return {'shape': dy_shape,
592
+ 'dtype': dy['dtype'],
593
+ 'value': None}
594
+
595
+
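The `__infer__` above shrinks the gradient's H and W dims by the halo lengths that were sent out in the forward exchange. A pure-Python sketch of the same shape arithmetic; the directional comments follow the ordering documented for the forward NeighborExchangeV2 op, and the helper name is illustrative:

def neighbor_exchange_grad_shape(dy_shape, send_rank_ids, send_lens):
    """Output (dx) shape: drop the rows/columns that were sent to neighbours in the forward op."""
    shape = list(dy_shape)                                # NCHW
    if any(send_rank_ids[i] != -1 for i in (5, 6, 7)):
        shape[3] -= send_lens[2]                          # left halo
    if any(send_rank_ids[i] != -1 for i in (1, 2, 3)):
        shape[3] -= send_lens[3]                          # right halo
    if any(send_rank_ids[i] != -1 for i in (0, 1, 7)):
        shape[2] -= send_lens[0]                          # top halo
    if any(send_rank_ids[i] != -1 for i in (3, 4, 5)):
        shape[2] -= send_lens[1]                          # bottom halo
    return shape

# Exchange with the neighbour above only (index 0 set), sending 1 row.
assert neighbor_exchange_grad_shape([8, 4, 17, 16],
                                    [0, -1, -1, -1, -1, -1, -1, -1],
                                    [1, 0, 0, 0]) == [8, 4, 16, 16]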
596
+ class _PoolGrad(PrimitiveWithInfer):
597
+ """Gradients of the max/avg pool operation."""
598
+
599
+ @prim_attr_register
600
+ def __init__(self, kernel_size, strides, pad_mode="VALID", data_format="NCHW"):
601
+ self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])
602
+
603
+ validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
604
+ validator.check_value_type('strides', strides, [int, tuple], self.name)
605
+ validator.check_value_type('pad_mode', pad_mode, [str], self.name)
606
+ self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
607
+ self.add_prim_attr("pad_mode", self.pad_mode)
608
+ self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
609
+ if context.get_context("device_target") != "GPU" and self.format == "NHWC":
610
+ raise ValueError("NHWC format only support in GPU target.")
611
+ self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax")
612
+ if not self.is_maxpoolgradwithargmax:
613
+ self.add_prim_attr('data_format', self.format)
614
+
615
+ def _grad_check_int_or_tuple(arg_name, arg_val, is_argmax):
616
+ validator.check_value_type(arg_name, arg_val, (int, tuple), self.name)
617
+ error_msg = ValueError(f"For '{self.name}' the '{arg_name}' must be a positive int number "
618
+ f"or a tuple of two or four positive int numbers, but got {arg_val}")
619
+ if isinstance(arg_val, int):
620
+ ret = (1, arg_val, arg_val, 1) if is_argmax else (1, 1, arg_val, arg_val)
621
+ elif len(arg_val) == 2:
622
+ ret = (1, arg_val[0], arg_val[1], 1) if is_argmax else (1, 1, arg_val[0], arg_val[1])
623
+ elif len(arg_val) == 4:
624
+ ret = arg_val
625
+ else:
626
+ raise error_msg
627
+ # whether all elements of tuple are positive integers
628
+ for item in ret:
629
+ if not isinstance(item, int) or item <= 0:
630
+ raise error_msg
631
+ return ret
632
+
633
+ kernel_size = _grad_check_int_or_tuple("kernel_size", kernel_size, self.is_maxpoolgradwithargmax)
634
+ strides = _grad_check_int_or_tuple("strides", strides, self.is_maxpoolgradwithargmax)
635
+ if self.format == "NCHW":
636
+ self.kernel_size = kernel_size
637
+ self.strides = strides
638
+ else:
639
+ self.kernel_size = [kernel_size[0], kernel_size[2], kernel_size[3], kernel_size[1]]
640
+ self.strides = [strides[0], strides[2], strides[3], strides[1]]
641
+ self.add_prim_attr("kernel_size", self.kernel_size)
642
+ self.add_prim_attr("strides", self.strides)
643
+
644
+
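The nested check above turns an int or 2-tuple into a 4-element attribute and, for 'NHWC', reorders it from (N, C, H, W) to (N, H, W, C). A compact sketch of that rule, assuming nothing beyond the code shown (`normalize_pool_attr` is an illustrative name):

def normalize_pool_attr(val, data_format="NCHW", is_argmax=False):
    """int / 2-tuple / 4-tuple -> 4-tuple, laid out the way _PoolGrad stores it."""
    if isinstance(val, int):
        ret = (1, val, val, 1) if is_argmax else (1, 1, val, val)
    elif len(val) == 2:
        ret = (1, val[0], val[1], 1) if is_argmax else (1, 1, val[0], val[1])
    elif len(val) == 4:
        ret = tuple(val)
    else:
        raise ValueError(f"expected an int or a tuple of 2 or 4 positive ints, got {val}")
    if any((not isinstance(v, int)) or v <= 0 for v in ret):
        raise ValueError(f"expected positive ints, got {ret}")
    # NHWC stores (N, H, W, C); the NCHW form is reordered accordingly.
    return ret if data_format == "NCHW" else (ret[0], ret[2], ret[3], ret[1])

assert normalize_pool_attr(2) == (1, 1, 2, 2)
assert normalize_pool_attr((3, 3), data_format="NHWC") == (1, 3, 3, 1)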
645
+ class AvgPoolGradVm(_PoolGrad):
646
+ """Gradients of the avg pool operation for vm."""
647
+
648
+ @prim_attr_register
649
+ def __init__(self, kernel_size=1, strides=1, pad_mode="VALID"):
650
+ super(AvgPoolGradVm, self).__init__(kernel_size, strides, pad_mode)
651
+ self.init_prim_io_names(inputs=['x_origin', 'grad', 'mean_matrix', 'kernel_matrix'], outputs=['output'])
652
+
653
+ def __infer__(self, origin_input, dout, mean_matrix, kernel_matrix):
654
+ out = {
655
+ 'value': None,
656
+ 'shape': tuple(origin_input['value']),
657
+ 'dtype': dout['dtype'],
658
+ }
659
+
660
+ return out
661
+
662
+
663
+ class AvgPoolGradGe(_PoolGrad):
664
+ """Gradients of the avg pool operation for ge."""
665
+
666
+ @prim_attr_register
667
+ def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
668
+ super(AvgPoolGradGe, self).__init__(kernel_size, strides, pad_mode, data_format)
669
+
670
+ def __infer__(self, origin_input, dout):
671
+ out = {
672
+ 'value': None,
673
+ 'shape': tuple(origin_input['value']),
674
+ 'dtype': dout['dtype'],
675
+ }
676
+
677
+ return out
678
+
679
+
680
+ class AvgPoolGradV1(Primitive):
681
+ """Gradients of the AvgPoolV1 operation."""
682
+
683
+ @prim_attr_register
684
+ def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
685
+ validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
686
+ validator.check_value_type('strides', strides, [int, tuple], self.name)
687
+ self.pad_mode = validator.check_string(
688
+ pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
689
+ self.add_prim_attr("pad_mode", self.pad_mode)
690
+ self.format = validator.check_string(
691
+ data_format, ['NCHW', 'NHWC'], 'format', self.name)
692
+ self.add_prim_attr('data_format', self.format)
693
+
694
+ def _avgpoolgrad_check_int_or_tuple(argname, argval):
695
+ validator.check_value_type(argname, argval, (int, tuple), self.name)
696
+ errormsg = ValueError(f"For '{self.name}' the '{argname}' should be a positive int number "
697
+ f"or a tuple of two or four positive int numbers, but got {argval}")
698
+ if isinstance(argval, int):
699
+ ret = (1, 1, argval, argval)
700
+ elif len(argval) == 2:
701
+ ret = (1, 1, argval[0], argval[1])
702
+ elif len(argval) == 4:
703
+ ret = argval
704
+ else:
705
+ raise errormsg
706
+ # whether all elements of tuple are positive integers?
707
+ for it in ret:
708
+ if not isinstance(it, int) or it <= 0:
709
+ raise errormsg
710
+ return ret
711
+
712
+ self.kernel_size = _avgpoolgrad_check_int_or_tuple(
713
+ "kernel_size", kernel_size)
714
+ self.strides = _avgpoolgrad_check_int_or_tuple("strides", strides)
715
+
716
+ self.kernel_size_adapt = self.kernel_size if self.format == "NCHW" else (
717
+ self.kernel_size[0], self.kernel_size[2], self.kernel_size[3], self.kernel_size[1])
718
+ self.strides_adapt = self.strides if self.format == "NCHW" else (
719
+ self.strides[0], self.strides[2], self.strides[3], self.strides[1])
720
+
721
+ # If length of some attrs is 4 we regard it as legal, either by using the op directly,
722
+ # or passed from an instance of forward op AvgPoolV1.
723
+ if len(self.kernel_size) == 4:
724
+ self.kernel_size_adapt = self.kernel_size
725
+ if len(self.strides) == 4:
726
+ self.strides_adapt = self.strides
727
+
728
+ self.add_prim_attr("kernel_size", self.kernel_size_adapt)
729
+ self.add_prim_attr("strides", self.strides_adapt)
730
+
731
+
732
+ class AdaptiveAvgPool2DGrad(Primitive):
733
+ """Gradients of the adaptive avg pool 2D operation."""
734
+
735
+ @prim_attr_register
736
+ def __init__(self):
737
+ """Initialize AdaptiveAvgPool2DGrad"""
738
+ self.init_prim_io_names(inputs=['input_grad', 'orig_input_shape'], outputs=['output_grad'])
739
+
740
+
741
+ class AdaptiveAvgPool3DGrad(Primitive):
742
+ """Performs grad of AdaptiveAvgPool3D operation."""
743
+ @prim_attr_register
744
+ def __init__(self):
745
+ self.init_prim_io_names(inputs=['y_grad', 'orig_input_shape'], outputs=['x_grad'])
746
+
747
+
748
+ class AvgPool3DGrad(Primitive):
749
+ """Gradients of the avg pool3d operation."""
750
+
751
+ @prim_attr_register
752
+ def __init__(self, kernel_size=1, strides=1, pads=0, ceil_mode=False,
753
+ count_include_pad=True, divisor_override=0, data_format="NCDHW", pad_mode="pad"):
754
+ self.init_prim_io_names(inputs=['origin_input_shape', 'grads'], outputs=['output'])
755
+ self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name, allow_five=True, ret_five=True)
756
+ self.add_prim_attr('kernel_size', self.kernel_size)
757
+ self.strides = _check_3d_int_or_tuple('strides', strides, self.name, allow_five=True, ret_five=True)
758
+ self.add_prim_attr('strides', self.strides)
759
+ validator.check_value_type('pad_mode', pad_mode, [str], self.name)
760
+ self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME', 'PAD'], 'pad_mode', self.name)
761
+ validator.check_value_type('pads', pads, (int, tuple), self.name)
762
+ if isinstance(pads, int):
763
+ pads = (pads,) * 6
764
+ validator.check_equal_int(len(pads), 6, 'pad size', self.name)
765
+ for item in pads:
766
+ validator.check_non_negative_int(item, 'pad item', self.name)
767
+ self.add_prim_attr('pad_list', pads)
768
+ self.ceil_mode = validator.check_value_type('ceil_mode', ceil_mode, bool, self.name)
769
+ self.count_include_pad = validator.check_value_type('count_include_pad', count_include_pad, bool, self.name)
770
+ self.divisor_override = validator.check_value_type('divisor_override', divisor_override, int, self.name)
771
+ self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
772
+
773
+
774
+ class AdaptiveMaxPool2DGrad(Primitive):
775
+ """Gradients of the adaptive max pool 2D operation."""
776
+ @prim_attr_register
777
+ def __init__(self):
778
+ """Initialize AdaptiveMaxPool2DGrad"""
779
+ self.init_prim_io_names(inputs=['y_grad', 'x', 'argmax'], outputs=['x_grad'])
780
+
781
+
782
+ class MaxPoolGrad(_PoolGrad):
783
+ """Performs gradients of the max pool operation."""
784
+
785
+ @prim_attr_register
786
+ def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
787
+ super(MaxPoolGrad, self).__init__(kernel_size, strides, pad_mode, data_format)
788
+
789
+ def infer_shape(self, x1_shape, x2_shape, grad_shape):
790
+ return x1_shape
791
+
792
+ def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
793
+ return x1_dtype
794
+
795
+
796
+ class MaxPoolGradV1(Primitive):
797
+ """Performs gradients of the MaxPoolV1 operation."""
798
+
799
+ @prim_attr_register
800
+ def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
801
+ self.init_prim_io_names(
802
+ inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])
803
+
804
+ validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
805
+ validator.check_value_type('strides', strides, [int, tuple], self.name)
806
+ self.pad_mode = validator.check_string(
807
+ pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
808
+ self.add_prim_attr("pad_mode", self.pad_mode)
809
+ self.format = validator.check_string(
810
+ data_format, ['NCHW', 'NHWC'], 'format', self.name)
811
+ self.add_prim_attr('data_format', self.format)
812
+
813
+ def _grad_check_int_or_tuple(arg_name, arg_val):
814
+ validator.check_value_type(
815
+ arg_name, arg_val, (int, tuple), self.name)
816
+ error_msg = ValueError(f"For '{self.name}' the '{arg_name}' should be an positive int number "
817
+ f"or a tuple of two or four positive int numbers, but got {arg_val}")
818
+ if isinstance(arg_val, int):
819
+ ret = (1, 1, arg_val, arg_val)
820
+ elif len(arg_val) == 2:
821
+ ret = (1, 1, arg_val[0], arg_val[1])
822
+ elif len(arg_val) == 4:
823
+ ret = arg_val
824
+ else:
825
+ raise error_msg
826
+ # whether all elements of tuple are positive integers
827
+ for item in ret:
828
+ if not isinstance(item, int) or item <= 0:
829
+ raise error_msg
830
+ return ret
831
+
832
+ self.kernel_size = _grad_check_int_or_tuple("kernel_size", kernel_size)
833
+ self.strides = _grad_check_int_or_tuple("strides", strides)
834
+
835
+ kernel_size_adapted = self.kernel_size if self.format == 'NCHW' else (
836
+ self.kernel_size[0], self.kernel_size[2], self.kernel_size[3], self.kernel_size[1])
837
+ strides_adapted = self.strides if self.format == 'NCHW' else (
838
+ self.strides[0], self.strides[2], self.strides[3], self.strides[1])
839
+
840
+ if isinstance(kernel_size, tuple) and len(kernel_size) == 4:
841
+ kernel_size_adapted = kernel_size
842
+ if isinstance(strides, tuple) and len(strides) == 4:
843
+ strides_adapted = strides
844
+
845
+ self.add_prim_attr("kernel_size", kernel_size_adapted)
846
+ self.add_prim_attr("strides", strides_adapted)
847
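Editor's note: a minimal sketch (hypothetical helper name, not package code) of how MaxPoolGradV1 normalizes kernel_size/strides above: an int or 2-tuple is padded to four values in NCHW order, and an NHWC layout re-orders the result.

def normalize_2d_attr(val):
    if isinstance(val, int):
        return (1, 1, val, val)
    if len(val) == 2:
        return (1, 1, val[0], val[1])
    return tuple(val)                       # a 4-tuple is taken as-is

assert normalize_2d_attr(3) == (1, 1, 3, 3)
assert normalize_2d_attr((2, 4)) == (1, 1, 2, 4)
# For data_format == 'NHWC', the (N, C, H, W) tuple is re-ordered to (N, H, W, C):
n, c, h, w = normalize_2d_attr((2, 4))
assert (n, h, w, c) == (1, 2, 4, 1)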
+
848
+
849
+ class MaxPoolGradGrad(_PoolGrad):
850
+ r"""
851
+ Performs gradients of the MaxPoolGrad operation.
852
+
853
+ Args:
854
+ kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
855
+ is an int number that represents height and width are both kernel_size, or a tuple
856
+ of two int numbers that represent height and width respectively. Default: 1.
857
+ strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
858
+ the height and width of movement are both strides, or a tuple of two int numbers that
859
+ represent height and width of movement respectively. Default: 1.
860
+ pad_mode (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
861
+ Default: "valid".
862
+
863
+ - same: Adopts the way of completion. The height and width of the output will be the same as
864
+ the input. The total number of padding will be calculated in horizontal and vertical
865
+ directions and evenly distributed to top and bottom, left and right if possible.
866
+ Otherwise, the last extra padding will be done from the bottom and the right side.
867
+
868
+ - valid: Adopts the way of discarding. The possible largest height and width of output
869
+ will be returned without padding. Extra pixels will be discarded.
870
+
871
+ Inputs:
872
+ - **origin_input** (Tensor) - Tensor with data format "NCHW".
873
+ For Ascend, data type must be float16. For CPU and GPU, data type support float16 and float32.
874
+ - **origin_output** (Tensor) - Data type same as `origin_input`.
875
+ - **grad** (Tensor) - Data type and shape same as `origin_input`.
876
+
877
+ Outputs:
878
+ Tensor, with data type same as `origin_input`. Shape same as `origin_output`.
879
+
880
+ Raises:
881
+ TypeError: If kernel_size is neither int nor a tuple of 2/4 int numbers.
882
+ TypeError: If strides is neither int nor a tuple of 2/4 int numbers.
883
+ TypeError: If pad_mode is not string.
884
+ ValueError: If pad_mode is neither "same" nor "valid"(not case sensitive).
885
+ TypeError: For Ascend, input data type is not float16. For CPU or GPU, input data type is neither
886
+ float16 nor float32.
887
+ ValueError: If the rank of `origin_input`, `origin_output` or `grad` is not equal to 4.
888
+ ValueError: If data types of all inputs are not equal.
889
+ ValueError: If the shapes of `origin_input` and `grad` are not equal.
890
+
891
+ Supported Platforms:
892
+ ``Ascend`` ``GPU`` ``CPU``
893
+ """
894
+
895
+ @prim_attr_register
896
+ def __init__(self, kernel_size=1, strides=1, pad_mode="VALID"):
897
+ super(MaxPoolGradGrad, self).__init__(kernel_size, strides, pad_mode)
898
+
899
+ def infer_shape(self, x1_shape, x2_shape, grad_shape):
900
+ return x2_shape
901
+
902
+ def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
903
+ args = {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'grad_dtype': grad_dtype}
904
+ validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name)
905
+ return x2_dtype
906
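Editor's note: a hedged usage sketch for MaxPoolGradGrad that only illustrates the input/output contract stated in the docstring above (grad shares origin_input's shape, the result shares origin_output's shape). It assumes a configured MindSpore backend, and _grad_ops is an internal module defined by this file rather than a public API.

import numpy as np
import mindspore as ms
from mindspore.ops.operations import _grad_ops as G   # internal module, assumed import path

x = ms.Tensor(np.random.rand(1, 1, 4, 4).astype(np.float32))   # origin_input, NCHW
y = ms.ops.MaxPool(kernel_size=2, strides=2)(x)                # origin_output, shape (1, 1, 2, 2)
dy = ms.Tensor(np.ones((1, 1, 4, 4), np.float32))              # grad, same shape as origin_input
out = G.MaxPoolGradGrad(kernel_size=2, strides=2)(x, y, dy)
print(out.shape)                                               # (1, 1, 2, 2), same as origin_output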
+
907
+
908
+ def _get_max_pool3d_grad_pads_by_pad_mode(input_shape, kernel_size, strides, pad_mode):
909
+ """
910
+ Helper to compute the max pool3d grad pads according to pad_mode.
911
+ """
912
+
913
+ def get_pad(origin_shape, ksize, stride):
914
+ tail = origin_shape % stride
915
+ pad = (ksize - tail) if tail > 0 else (ksize - stride)
916
+ pad = max(pad, 0)
917
+ pad1 = int(pad / 2)
918
+ pad2 = int(pad / 2) + pad % 2
919
+ return pad1, pad2
920
+
921
+ _, _, d, h, w = input_shape
922
+ _, _, kd, kh, kw = kernel_size
923
+ _, _, strd, strh, strw = strides
924
+
925
+ pads = (0, 0, 0, 0, 0, 0)
926
+ if pad_mode == 'SAME':
927
+ pads_d = get_pad(d, kd, strd)
928
+ pads_h = get_pad(h, kh, strh)
929
+ pads_w = get_pad(w, kw, strw)
930
+ pads = pads_d + pads_h + pads_w
931
+ return pads
932
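Editor's note: the SAME-padding split computed by get_pad above can be checked with plain integer arithmetic. The snippet below is an illustrative restatement, not package code.

def same_pad(origin, ksize, stride):
    tail = origin % stride
    pad = max(ksize - tail if tail > 0 else ksize - stride, 0)
    return pad // 2, pad // 2 + pad % 2      # (front/top/left, back/bottom/right)

assert same_pad(5, 3, 2) == (1, 1)   # tail 1 -> pad 2, split evenly
assert same_pad(7, 2, 2) == (0, 1)   # tail 1 -> pad 1, the odd element goes to the second side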
+
933
+
934
+ class MaxPool3DGrad(Primitive):
935
+ """Gradients of the max pool3d operation."""
936
+
937
+ @prim_attr_register
938
+ def __init__(self, kernel_size=(1, 1, 1, 1, 1), strides=(1, 1, 1, 1, 1),
939
+ pad_mode='VALID', pad_list=0, data_format="NCDHW"):
940
+ self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])
941
+ validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
942
+ validator.check_value_type('strides', strides, [int, tuple], self.name)
943
+ validator.check_value_type('pad_mode', pad_mode, [str], self.name)
944
+ self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
945
+ if pad_mode.upper() == 'PAD':
946
+ pad_mode = 'CALCULATED'
947
+ self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME', 'CALCULATED'], 'pad_mode', self.name)
948
+ self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name,
949
+ allow_five=True, ret_five=True)
950
+ self.add_prim_attr("kernel_size", self.kernel_size)
951
+ self.strides = _check_3d_int_or_tuple("strides", strides, self.name, allow_five=True, ret_five=True)
952
+ self.add_prim_attr("strides", self.strides)
953
+ validator.check_value_type('pad_list', pad_list, (int, tuple), self.name)
954
+ self.pad_list = pad_list
955
+ if isinstance(self.pad_list, int):
956
+ self.pad_list = (self.pad_list,) * 6
957
+ if len(self.pad_list) == 3:
958
+ self.pad_list = (pad_list[0], pad_list[0], pad_list[1], pad_list[1], pad_list[2], pad_list[2])
959
+ if len(self.pad_list) != 3 and len(self.pad_list) != 6:
960
+ raise ValueError(f"For `maxpool3d` attr 'pad_list' must be an positive int number or a tuple of "
961
+ f"three or six positive int numbers, but got `{len(self.pad_list)}` numbers.")
962
+ if self.pad_mode != 'CALCULATED' and self.pad_list != (0, 0, 0, 0, 0, 0):
963
+ raise ValueError(f"For '{self.name}', when pad_list is not 0, pad_mode must be set as 'pad'.")
964
+ if self.pad_mode == 'CALCULATED':
965
+ for item in self.pad_list:
966
+ validator.check_non_negative_int(item, 'pad_list item', self.name)
967
+ self.add_prim_attr("pad_list", self.pad_list)
968
+
969
+
970
+ class MaxPool3DGradGrad(PrimitiveWithInfer):
971
+ r"""Gradients of the max pool3d grad operation.
972
+
973
+ Args:
974
+ kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
975
+ is an int number that represents the depth, height and width are all kernel_size, or a tuple
976
+ of three int numbers that represent depth, height and width respectively. Default: 1.
977
+ strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
978
+ the depth, height and width of movement are all strides, or a tuple of three int numbers that
979
+ represent depth, height and width of movement respectively. Default: 1.
980
+ pad_mode (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
981
+ Default: "valid".
982
+
983
+ - same: Adopts the way of completion. The depth, height and width of the output will be the
984
+ same as the input. The total number of padding will be calculated in depth, horizontal and
985
+ vertical directions and evenly distributed to front and back, top and bottom, left and
986
+ right if possible. Otherwise, the last extra padding will be done from the back, the bottom
987
+ and the right side.
988
+
989
+ - valid: Adopts the way of discarding. The possible largest height and width of output
990
+ will be returned without padding. Extra pixels will be discarded.
991
+
992
+ Inputs:
993
+ - **origin_input** (Tensor) - Tensor with data format "NCDHW".
994
+ For Ascend, data type must be float16. For CPU and GPU, data type support float16 and float32.
995
+ - **origin_output** (Tensor) - Data type same as `origin_input`.
996
+ - **grad** (Tensor) - Data type and shape same as `origin_input`.
997
+
998
+ Outputs:
999
+ Tensor, with data type same as `origin_input`. Shape same as `origin_output`.
1000
+
1001
+ Raises:
1002
+ TypeError: If kernel_size is neither int nor a tuple of 3/5 int numbers.
1003
+ TypeError: If strides is neither int nor a tuple of 3/5 int numbers.
1004
+ TypeError: If pad_mode is not string.
1005
+ ValueError: If pad_mode is neither "same" nor "valid"(not case sensitive).
1006
+ TypeError: For Ascend, input data type is not float16. For CPU or GPU, input data type is neither
1007
+ float16 nor float32.
1008
+ ValueError: If the rank of `origin_input`, `origin_output` or `grad` is not equal to 5.
1009
+ ValueError: If data types of all inputs are not equal.
1010
+ ValueError: If the shapes of `origin_input` and `grad` are not equal.
1011
+
1012
+ Supported Platforms:
1013
+ ``Ascend`` ``GPU`` ``CPU``
1014
+ """
1015
+
1016
+ @prim_attr_register
1017
+ def __init__(self, kernel_size=(1, 1, 1, 1, 1), strides=(1, 1, 1, 1, 1), pad_mode='VALID', data_format="NCDHW"):
1018
+ validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
1019
+ validator.check_value_type('strides', strides, [int, tuple], self.name)
1020
+ validator.check_value_type('pad_mode', pad_mode, [str], self.name)
1021
+ self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
1022
+ self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
1023
+ self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name,
1024
+ allow_five=True, ret_five=True)
1025
+ self.add_prim_attr("kernel_size", self.kernel_size)
1026
+ self.strides = _check_3d_int_or_tuple("strides", strides, self.name, allow_five=True, ret_five=True)
1027
+ self.add_prim_attr("strides", self.strides)
1028
+
1029
+ def infer_shape(self, x_shape, y_shape, grad_shape):
1030
+ validator.check_equal_int(len(x_shape), 5, "x rank", self.name)
1031
+ validator.check('x_shape', x_shape, 'grad_shape', grad_shape, prim_name=self.name)
1032
+ pad_list = _get_max_pool3d_grad_pads_by_pad_mode(x_shape, self.kernel_size, self.strides, self.pad_mode)
1033
+ for pad in pad_list:
1034
+ validator.check_non_negative_int(pad, 'element of pad_list', self.name)
1035
+ self.add_prim_attr("pad_list", pad_list)
1036
+ return y_shape
1037
+
1038
+ def infer_dtype(self, x_dtype, y_dtype, grad_dtype):
1039
+ args = {'x_dtype': x_dtype, 'y_dtype': y_dtype}
1040
+ validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
1041
+ validator.check_tensor_dtype_valid('grad_dtype', grad_dtype, [mstype.float16, mstype.float32], self.name)
1042
+ return x_dtype
1043
+
1044
+
1045
+ class MaxPoolGradWithArgmax(Primitive):
1046
+ """Computes the gradients of MaxPoolWithArgmax."""
1047
+ @prim_attr_register
1048
+ def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
1049
+ self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])
1050
+ validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
1051
+ validator.check_value_type('strides', strides, [int, tuple], self.name)
1052
+ validator.check_value_type('pad_mode', pad_mode, [str], self.name)
1053
+ self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
1054
+ self.add_prim_attr("pad_mode", self.pad_mode)
1055
+ self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
1056
+ if context.get_context("device_target") != "GPU" and self.format == "NHWC":
1057
+ raise ValueError("NHWC format only support in GPU target.")
1058
+ self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax")
1059
+ if not self.is_maxpoolgradwithargmax:
1060
+ self.add_prim_attr('data_format', self.format)
1061
+
1062
+ def _grad_check_int_or_tuple(arg_name, arg_val):
1063
+ validator.check_value_type(arg_name, arg_val, (int, tuple), self.name)
1064
+ error_msg = ValueError(f"For '{self.name}' the '{arg_name}' must be an positive int number "
1065
+ f"or a tuple of two or four positive int numbers, but got {arg_val}")
1066
+ if isinstance(arg_val, int):
1067
+ ret = (1, arg_val, arg_val, 1)
1068
+ elif len(arg_val) == 2:
1069
+ ret = (1, arg_val[0], arg_val[1], 1)
1070
+ elif len(arg_val) == 4:
1071
+ ret = arg_val
1072
+ else:
1073
+ raise error_msg
1074
+ # whether all elements of tuple are positive integers
1075
+ for item in ret:
1076
+ if not isinstance(item, int) or item <= 0:
1077
+ raise error_msg
1078
+ return ret
1079
+
1080
+ kernel_size = _grad_check_int_or_tuple("kernel_size", kernel_size)
1081
+ self.kernel_size = kernel_size
1082
+ self.add_prim_attr("kernel_size", self.kernel_size)
1083
+
1084
+ strides = _grad_check_int_or_tuple("strides", strides)
1085
+ self.strides = strides
1086
+ self.add_prim_attr("strides", self.strides)
1087
+
1088
+
1089
+ class MaxPoolGradWithArgmaxV2(Primitive):
1090
+ """Gradients of the MaxPoolWithArgmaxV2 operation."""
1091
+
1092
+ @prim_attr_register
1093
+ def __init__(self, kernel_size, strides=None, pads=0, dilation=(1, 1), ceil_mode=False, argmax_type=mstype.int64):
1094
+ self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['y'])
1095
+ self.kernel_size = _check_positive_int_or_tuple("kernel_size", kernel_size, self.name, allow_four=True,
1096
+ ret_four=True)
1097
+ self.add_prim_attr('kernel_size', self.kernel_size)
1098
+ if strides is None:
1099
+ strides = kernel_size
1100
+ self.strides = _check_positive_int_or_tuple("strides", strides, self.name, allow_four=True, ret_four=True)
1101
+ self.add_prim_attr('strides', self.strides)
1102
+ self.pads = _check_positive_int_or_tuple("pads", pads, self.name, allow_four=True, ret_four=True,
1103
+ strict_positive=False)
1104
+ self.add_prim_attr('pads', self.pads)
1105
+ validator.check_value_type('ceil_mode', ceil_mode, bool, self.name)
1106
+ self.add_prim_attr('ceil_mode', self.ceil_mode)
1107
+ self.dilation = _check_positive_int_or_tuple("dilation", dilation, self.name, allow_four=True, ret_four=True)
1108
+ self.add_prim_attr('dilation', self.dilation)
1109
+ self.add_prim_attr('argmax_type', self.argmax_type)
1110
+
1111
+
1112
+ class MaxPool3DGradWithArgmax(Primitive):
1113
+ """Gradients of the maxpool3Dwithargmax operation."""
1114
+
1115
+ @prim_attr_register
1116
+ def __init__(self, ksize, strides, pads, dilation=(1, 1, 1), ceil_mode=False, data_format="NCDHW"):
1117
+ self.init_prim_io_names(inputs=['x', 'grads', 'argmax'], outputs=['y'])
1118
+ validator.check_value_type('ceil_mode', ceil_mode, bool, self.name)
1119
+ validator.check_value_type('data_format', data_format, str, self.name)
1120
+ self.data_format = validator.check_string(data_format, ['NCDHW'], 'data_format', self.name)
1121
+ self.ksize = _check_3d_int_or_tuple("ksize", ksize, self.name, ret_five=False)
1122
+ self.add_prim_attr('ksize', self.ksize)
1123
+ self.strides = _check_3d_int_or_tuple("strides", strides, self.name, ret_five=False)
1124
+ self.add_prim_attr('strides', self.strides)
1125
+ self.pads = _check_3d_int_or_tuple("pads", pads, self.name, greater_zero=False, ret_five=False)
1126
+ self.add_prim_attr('pads', self.pads)
1127
+ self.dilation = _check_3d_int_or_tuple("dilation", dilation, self.name, allow_five=True, ret_five=False)
1128
+ self.add_prim_attr('dilation', self.dilation)
1129
+
1130
+
1131
+ class MaxPoolGradGradWithArgmax(_PoolGrad):
1132
+ r"""
1133
+ Computes the gradients of MaxPoolGradWithArgmax.
1134
+
1135
+ Args:
1136
+ kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
1137
+ is an int number that represents height and width are both kernel_size, or a tuple
1138
+ of two int numbers that represent height and width respectively. Default: 1.
1139
+ strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
1140
+ the height and width of movement are both strides, or a tuple of two int numbers that
1141
+ represent height and width of movement respectively. Default: 1.
1142
+ pad_mode (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
1143
+ Default: "valid".
1144
+
1145
+ - same: Adopts the way of completion. The height and width of the output will be the same as
1146
+ the input. The total number of padding will be calculated in horizontal and vertical
1147
+ directions and evenly distributed to top and bottom, left and right if possible.
1148
+ Otherwise, the last extra padding will be done from the bottom and the right side.
1149
+
1150
+ - valid: Adopts the way of discarding. The possible largest height and width of output
1151
+ will be returned without padding. Extra pixels will be discarded.
1152
+
1153
+ Inputs:
1154
+ - **x** (Tensor) - Tensor with data format "NCHW".
1155
+ For Ascend, data type must be float16. For CPU and GPU, data type support float16 and float32.
1156
+ - **grad** (Tensor) - Data type and shape same as `x`.
1157
+ - **argmax** (Tensor) - Data type must be int32 or int64.
1158
+
1159
+ Outputs:
1160
+ Tensor, with data type same as `x`. Shape same as `argmax`.
1161
+
1162
+ Raises:
1163
+ TypeError: If kernel_size is neither int nor a tuple of 2/4 int numbers.
1164
+ TypeError: If strides is neither int nor a tuple of 2/4 int numbers.
1165
+ TypeError: If pad_mode is not string.
1166
+ ValueError: If pad_mode is neither "same" nor "valid"(not case sensitive).
1167
+ TypeError: For Ascend, the data types of `x` and `grad` are not float16.
1168
+ For CPU or GPU, the data types of `x` and `grad` are neither float16 nor float32.
1169
+ TypeError: The data type of `argmax` is neither int32 nor int64.
1170
+ ValueError: If the rank of `x`, `grad` or `argmax` is not equal to 4.
1171
+ ValueError: If the shapes of `x` and `grad` are not equal.
1172
+
1173
+ Supported Platforms:
1174
+ ``Ascend`` ``GPU`` ``CPU``
1175
+ """
1176
+
1177
+ @prim_attr_register
1178
+ def __init__(self, kernel_size=1, strides=1, pad_mode="VALID"):
1179
+ self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['output'])
1180
+ super(MaxPoolGradGradWithArgmax, self).__init__(kernel_size, strides, pad_mode)
1181
+
1182
+ def infer_shape(self, x_shape, grad_shape, argmax_shape):
1183
+ if not grad_shape:
1184
+ raise TypeError("The dout of MaxPoolGradGradWithArgmax must be a Tensor.")
1185
+ return x_shape
1186
+
1187
+ def infer_dtype(self, x_dtype, grad_dtype, argmax_dtype):
1188
+ args = {'x_dtype': x_dtype, 'grad_dtype': grad_dtype}
1189
+ validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name)
1190
+ return grad_dtype
1191
+
1192
+
1193
+ class MinimumGradGrad(Primitive):
1194
+ """Grad for minimum_grad."""
1195
+ @prim_attr_register
1196
+ def __init__(self):
1197
+ """Initialize MinimumGradGrad"""
1198
+ super().__init__("MinimumGradGrad")
1199
+ self.init_prim_io_names(inputs=['x1', 'x2', 'grad_y1', 'grad_y2'],
1200
+ outputs=['sopd_x1', 'sopd_x2', 'sopd_grads'])
1201
+
1202
+
1203
+ class L2NormalizeGrad(Primitive):
1204
+ r"""
1205
+ Gradients of L2 normalize.
1206
+
1207
+ Args:
1208
+ axis (Union[list(int), tuple(int), int]): The begin axis for the input to apply L2 normalize. Default: 0.
1209
+ epsilon (float): A small value added for numerical stability. Default: 1e-4.
1210
+
1211
+ Inputs:
1212
+ - **input_x** (Tensor) - Must be the input `weight` of forward operator L2Normalize.
1213
+ - **out** (Tensor) - Must be the output of forward operator L2Normalize.
1214
+ - **dout** (Tensor) - The backprop of the next layer.
1215
+
1216
+ Outputs:
1217
+ Tensor, gradients of L2Normalize `input_x`.
1218
+ """
1219
+
1220
+ @prim_attr_register
1221
+ def __init__(self, axis=0, epsilon=1e-4):
1222
+ axis = [axis] if isinstance(axis, int) else axis
1223
+ validator.check_value_type('axis', axis, [list, tuple], self.name)
1224
+ validator.check_value_type('epsilon', epsilon, [int, float], self.name)
1225
+ self.add_prim_attr('axis', axis)
1226
+ self.init_attrs['axis'] = axis
1227
+ if len(axis) != 1:
1228
+ raise TypeError("The length of axis must be 1, later will support multiple axis!")
1229
+
1230
+
1231
+ class LSTMGradData(Primitive):
1232
+ """Computes the data gradients of LSTM."""
1233
+
1234
+ @prim_attr_register
1235
+ def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
1236
+ self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
1237
+ self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
1238
+ self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
1239
+ self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1240
+ self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
1241
+ self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
1242
+ self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
1243
+
1244
+ if bidirectional:
1245
+ self.num_directions = 2
1246
+ else:
1247
+ self.num_directions = 1
1248
+
1249
+
1250
+ class LSTMGradWeight(Primitive):
1251
+ """Computes the weight gradients of LSTM."""
1252
+
1253
+ @prim_attr_register
1254
+ def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
1255
+ self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
1256
+ self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
1257
+ self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
1258
+ self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1259
+ self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
1260
+ self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
1261
+ self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
1262
+
1263
+ if bidirectional:
1264
+ self.num_directions = 2
1265
+ else:
1266
+ self.num_directions = 1
1267
+
1268
+
1269
+ class LSTMGrad(Primitive):
1270
+ """Computes the data and weight gradients of LSTM."""
1271
+
1272
+ @prim_attr_register
1273
+ def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout, proj_size=0):
1274
+ self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
1275
+ self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
1276
+ self.proj_size = validator.check_int_range(proj_size, 0, hidden_size, validator.INC_LEFT,
1277
+ 'proj_size', self.name)
1278
+ self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
1279
+ self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1280
+ self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
1281
+ self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
1282
+ self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
1283
+
1284
+ if bidirectional:
1285
+ self.num_directions = 2
1286
+ else:
1287
+ self.num_directions = 1
1288
+
1289
+
1290
+ class DynamicRNNGrad(Primitive):
1291
+ """Computes the input gradients of DynamicRNN."""
1292
+
1293
+ @prim_attr_register
1294
+ def __init__(self,
1295
+ cell_type='LSTM',
1296
+ direction='UNIDIRECTIONAL',
1297
+ cell_depth=1,
1298
+ use_peephole=False,
1299
+ keep_prob=1.0,
1300
+ cell_clip=-1.0,
1301
+ num_proj=0,
1302
+ time_major=True,
1303
+ forget_bias=0.0):
1304
+ self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
1305
+
1306
+
1307
+ class GruGradData(PrimitiveWithInfer):
1308
+ """Computes the data gradients of GRU."""
1309
+
1310
+ @prim_attr_register
1311
+ def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
1312
+ self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
1313
+ self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
1314
+ self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
1315
+ self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1316
+ self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
1317
+ self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
1318
+ self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
1319
+
1320
+ if bidirectional:
1321
+ self.num_directions = 2
1322
+ else:
1323
+ self.num_directions = 1
1324
+
1325
+ def infer_shape(self, y_shape, dy_shape, dhy_shape, w_shape,
1326
+ hx_shape, reserve_shape, state_shape):
1327
+ # dhy and dcy should be same shape
1328
+ validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name)
1329
+
1330
+ validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, validator.EQ, "h_shape[0]", self.name)
1331
+ validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name)
1332
+
1333
+ validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name)
1334
+ validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name)
1335
+ validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, validator.EQ, "dy[2]", self.name)
1336
+
1337
+ dx_shape = (y_shape[0], y_shape[1], self.input_size)
1338
+ dhx_shape = dhy_shape
1339
+
1340
+ return (dx_shape, dhx_shape)
1341
+
1342
+ def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, w_dtype,
1343
+ hx_dtype, reserve_dtype, state_dtype):
1344
+ args = {"dy": dy_dtype, "dhy": dhy_dtype}
1345
+ validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name)
1346
+ return (dy_dtype, dy_dtype)
1347
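Editor's note: a small shape sketch for the relations enforced in GruGradData.infer_shape above; the sizes are illustrative only (unidirectional, single layer).

num_step, batch, input_size, hidden_size = 5, 4, 8, 16
y_shape   = (num_step, batch, hidden_size)        # hidden_size * num_directions with num_directions == 1
dhy_shape = (1, batch, hidden_size)               # num_layers * num_directions == 1
dx_shape  = (y_shape[0], y_shape[1], input_size)
dhx_shape = dhy_shape
assert dx_shape == (5, 4, 8) and dhx_shape == (1, 4, 16)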
+
1348
+
1349
+ class GruGradWeight(PrimitiveWithInfer):
1350
+ """Computes the weight gradients of GRU."""
1351
+
1352
+ @prim_attr_register
1353
+ def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
1354
+ self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
1355
+ self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
1356
+ self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
1357
+ self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1358
+ self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
1359
+ self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
1360
+ self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
1361
+
1362
+ if bidirectional:
1363
+ self.num_directions = 2
1364
+ else:
1365
+ self.num_directions = 1
1366
+
1367
+ def infer_shape(self, x_shape, hx_shape, y_shape, reserve_shape, state_shape):
1368
+ weight_size = 0
1369
+ gate_size = 3 * self.hidden_size
1370
+ for layer in range(self.num_layers):
1371
+ for _ in range(self.num_directions):
1372
+ input_layer_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
1373
+ weight_size += gate_size * input_layer_size
1374
+ weight_size += gate_size * self.hidden_size
1375
+ if self.has_bias:
1376
+ weight_size += 2 * gate_size
1377
+
1378
+ return (weight_size, 1, 1)
1379
+
1380
+ def infer_dtype(self, x_dtype, hx_dtype, y_dtype, reserve_dtype, state_dtype):
1381
+ return hx_dtype
1382
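Editor's note: the flat weight size returned by GruGradWeight.infer_shape follows directly from the loop above. A worked example with illustrative sizes:

input_size, hidden_size, num_layers, num_directions, has_bias = 10, 20, 1, 1, True
gate_size = 3 * hidden_size                                # GRU packs three gates
weight_size = 0
for layer in range(num_layers):
    for _ in range(num_directions):
        in_size = input_size if layer == 0 else hidden_size * num_directions
        weight_size += gate_size * in_size                 # input-hidden weights
        weight_size += gate_size * hidden_size             # hidden-hidden weights
        if has_bias:
            weight_size += 2 * gate_size                   # two bias vectors
assert weight_size == 1920                                 # infer_shape returns (1920, 1, 1)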
+
1383
+
1384
+ class GRUV2Grad(Primitive):
1385
+ """Computes the grad gradients of GRU."""
1386
+
1387
+ @prim_attr_register
1388
+ def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
1389
+ self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
1390
+ self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
1391
+ self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
1392
+ self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1393
+ self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
1394
+ self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
1395
+ self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
1396
+
1397
+ if bidirectional:
1398
+ self.num_directions = 2
1399
+ else:
1400
+ self.num_directions = 1
1401
+
1402
+
1403
+ class DynamicGRUV2Grad(Primitive):
1404
+ r"""
1405
+ Computes the input gradients of DynamicGRUV2.
1406
+
1407
+ Args:
1408
+ direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'.
1409
+ Only 'UNIDIRECTIONAL' is currently supported.
1410
+ cell_depth (int): An integer identifying the cell depth in the op. Default: 1.
1411
+ keep_prob (float): A float identifying the keep prob in the op. Default: 1.0.
1412
+ cell_clip (float): A float identifying the cell clip in the op. Default: -1.0.
1413
+ num_proj (int): An integer identifying the num proj in the op. Default: 0.
1414
+ time_major (bool): A bool identifying the time major in the op. Default: ``True``.
1415
+ gate_order (str): A string identifying the gate order in weight and bias. Default: 'rzh'.
1416
+ 'zrh' is another option.
1417
+ reset_after (bool): A bool identifying whether to apply the reset gate after matrix multiplication.
1418
+ Default: ``True``.
1419
+
1420
+ Inputs:
1421
+ - **x** (Tensor) - Current words. Tensor of shape :math:`(num\_step, batch\_size, input\_size)`.
1422
+ The data type must be float16 or float32.
1423
+ - **weight_input** (Tensor) - Weight. Tensor of shape :math:`(input\_size, 3 x hidden\_size)`.
1424
+ The data type must be float16 or float32.
1425
+ - **weight_hidden** (Tensor) - Bias. Tensor of shape :math:`(hidden\_size, 3 x hidden\_size)`.
1426
+ The data type must be float16 or float32.
1427
+ - **y** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, min(hidden\_size, num\_proj))`
1428
+ if num_proj > 0, or of shape :math:`(num\_step, batch\_size, hidden\_size)`
1429
+ if num_proj == 0.
1430
+ The data type must be float16 or float32.
1431
+ - **init_h** (Tensor) - Hidden state of initial time.
1432
+ Tensor of shape :math:`(batch\_size, hidden\_size)`.
1433
+ The data type must be float16 or float32.
1434
+ - **h** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
1435
+ The data type must be float16 or float32.
1436
+ - **dy** (Tensor) - Gradient of `y`, has the same shape and data type as `y`.
1437
+ - **dh** (Tensor) - Gradient of `h`, has the same shape and data type as `init_h`.
1438
+ - **update** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
1439
+ The data type must be float16 or float32.
1440
+ - **reset** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
1441
+ The data type must be float16 or float32.
1442
+ - **new** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
1443
+ The data type must be float16 or float32.
1444
+ - **hidden_new** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
1445
+ The data type must be float16 or float32.
1446
+ - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(batch\_size)`.
1447
+ Only `None` is currently supported.
1448
+ - **mask** (Tensor) - A 4-D Tensor. The data type must be float16 or float32.
1449
+
1450
+ Outputs:
1451
+ - **dw_input** (Tensor) - A Tensor has the same shape as `weight_input`.
1452
+ Has the same type with input `x`.
1453
+ - **dw_hidden** (Tensor) - A Tensor has the same shape as `weight_hidden`.
1454
+ Has the same type with input `x`.
1455
+ - **db_input** (Tensor) - A Tensor of shape :math:`(3 x hidden\_size)`.
1456
+ Has the same type with input `init\_h`.
1457
+ - **db_hidden** (Tensor) - A Tensor of shape :math:`(3 x hidden\_size)`.
1458
+ Has the same type with input `init\_h`.
1459
+ - **dx** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
1460
+ Has the same type with input `x`.
1461
+ - **dh_prev** (Tensor) - A Tensor of shape :math:`(batch\_size, hidden\_size)`.
1462
+ Has the same type with input `init\_h`.
1463
+ """
1464
+
1465
+ @prim_attr_register
1466
+ def __init__(self,
1467
+ direction='UNIDIRECTIONAL',
1468
+ cell_depth=1,
1469
+ keep_prob=1.0,
1470
+ cell_clip=-1.0,
1471
+ num_proj=0,
1472
+ time_major=True,
1473
+ gate_order="rzh",
1474
+ reset_after=True):
1475
+ self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name)
1476
+ self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
1477
+ self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name)
1478
+ self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name)
1479
+ self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name)
1480
+ self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
1481
+ self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name)
1482
+ self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
1483
+ self.init_prim_io_names(inputs=[
1484
+ "x", "weight_input", "weight_hidden", "y", "init_h", "h", "dy",
1485
+ "dh", "update", "reset", "new", "hidden_new", "seq_length", "mask"
1486
+ ],
1487
+ outputs=[
1488
+ "dw_input", "dw_hidden", "db_input",
1489
+ "db_hidden", "dx", "dh_prev"
1490
+ ])
1491
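Editor's note: a shape-only sketch of the DynamicGRUV2Grad interface documented above, with illustrative sizes and num_proj == 0; it simply restates the docstring's shape relations.

num_step, batch, input_size, hidden_size = 2, 4, 8, 16
x_shape             = (num_step, batch, input_size)
weight_input_shape  = (input_size, 3 * hidden_size)         # (8, 48)
weight_hidden_shape = (hidden_size, 3 * hidden_size)        # (16, 48)
y_shape = h_shape   = (num_step, batch, hidden_size)
init_h_shape        = (batch, hidden_size)
# Output gradients mirror their forward counterparts:
dw_input_shape, dw_hidden_shape = weight_input_shape, weight_hidden_shape
db_input_shape = db_hidden_shape = (3 * hidden_size,)       # (48,)
dh_prev_shape = init_h_shape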
+
1492
+
1493
+ class RandomGammaGrad(Primitive):
1494
+ r"""
1495
+ Computes the derivative of a random sample of Gamma with respect to alpha.
1496
+
1497
+ Inputs:
1498
+ - **alpha** (Tensor) - α is the shape parameter of RandomGamma distribution.
1499
+ It must be greater than 0. Must be one of the following types: float32, float64.
1500
+ - **sample** (Tensor) - The sample of random gamma tensor. Must be one of the
1501
+ following types: float32, float64.
1502
+
1503
+ Outputs:
1504
+ The dtype is the same type as alpha.
1505
+ The output shape is derived from the input through broadcasting.
1506
+
1507
+ Raises:
1508
+ TypeError: If data type of `alpha` and `sample` is not float32 or float64.
1509
+ TypeError: If data type of `alpha` and `sample` is not same.
1510
+ ValueError: If the shape last dim of `sample` and `alpha` is not equal.
1511
+
1512
+ Supported Platforms:
1513
+ ``GPU``
1514
+
1515
+ Examples:
1516
+ >>> alpha = Tensor(np.array([1., 0.6, 3., 26.]), mstype.float32)
1517
+ >>> sample = Tensor(np.array([6., 7, 11., 0.5]), mstype.float32)
1518
+ >>> randomgammagrad = ops.RandomGammaGrad()
1519
+ >>> output = randomgammagrad(alpha, sample)
1520
+ >>> print(output)
1521
+ [2.5142431 3.4334087 1.8847835 0.07780622]
1522
+ """
1523
+
1524
+ @prim_attr_register
1525
+ def __init__(self):
1526
+ """Initialize RandomGammaGrad"""
1527
+ self.init_prim_io_names(inputs=['alpha', 'sample'], outputs=['output'])
1528
+ self.add_prim_attr("side_effect_hidden", True)
1529
+
1530
+
1531
+ class ROIAlignGrad(Primitive):
1532
+ """
1533
+ ROIAlignGrad operator.
1534
+
1535
+ Args:
1536
+ pooled_height (int): The output feature height.
1537
+ pooled_width (int): The output feature width.
1538
+ spatial_scale (float): The feature stride.
1539
+ sample_num (int): Number of sampling points. Default: 2.
1540
+ """
1541
+
1542
+ @prim_attr_register
1543
+ def __init__(self, pooled_height, pooled_width, spatial_scale, sample_num=2):
1544
+ """Initialize ROIAlignGrad"""
1545
+ self.init_prim_io_names(inputs=["dy", "rois", "xdiff_shape"], outputs=["dx"])
1546
+ validator.check_value_type("pooled_height", pooled_height, [int], self.name)
1547
+ validator.check_value_type("pooled_width", pooled_width, [int], self.name)
1548
+ validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
1549
+ validator.check_value_type("sample_num", sample_num, [int], self.name)
1550
+ self.pooled_height = pooled_height
1551
+ self.pooled_width = pooled_width
1552
+ self.spatial_scale = spatial_scale
1553
+ self.sample_num = sample_num
1554
+
1555
+
1556
+ class PsROIPoolingGrad(PrimitiveWithInfer):
1557
+ """
1558
+ PsROIPoolingGrad operator.
1559
+ """
1560
+
1561
+ @prim_attr_register
1562
+ def __init__(self, batch_size, channels, height, width, num_rois,
1563
+ pooled_height, pooled_width, spatial_scale, out_dim):
1564
+ """Initialize PsROIPoolingGrad"""
1565
+ validator.check_value_type("batch_size", batch_size, [int], self.name)
1566
+ validator.check_value_type("channels", channels, [int], self.name)
1567
+ validator.check_value_type("height", height, [int], self.name)
1568
+ validator.check_value_type("width", width, [int], self.name)
1569
+ validator.check_value_type("num_rois", num_rois, [int], self.name)
1570
+ validator.check_value_type("pooled_height", pooled_height, [int], self.name)
1571
+ validator.check_value_type("pooled_width", pooled_width, [int], self.name)
1572
+ validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
1573
+ validator.check_value_type("out_dim", out_dim, [int], self.name)
1574
+ self.batch_size = batch_size
1575
+ self.channels = channels
1576
+ self.height = height
1577
+ self.width = width
1578
+ self.num_rois = num_rois
1579
+ self.pooled_height = pooled_height
1580
+ self.pooled_width = pooled_width
1581
+ self.spatial_scale = spatial_scale
1582
+ self.out_dim = out_dim
1583
+
1584
+ def infer_shape(self, ydiff_shape, rois_shape, mapping_channel_shape):
1585
+ return [self.batch_size, self.channels, self.height, self.width]
1586
+
1587
+ def infer_dtype(self, ydiff_type, rois_type, mapping_channel_type):
1588
+ return ydiff_type
1589
+
1590
+
1591
+ class _ActivationGrad(PrimitiveWithInfer):
1592
+ """_ActivationGrad base class."""
1593
+
1594
+ @prim_attr_register
1595
+ def __init__(self):
1596
+ self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
1597
+
1598
+ def infer_shape(self, y_grad_shape, x_shape):
1599
+ return x_shape
1600
+
1601
+ def infer_dtype(self, y_grad_dtype, x_dtype):
1602
+ valid_dtypes = (mstype.float16, mstype.float32)
1603
+ validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name)
1604
+ validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
1605
+ return x_dtype
1606
+
1607
+
1608
+ class SigmoidCrossEntropyWithLogitsGrad(Primitive):
1609
+ """Computes the gradients of `SigmoidCrossEntropyWithLogits`."""
1610
+
1611
+ @prim_attr_register
1612
+ def __init__(self):
1613
+ """Initialize SigmoidCrossEntropyWithLogitsGrad"""
1614
+ self.init_prim_io_names(inputs=['x', 'y', 'dout'], outputs=['x_grad'])
1615
+
1616
+
1617
+ class SliceGrad(PrimitiveWithInfer):
1618
+ """Reverse of slice."""
1619
+
1620
+ @prim_attr_register
1621
+ def __init__(self):
1622
+ """Initialize SliceGrad"""
1623
+ self.init_prim_io_names(inputs=['dy', 'x', 'begin', 'size'], outputs=['dx'])
1624
+
1625
+ def __infer__(self, dy, x, begin, size):
1626
+ dy_shape, x_shape, size_value, begin_v = dy['shape'], x['shape'], size['value'], begin['value']
1627
+ dy_shape_len = len(dy_shape)
1628
+ if size_value is not None and not is_shape_unknown(x_shape) and not is_shape_unknown(dy_shape):
1629
+ size_value = list(size_value)
1630
+ for i in range(dy_shape_len):
1631
+ if size_value[i] == -1:
1632
+ size_value[i] = x_shape[i] - begin_v[i]
1633
+ validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], validator.LE, self.name)
1634
+ validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]',
1635
+ size_value[i], validator.EQ, self.name)
1636
+
1637
+ return {'shape': x_shape,
1638
+ 'dtype': x['dtype'],
1639
+ 'value': None}
1640
+
1641
+
1642
+ class SmoothL1LossGrad(Primitive):
1643
+ """Computes gradient for prediction on SmoothL1Loss."""
1644
+
1645
+ @prim_attr_register
1646
+ def __init__(self, beta=1.0, reduction='none'):
1647
+ self.add_prim_attr('sigma', self.beta)
1648
+ self.reduction = validator.check_string(
1649
+ reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
1650
+
1651
+
1652
+ class SoftMarginLossGrad(Primitive):
1653
+ """Computes gradient for prediction on SoftMarginLoss."""
1654
+
1655
+ @prim_attr_register
1656
+ def __init__(self, reduction="mean"):
1657
+ self.init_prim_io_names(inputs=['predict', 'label', "dout"], outputs=['gradient'])
1658
+ self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
1659
+
1660
+
1661
+ class StridedSliceGrad(Primitive):
1662
+ """
1663
+ Performs grad of StridedSlice operation.
1664
+
1665
+ Args:
1666
+ begin_mask (int): Start indexing the slice. Default: 0.
1667
+ end_mask (int): End indexing the slice. Default: 0.
1668
+ ellipsis_mask (int): An int32 mask. Default: 0.
1669
+ new_axis_mask (int): An int32 mask. Default: 0.
1670
+ shrink_axis_mask (int): An int32 mask. Default: 0.
1671
+
1672
+ Returns:
1673
+ Tensor, has the same shape of input.
1674
+ """
1675
+
1676
+ @prim_attr_register
1677
+ def __init__(self,
1678
+ begin_mask=0,
1679
+ end_mask=0,
1680
+ ellipsis_mask=0,
1681
+ new_axis_mask=0,
1682
+ shrink_axis_mask=0):
1683
+ """Initialize StridedSliceGrad"""
1684
+ validator.check_value_type('begin_mask', begin_mask, [int], self.name)
1685
+ validator.check_value_type('end_mask', end_mask, [int], self.name)
1686
+ validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)
1687
+ validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name)
1688
+ validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name)
1689
+ self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])
1690
+
1691
+
1692
+ class SoftplusGrad(Primitive):
1693
+ """Computes gradient for the Softplus activation."""
1694
+
1695
+ @prim_attr_register
1696
+ def __init__(self):
1697
+ self.init_prim_io_names(inputs=['gradients', 'features'], outputs=['backprops'])
1698
+
1699
+
1700
+ class TanhGrad(Primitive):
1701
+ """Computes gradient of hyperbolic tangent of input element-wise."""
1702
+
1703
+ @prim_attr_register
1704
+ def __init__(self):
1705
+ """Initialize TanhGrad"""
1706
+ self.init_prim_io_names(inputs=['y', 'dy'], outputs=['z'])
1707
+
1708
+
1709
+ class MirrorPadGrad(Primitive):
1710
+ """Gradients of MirrorPad operation."""
1711
+
1712
+ @prim_attr_register
1713
+ def __init__(self, mode="REFLECT"):
1714
+ """Initialize MirrorPad"""
1715
+ self.init_prim_io_names(inputs=['dy', 'paddings'], outputs=['output'])
1716
+ validator.check_string(mode, ['REFLECT', 'SYMMETRIC'], 'mode', self.name)
1717
+ self.mode = mode
1718
+
1719
+
1720
+ class PadV3Grad(Primitive):
1721
+ """Gradients of PadV3 operation."""
1722
+
1723
+ @prim_attr_register
1724
+ def __init__(self, mode='reflect', paddings_contiguous=True):
1725
+ """Initialize Padv3Grad"""
1726
+ self.add_prim_attr("cust_aicpu", self.name)
1727
+ self.init_prim_io_names(inputs=['x', 'paddings'], outputs=['y'])
1728
+ validator.check_string(mode, ['reflect', 'edge', 'circular'], 'mode', self.name)
1729
+ validator.check_bool(paddings_contiguous, "paddings_contiguous", self.name)
1730
+ self.mode = mode
1731
+ self.paddings_contiguous = paddings_contiguous
1732
+
1733
+
1734
+ class EmbeddingLookupCommGrad(PrimitiveWithInfer):
1735
+ """
1736
+ Performs the gradient for the communication part of EmbeddingLookup operator.
1737
+
1738
+ This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking,
1739
+ this primitive is implemented by StridedSlice --> _HostAllGather --> Concat. This primitive runs on host.
1740
+ """
1741
+
1742
+ @prim_attr_register
1743
+ def __init__(self):
1744
+ self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output'])
1745
+ self.set_device('CPU')
1746
+ self.tuple_setitem = Primitive('tuple_setitem')
1747
+
1748
+ def __infer__(self, dy, split_num):
1749
+ """
1750
+ This primitive is implemented by three steps:
1751
+ 1) Splits the 'dy' along dimension 0 into 'split_num' parts.
1752
+ 2) For each part, perform _HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host.
1753
+ 3) After _HostAllGather, there are still 'split_num' parts in each process. Then, perform Concat on them
1754
+ along dimension 0.
1755
+
1756
+ The output shape of this primitive: shape(output)[0] == shape(dy)[0] * 8
1757
+ """
1758
+ dy_shape = tuple(dy['shape'])
1759
+ split_num_value = split_num['value']
1760
+ validator.check_value_type("split_num_value", split_num_value, [int], self.name)
1761
+ dy_shape_all = self.tuple_setitem(dy_shape, 0, dy_shape[0] * 8)
1762
+ return {'shape': dy_shape_all,
1763
+ 'dtype': dy['dtype'],
1764
+ 'value': None}
1765
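Editor's note: the only non-trivial part of the __infer__ contract above is the eightfold expansion of dimension 0 (the fixed _HostAllGather group of 8 processes). A one-line check with illustrative shapes:

dy_shape = (4, 16)
out_shape = (dy_shape[0] * 8,) + dy_shape[1:]
assert out_shape == (32, 16)       # shape(output)[0] == shape(dy)[0] * 8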
+
1766
+
1767
+ class RefToEmbed(Primitive):
1768
+ r"""
1769
+ Make a key from Ref.
1770
+
1771
+ The Key is a symbolic_key, an embedding of a Parameter, which is used as the key of the variable in env_type;
1772
+ items are fetched by the operation `EnvironGet` with the symbolic_key instance. The `Parameter` is a ref.
1773
+
1774
+ Inputs:
1775
+ - **input** (Ref) - Target ref, ref is short for reference. The value of a Parameter is a ref.
1776
+
1777
+ Outputs:
1778
+ symbolic_key, made from the Ref.
1779
+
1780
+ Examples:
1781
+ >>> class Net(nn.Cell):
1782
+ >>> def __init__(self):
1783
+ >>> super(Net, self).__init__()
1784
+ >>> self.weight = mindspore.Parameter(1.0, name='weight')
1785
+ >>>
1786
+ >>> def construct(self):
1787
+ >>> key = RefToEmbed()(self.weight)
1788
+ >>> return key, self.weight
1789
+ """
1790
+ __mindspore_signature__ = (
1791
+ sig.make_sig('variable', sig.sig_rw.RW_REF),
1792
+ )
1793
+
1794
+ @prim_attr_register
1795
+ def __init__(self):
1796
+ pass
1797
+
1798
+
1799
+ class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
1800
+ """Computes the state gradients of BasicLSTMCell."""
1801
+
1802
+ @prim_attr_register
1803
+ def __init__(self, forget_bias, activation):
1804
+ self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
1805
+ self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
1806
+
1807
+ def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
1808
+ # dhy and dcy should be same shape
1809
+ validator.check_equal_int(len(c_shape), 2, "c rank", self.name)
1810
+ validator.check("dht rank", len(dht_shape), "c rank", len(c_shape), validator.EQ, self.name)
1811
+ validator.check("dct rank", len(dct_shape), "c rank", len(c_shape), validator.EQ, self.name)
1812
+ validator.check("it rank", len(it_shape), "c rank", len(c_shape), validator.EQ, self.name)
1813
+ validator.check("jt rank", len(jt_shape), "c rank", len(c_shape), validator.EQ, self.name)
1814
+ validator.check("ft rank", len(ft_shape), "c rank", len(c_shape), validator.EQ, self.name)
1815
+ validator.check("ot rank", len(ot_shape), "c rank", len(c_shape), validator.EQ, self.name)
1816
+ validator.check("tanhct rank", len(tanhct_shape), "c rank", len(c_shape), validator.EQ, self.name)
1817
+ validator.check("dht shape", dht_shape, "c shape", c_shape, validator.EQ, self.name)
1818
+ validator.check("dct shape", dct_shape, "c shape", c_shape, validator.EQ, self.name)
1819
+ validator.check("it shape", it_shape, "c shape", c_shape, validator.EQ, self.name)
1820
+ validator.check("jt shape", jt_shape, "c shape", c_shape, validator.EQ, self.name)
1821
+ validator.check("ft shape", ft_shape, "c shape", c_shape, validator.EQ, self.name)
1822
+ validator.check("ot shape", ot_shape, "c shape", c_shape, validator.EQ, self.name)
1823
+ validator.check("tanhct shape", tanhct_shape, "c shape", c_shape, validator.EQ, self.name)
1824
+
1825
+ dgate_shape = (c_shape[0], 4 * c_shape[1])
1826
+ dct_1_shape = c_shape
1827
+
1828
+ return (dgate_shape, dct_1_shape)
1829
+
1830
+ def infer_dtype(self, c_dtype, dht_dtype, dct_dtype, it_dtype, jt_dtype, ft_dtype, ot_dtype, tanhct_dtype):
1831
+ validator.check_subclass("c", c_dtype, [mstype.tensor_type], self.name)
1832
+ validator.check_subclass("dht", dht_dtype, [mstype.tensor_type], self.name)
1833
+ validator.check_subclass("dct", dct_dtype, [mstype.tensor_type], self.name)
1834
+ validator.check_subclass("it", it_dtype, [mstype.tensor_type], self.name)
1835
+ validator.check_subclass("jt", jt_dtype, [mstype.tensor_type], self.name)
1836
+ validator.check_subclass("ft", ft_dtype, [mstype.tensor_type], self.name)
1837
+ validator.check_subclass("ot", ot_dtype, [mstype.tensor_type], self.name)
1838
+ validator.check_subclass("tanhct", tanhct_dtype, [mstype.tensor_type], self.name)
1839
+ validator.check_type_name("c", c_dtype, [mstype.float16, mstype.float32], self.name)
1840
+ validator.check_type_name("dht", dht_dtype, [mstype.float16, mstype.float32], self.name)
1841
+ validator.check_type_name("dct", dct_dtype, [mstype.float16, mstype.float32], self.name)
1842
+ validator.check_type_name("it", it_dtype, [mstype.float16, mstype.float32], self.name)
1843
+ validator.check_type_name("jt", jt_dtype, [mstype.float16, mstype.float32], self.name)
1844
+ validator.check_type_name("ft", ft_dtype, [mstype.float16, mstype.float32], self.name)
1845
+ validator.check_type_name("ot", ot_dtype, [mstype.float16, mstype.float32], self.name)
1846
+ validator.check_type_name("tanhct", tanhct_dtype, [mstype.float16, mstype.float32], self.name)
1847
+ return (c_dtype, c_dtype)
1848
+
1849
+
1850
+ class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
1851
+ """Computes the weight gradients of BasicLSTM."""
1852
+
1853
+ @prim_attr_register
1854
+ def __init__(self):
1855
+ pass
1856
+
1857
+ def infer_shape(self, x_shape, h_shape, dgate_shape):
1858
+ validator.check_equal_int(len(x_shape), 2, "x rank", self.name)
1859
+ validator.check("h rank", len(h_shape), " x rank", len(x_shape), validator.EQ, self.name)
1860
+ validator.check("dgate rank", len(dgate_shape), "x rank", len(x_shape), validator.EQ, self.name)
1861
+ validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], validator.EQ, self.name)
1862
+ validator.check("dgate_shape[0]", dgate_shape[0], "h_shape[0]", h_shape[0], validator.EQ, self.name)
1863
+ validator.check("dgate_shape[1]", dgate_shape[1], "4*h_shape[1]", 4 * h_shape[1], validator.EQ, self.name)
1864
+ input_size = x_shape[1]
1865
+ hidden_size = h_shape[1]
1866
+ dw_shape = (input_size + hidden_size, 4 * hidden_size)
1867
+ db_shape = (4 * hidden_size,)
1868
+ return (dw_shape, db_shape)
1869
+
1870
+ def infer_dtype(self, x_dtype, h_dtype, dgate_dtype):
1871
+ validator.check_subclass("x", x_dtype, mstype.tensor_type, self.name)
1872
+ validator.check_subclass("h", h_dtype, mstype.tensor_type, self.name)
1873
+ validator.check_subclass("dgate", dgate_dtype, mstype.tensor_type, self.name)
1874
+ validator.check_type_name("x", x_dtype, [mstype.float16, mstype.float32], self.name)
1875
+ validator.check_type_name("h", h_dtype, [mstype.float16, mstype.float32], self.name)
1876
+ validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name)
1877
+ return (x_dtype, x_dtype)
1878
+
1879
+
1880
+ class BasicLSTMCellInputGrad(PrimitiveWithInfer):
1881
+ """Computes the input gradients of BasicLSTM."""
1882
+
1883
+ @prim_attr_register
1884
+ def __init__(self, keep_prob):
1885
+ self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
1886
+ self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, validator.INC_BOTH, "keep_prob", self.name)
1887
+
1888
+ def infer_shape(self, dgate_shape, w_shape):
1889
+ validator.check_equal_int(len(dgate_shape), 2, "dgate rank", self.name)
1890
+ validator.check_equal_int(len(w_shape), 2, "w rank", self.name)
1891
+ validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[1]", w_shape[1], validator.EQ, self.name)
1892
+ batch_size = dgate_shape[0]
1893
+ hidden_size = dgate_shape[1] // 4
1894
+ input_size = w_shape[0] - hidden_size
1895
+ dxt_shape = (batch_size, input_size)
1896
+ dht_shape = (batch_size, hidden_size)
1897
+ return (dxt_shape, dht_shape)
1898
+
1899
+ def infer_dtype(self, dgate_dtype, w_dtype):
1900
+ validator.check_subclass("dgate", dgate_dtype, mstype.tensor_type, self.name)
1901
+ validator.check_subclass("w", w_dtype, mstype.tensor_type, self.name)
1902
+ validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name)
1903
+ validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name)
1904
+ return (dgate_dtype, dgate_dtype)
1905
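Editor's note: a shape sketch tying together BasicLSTMCellWeightGrad and BasicLSTMCellInputGrad above; the batch/input/hidden sizes are illustrative.

batch, input_size, hidden_size = 8, 25, 20
x_shape     = (batch, input_size)
h_shape     = (batch, hidden_size)
dgate_shape = (batch, 4 * hidden_size)                    # (8, 80): the four LSTM gates packed together
# BasicLSTMCellWeightGrad: stacked weight rows plus one bias entry per gate unit.
dw_shape = (input_size + hidden_size, 4 * hidden_size)    # (45, 80)
db_shape = (4 * hidden_size,)                             # (80,)
# BasicLSTMCellInputGrad inverts the same relation from dgate and w.
w_shape = dw_shape
hidden_from_dgate = dgate_shape[1] // 4                   # 20
input_from_w = w_shape[0] - hidden_from_dgate             # 25
assert (batch, input_from_w) == x_shape and (batch, hidden_from_dgate) == h_shape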
+
1906
+
1907
+ class InvGrad(Primitive):
1908
+ """Computes gradients for inv operation."""
1909
+
1910
+ @prim_attr_register
1911
+ def __init__(self):
1912
+ self.init_prim_io_names(inputs=['x', 'grad'], outputs=['y'])
1913
+
1914
+
1915
+ class LRNGrad(Primitive):
1916
+ """Computes gradients for LRN operation."""
1917
+
1918
+ @prim_attr_register
1919
+ def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5):
1920
+ self.init_prim_io_names(inputs=['grads', 'x', 'y'], outputs=['z'])
1921
+ validator.check_value_type("depth_radius", depth_radius, [int], self.name)
1922
+ validator.check_value_type("bias", bias, [float], self.name)
1923
+ validator.check_value_type("alpha", alpha, [float], self.name)
1924
+ validator.check_value_type("beta", beta, [float], self.name)
1925
+
1926
+
1927
+ class MvlgammaGrad(Primitive):
1928
+ r"""
1929
+ Computes gradients for Mvlgamma.
1930
+
1931
+ The following formula shows the mathematical calculation process of Mvlgamma:
1932
+
1933
+ .. math::
1934
+
1935
+ \log (\Gamma_{p}(a))=C+\sum_{i=1}^{p} \log (\Gamma(a-\frac{i-1}{2}))
1936
+
1937
+ where :math:`C = \log(\pi) \times \frac{p(p-1)}{4}` and :math:`\Gamma(\cdot)` is the Gamma function.
1938
+
1939
+ Args:
1940
+ p(int): The number of dimensions. And the value of `p` must be greater than or equal to 1.
1941
+
1942
+ Inputs:
1943
+ - **y_grad** (Tensor) - The input gradient.
1944
+ - **x** (Tensor) - The input of Mvlgamma with data type of float32 or float64.
1945
+
1946
+ Outputs:
1947
+ Tensor, has the same shape and type as `x`.
1948
+
1949
+ Raises:
1950
+ TypeError: If dtype of `y_grad` or `x` is neither float32 nor float64.
1951
+ TypeError: If `p` is not an int.
1952
+ ValueError: If p is not greater than or equal to 1.
1953
+ ValueError: If all elements of `x` are not greater than (p-1)/2.
1954
+
1955
+ Supported Platforms:
1956
+ ``Ascend`` ``CPU``
1957
+ """
1958
+
1959
+ @prim_attr_register
1960
+ def __init__(self, p):
1961
+ self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['x_grad'])
1962
+ self.p = validator.check_value_type('p', p, [int], self.name)
1963
+
1964
+
1965
+ class CdistGrad(Primitive):
1966
+ """Computes gradient for Cdist."""
1967
+
1968
+ @prim_attr_register
1969
+ def __init__(self, p=2.0):
1970
+ validator.check_value_type("p", p, [float], self.name)
1971
+ self.init_prim_io_names(inputs=['grad', 'input_x', 'input_y', 'cdist'], outputs=['output'])
1972
+
1973
+
1974
+ class PdistGrad(Primitive):
1975
+ """Computes gradient for Pdist operation.
1976
+
1977
+ Args:
1978
+ p (float): the p value for the Pdist formulation. Default: 2.0.
1979
+
1980
+ Inputs:
1981
+ - **y_grad** (Tensor) - The gradients of loss to output of Pdist function.
1982
+ - **x** (Tensor) - Input tensor of shape :math:`(N, M)`.
1983
+ Must be the input `x` of the forward operator Pdist.
1984
+ - **y** (Tensor) - Input tensor of shape :math:`(N*(N-1)/2)`.
1985
+ Must be the output `y` of the forward operator Pdist.
1986
+
1987
+ Outputs:
1988
+ Tensor, with the same shape and dtype as `x`.
1989
+
1990
+ Raises:
1991
+ TypeError: If one of `y_grad`, `x` and `y` is not a Tensor.
1992
+ TypeError: If dtype of `y_grad`, `x` and `y` are not all float16, float32 or float64.
1993
+ TypeError: If `p` is not a float.
1994
+ ValueError: If `p` is a negative float.
1995
+ ValueError: If shape of `y_grad` is not same as `y`.
1996
+ ValueError: If dimension of `x` is not 2.
1997
+
1998
+ Supported Platforms:
1999
+ ``Ascend`` ``GPU`` ``CPU``
2000
+ """
2001
+
2002
+ @prim_attr_register
2003
+ def __init__(self, p=2.0):
2004
+ validator.check_value_type("p", p, [float], self.name)
2005
+ if p < 0:
2006
+ raise ValueError(f"Pdist p must be a non-negative value, but got {p}.")
2007
+ self.init_prim_io_names(inputs=['y_grad', 'x', 'y'], outputs=['x_grad'])
2008
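+ # Hedged usage sketch (not part of the original source), mirroring the doctest style used
+ # elsewhere in this file; `ops.Pdist` is assumed to be the matching forward operator:
+ # >>> x = Tensor(np.ones([4, 3]), mstype.float32)      # shape (N, M) = (4, 3)
+ # >>> y = ops.Pdist(p=2.0)(x)                          # shape (N*(N-1)/2,) = (6,)
+ # >>> y_grad = Tensor(np.ones([6]), mstype.float32)
+ # >>> x_grad = PdistGrad(p=2.0)(y_grad, x, y)          # same shape and dtype as x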
+
2009
+
2010
+ class MultilabelMarginLossGrad(Primitive):
2011
+ """
2012
+ Computes the gradients of MultilabelMarginLoss operation.
2013
+
2014
+ Args:
2015
+ reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
2016
+ ``'sum'`` . Default: ``'mean'`` .
2017
+
2018
+ - ``'none'``: no reduction will be applied.
2019
+ - ``'mean'``: compute and return the mean of elements in the output.
2020
+ - ``'sum'``: the output elements will be summed.
2021
+
2022
+ Inputs:
2023
+ - **y_grad** (Tensor) - The gradients of loss to output of MultilabelMarginLoss function, with
2024
+ the same shape and data type as forward output `y`.
2025
+ - **x** (Tensor) - Predict data. Tensor of shape :math:`(C)` or :math:`(N, C)`, where :math:`N`
2026
+ is the batch size and :math:`C` is the number of classes. Data type must be float16 or float32.
2027
+ - **target** (Tensor) - Ground truth data, with the same shape as `x`, data type must be int32 and
2028
+ label targets padded by -1.
2029
+ - **is_target** (Tensor) - Forward output tensor for backward input, with the same shape and
2030
+ data type as `target`.
2031
+
2032
+ Outputs:
2033
+ The shape of output :math:`(C)` or :math:`(N, C)`, with the same shape and data type as `x`.
2034
+
2035
+ Raises:
2036
+ TypeError: If `x` or `target` or `y_grad` is not a Tensor.
2037
+ TypeError: If dtype of `x` is neither float16 nor float32.
2038
+ TypeError: If dtype of `target` is not int32.
2039
+ TypeError: If dtype of `y_grad` is not the same as `x`.
2040
+ ValueError: If length of shape of `x` is neither 1 nor 2.
2041
+ ValueError: If shape of `x` is not the same as `target`.
2042
+ ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.
2043
+ ValueError: If shape of `y_grad` is not the same as forward output `y`.
2044
+
2045
+ Supported Platforms:
2046
+ ``Ascend``
2047
+ """
2048
+
2049
+ @prim_attr_register
2050
+ def __init__(self, reduction="mean"):
2051
+ """Initialize MultilabelMarginLossGrad"""
2052
+ self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
2053
+ self.init_prim_io_names(inputs=['y_grad', 'x', 'target', 'is_target'], outputs=['x_grad'])
2054
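+ # Editor's shape sketch (not part of the original source), following the docstring above:
+ # for x, target and is_target of shape (N, C) = (2, 4) and reduction="mean" (so the forward
+ # output y, and therefore y_grad, is a scalar), the returned x_grad has shape (2, 4) and x's dtype.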
+
2055
+
2056
+ class Dilation2DBackpropInput(Primitive):
2057
+ """
2058
+ Computes the gradient of morphological 2-D dilation with respect to the input.
2059
+
2060
+ .. warning::
2061
+ This operator is an experimental operator, which has some accuracy problems for some inputs.
2062
+
2063
+ Args:
2064
+ stride (Union[int, tuple[int]]): The distance of filter moving, an int number that represents
2065
+ the height and width of movement are both strides, a tuple of two int numbers that
2066
+ represent height and width of movement respectively, or a tuple of four int numbers which
2067
+ should be :math:`(1, 1, H_{stride}, W_{stride})`.
2068
+ dilation (Union[int, tuple[int]]): The input stride for atrous morphological dilation. The data
2069
+ type is int or a tuple of 2 or 4 integers. Its value must be greater or equal to 1 and bounded
2070
+ by the height and width of the input `x`.
2071
+ pad_mode (str): Specifies padding mode. The optional values are "same", "valid".
2072
+ Default: "same". Both upper and lower case are supported.
2073
+ data_format (str): The format of input and output data. Only NCHW format is supported at present.
2074
+ Default: 'NCHW'.
2075
+
2076
+ Inputs:
2077
+ - **x** (Tensor) - Input data. A four dimension tensor with float16 or float32 data type. The shape must be
2078
+ :math:`(N, C_{in}, H_{in}, W_{in})`.
2079
+ - **filter** (Tensor) - A three dimension tensor with the same type as input. The shape must be
2080
+ :math:`(C_{in}, H_{filter}, W_{filter})`.
2081
+ - **out_backprop** (Tensor) - The gradients with respect to the output of the convolution.
2082
+ A four dimension tensor with float16 or float32 data type. The shape must be
2083
+ :math:`(N, C_{in}, H_{out}, W_{out})`.
2084
+
2085
+ Outputs:
2086
+ Tensor, the gradients with respect to the input `x` of the dilation. It has the same shape and type as the input `x`.
2087
+
2088
+ Raises:
2089
+ TypeError: If type of `x` or `filter` is not one of the types in [uint8, uint16, uint32, uint64, int8, int16,
2090
+ int32, int64, float16, float32, float64].
2091
+ TypeError: If type of `out_backprop` is not one of the types in [uint8, uint16, uint32, uint64, int8, int16,
2092
+ int32, int64, float16, float32, float64].
2093
+ TypeError: If `stride` or `dilation` is not an int number or a tuple of two or four int numbers.
2094
+ ValueError: If the length of `stride` or `dilation` is neither two nor four when they are tuples.
2095
+ ValueError: If `stride` or `dilation` is not (1, 1, height, width) when it is a tuple of four int numbers.
2096
+ ValueError: If `stride` is not in the range of [1, 255].
2097
+ ValueError: If `dilation` is less than 1.
2098
+ ValueError: If `pad_mode` is not a str of 'same', 'valid', 'SAME' or 'VALID'.
2099
+ ValueError: If `data_format` is not the str of 'NCHW'.
2100
+
2101
+ Supported Platforms:
2102
+ ``Ascend`` ``GPU`` ``CPU``
2103
+
2104
+ Examples:
2105
+ (pad_mode="SAME", data_format="NCHW")
2106
+ >>> out_backprop = Tensor(np.ones([1, 3, 4, 4]), mstype.float32)
2107
+ >>> filter = Tensor(np.ones([3 , 2 , 2]), mstype.float32)
2108
+ >>> x = Tensor(np.ones([1, 3, 4, 4]), mstype.float32)
2109
+ >>> dilation_backprop_input = G.Dilation2DBackpropInput(stride=1, dilation=1)
2110
+ >>> output = dilation_backprop_input(x, filter, out_backprop)
2111
+ >>> print(output)
2112
+ [[[[1. 1. 1. 1.]
2113
+ [1. 1. 1. 1.]
2114
+ [1. 1. 1. 1.]
2115
+ [1. 1. 1. 1.]]
2116
+ [[1. 1. 1. 1.]
2117
+ [1. 1. 1. 1.]
2118
+ [1. 1. 1. 1.]
2119
+ [1. 1. 1. 1.]]
2120
+ [[1. 1. 1. 1.]
2121
+ [1. 1. 1. 1.]
2122
+ [1. 1. 1. 1.]
2123
+ [1. 1. 1. 1.]]]]
2124
+ """
2125
+
2126
+ @prim_attr_register
2127
+ def __init__(self, stride, dilation, pad_mode="SAME", data_format="NCHW"):
2128
+ """Initialize Dilation2DBackpropInput"""
2129
+
2130
+ def _check_format_stride_or_dilation(arg_name, arg_value, prim_name, data_format):
2131
+ validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name)
2132
+ if isinstance(arg_value, int):
2133
+ ret_value = (1, arg_value, arg_value, 1) if data_format == "NHWC" else (1, 1, arg_value, arg_value)
2134
+ elif len(arg_value) == 2:
2135
+ ret_value = (1, arg_value[0], arg_value[1], 1) if data_format == "NHWC" else \
2136
+ (1, 1, arg_value[0], arg_value[1])
2137
+ elif len(arg_value) == 4:
2138
+ if data_format == "NHWC" and (arg_value[0] != 1 or arg_value[3] != 1):
2139
+ raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be "
2140
+ f"[1, {arg_name}_height, {arg_name}_weigth, 1] when data_format is 'NHWC', "
2141
+ f"but got {arg_value}")
2142
+ if data_format == "NCHW" and (arg_value[0] != 1 or arg_value[1] != 1):
2143
+ raise ValueError(
2144
+ f"For '{prim_name}' attr '{arg_name}' should be [1, 1, {arg_name}_height, {arg_name}_weigth]"
2145
+ f"when data_format is 'NCHW', but got {arg_value}")
2146
+ ret_value = arg_value
2147
+ else:
2148
+ raise ValueError(
2149
+ f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of two "
2150
+ f"or four positive int numbers, but got {arg_value}")
2151
+ for item in ret_value:
2152
+ if isinstance(item, int) and not isinstance(item, bool) and item > 0:
2153
+ continue
2154
+ raise ValueError(
2155
+ f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of two "
2156
+ f"or four positive int numbers, but got {arg_value}")
2157
+ return ret_value
2158
+
2159
+ if data_format == 'NHWC':
2160
+ raise ValueError(f"For '{self.name}', NHWC format is not supported at present.")
2161
+ self.data_format = validator.check_string(self.data_format, ['NCHW', 'NHWC'], 'data_format', self.name)
2162
+ self.add_prim_attr("data_format", self.data_format)
2163
+ self.pad_mode = validator.check_string(self.pad_mode, ["SAME", "VALID", 'same', "valid"], "pad_mode", self.name)
2164
+ self.add_prim_attr("pad_mode", self.pad_mode.upper())
2165
+ self.stride = _check_format_stride_or_dilation("stride", stride, self.name, self.data_format)
2166
+ self.add_prim_attr("stride", self.stride)
2167
+ self.dilation = _check_format_stride_or_dilation("dilation", dilation, self.name, self.data_format)
2168
+ self.add_prim_attr("dilation", self.dilation)
2169
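+ # Editor's note (not part of the original source): per _check_format_stride_or_dilation above,
+ # with the default data_format "NCHW" an int stride/dilation is normalized to a 4-tuple,
+ # e.g. stride=1 -> (1, 1, 1, 1) and stride=(2, 3) -> (1, 1, 2, 3); a 4-tuple must already
+ # carry 1 in the N and C positions.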
+
2170
+
2171
+ class Dilation2DBackpropFilter(Primitive):
2172
+ """
2173
+ Computes the gradient of morphological 2-D dilation with respect to the filter.
2174
+
2175
+ .. warning::
2176
+ This operator is an experimental operator, which has some accuracy problems for some inputs.
2177
+
2178
+ Args:
2179
+ stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
2180
+ the height and width of movement are both strides, a tuple of two int numbers that
2181
+ represent height and width of movement respectively, or a tuple of four int numbers which
2182
+ should be :math:`(1, 1, H_{stride}, W_{stride})`.
2183
+ dilation (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers or a tuple of 4 integers.
2184
+ Specifies the dilation rate to use for dilated convolution.
2185
+ If set to be :math:`k > 1`, there will be :math:`k - 1` pixels skipped for each sampling location.
2186
+ Its value must be greater or equal to 1 and bounded by the height and width of the input `x`.
2187
+ pad_mode (str): Specifies padding mode. The optional values are "same", "valid".
2188
+ Default: "same". Both upper and lower case are supported.
2189
+ data_format (str): The format of input and output data. Only NCHW format is supported at present.
2190
+ Default: 'NCHW'.
2191
+
2192
+ Inputs:
2193
+ - **x** (Tensor) - Input data. A four dimension tensor with float16 or float32 data type. The shape must be
2194
+ :math:`(N, C_{in}, H_{in}, W_{in})`.
2195
+ - **filter** (Tensor) - A three dimension tensor with the same type as input. The shape must be
2196
+ :math:`(C_{in}, H_{filter}, W_{filter})`.
2197
+ - **out_backprop** (Tensor) - The gradients with respect to the output of the convolution.
2198
+ A four dimension tensor with float16 or float32 data type. The shape must be
2199
+ :math:`(N, C_{in}, H_{out}, W_{out})`.
2200
+
2201
+ Outputs:
2202
+ Tensor, the gradients with respect to the filter of the dilation, with the same data type as the input `filter`.
2203
+
2204
+ Raises:
2205
+ TypeError: If type of `x` or `filter` is not one of the types in [uint8, uint16, uint32, uint64, int8, int16,
2206
+ int32, int64, float16, float32, float64].
2207
+ TypeError: If type of `out_backprop` is not one of the types in [uint8, uint16, uint32, uint64, int8, int16,
2208
+ int32, int64, float16, float32, float64].
2209
+ TypeError: If `stride` or `dilation` is not an int number or a tuple of two or four int numbers.
2210
+ ValueError: If the length of `stride` or `dilation` is neither two nor four when they are tuples.
2211
+ ValueError: If `stride` or `dilation` is not (1, 1, height, width) when it is a tuple of four int numbers.
2212
+ ValueError: If `stride` is not in the range of [1, 255].
2213
+ ValueError: If `dilation` is less than 1.
2214
+ ValueError: If `pad_mode` is not a str of 'same', 'valid', 'SAME' or 'VALID'.
2215
+ ValueError: If `data_format` is not the str of 'NCHW'.
2216
+
2217
+
2218
+ Supported Platforms:
2219
+ ``Ascend`` ``GPU`` ``CPU``
2220
+
2221
+ Examples:
2222
+ (pad_mode="SAME", data_format="NCHW")
2223
+ >>> x = Tensor(np.ones([2, 3, 4, 4]), mstype.float32)
2224
+ >>> filter = Tensor(np.ones([3,2,2]), mstype.float32)
2225
+ >>> out_backprop = Tensor(np.ones([2,3,2,2]), mstype.float32)
2226
+ >>> dilation_backprop_filter = G.Dilation2DBackpropFilter(stride=2, dilation=1)
2227
+ >>> output = dilation_backprop_filter(x, filter, out_backprop)
2228
+ >>> print(output)
2229
+ [[[8. 8. 8.]
2230
+ [0. 0. 0.]]
2231
+ [[0. 0. 0.]
2232
+ [0. 0. 0.]]]
2233
+ """
2234
+
2235
+ @prim_attr_register
2236
+ def __init__(self, stride, dilation, pad_mode="SAME", data_format="NCHW"):
2237
+ """Initialize Dilation2DBackpropFilter"""
2238
+
2239
+ def _check_format_stride_or_dilation(arg_name, arg_value, prim_name, data_format):
2240
+ validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name)
2241
+ if isinstance(arg_value, int):
2242
+ ret_value = (1, arg_value, arg_value, 1) if data_format == "NHWC" else (1, 1, arg_value, arg_value)
2243
+ elif len(arg_value) == 2:
2244
+ ret_value = (1, arg_value[0], arg_value[1], 1) if data_format == "NHWC" else \
2245
+ (1, 1, arg_value[0], arg_value[1])
2246
+ elif len(arg_value) == 4:
2247
+ if data_format == "NHWC" and (arg_value[0] != 1 or arg_value[3] != 1):
2248
+ raise ValueError(
2249
+ f"For '{prim_name}' attr '{arg_name}' should be [1, {arg_name}_height, {arg_name}_weigth, 1]"
2250
+ f"when data_format is 'NHWC', but got {arg_value}")
2251
+ if data_format == "NCHW" and (arg_value[0] != 1 or arg_value[1] != 1):
2252
+ raise ValueError(
2253
+ f"For '{prim_name}' attr '{arg_name}' should be [1, 1, {arg_name}_height, {arg_name}_weigth]"
2254
+ f"when data_format is 'NCHW', but got {arg_value}")
2255
+ ret_value = arg_value
2256
+ else:
2257
+ raise ValueError(
2258
+ f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of two "
2259
+ f"or four positive int numbers, but got {arg_value}")
2260
+ for item in ret_value:
2261
+ if isinstance(item, int) and not isinstance(item, bool) and item > 0:
2262
+ continue
2263
+ raise ValueError(
2264
+ f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of two "
2265
+ f"or four positive int numbers, but got {arg_value}")
2266
+ return ret_value
2267
+
2268
+ if data_format == 'NHWC':
2269
+ raise ValueError(f"For '{self.name}', NHWC format is not supported at present.")
2270
+ self.data_format = validator.check_string(self.data_format, ['NCHW', 'NHWC'], 'data_format', self.name)
2271
+ self.add_prim_attr("data_format", self.data_format)
2272
+ self.pad_mode = validator.check_string(self.pad_mode, ["SAME", "VALID", 'same', "valid"], "pad_mode", self.name)
2273
+ self.add_prim_attr("pad_mode", self.pad_mode.upper())
2274
+ self.stride = _check_format_stride_or_dilation("stride", stride, self.name, self.data_format)
2275
+ def is_in_range(x):
2276
+ return 1 <= x <= 255
2277
+ if not is_in_range(self.stride[2]) or not is_in_range(self.stride[3]):
2278
+ raise ValueError(f"For '{self.name}', size of stride is not supported, "
2279
+ f'stride should be in the range of [1, 255], '
2280
+ f'but got stride_h: `{self.stride[2]}`, stride_w: `{self.stride[3]}`.')
2281
+ self.add_prim_attr("stride", self.stride)
2282
+ self.dilation = _check_format_stride_or_dilation("dilation", dilation, self.name, self.data_format)
2283
+ self.add_prim_attr("dilation", self.dilation)
2284
+
2285
+
2286
+ class ParallelResizeBilinearGrad(PrimitiveWithInfer):
2287
+ """ParallelResizeBilinearGrad ops"""
2288
+
2289
+ @prim_attr_register
2290
+ def __init__(self, ori_image_size, src_start_w, dst_start_w, align_corners):
2291
+ """Initialize ParallelResizeBilinearGrad."""
2292
+ self.init_prim_io_names(inputs=["grad", "x", "size"], outputs=['y'])
2293
+ validator.check_value_type("ori_image_size", ori_image_size, [tuple, list], self.name)
2294
+ validator.check_value_type("src_start_w", src_start_w, [int], self.name)
2295
+ validator.check_value_type("dst_start_w", dst_start_w, [int], self.name)
2296
+ validator.check_value_type("align_corners", align_corners, [bool], self.name)
2297
+ self.ori_image_size = list(ori_image_size)
2298
+ self.src_start_w = src_start_w
2299
+ self.dst_start_w = dst_start_w
2300
+ self.align_corners = align_corners
2301
+ self.half_pixel_centers = False
2302
+ self.add_prim_attr('ori_image_size', self.ori_image_size)
2303
+ self.add_prim_attr('src_start_w', self.src_start_w)
2304
+ self.add_prim_attr('dst_start_w', self.dst_start_w)
2305
+ self.add_prim_attr('align_corners', self.align_corners)
2306
+ self.add_prim_attr('half_pixel_centers', self.half_pixel_centers)
2307
+
2308
+ def __infer__(self, grad, x, size):
2309
+ size_val = size['value']
2310
+ grad_shape = grad['shape']
2311
+ grad_dtype = grad['dtype']
2312
+ x_shape = x['shape']
2313
+ x_dtype = x['dtype']
2314
+ validator.check_tensor_dtype_valid("grad_dtype", grad_dtype, [mstype.float16, mstype.float32], self.name)
2315
+ validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float16, mstype.float32], self.name)
2316
+ if size_val is None:
2317
+ raise ValueError("size must be const input")
2318
+ output_shape = [grad_shape[0], grad_shape[1], x_shape[2], x_shape[3]]
2319
+
2320
+ return {'shape': output_shape,
2321
+ 'dtype': x_dtype,
2322
+ 'value': None}
2323
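+ # Editor's shape sketch (not part of the original source): __infer__ above combines grad's N/C
+ # with x's H/W, e.g. grad of shape (1, 3, 8, 8) and x of shape (1, 3, 4, 4) infer an output
+ # of shape [1, 3, 4, 4] with x's dtype.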
+
2324
+
2325
+ class MultiMarginLossGrad(Primitive):
2326
+ """
2327
+ Computes the gradients of MultiMarginLoss operation.
2328
+
2329
+ Args:
2330
+ p (int): Optional. The norm degree for pairwise distance. Should be 1 or 2. Default: 1.
2331
+ margin (float): Optional. A parameter to change pairwise distance. Default: 1.0.
2332
+ reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
2333
+ ``'sum'`` . Default: ``'mean'`` .
2334
+
2335
+ - ``'none'``: no reduction will be applied.
2336
+ - ``'mean'``: compute and return the weighted mean of elements in the output.
2337
+ - ``'sum'``: the output elements will be summed.
2338
+
2339
+ Inputs:
2340
+ - **y_grad** (Tensor) - If it is not a scalar, the shape of `y_grad` is :math:`(N, C)`.
2341
+ Data type only supports float16, float32 or float64.
2342
+ - **x** (Tensor) - Input x, with shape :math:`(N, C)`. Data type only supports float16, float32 or float64.
2343
+ - **target** (Tensor) - Ground truth labels, with shape :math:`(N,)`. Data type only supports int64. The
2344
+ value of `target` should be non-negative and less than C.
2345
+ - **weight** (Tensor, optional) - The rescaling weight to each class with shape :math:`(C,)`. Data type only
2346
+ supports float16, float32 or float64. Default: ``None``.
2347
+
2348
+ Outputs:
2349
+ The output has shape :math:`(N, C)`. Data type only supports float16, float32 or float64.
2350
+ It has the same data type as `x`.
2351
+
2352
+ Raises:
2353
+ TypeError: If dtype of `p` and `target` is not int.
2354
+ TypeError: If dtype of `margin` is not float.
2355
+ TypeError: If dtype of `reduction` is not str.
2356
+ TypeError: If dtype of `x` is not float16, float or float64.
2357
+ TypeError: If dtype of `weight` and `x` is not the same.
2358
+ ValueError: If 'p' is not 1 or 2.
2359
+ ValueError: If 'reduction' is not one of {'none','sum','mean'}.
2360
+ ValueError: If shape[0] of `x` is not equal to shape[0] of `target`.
2361
+ ValueError: If shape[1] of `x` is not equal to shape[0] of `weight`.
2362
+ ValueError: If rank of `weight` is not 1.
2363
+ ValueError: If rank of `x` is not 2 or rank of 'target' is not 1.
2364
+
2365
+ Supported Platforms:
2366
+ ``Ascend`` ``CPU``
2367
+ """
2368
+ __mindspore_signature__ = (
2369
+ sig.make_sig('y_grad'),
2370
+ sig.make_sig('x'),
2371
+ sig.make_sig('target'),
2372
+ sig.make_sig('weight', default=None)
2373
+ )
2374
+
2375
+ @prim_attr_register
2376
+ def __init__(self, p=1, margin=1.0, reduction="mean"):
2377
+ """Initialize MultiMarginLossGrad"""
2378
+ self.p = validator.check_value_type('p', p, [int], self.name)
2379
+ validator.check_int(p, {1, 2}, validator.IN, 'p', self.name)
2380
+ self.margin = validator.check_value_type('margin', margin, [float], self.name)
2381
+ self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
2382
+ self.init_prim_io_names(inputs=['y_grad', 'x', 'target', 'weight'], outputs=['x_grad'])
2383
+
2384
+ def __call__(self, y_grad, x, target, weight=None):
2385
+ return super().__call__(y_grad, x, target, weight)
2386
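+ # Hedged usage sketch (not part of the original source): the signature above makes `weight`
+ # optional, so the primitive can be called either way:
+ # >>> op = MultiMarginLossGrad(p=1, margin=1.0, reduction="mean")
+ # >>> x_grad = op(y_grad, x, target)            # weight defaults to None
+ # >>> x_grad = op(y_grad, x, target, weight)    # with per-class weights of shape (C,)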
+
2387
+
2388
+ class SparseSegmentMeanGrad(Primitive):
2389
+ """
2390
+ Computes gradients for the SparseSegmentMean operation.
2391
+
2392
+ Inputs:
2393
+ - **x** (Tensor) - A Tensor of the first input of SparseSegmentMeanGrad.
2394
+ - **indices** (Tensor) - Indices is a 1-D tensor with indices into `x`. Must be one of the following
2395
+ types: int32, int64. Has same rank as `segment_ids`. The shape should be :math:`(N,)`.
2396
+ - **segment_ids** (Tensor) - Segment_ids is a 1-D tensor with indices into the output `y`. Must be one of the
2397
+ following types: int32, int64. Values should be sorted and can be repeated. The shape should be :math:`(N,)`.
2398
+ - **output_dim0** (Tensor) - Output_dim0 is a 0-D tensor. Dimension 0 of `x` passed to SparseSegmentMean op.
2399
+
2400
+ Outputs:
2401
+ A Tensor. Has the same type as `x` .
2402
+ Has same shape as `x`, except for dimension 0 which is the value of `output_dim0`.
2403
+
2404
+ Raises:
2405
+ TypeError: If `x` or `indices` or `segment_ids` is not a tensor.
2406
+ TypeError: If the dtype of `x` is not any of the following data types: {float32, float64}.
2407
+ TypeError: If the dtype of `indices` is not int32.
2408
+ TypeError: If the dtype of `segment_ids` is not int32.
2409
+ TypeError: If the dtype of `output_dim0` is not int32.
2410
+ ValueError: If dimension size of `x` is less than 1.
2411
+ ValueError: If rank of `indices` or `segment_ids` is not 1.
2412
+ ValueError: If dimension size of `output_dim0` is not 0.
2413
+ ValueError: If the first dimension of `indices` is not equal to the first dimension of `segment_ids`.
2414
+ ValueError: If `segment_ids` is not sorted.
2415
+ ValueError: If `indices` is out of range of `output_dim0`.
2416
+
2417
+ Supported Platforms:
2418
+ ``Ascend`` ``GPU`` ``CPU``
2419
+ """
2420
+
2421
+ @prim_attr_register
2422
+ def __init__(self):
2423
+ """Initialize SparseSegmentMeanGrad"""
2424
+ self.init_prim_io_names(inputs=['x', 'indices', 'segment_ids', 'output_dim0'], outputs=['y'])
2425
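+ # Editor's shape sketch (not part of the original source), following the docstring above:
+ # for x of shape (8, 5), indices and segment_ids of shape (N,), and output_dim0 holding 10,
+ # the output y has shape (10, 5) and the dtype of x.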
+
2426
+
2427
+ class FractionalMaxPoolGrad(Primitive):
2428
+ """Computes gradients for FractionalMaxPool operation."""
2429
+
2430
+ @prim_attr_register
2431
+ def __init__(self, overlapping=False):
2432
+ self.init_prim_io_names(inputs=["orig_input", "orig_output", "out_backprop",
2433
+ "row_pooling_sequence", "col_pooling_sequence"],
2434
+ outputs=["y"])
2435
+ validator.check_value_type("overlapping", overlapping, [bool], self.name)
2436
+
2437
+
2438
+ class FractionalMaxPool3DGradWithFixedKsize(Primitive):
2439
+ """Computes gradients for FractionalMaxPool3DWithFixedKsize operation."""
2440
+
2441
+ @prim_attr_register
2442
+ def __init__(self, data_format="NCDHW"):
2443
+ self.init_prim_io_names(inputs=["origin_input", "out_backprop", "argmax"], outputs=["y"])
2444
+ self.data_format = validator.check_string(data_format, ['NCDHW', "NDHWC"], 'data_format', self.name)
2445
+
2446
+
2447
+ class MaxUnpool2DGrad(Primitive):
2448
+ r"""
2449
+ Gradients for MaxUnpool2D operation.
2450
+ """
2451
+
2452
+ @prim_attr_register
2453
+ def __init__(self, ksize, strides=0, pads=0, output_shape=(), data_format="NCHW"):
2454
+ """Initialize MaxUnpool2DGrad."""
2455
+ self.init_prim_io_names(inputs=['x', 'grads', 'argmax'], outputs=['y'])
2456
+ validator.check_value_type("ksize", ksize, [int, tuple], self.name)
2457
+ validator.check_value_type("strides", strides, [int, tuple], self.name)
2458
+ validator.check_value_type("pads", pads, [int, tuple], self.name)
2459
+ validator.check_value_type("output_shape", output_shape, [tuple], self.name)
2460
+ validator.check_string(data_format, ['NCHW', 'NHWC'], 'data_format', self.name)
2461
+ validator.check_int(len(ksize), 4, validator.EQ, "ksize rank", self.name)
2462
+ validator.check_int(len(strides), 4, validator.EQ, "strides rank", self.name)
2463
+ validator.check_int(len(pads), 4, validator.EQ, "pads rank", self.name)
2464
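+ # Editor's note (not part of the original source): although ints pass the value-type checks
+ # above, the subsequent len() rank checks mean ksize, strides and pads must in practice be
+ # 4-element tuples, e.g. ksize=(1, 1, 2, 2); MaxUnpool3DGrad below behaves analogously with
+ # 5-element tuples.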
+
2465
+
2466
+ class MaxUnpool3DGrad(Primitive):
2467
+ r"""
2468
+ Gradients for MaxUnpool3D operation.
2469
+ """
2470
+
2471
+ @prim_attr_register
2472
+ def __init__(self, ksize, strides=0, pads=0, output_shape=(), data_format="NCDHW"):
2473
+ """Initialize MaxUnpool3DGrad."""
2474
+ self.init_prim_io_names(inputs=['x', 'grads', 'argmax'], outputs=['y'])
2475
+ validator.check_value_type("ksize", ksize, [int, tuple], self.name)
2476
+ validator.check_value_type("strides", strides, [int, tuple], self.name)
2477
+ validator.check_value_type("pads", pads, [int, tuple], self.name)
2478
+ validator.check_value_type("output_shape", output_shape, [tuple], self.name)
2479
+ validator.check_string(data_format, ['NCDHW', 'NDHWC'], 'data_format', self.name)
2480
+ validator.check_int(len(ksize), 5, validator.EQ, "ksize rank", self.name)
2481
+ validator.check_int(len(strides), 5, validator.EQ, "strides rank", self.name)
2482
+ validator.check_int(len(pads), 5, validator.EQ, "pads rank", self.name)
2483
+
2484
+
2485
+ class FractionalAvgPoolGrad(Primitive):
2486
+ """Computes gradients for FractionalAvgPool operation."""
2487
+
2488
+ @prim_attr_register
2489
+ def __init__(self, overlapping=False):
2490
+ self.add_prim_attr("max_length", 1000000)
2491
+ self.init_prim_io_names(inputs=["orig_input_tensor_shape", "out_backprop", "row_pooling_sequence",
2492
+ "col_pooling_sequence"],
2493
+ outputs=["y"])
2494
+ validator.check_value_type("overlapping", overlapping, [bool], self.name)
2495
+
2496
+
2497
+ class PSROIPoolingGrad(Primitive):
2498
+ """Computes gradients for PSROIPooling operation."""
2499
+
2500
+ @prim_attr_register
2501
+ def __init__(self, input_size, spatial_scale, group_size, output_dim):
2502
+ """Initialize PSROIPoolingGrad."""
2503
+ self.init_prim_io_names(inputs=["x", "rois"], outputs=['y'])
2504
+ validator.check_value_type("input_size", input_size, [int, tuple], self.name)
2505
+ validator.check_positive_float(spatial_scale, "spatial_scale", self.name)
2506
+ validator.check_positive_int(group_size, "group_size", self.name)
2507
+ validator.check_positive_int(output_dim, "output_dim", self.name)
2508
+
2509
+ if isinstance(input_size, int):
2510
+ self.input_size = [input_size, input_size]
2511
+ else:
2512
+ self.input_size = list(input_size)
2513
+
2514
+ validator.check_positive_int_sequence(self.input_size, "input_size", self.name)
2515
+ self.spatial_scale = spatial_scale
2516
+ self.group_size = group_size
2517
+ self.output_dim = output_dim
2518
+
2519
+ self.add_prim_attr('input_size', self.input_size)
2520
+ self.add_prim_attr('spatial_scale', self.spatial_scale)
2521
+ self.add_prim_attr('group_size', self.group_size)
2522
+ self.add_prim_attr('output_dim', self.output_dim)
2523
+
2524
+
2525
+ class AdaptiveMaxPool3DGrad(Primitive):
2526
+ """Computes gradients for AdaptiveMaxPool3D operation."""
2527
+
2528
+ @prim_attr_register
2529
+ def __init__(self):
2530
+ """Initialize AdaptiveMaxPool3DGrad"""
2531
+ self.init_prim_io_names(inputs=['input_grad', 'x', 'argmax'], outputs=['output_grad'])
2532
+
2533
+
2534
+ class TraceGrad(Primitive):
2535
+ """
2536
+ Computes grad for Trace operation.
2537
+
2538
+ Inputs:
2539
+ - **y_grad** (Tensor) - The gradients of loss to the output of the Trace function.
2540
+ Currently grad data type support float16, float32, int8, int16, int32, int64,
2541
+ uint8, uint16, uint32, uint64, float64.
2542
+ - **x_shape** (Tensor) - The shape of the input of the Trace function.
2543
+ Currently shape data type support int32, int64.
2544
+
2545
+ Outputs:
2546
+ x_grad - Tensor, with the same data type as `y_grad` and the shape given by `x_shape`.
2547
+
2548
+ Raises:
2549
+ TypeError: If `x_shape` is not a Tensor.
2550
+ TypeError: If the dtype of `x_shape` is neither int32 nor int64.
2551
+ ValueError: If `x_shape` is not a 1D Tensor.
2552
+ ValueError: If length of shape of `x_shape` is not equal to 2.
2553
+
2554
+ Supported Platforms:
2555
+ ``Ascend`` ``GPU`` ``CPU``
2556
+ """
2557
+
2558
+ @prim_attr_register
2559
+ def __init__(self):
2560
+ self.init_prim_io_names(inputs=['y_grad', 'x_shape'], outputs=['x_grad'])
2561
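+ # Editor's shape sketch (not part of the original source), following the docstring above:
+ # x_shape is a 1-D tensor describing the 2-D forward input, e.g. Tensor([3, 4], int64),
+ # and the returned x_grad then has shape (3, 4) with the dtype of y_grad.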
+
2562
+
2563
+ class IgammaGradA(Primitive):
2564
+ r"""
2565
+ Computes the gradient of igamma(a, x) wrt a.
2566
+
2567
+ Inputs:
2568
+ - **a** (Tensor) - The input tensor. With float32 or float64 data type.
2569
+ - **x** (Tensor) - The input tensor. With float32 or float64 data type. `x` should have
2570
+ the same dtype with `a`.
2571
+
2572
+ Outputs:
2573
+ Tensor, has the same dtype as `a` and `x`.
2574
+
2575
+ Raises:
2576
+ TypeError: If `a` or `x` is not a Tensor.
2577
+ TypeError: If dtype of `x` and `a` is neither float32 nor float64.
2578
+ TypeError: If `x` has a different dtype than `a`.
2579
+ ValueError: If `a` could not be broadcast to a tensor with shape of `x`.
2580
+
2581
+ Supported Platforms:
2582
+ ``Ascend`` ``GPU`` ``CPU``
2583
+
2584
+ Examples:
2585
+ >>> a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
2586
+ >>> x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
2587
+ >>> igammagrada = G.IgammaGradA()
2588
+ >>> output = igammagrada(a, x)
2589
+ >>> print(output)
2590
+ [-0.2940046 -0.20153049 -0.13028376 -0.08352186]
2591
+ """
2592
+
2593
+ @prim_attr_register
2594
+ def __init__(self):
2595
+ """Initialize IgammaGradA"""
2596
+ self.init_prim_io_names(inputs=['a', 'x'], outputs=['z'])
2597
+
2598
+
2599
+ class DeformableOffsetsGrad(Primitive):
2600
+ r"""
2601
+ Computes gradients of DeformableOffsets operation.
2602
+ Args:
2603
+ strides (tuple[int, int, int, int]): A tuple of 4 integers. The stride of sliding windows for height
2604
+ and width for H/W dimension.
2605
+ pads (tuple[int, int, int, int]): A tuple of 4 integers. Padding added to H/W dimension of the input. The number
2606
+ of pixels to add to each (top, bottom, left, right) side of the input.
2607
+ kernel_size (tuple[int, int]): Kernel size, a tuple of 2 integers.
2608
+ dilations (tuple[int, int, int, int]): A tuple of 4 integers. The dilation factor for each dimension of
2609
+ input. Default: (1, 1, 1, 1).
2610
+ data_format (str): An optional string from: "NCHW", "NHWC". Specifies the data format of the input x. Default:
2611
+ "NCHW".
2612
+ deformable_groups (int): Specify the C-axis grouping number of input x. Default: 1.
2613
+ modulated (bool): Specify version of DeformableOffsetsGrad, true means v2, false means v1. Default: ``True``.
2614
+
2615
+ Inputs:
2616
+ - **grad** (Tensor) - The input grad tensor. With float16 or float32 data type.
2617
+ - **x** (Tensor) - The input `x` of DeformableOffsets with data type of float16 or float32.
2618
+ - **offsets** (Tensor) - The input 'offsets' of DeformableOffsets with data type of float16 or float32.
2619
+
2620
+ Outputs:
2621
+ - **grad_x** (Tensor) - The output grad of input `x`, with the same dtype and shape as input `x`.
2622
+ - **grad_offsets** (Tensor) - The output grad of input `offsets`, with the same dtype and shape as input `offsets`.
2623
+
2624
+ Supported Platforms:
2625
+ ``Ascend`` ``GPU`` ``CPU``
2626
+ """
2627
+
2628
+ @prim_attr_register
2629
+ def __init__(self,
2630
+ strides,
2631
+ pads,
2632
+ kernel_size,
2633
+ dilations=(1, 1, 1, 1),
2634
+ data_format="NCHW",
2635
+ deformable_groups=1,
2636
+ modulated=True):
2637
+ """Initialize DeformableOffsetsGrad"""
2638
+ self.init_prim_io_names(inputs=['out_backprop', 'input', 'offsets'], outputs=['out_grad'])
2639
+
2640
+ self.strides = _check_positive_int_or_tuple('strides', strides, self.name, allow_four=True, ret_four=True)
2641
+ self.add_prim_attr('strides', self.strides)
2642
+
2643
+ self.pads = pads
2644
+ self.add_prim_attr('pads', self.pads)
2645
+
2646
+ self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name, allow_four=True,
2647
+ ret_four=False)
2648
+ self.add_prim_attr('ksize', self.kernel_size)
2649
+
2650
+ self.dilations = _check_positive_int_or_tuple('dilations', dilations, self.name, allow_four=True, ret_four=True)
2651
+ self.add_prim_attr('dilations', dilations)
2652
+
2653
+ self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
2654
+ self.add_prim_attr('data_format', self.data_format)
2655
+
2656
+ self.deformable_groups = validator.check_positive_int(deformable_groups, 'deformable_groups', self.name)
2657
+ self.add_prim_attr('deformable_groups', self.deformable_groups)
2658
+
2659
+ self.modulated = validator.check_bool(modulated, 'modulated', self.name)
2660
+ self.add_prim_attr('modulated', self.modulated)
2661
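+ # Editor's note (not part of the original source): strides and dilations are normalized to
+ # 4-tuples and kernel_size to a 2-tuple by _check_positive_int_or_tuple above; note that the
+ # kernel size is stored under the prim attr name 'ksize' rather than 'kernel_size'.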
+
2662
+
2663
+ class MedianGrad(Primitive):
2664
+ """
2665
+ Computes gradient for Median operation.
2666
+
2667
+ .. warning::
2668
+ When attr `global_median` is True, the value of Median's second output Tensor `indices` is meaningless.
2669
+
2670
+ Args:
2671
+ global_median (bool): Whether the output tensor is the global median of all input tensor elements
2672
+ or not in Median operation.
2673
+ axis (int): The dimension need to reduce in Median operation.
2674
+ keep_dims (bool): Whether the output tensor need to retain `axis` dimension or not in Median operation.
2675
+
2676
+ Inputs:
2677
+ - **y_grad** (Tensor) - The gradients of loss to output of Median function.
2678
+ - **x** (Tensor) - The first input is a tensor whose data type is number.
2679
+ The dtype is one of the following: int16, int32, int64, float32, double.
2680
+ - **y** (Tensor) - The first output of Median function, which datatype is same as `x`.
2681
+ - **indices** (Tensor) - The second output of Median function, which datatype is int64.
2682
+
2683
+ Outputs:
2684
+ x_grad - Tensor, has the same shape as `x`; its dtype is double only when the dtype of `x` is double.
2685
+ Otherwise, dtype of `x_grad` is float32.
2686
+
2687
+ Raises:
2688
+ TypeError: If dtype of `y_grad` is not the same as `x`.
2689
+ ValueError: If shape of `y_grad` is not the same as `y`.
2690
+
2691
+ Supported Platforms:
2692
+ ``Ascend`` ``CPU``
2693
+ """
2694
+
2695
+ @prim_attr_register
2696
+ def __init__(self, global_median=False, axis=0, keep_dims=False):
2697
+ validator.check_value_type("global_median", global_median, [bool], self.name)
2698
+ self.global_median = global_median
2699
+ if global_median is False:
2700
+ validator.check_value_type("axis", axis, [int], self.name)
2701
+ validator.check_value_type("keep_dims", keep_dims, [bool], self.name)
2702
+ self.init_prim_io_names(inputs=['y_grad', 'x', 'y', 'indices'], outputs=['x_grad'])
2703
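+ # Hedged usage sketch (not part of the original source): axis and keep_dims are only validated
+ # when global_median is False, matching the reduction-style forward Median:
+ # >>> grad_op = MedianGrad(global_median=False, axis=0, keep_dims=False)
+ # >>> x_grad = grad_op(y_grad, x, y, indices)   # y and indices come from the forward Median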
+
2704
+
2705
+ class SparseSegmentSumGrad(Primitive):
2706
+ """
2707
+ Computes gradients for the SparseSegmentSum operation.
2708
+
2709
+ Inputs:
2710
+ - **grad** (Tensor) - A tensor.
2711
+ - **indices** (Tensor) - Indices is a 1-D tensor. Must be one of the following types: int32, int64.
2712
+ Has same rank as segment_ids. The shape should be :math:`(N,)`.
2713
+ - **segment_ids** (Tensor) - Segment_ids is a 1-D tensor. Must be one of the following types: int32, int64.
2714
+ Values should be sorted and can be repeated. The shape should be :math:`(N,)`.
2715
+ - **output_dim0** (Tensor) - Output_dim0 is a 0-D tensor. Dimension 0 of `x` passed to SparseSegmentSum op.
2716
+
2717
+ Outputs:
2718
+ A Tensor. Has the same type as `grad` .
2719
+ Has same shape as `grad`, except for dimension 0 which is the value of `output_dim0`.
2720
+
2721
+ Raises:
2722
+ TypeError: If `grad` or `indices` or `segment_ids` or `output_dim0` is not a tensor.
2723
+ TypeError: If the dtype of `grad` is not any of the following data types: {float16, float32, float64}.
2724
+ TypeError: If the dtype of `indices` and `segment_ids` and `output_dim0` is not int32 or int64.
2725
+ ValueError: If dimension size of `grad` is less than 1.
2726
+ ValueError: If rank of `indices` or `segment_ids` is not 1.
2727
+ ValueError: If dimension size of `output_dim0` is not 0.
2728
+ ValueError: If shape[0] of `indices` is not corresponding to shape[0] of `segment_ids`.
2729
+ ValueError: If `segment_ids` is not sorted.
2730
+ ValueError: If the last number of `segment_ids` is out of range of the first dimension of `grad`.
2731
+ ValueError: If `indices` is bigger than or equal to `output_dim0`.
2732
+
2733
+ Supported Platforms:
2734
+ ``GPU``
2735
+ """
2736
+ __mindspore_signature__ = (
2737
+ sig.make_sig('grad', dtype=sig.sig_dtype.T1),
2738
+ sig.make_sig('indices', dtype=sig.sig_dtype.T),
2739
+ sig.make_sig('segment_ids', dtype=sig.sig_dtype.T),
2740
+ sig.make_sig('output_dim0', dtype=sig.sig_dtype.T)
2741
+ )
2742
+
2743
+ @prim_attr_register
2744
+ def __init__(self):
2745
+ """Initialize SparseSegmentSumGrad"""
2746
+ self.init_prim_io_names(inputs=['grad', 'indices', 'segment_ids', 'output_dim0'], outputs=['y'])
2747
+
2748
+
2749
+ class SparseSegmentSqrtNGrad(Primitive):
2750
+ """
2751
+ Computes gradients for the SparseSegmentSqrtN operation.
2752
+
2753
+ Inputs:
2754
+ - **x** (Tensor) - A tensor. Its rank must be greater than or equal to one.
2755
+ - **indices** (Tensor) - Indices is a 1-D tensor with indices into `x`. Must be one of the following
2756
+ types: int32, int64. Has same rank as segment_ids. The shape should be :math:`(N,)`.
2757
+ - **segment_ids** (Tensor) - Segment_ids is a 1-D tensor with indices into the output `y`. Must be one
2758
+ of the following types: int32, int64. Values should be sorted and can be repeated. The shape should
2759
+ be :math:`(N,)`.
2760
+ - **output_dim0** (Tensor) - Output_dim0 is a 0-D tensor. Dimension 0 of `x` passed to SparseSegmentSqrtN op.
2761
+
2762
+ Outputs:
2763
+ A Tensor. Has the same type as `x` .
2764
+ Has same shape as `x`, except for dimension 0 which is the value of `output_dim0`.
2765
+
2766
+ Raises:
2767
+ TypeError: If `x` or `indices` or `segment_ids` or `output_dim0` is not a tensor.
2768
+ TypeError: If the dtype of `x` is not any of the following data types: {float16, float32, float64}.
2769
+ TypeError: If the dtype of `indices` is not int32.
2770
+ TypeError: If the dtype of `segment_ids` is not int32.
2771
+ TypeError: If the dtype of `output_dim0` is not int32.
2772
+ ValueError: If dimension size of `x` is less than 1.
2773
+ ValueError: If rank of `indices` or `segment_ids` is not 1.
2774
+ ValueError: If dimension size of `output_dim0` is not 0.
2775
+ ValueError: If shape[0] of `indices` is not corresponding to shape[0] of `segment_ids`.
2776
+ ValueError: If `segment_ids` is not sorted.
2777
+ ValueError: If the last number of `segment_ids` is out of range of the first dimension of `x`.
2778
+ ValueError: If `indices` is bigger than or equal to `output_dim0`.
2779
+
2780
+ Supported Platforms:
2781
+ ``Ascend`` ``GPU`` ``CPU``
2782
+ """
2783
+
2784
+ @prim_attr_register
2785
+ def __init__(self):
2786
+ """Initialize SparseSegmentSqrtNGrad"""
2787
+ self.init_prim_io_names(inputs=['x', 'indices', 'segment_ids', 'output_dim0'], outputs=['y'])
2788
+
2789
+
2790
+ class SparseSliceGrad(Primitive):
2791
+ r"""
2792
+ Computes gradients for SparseSlice operation.
2793
+
2794
+ Inputs:
2795
+ - **backprop_val_grad** (Tensor) - A 1D Tensor.
2796
+ The shape should be :math:`(N,)`.
2797
+ - **indices** (Tensor) - A 2D Tensor (N x R matrix) of type int64. The indices of the SparseTensor.
2798
+ Support int64, each element value should be a non-negative int number. This tensor should be sorted.
2799
+ The shape is :math:`(N, R)`.
2800
+ - **start** (Tensor) - A 1D Tensor of type int64, represents the start of the indices.
2801
+ The shape should be :math:`(R,)`.
2802
+ - **new_indices** (Tensor) - A 2D Tensor (N x C matrix) of type int64. The indices of the SparseTensor.
2803
+ Support int64, each element value should be a non-negative int number. This tensor should be sorted.
2804
+ The shape is :math:`(N, C)`.
2805
+
2806
+ Outputs:
2807
+ - **y_grad** (Tensor) - Has the same type as `backprop_val_grad`.
2808
+ Its number of elements matches the number of rows of `indices`.
2809
+
2810
+ Raises:
2811
+ TypeError: If the dtype of `indices`, `start`, `new_indices` are not int64.
2812
+ ValueError: If `indices`, `new_indices` are not 2-D tensor.
2813
+ ValueError: If `backprop_val_grad`, `start` is not a 1-D tensor.
2814
+ ValueError: If the number of `backprop_val_grad` is not corresponding to the number of `new_indices`.
2815
+ ValueError: If the shape of `indices[1]` is not corresponding to `start[1]`.
2816
+ ValueError: If the shape of `indices[1]` is not corresponding to `new_indices[1]`.
2817
+ RuntimeError: If the `backprop_val_grad` is not all backpropagated, because `indices` or `new_indices`
2818
+ is not sorted.
2819
+
2820
+ Supported Platforms:
2821
+ ``Ascend`` ``GPU`` ``CPU``
2822
+ Examples:
2823
+ >>> backprop_val_grad = Tensor(np.array([1, 2, 3, 4]).astype(np.int64))
2824
+ >>> indices = Tensor(np.array([[0, 0], [0, 2], [1, 2], [1, 3], [2, 3], [2, 4]]).astype(np.int64))
2825
+ >>> start = Tensor(np.array([0, 0]).astype(np.int64))
2826
+ >>> new_indices = Tensor(np.array([[0, 2], [1, 2], [1, 3], [2, 4]]).astype(np.int64))
2827
+ >>> grad = SparseSliceGrad()
2828
+ >>> output = grad(backprop_val_grad, indices, start, new_indices)
2829
+ >>> print(output)
2830
+ [0 1 2 3 0 4]
2831
+ """
2832
+
2833
+ @prim_attr_register
2834
+ def __init__(self):
2835
+ """Initialize SparseSliceGrad."""
2836
+ self.init_prim_io_names(inputs=['backprop_val_grad', 'indices', 'start', 'new_indices'], outputs=['y_grad'])
2837
+
2838
+
2839
+ class FractionalMaxPoolGradWithFixedKsize(Primitive):
2840
+ """
2841
+ Computes the gradients of FractionalMaxPoolWithFixedKsize.
2842
+
2843
+ Args:
2844
+ data_format (str): The optional value for data format. Only 'NCHW' is supported. Default: "NCHW".
2845
+
2846
+ Inputs:
2847
+ - **origin_input** (Tensor) - Tensor with data format "NCHW", data type must be int32 or int64.
2848
+ - **out_backprop** (Tensor) - The gradients with respect to the output of FractionalMaxPoolWithFixedKsize
2849
+ function. Tensor with data format "NCHW", whose data type is float16, float32, float64, int32 or int64.
2850
+ - **argmax** (Tensor) - The second output of FractionalMaxPoolWithFixedKsize function, whose data
2851
+ type is int64.
2852
+
2853
+ Outputs:
2854
+ - **y** (Tensor) - Tensor, with the same shape as `origin_input`, and the same data type as
2855
+ the input `out_backprop`.
2856
+
2857
+ Raises:
2858
+ TypeError: If data type of `out_backprop` is not one of the following: float16, float32, float64, int32, int64.
2859
+ TypeError: If data type of `argmax` is not int64.
2860
+ ValueError: If the shape of `out_backprop` and `argmax` is not equal.
2861
+ ValueError: If the first dimension size of `origin_input` and `out_backprop` is not equal.
2862
+ ValueError: If the second dimension size of `origin_input` and `out_backprop` is not equal.
2863
+
2864
+ Supported Platforms:
2865
+ ``Ascend`` ``GPU`` ``CPU``
2866
+ """
2867
+
2868
+ @prim_attr_register
2869
+ def __init__(self, data_format="NCHW"):
2870
+ self.data_format = validator.check_string(data_format, ['NCHW'], 'data_format', self.name)
2871
+ self.add_prim_attr("data_format", self.data_format)
2872
+ self.init_prim_io_names(inputs=['origin_input', 'out_backprop', 'argmax'], outputs=['y'])
2873
+
2874
+
2875
+ class AffineGridGrad(Primitive):
2876
+ r"""
2877
+ Computes gradients for AffineGrid operation.
2878
+
2879
+ Args:
2880
+ align_corners (bool): if True, consider -1 and 1 to refer to the centers
2881
+ of the corner pixels rather than the image corners. Default: ``False``.
2882
+
2883
+ Inputs:
2884
+ - **y_grad** (Tensor) - Data type must be float16 or float32.
2885
+ - **x_size** (tuple) - Data type must be int32 or int64.
2886
+
2887
+ Outputs:
2888
+ Tensor, with data type same as `y_grad`.
2889
+
2890
+ Supported Platforms:
2891
+ ``CPU``
2892
+
2893
+ Examples:
2894
+ >>> import mindspore.ops.operations._grad_ops as _grad_ops
2895
+ >>> affinegridgrad = _grad_ops.AffineGridGrad()
2896
+ >>> y_grad = Tensor(np.ones([1, 2, 2, 2]), mindspore.float32)
2897
+ >>> x_size = (1, 2, 2, 2)
2898
+ >>> x_grad = affinegridgrad(y_grad, x_size)
2899
+ >>> print(x_grad)
2900
+ [[[0. 0. 4.]
2901
+ [0. 0. 4.]]]
2902
+ """
2903
+
2904
+ @prim_attr_register
2905
+ def __init__(self, align_corners=False):
2906
+ """Initialize AffineGridGrad."""
2907
+ validator.check_value_type("align_corners", align_corners, [bool], self.name)
2908
+ self.init_prim_io_names(inputs=['y_grad', 'x_size'], outputs=['x_grad'])
2909
+
2910
+
2911
+
2912
+ class GluGrad(Primitive):
2913
+ """
2914
+ Computes grad for Glu operation.
2915
+ """
2916
+
2917
+ @prim_attr_register
2918
+ def __init__(self, axis):
2919
+ self.add_prim_attr("cust_aicpu", self.name)
2920
+ self.init_prim_io_names(inputs=["grads", "x"], outputs=["y"])
2921
+ validator.check_value_type("axis", axis, [int], self.name)
2922
+
2923
+
2924
+ class MapTensorGetGrad(Primitive):
2925
+ """
2926
+ Computes gradients for MapTensorGet operation.
2927
+
2928
+ Inputs:
2929
+ - **map_tensor** (MapTensor) - The input `map_tensor` of the forward operator MapTensorGet.
2930
+ - **key_tensor** (Tensor) - The input `key_tensor` of the forward operator MapTensorGet.
2931
+ - **default_value** (Scalar) - The input `default_value` of the forward operator MapTensorGet.
2932
+ - **grad** (Tensor) - The grad value according to the forward operator MapTensorGet.
2933
+
2934
+ Outputs:
2935
+ - **output** (MapTensor) - MapTensor with grad values.
2936
+ """
2937
+ @prim_attr_register
2938
+ def __init__(self):
2939
+ """Initialize MapTensorGetGrad"""
2940
+ self.init_prim_io_names(inputs=['map_tensor', 'key_tensor', 'default_value', 'grad'], outputs=['output'])
2941
+ self.add_prim_attr('side_effect_mem', True)
2942
+
2943
+
2944
+ class ResizeV2Grad(Primitive):
2945
+ r"""
2946
+ Calculates the gradient of ResizeV2 operation.
2947
+
2948
+ Supported Platforms:
2949
+ ``CPU``
2950
+ """
2951
+
2952
+ @prim_attr_register
2953
+ def __init__(self, coordinate_transformation_mode="half_pixel", mode="nearest"):
2954
+ """Initialize ResizeV2Grad."""
2955
+ self.init_prim_io_names(inputs=["grads", "roi", "scales", "original_size"], outputs=["y"])
2956
+ self.add_prim_attr("nearest_mode", "floor")
2957
+ self.add_prim_attr("cubic_coeff_a", -0.75)
2958
+ validator.check_value_type(
2959
+ "coordinate_transformation_mode", coordinate_transformation_mode, [str], self.name)
2960
+ validator.check_string(coordinate_transformation_mode,
2961
+ ["align_corners", "half_pixel"], "coordinate_transformation_mode", self.name)
2962
+ validator.check_value_type("mode", mode, [str], self.name)
2963
+ validator.check_string(mode, ["nearest", "linear", "cubic"], "mode", self.name)
2964
+
2965
+
2966
+ class WKVGrad(Primitive):
2967
+ r"""
2968
+ Calculates the gradient of WKV operation.
2969
+
2970
+ Supported Platforms:
2971
+ ``Ascend``
2972
+ """
2973
+
2974
+ @prim_attr_register
2975
+ def __init__(self):
2976
+ """Initialize WKVGrad."""
2977
+ self.init_prim_io_names(inputs=["time_first", "time_decay", "key", "value", "gy"],
2978
+ outputs=["gw", "gu", "gk", "gv"])