mindspore 1.10.0__cp37-cp37m-win_amd64.whl → 2.0.0rc1__cp37-cp37m-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (966) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/ConcurrencyCheck.dll +0 -0
  3. mindspore/CppBuildInsights.dll +0 -0
  4. mindspore/CppCoreCheck.dll +0 -0
  5. mindspore/EnumIndex.dll +0 -0
  6. mindspore/EspXEngine.dll +0 -0
  7. mindspore/HResultCheck.dll +0 -0
  8. mindspore/KernelTraceControl.dll +0 -0
  9. mindspore/LocalESPC.dll +0 -0
  10. mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
  11. mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
  12. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  13. mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
  14. mindspore/Newtonsoft.Json.dll +0 -0
  15. mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
  16. mindspore/VariantClear.dll +0 -0
  17. mindspore/__init__.py +9 -4
  18. mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
  19. mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
  20. mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
  21. mindspore/_check_jit_forbidden_api.py +102 -0
  22. mindspore/_checkparam.py +1066 -1001
  23. mindspore/_extends/builtin_operations.py +32 -4
  24. mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
  25. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
  26. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
  27. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
  28. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
  29. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
  30. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  31. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
  32. mindspore/_extends/parse/__init__.py +5 -3
  33. mindspore/_extends/parse/namespace.py +17 -2
  34. mindspore/_extends/parse/parser.py +193 -34
  35. mindspore/_extends/parse/resources.py +7 -8
  36. mindspore/_extends/parse/standard_method.py +1780 -435
  37. mindspore/_extends/parse/trope.py +3 -1
  38. mindspore/amp.py +53 -58
  39. mindspore/atlprov.dll +0 -0
  40. mindspore/boost/adasum.py +3 -2
  41. mindspore/boost/boost.py +2 -2
  42. mindspore/boost/boost_cell_wrapper.py +46 -26
  43. mindspore/boost/dim_reduce.py +6 -5
  44. mindspore/boost/grad_accumulation.py +2 -1
  45. mindspore/boost/group_loss_scale_manager.py +1 -1
  46. mindspore/c1.dll +0 -0
  47. mindspore/c1xx.dll +0 -0
  48. mindspore/c2.dll +0 -0
  49. mindspore/cfgpersist.dll +0 -0
  50. mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
  51. mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
  52. mindspore/common/__init__.py +11 -10
  53. mindspore/common/_decorator.py +2 -0
  54. mindspore/common/_register_for_adapter.py +55 -0
  55. mindspore/common/_stub_tensor.py +201 -0
  56. mindspore/common/_utils.py +57 -0
  57. mindspore/common/api.py +582 -297
  58. mindspore/common/dtype.py +66 -18
  59. mindspore/common/dump.py +2 -2
  60. mindspore/common/initializer.py +38 -1
  61. mindspore/common/jit_config.py +25 -13
  62. mindspore/common/mutable.py +53 -24
  63. mindspore/common/parameter.py +60 -37
  64. mindspore/common/seed.py +8 -24
  65. mindspore/common/sparse_tensor.py +927 -0
  66. mindspore/common/tensor.py +1627 -3900
  67. mindspore/communication/__init__.py +10 -5
  68. mindspore/communication/_comm_helper.py +78 -214
  69. mindspore/communication/_hccl_management.py +2 -1
  70. mindspore/communication/management.py +136 -47
  71. mindspore/config/op_info.config +501 -1008
  72. mindspore/context.py +291 -56
  73. mindspore/d3dcompiler_47.dll +0 -0
  74. mindspore/dataset/__init__.py +12 -8
  75. mindspore/dataset/audio/__init__.py +9 -9
  76. mindspore/dataset/audio/transforms.py +1090 -228
  77. mindspore/dataset/audio/utils.py +87 -39
  78. mindspore/dataset/audio/validators.py +223 -1
  79. mindspore/dataset/callback/ds_callback.py +17 -15
  80. mindspore/dataset/core/config.py +246 -17
  81. mindspore/dataset/core/py_util_helpers.py +4 -3
  82. mindspore/dataset/core/validator_helpers.py +10 -10
  83. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  84. mindspore/dataset/debug/debug_hook.py +65 -0
  85. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  86. mindspore/dataset/engine/__init__.py +7 -3
  87. mindspore/dataset/engine/cache_client.py +9 -9
  88. mindspore/dataset/engine/datasets.py +648 -477
  89. mindspore/dataset/engine/datasets_audio.py +165 -167
  90. mindspore/dataset/engine/datasets_standard_format.py +93 -67
  91. mindspore/dataset/engine/datasets_text.py +492 -342
  92. mindspore/dataset/engine/datasets_user_defined.py +85 -50
  93. mindspore/dataset/engine/datasets_vision.py +1224 -699
  94. mindspore/dataset/engine/graphdata.py +134 -69
  95. mindspore/dataset/engine/iterators.py +50 -9
  96. mindspore/dataset/engine/offload.py +52 -31
  97. mindspore/dataset/engine/samplers.py +27 -24
  98. mindspore/dataset/engine/serializer_deserializer.py +14 -15
  99. mindspore/dataset/engine/validators.py +213 -52
  100. mindspore/dataset/text/__init__.py +10 -8
  101. mindspore/dataset/text/transforms.py +152 -57
  102. mindspore/dataset/text/utils.py +98 -49
  103. mindspore/dataset/text/validators.py +25 -0
  104. mindspore/dataset/transforms/__init__.py +4 -2
  105. mindspore/dataset/transforms/c_transforms.py +11 -13
  106. mindspore/dataset/transforms/py_transforms.py +2 -2
  107. mindspore/dataset/transforms/py_transforms_util.py +10 -0
  108. mindspore/dataset/transforms/transforms.py +13 -15
  109. mindspore/dataset/transforms/validators.py +7 -7
  110. mindspore/dataset/utils/__init__.py +2 -1
  111. mindspore/dataset/utils/browse_dataset.py +13 -13
  112. mindspore/dataset/utils/line_reader.py +121 -0
  113. mindspore/dataset/vision/__init__.py +8 -7
  114. mindspore/dataset/vision/c_transforms.py +125 -126
  115. mindspore/dataset/vision/py_transforms.py +37 -37
  116. mindspore/dataset/vision/py_transforms_util.py +23 -20
  117. mindspore/dataset/vision/transforms.py +316 -315
  118. mindspore/dataset/vision/utils.py +313 -17
  119. mindspore/dataset/vision/validators.py +6 -6
  120. mindspore/default_config.py +0 -1
  121. mindspore/dpcmi.dll +0 -0
  122. mindspore/{compression → experimental}/__init__.py +6 -5
  123. mindspore/experimental/map_parameter.py +275 -0
  124. mindspore/include/OWNERS +0 -1
  125. mindspore/include/api/callback/callback.h +9 -13
  126. mindspore/include/api/callback/ckpt_saver.h +2 -2
  127. mindspore/include/api/callback/loss_monitor.h +2 -2
  128. mindspore/include/api/callback/lr_scheduler.h +5 -5
  129. mindspore/include/api/callback/time_monitor.h +2 -2
  130. mindspore/include/api/callback/train_accuracy.h +4 -6
  131. mindspore/include/api/cfg.h +19 -6
  132. mindspore/include/api/context.h +70 -9
  133. mindspore/include/api/delegate.h +8 -1
  134. mindspore/include/api/dual_abi_helper.h +8 -24
  135. mindspore/include/api/metrics/accuracy.h +2 -2
  136. mindspore/include/api/metrics/metrics.h +4 -3
  137. mindspore/include/api/model.h +9 -4
  138. mindspore/include/api/model_group.h +68 -0
  139. mindspore/include/api/model_parallel_runner.h +17 -17
  140. mindspore/include/api/net.h +12 -11
  141. mindspore/include/api/serialization.h +20 -4
  142. mindspore/include/api/status.h +7 -1
  143. mindspore/include/api/types.h +25 -21
  144. mindspore/include/api/visible.h +4 -0
  145. mindspore/include/c_api/model_c.h +5 -0
  146. mindspore/include/c_api/status_c.h +1 -1
  147. mindspore/include/dataset/config.h +1 -1
  148. mindspore/include/dataset/constants.h +14 -0
  149. mindspore/include/dataset/text.h +59 -0
  150. mindspore/include/dataset/vision.h +56 -117
  151. mindspore/include/dataset/vision_lite.h +102 -0
  152. mindspore/jpeg62.dll +0 -0
  153. mindspore/log.py +28 -28
  154. mindspore/mindrecord/common/exceptions.py +2 -4
  155. mindspore/mindrecord/filereader.py +19 -1
  156. mindspore/mindrecord/filewriter.py +250 -88
  157. mindspore/mindrecord/mindpage.py +13 -13
  158. mindspore/mindrecord/shardheader.py +15 -15
  159. mindspore/mindrecord/shardreader.py +9 -0
  160. mindspore/mindrecord/shardwriter.py +29 -29
  161. mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
  162. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
  163. mindspore/mindrecord/tools/csv_to_mr.py +4 -4
  164. mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
  165. mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
  166. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  167. mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
  168. mindspore/mindspore_common.dll +0 -0
  169. mindspore/mindspore_core.dll +0 -0
  170. mindspore/mindspore_glog.dll +0 -0
  171. mindspore/mindspore_shared_lib.dll +0 -0
  172. mindspore/msobj140.dll +0 -0
  173. mindspore/mspdb140.dll +0 -0
  174. mindspore/mspdbcore.dll +0 -0
  175. mindspore/mspdbst.dll +0 -0
  176. mindspore/mspft140.dll +0 -0
  177. mindspore/msvcdis140.dll +0 -0
  178. mindspore/msvcp140_1.dll +0 -0
  179. mindspore/msvcp140_2.dll +0 -0
  180. mindspore/msvcp140_atomic_wait.dll +0 -0
  181. mindspore/msvcp140_codecvt_ids.dll +0 -0
  182. mindspore/nn/__init__.py +1 -5
  183. mindspore/nn/cell.py +297 -234
  184. mindspore/nn/dynamic_lr.py +1 -1
  185. mindspore/nn/grad/cell_grad.py +17 -42
  186. mindspore/nn/layer/__init__.py +7 -4
  187. mindspore/nn/layer/activation.py +131 -88
  188. mindspore/nn/layer/basic.py +313 -613
  189. mindspore/nn/layer/channel_shuffle.py +103 -0
  190. mindspore/nn/layer/combined.py +1 -1
  191. mindspore/nn/layer/container.py +52 -6
  192. mindspore/nn/layer/conv.py +112 -43
  193. mindspore/nn/layer/dense.py +10 -9
  194. mindspore/nn/layer/embedding.py +36 -34
  195. mindspore/nn/layer/image.py +123 -27
  196. mindspore/nn/layer/math.py +108 -107
  197. mindspore/nn/layer/normalization.py +212 -366
  198. mindspore/nn/layer/padding.py +370 -42
  199. mindspore/nn/layer/pooling.py +1443 -219
  200. mindspore/nn/layer/rnn_cells.py +11 -16
  201. mindspore/nn/layer/rnns.py +38 -39
  202. mindspore/nn/layer/thor_layer.py +24 -25
  203. mindspore/nn/layer/timedistributed.py +5 -5
  204. mindspore/nn/layer/transformer.py +701 -0
  205. mindspore/nn/learning_rate_schedule.py +8 -8
  206. mindspore/nn/loss/__init__.py +9 -6
  207. mindspore/nn/loss/loss.py +678 -142
  208. mindspore/nn/metrics.py +53 -0
  209. mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
  210. mindspore/nn/optim/ada_grad.py +8 -8
  211. mindspore/nn/optim/adadelta.py +2 -3
  212. mindspore/nn/optim/adafactor.py +18 -14
  213. mindspore/nn/optim/adam.py +429 -87
  214. mindspore/nn/optim/adamax.py +5 -6
  215. mindspore/nn/optim/adasum.py +10 -8
  216. mindspore/nn/optim/asgd.py +7 -7
  217. mindspore/nn/optim/ftrl.py +81 -11
  218. mindspore/nn/optim/lamb.py +7 -8
  219. mindspore/nn/optim/lars.py +4 -4
  220. mindspore/nn/optim/lazyadam.py +82 -7
  221. mindspore/nn/optim/momentum.py +8 -7
  222. mindspore/nn/optim/optimizer.py +19 -10
  223. mindspore/nn/optim/proximal_ada_grad.py +6 -5
  224. mindspore/nn/optim/rmsprop.py +3 -3
  225. mindspore/nn/optim/rprop.py +20 -16
  226. mindspore/nn/optim/sgd.py +21 -15
  227. mindspore/nn/optim/thor.py +23 -21
  228. mindspore/nn/probability/__init__.py +0 -2
  229. mindspore/nn/probability/bijector/bijector.py +7 -6
  230. mindspore/nn/probability/bijector/invert.py +4 -2
  231. mindspore/nn/probability/bijector/softplus.py +2 -2
  232. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  233. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  234. mindspore/nn/probability/distribution/__init__.py +6 -0
  235. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
  236. mindspore/nn/probability/distribution/_utils/utils.py +11 -17
  237. mindspore/nn/probability/distribution/bernoulli.py +6 -6
  238. mindspore/nn/probability/distribution/beta.py +1 -1
  239. mindspore/nn/probability/distribution/categorical.py +9 -9
  240. mindspore/nn/probability/distribution/cauchy.py +8 -8
  241. mindspore/nn/probability/distribution/distribution.py +12 -6
  242. mindspore/nn/probability/distribution/exponential.py +5 -5
  243. mindspore/nn/probability/distribution/gamma.py +3 -3
  244. mindspore/nn/probability/distribution/geometric.py +6 -5
  245. mindspore/nn/probability/distribution/gumbel.py +5 -5
  246. mindspore/nn/probability/distribution/half_normal.py +133 -0
  247. mindspore/nn/probability/distribution/laplace.py +128 -0
  248. mindspore/nn/probability/distribution/log_normal.py +0 -1
  249. mindspore/nn/probability/distribution/logistic.py +4 -5
  250. mindspore/nn/probability/distribution/normal.py +11 -15
  251. mindspore/nn/probability/distribution/poisson.py +6 -2
  252. mindspore/nn/probability/distribution/student_t.py +150 -0
  253. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  254. mindspore/nn/probability/distribution/uniform.py +5 -5
  255. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  256. mindspore/nn/reinforcement/tensor_array.py +2 -2
  257. mindspore/nn/sparse/sparse.py +8 -1
  258. mindspore/nn/wrap/cell_wrapper.py +55 -27
  259. mindspore/nn/wrap/grad_reducer.py +20 -11
  260. mindspore/nn/wrap/loss_scale.py +47 -30
  261. mindspore/numpy/array_creations.py +33 -22
  262. mindspore/numpy/array_ops.py +46 -42
  263. mindspore/numpy/logic_ops.py +6 -27
  264. mindspore/numpy/math_ops.py +26 -19
  265. mindspore/numpy/utils.py +1 -8
  266. mindspore/numpy/utils_const.py +112 -62
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/opencv_imgproc452.dll +0 -0
  270. mindspore/ops/__init__.py +6 -3
  271. mindspore/ops/_constants.py +0 -6
  272. mindspore/ops/_grad/__init__.py +2 -1
  273. mindspore/ops/_grad/grad_array_ops.py +209 -152
  274. mindspore/ops/_grad/grad_base.py +55 -17
  275. mindspore/ops/_grad/grad_clip_ops.py +11 -3
  276. mindspore/ops/_grad/grad_comm_ops.py +58 -47
  277. mindspore/ops/_grad/grad_implementations.py +21 -61
  278. mindspore/ops/_grad/grad_inner_ops.py +48 -6
  279. mindspore/ops/_grad/grad_math_ops.py +306 -161
  280. mindspore/ops/_grad/grad_nn_ops.py +192 -181
  281. mindspore/ops/_grad/grad_other_ops.py +1 -1
  282. mindspore/ops/_grad/grad_quant_ops.py +5 -5
  283. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  284. mindspore/ops/_grad/grad_sparse.py +15 -9
  285. mindspore/ops/_grad_experimental/__init__.py +1 -0
  286. mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
  287. mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
  288. mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
  289. mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
  290. mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
  291. mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
  292. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  293. mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
  294. mindspore/ops/_op_impl/__init__.py +3 -3
  295. mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
  296. mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
  297. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  298. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
  299. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  300. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  301. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
  302. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  303. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  304. mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
  305. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  306. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
  307. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  308. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  309. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  310. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  311. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  312. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  313. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  314. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  315. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  316. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  317. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  318. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  319. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  320. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  321. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  322. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  323. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  325. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
  326. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
  327. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  328. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  329. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  330. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  331. mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
  332. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  333. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  334. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  335. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  336. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  337. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  338. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  339. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  340. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  341. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  342. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  343. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  344. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  345. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  346. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  347. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  348. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  349. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  350. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  351. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  352. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
  353. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  354. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
  355. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  356. mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
  357. mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
  358. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  359. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  360. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  361. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  362. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  363. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  364. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  365. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
  366. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  367. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  368. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  369. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  370. mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
  371. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  372. mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
  373. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  374. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  375. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  376. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  377. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  378. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  379. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  380. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  381. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  382. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  383. mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
  384. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  385. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  386. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  387. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  388. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  389. mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
  390. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  391. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  392. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  393. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  394. mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
  395. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  396. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  397. mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
  398. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  399. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  400. mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
  401. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  402. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  403. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  404. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  405. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  406. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  407. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  408. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  409. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  410. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  411. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  412. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  413. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  414. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  415. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  416. mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
  417. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  418. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  419. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  420. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  421. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  422. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  423. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  424. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  425. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  426. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  427. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  428. mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
  429. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  430. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  431. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  432. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  433. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  434. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  435. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  436. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  437. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  438. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  439. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  440. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  441. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  442. mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
  443. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  444. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  445. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  446. mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
  447. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
  448. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
  449. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  450. mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
  451. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  452. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  453. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  454. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  455. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  456. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  457. mindspore/ops/_op_impl/cpu/__init__.py +1 -2
  458. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  459. mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
  460. mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
  461. mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
  462. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  463. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  464. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  465. mindspore/ops/_op_impl/tbe/__init__.py +27 -608
  466. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
  467. mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
  468. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  469. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  470. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  471. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
  472. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  473. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  474. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
  475. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
  476. mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
  477. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  478. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
  479. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  480. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  481. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  482. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  483. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  484. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
  485. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
  486. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  487. mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
  488. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
  489. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  490. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  491. mindspore/ops/_op_impl/tbe/greater.py +2 -0
  492. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  493. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
  494. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  495. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  496. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
  497. mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
  498. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
  499. mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
  500. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
  501. mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
  502. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
  503. mindspore/ops/_op_impl/tbe/slice.py +26 -15
  504. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  505. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  506. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
  507. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  508. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
  509. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
  510. mindspore/ops/_primitive_cache.py +3 -2
  511. mindspore/ops/_register_for_op.py +11 -0
  512. mindspore/ops/_utils/__init__.py +1 -1
  513. mindspore/ops/_utils/utils.py +20 -41
  514. mindspore/ops/_vmap/__init__.py +2 -2
  515. mindspore/ops/_vmap/vmap_array_ops.py +170 -78
  516. mindspore/ops/_vmap/vmap_base.py +24 -10
  517. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  518. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  519. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
  520. mindspore/ops/_vmap/vmap_image_ops.py +52 -0
  521. mindspore/ops/_vmap/vmap_math_ops.py +77 -6
  522. mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
  523. mindspore/ops/_vmap/vmap_other_ops.py +3 -1
  524. mindspore/ops/_vmap/vmap_random_ops.py +55 -3
  525. mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
  526. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  527. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  528. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
  529. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
  530. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
  531. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
  532. mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
  533. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  534. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  535. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
  537. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  538. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
  539. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  541. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
  542. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
  543. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  544. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  545. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  546. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  547. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  548. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  549. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  550. mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
  551. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  552. mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
  553. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
  554. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  555. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
  556. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
  559. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
  560. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  561. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  563. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
  565. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  566. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  567. mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
  568. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
  569. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  570. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
  571. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
  572. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
  573. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
  574. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  575. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
  576. mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
  577. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  578. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  579. mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
  580. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  581. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
  582. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
  583. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
  584. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  585. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  586. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  587. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  588. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  589. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
  590. mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
  591. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
  592. mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
  593. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  594. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
  595. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
  596. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
  597. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  598. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  599. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  600. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  601. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  602. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  603. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  604. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  605. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  606. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  607. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  608. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
  609. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
  610. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
  611. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
  612. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  613. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  614. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  615. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  616. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  617. mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
  618. mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
  619. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  620. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  621. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
  622. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
  623. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
  624. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
  625. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  626. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
  627. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
  628. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
  629. mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
  630. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  631. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  632. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
  633. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
  634. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
  635. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  636. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  637. mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
  638. mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
  639. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  640. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  642. mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
  643. mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
  644. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  645. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  646. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  647. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  648. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  649. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
  650. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
  651. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  652. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  653. mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
  654. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
  655. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
  656. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
  657. mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
  658. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  659. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  660. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
  661. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
  662. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
  663. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  664. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  665. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
  666. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
  667. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
  668. mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
  669. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
  670. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  671. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  672. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
  673. mindspore/ops/bprop_mindir/__init__.py +1 -4
  674. mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
  675. mindspore/ops/composite/__init__.py +12 -13
  676. mindspore/ops/composite/base.py +261 -254
  677. mindspore/ops/composite/env_ops.py +41 -0
  678. mindspore/ops/composite/math_ops.py +197 -156
  679. mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
  680. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
  681. mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
  682. mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
  683. mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
  684. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
  685. mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
  686. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  687. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  688. mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
  689. mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
  690. mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
  691. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
  692. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  693. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
  694. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
  695. mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
  696. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  697. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  698. mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
  699. mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
  700. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
  701. mindspore/ops/function/__init__.py +323 -8
  702. mindspore/ops/function/array_func.py +3511 -780
  703. mindspore/ops/function/clip_func.py +329 -0
  704. mindspore/ops/function/debug_func.py +6 -6
  705. mindspore/ops/function/grad/__init__.py +5 -1
  706. mindspore/ops/function/grad/grad_func.py +736 -65
  707. mindspore/ops/function/image_func.py +270 -0
  708. mindspore/ops/function/linalg_func.py +268 -8
  709. mindspore/ops/function/math_func.py +8032 -3164
  710. mindspore/ops/function/nn_func.py +5619 -1855
  711. mindspore/ops/function/other_func.py +115 -0
  712. mindspore/ops/function/parameter_func.py +11 -10
  713. mindspore/ops/function/random_func.py +939 -77
  714. mindspore/ops/function/sparse_func.py +249 -84
  715. mindspore/ops/function/sparse_unary_func.py +2303 -0
  716. mindspore/ops/function/spectral_func.py +146 -0
  717. mindspore/ops/function/vmap_func.py +114 -0
  718. mindspore/ops/functional.py +182 -254
  719. mindspore/ops/op_info_register.py +79 -34
  720. mindspore/ops/operations/__init__.py +210 -118
  721. mindspore/ops/operations/_csr_ops.py +7 -7
  722. mindspore/ops/operations/_embedding_cache_ops.py +25 -15
  723. mindspore/ops/operations/_grad_ops.py +447 -322
  724. mindspore/ops/operations/_inner_ops.py +547 -176
  725. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  726. mindspore/ops/operations/_ms_kernel.py +29 -27
  727. mindspore/ops/operations/_ocr_ops.py +11 -11
  728. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  729. mindspore/ops/operations/_quant_ops.py +186 -101
  730. mindspore/ops/operations/_rl_inner_ops.py +122 -61
  731. mindspore/ops/operations/_scalar_ops.py +466 -0
  732. mindspore/ops/operations/_sequence_ops.py +1047 -0
  733. mindspore/ops/operations/_tensor_array.py +10 -11
  734. mindspore/ops/operations/_thor_ops.py +4 -4
  735. mindspore/ops/operations/array_ops.py +1428 -1226
  736. mindspore/ops/operations/comm_ops.py +180 -117
  737. mindspore/ops/operations/control_ops.py +4 -2
  738. mindspore/ops/operations/custom_ops.py +185 -98
  739. mindspore/ops/operations/debug_ops.py +92 -54
  740. mindspore/ops/operations/image_ops.py +406 -211
  741. mindspore/ops/operations/inner_ops.py +42 -53
  742. mindspore/ops/operations/linalg_ops.py +32 -29
  743. mindspore/ops/operations/math_ops.py +2076 -897
  744. mindspore/ops/operations/nn_ops.py +1282 -1252
  745. mindspore/ops/operations/other_ops.py +124 -278
  746. mindspore/ops/operations/random_ops.py +345 -178
  747. mindspore/ops/operations/rl_ops.py +8 -9
  748. mindspore/ops/operations/sparse_ops.py +502 -157
  749. mindspore/ops/operations/spectral_ops.py +107 -0
  750. mindspore/ops/primitive.py +192 -15
  751. mindspore/ops/vm_impl_registry.py +23 -2
  752. mindspore/parallel/__init__.py +6 -1
  753. mindspore/parallel/_auto_parallel_context.py +199 -92
  754. mindspore/parallel/_cell_wrapper.py +4 -2
  755. mindspore/parallel/_cost_model_context.py +3 -0
  756. mindspore/parallel/_dp_allreduce_fusion.py +2 -1
  757. mindspore/parallel/_offload_context.py +185 -0
  758. mindspore/parallel/_parallel_serialization.py +167 -28
  759. mindspore/parallel/_ps_context.py +9 -5
  760. mindspore/parallel/_recovery_context.py +1 -1
  761. mindspore/parallel/_tensor.py +9 -1
  762. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  763. mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
  764. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  765. mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
  766. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  767. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
  768. mindspore/parallel/_utils.py +47 -7
  769. mindspore/parallel/algo_parameter_config.py +5 -1
  770. mindspore/parallel/checkpoint_transform.py +329 -0
  771. mindspore/parallel/shard.py +229 -0
  772. mindspore/perf_msvcbuildinsights.dll +0 -0
  773. mindspore/pgodb140.dll +0 -0
  774. mindspore/pgort140.dll +0 -0
  775. mindspore/profiler/__init__.py +2 -1
  776. mindspore/profiler/common/util.py +4 -3
  777. mindspore/profiler/common/validator/validate_path.py +2 -2
  778. mindspore/profiler/envprofiling.py +249 -0
  779. mindspore/profiler/parser/aicpu_data_parser.py +38 -39
  780. mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
  781. mindspore/profiler/parser/base_timeline_generator.py +471 -0
  782. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
  783. mindspore/profiler/parser/framework_parser.py +42 -16
  784. mindspore/profiler/parser/hccl_parser.py +158 -158
  785. mindspore/profiler/parser/hwts_log_parser.py +7 -6
  786. mindspore/profiler/parser/integrator.py +18 -1579
  787. mindspore/profiler/parser/minddata_analyzer.py +8 -8
  788. mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
  789. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  790. mindspore/profiler/parser/optime_parser.py +17 -18
  791. mindspore/profiler/parser/profiler_info.py +108 -0
  792. mindspore/profiler/parser/step_trace_parser.py +1 -1
  793. mindspore/profiler/profiling.py +396 -194
  794. mindspore/rewrite/__init__.py +6 -2
  795. mindspore/rewrite/api/node.py +51 -110
  796. mindspore/rewrite/api/node_type.py +10 -6
  797. mindspore/rewrite/api/pattern_engine.py +51 -7
  798. mindspore/rewrite/api/scoped_value.py +64 -53
  799. mindspore/rewrite/api/symbol_tree.py +108 -61
  800. mindspore/rewrite/api/tree_node_helper.py +2 -3
  801. mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
  802. mindspore/rewrite/ast_helpers/__init__.py +6 -3
  803. mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
  804. mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
  805. mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
  806. mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
  807. mindspore/rewrite/ast_transformers/__init__.py +0 -1
  808. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
  809. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
  810. mindspore/rewrite/common/__init__.py +2 -0
  811. mindspore/rewrite/common/event.py +1 -1
  812. mindspore/rewrite/common/observable.py +1 -1
  813. mindspore/rewrite/common/observer.py +1 -1
  814. mindspore/rewrite/common/rewrite_elog.py +35 -0
  815. mindspore/rewrite/namer.py +2 -2
  816. mindspore/rewrite/namespace.py +14 -4
  817. mindspore/rewrite/node.py +161 -13
  818. mindspore/rewrite/parser.py +0 -1
  819. mindspore/rewrite/parser_register.py +0 -1
  820. mindspore/rewrite/parsers/arguments_parser.py +3 -2
  821. mindspore/rewrite/parsers/assign_parser.py +267 -67
  822. mindspore/rewrite/parsers/attribute_parser.py +56 -0
  823. mindspore/rewrite/parsers/class_def_parser.py +191 -108
  824. mindspore/rewrite/parsers/constant_parser.py +101 -0
  825. mindspore/rewrite/parsers/container_parser.py +88 -0
  826. mindspore/rewrite/parsers/for_parser.py +28 -15
  827. mindspore/rewrite/parsers/function_def_parser.py +21 -5
  828. mindspore/rewrite/parsers/if_parser.py +11 -28
  829. mindspore/rewrite/parsers/module_parser.py +9 -6
  830. mindspore/rewrite/parsers/return_parser.py +3 -2
  831. mindspore/rewrite/sparsify/__init__.py +0 -0
  832. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  833. mindspore/rewrite/sparsify/sparsify.py +109 -0
  834. mindspore/rewrite/sparsify/utils.py +173 -0
  835. mindspore/rewrite/symbol_tree.py +322 -109
  836. mindspore/rewrite/symbol_tree_builder.py +45 -8
  837. mindspore/rewrite/symbol_tree_dumper.py +0 -1
  838. mindspore/rewrite/topological_manager.py +1 -2
  839. mindspore/run_check/_check_version.py +209 -112
  840. mindspore/run_check/run_check.py +2 -1
  841. mindspore/tbbmalloc.dll +0 -0
  842. mindspore/tinyxml2.dll +0 -0
  843. mindspore/train/__init__.py +6 -4
  844. mindspore/train/_utils.py +28 -5
  845. mindspore/train/amp.py +321 -50
  846. mindspore/train/callback/__init__.py +3 -1
  847. mindspore/train/callback/_backup_and_restore.py +120 -0
  848. mindspore/train/callback/_callback.py +8 -8
  849. mindspore/train/callback/_checkpoint.py +12 -9
  850. mindspore/train/callback/_early_stop.py +13 -7
  851. mindspore/train/callback/_history.py +8 -8
  852. mindspore/train/callback/_lambda_callback.py +6 -6
  853. mindspore/train/callback/_landscape.py +36 -38
  854. mindspore/train/callback/_loss_monitor.py +12 -6
  855. mindspore/train/callback/_lr_scheduler_callback.py +2 -4
  856. mindspore/train/callback/_on_request_exit.py +212 -0
  857. mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
  858. mindspore/train/callback/_summary_collector.py +27 -19
  859. mindspore/train/callback/_time_monitor.py +13 -7
  860. mindspore/train/checkpoint_pb2.py +68 -8
  861. mindspore/train/data_sink.py +122 -33
  862. mindspore/train/dataset_helper.py +28 -87
  863. mindspore/train/loss_scale_manager.py +4 -7
  864. mindspore/{nn → train}/metrics/__init__.py +20 -20
  865. mindspore/{nn → train}/metrics/accuracy.py +12 -10
  866. mindspore/{nn → train}/metrics/auc.py +4 -4
  867. mindspore/{nn → train}/metrics/bleu_score.py +4 -4
  868. mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
  869. mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
  870. mindspore/{nn → train}/metrics/dice.py +6 -5
  871. mindspore/{nn → train}/metrics/error.py +7 -5
  872. mindspore/{nn → train}/metrics/fbeta.py +9 -7
  873. mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
  874. mindspore/{nn → train}/metrics/loss.py +4 -3
  875. mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
  876. mindspore/{nn → train}/metrics/metric.py +6 -5
  877. mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
  878. mindspore/{nn → train}/metrics/perplexity.py +5 -4
  879. mindspore/{nn → train}/metrics/precision.py +5 -4
  880. mindspore/{nn → train}/metrics/recall.py +5 -4
  881. mindspore/{nn → train}/metrics/roc.py +7 -6
  882. mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
  883. mindspore/{nn → train}/metrics/topk.py +7 -5
  884. mindspore/train/mind_ir_pb2.py +339 -32
  885. mindspore/train/model.py +113 -84
  886. mindspore/train/serialization.py +547 -167
  887. mindspore/train/summary/_summary_adapter.py +1 -1
  888. mindspore/train/summary/summary_record.py +43 -12
  889. mindspore/train/train_thor/convert_utils.py +7 -1
  890. mindspore/train/train_thor/dataset_helper.py +3 -3
  891. mindspore/train/train_thor/model_thor.py +0 -4
  892. mindspore/turbojpeg.dll +0 -0
  893. mindspore/vcmeta.dll +0 -0
  894. mindspore/vcruntime140.dll +0 -0
  895. mindspore/vcruntime140_1.dll +0 -0
  896. mindspore/version.py +1 -1
  897. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
  898. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
  899. mindspore/compression/common/constant.py +0 -124
  900. mindspore/compression/export/__init__.py +0 -19
  901. mindspore/compression/export/quant_export.py +0 -514
  902. mindspore/compression/quant/qat.py +0 -636
  903. mindspore/compression/quant/quant_utils.py +0 -462
  904. mindspore/compression/quant/quantizer.py +0 -68
  905. mindspore/libatomic-1.dll +0 -0
  906. mindspore/libgcc_s_seh-1.dll +0 -0
  907. mindspore/libgfortran-4.dll +0 -0
  908. mindspore/libgomp-1.dll +0 -0
  909. mindspore/libjpeg-62.dll +0 -0
  910. mindspore/libmindspore.dll +0 -0
  911. mindspore/libmindspore_common.dll +0 -0
  912. mindspore/libmindspore_core.dll +0 -0
  913. mindspore/libmindspore_glog.dll +0 -0
  914. mindspore/libnnacl.dll +0 -0
  915. mindspore/libopencv_core452.dll +0 -0
  916. mindspore/libopencv_imgcodecs452.dll +0 -0
  917. mindspore/libopencv_imgproc452.dll +0 -0
  918. mindspore/libquadmath-0.dll +0 -0
  919. mindspore/libsqlite3.dll +0 -0
  920. mindspore/libssp-0.dll +0 -0
  921. mindspore/libstdc++-6.dll +0 -0
  922. mindspore/libtinyxml2.dll +0 -0
  923. mindspore/libturbojpeg.dll +0 -0
  924. mindspore/libwinpthread-1.dll +0 -0
  925. mindspore/nn/layer/quant.py +0 -1868
  926. mindspore/nn/layer/rnn_utils.py +0 -90
  927. mindspore/nn/probability/dpn/__init__.py +0 -22
  928. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  929. mindspore/nn/probability/dpn/vae/cvae.py +0 -138
  930. mindspore/nn/probability/dpn/vae/vae.py +0 -122
  931. mindspore/nn/probability/infer/__init__.py +0 -22
  932. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  933. mindspore/nn/probability/infer/variational/svi.py +0 -84
  934. mindspore/nn/probability/toolbox/__init__.py +0 -22
  935. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  936. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
  937. mindspore/nn/probability/transforms/__init__.py +0 -22
  938. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  939. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  940. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  941. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  942. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  943. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  944. mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
  945. mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
  946. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
  947. mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
  948. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
  949. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
  950. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
  951. mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
  952. mindspore/ops/composite/array_ops.py +0 -210
  953. mindspore/ops/composite/clip_ops.py +0 -238
  954. mindspore/ops/composite/random_ops.py +0 -426
  955. mindspore/ops/composite/vmap_ops.py +0 -38
  956. mindspore/ops/operations/sponge_ops.py +0 -3531
  957. mindspore/ops/operations/sponge_update_ops.py +0 -2546
  958. mindspore/parallel/nn/__init__.py +0 -42
  959. mindspore/parallel/nn/loss.py +0 -22
  960. mindspore/parallel/nn/moe.py +0 -21
  961. mindspore/parallel/nn/op_parallel_config.py +0 -22
  962. mindspore/parallel/nn/transformer.py +0 -31
  963. mindspore/run_check/_check_deps_version.py +0 -84
  964. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  965. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  966. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -14,23 +14,31 @@
14
14
  # ============================================================================
15
15
 
16
16
  """Defines gradient related operators with functional form."""
17
-
18
17
  from __future__ import absolute_import
19
18
  from functools import partial
20
- from mindspore.common import ms_function
19
+ import numpy as np
20
+ from mindspore.common import jit, mutable
21
21
  from mindspore.common import Tensor
22
22
  from mindspore.common import dtype as mstype
23
- from mindspore.nn.grad.cell_grad import _JvpInner
24
- from mindspore.nn.grad.cell_grad import _VjpInner
23
+ from mindspore.nn.cell import Cell
25
24
  from mindspore.nn.grad.cell_grad import _LinearizeInner
26
25
  from mindspore.ops.primitive import constexpr
27
- from mindspore.ops.function import ones, expand_dims
28
- from mindspore.ops.composite import _Grad, _TaylorOperation
26
+ from mindspore.ops.function.array_func import ones, expand_dims, size, reshape, broadcast_to, transpose
27
+ from mindspore.ops.composite import _Vmap, _Grad, _TaylorOperation, GradOperation
29
28
  from mindspore.ops import operations as P
29
+ from mindspore.ops.operations import _inner_ops as inner
30
30
 
31
31
  cast = P.Cast()
32
32
  dtype = P.DType()
33
33
  zeros = P.Zeros()
34
+ oneslike = P.OnesLike()
35
+
36
+
37
+ @constexpr
38
+ def _check_has_aux_type(inputs):
39
+ if not isinstance(inputs, bool):
40
+ raise TypeError("The 'has_aux' must be bool type.")
41
+ return True
34
42
 
35
43
 
36
44
  @constexpr
@@ -38,15 +46,27 @@ def _raise_type_error():
38
46
  raise TypeError("The inputs type must be a Tensor, tuple or list of Tensors.")
39
47
 
40
48
 
49
+ @constexpr
50
+ def _check_duplicate_grad_position(grad_position):
51
+ """Check if `grad_position` has duplicate positions when `grad_position` has more than one numbers."""
52
+ if len(set(grad_position)) != len(grad_position):
53
+ raise ValueError("There are duplicate positions in `grad_position`, please check it")
54
+
55
+
41
56
  @constexpr
42
57
  def _convert_grad_position_type(grad_position):
43
58
  """Check and convert the type and size of grad position index."""
44
59
  if isinstance(grad_position, tuple):
45
- for gp in grad_position:
60
+ _check_duplicate_grad_position(grad_position)
61
+ _grad_position = list(grad_position)
62
+ for i, gp in enumerate(_grad_position):
63
+ if isinstance(gp, bool):
64
+ _grad_position[i] = int(gp)
46
65
  if not isinstance(gp, int):
47
66
  raise TypeError(f"For 'F.grad', the element in 'grad_position' must be int.")
48
67
  if gp < 0:
49
68
  raise ValueError("The element in grad_position must be >= 0.")
69
+ grad_position = tuple(_grad_position)
50
70
  elif isinstance(grad_position, int):
51
71
  if grad_position < 0:
52
72
  raise ValueError("grad_position must be >= 0.")
@@ -57,11 +77,22 @@ def _convert_grad_position_type(grad_position):
57
77
 
58
78
 
59
79
  @constexpr
60
- def _get_grad_op(get_by_list, get_by_position, has_aux, get_value=False):
61
- return _Grad(get_by_list=get_by_list, get_by_position=get_by_position, has_aux=has_aux, get_value=get_value)
80
+ def _check_grad_position(grad_position, args_num):
81
+ """Check and convert grad position index."""
82
+ grad_position = _convert_grad_position_type(grad_position)
83
+ for gp in grad_position:
84
+ if gp < 0 or gp >= args_num:
85
+ raise ValueError("The element in grad_position must belong to [0, args_num).")
86
+ return grad_position
87
+
88
+
89
+ @constexpr
90
+ def _get_grad_op(get_by_list, get_by_position, has_aux, get_value=False, return_ids=False):
91
+ return _Grad(get_by_list=get_by_list, get_by_position=get_by_position, has_aux=has_aux, get_value=get_value,
92
+ return_ids=return_ids)
62
93
 
63
94
 
64
- def grad(fn, grad_position=0, weights=None, has_aux=False):
95
+ def grad(fn, grad_position=0, weights=None, has_aux=False, return_ids=False):
65
96
  """
66
97
  A wrapper function to generate the gradient function for the input function.
67
98
 
@@ -84,11 +115,19 @@ def grad(fn, grad_position=0, weights=None, has_aux=False):
84
115
  has_aux (bool): If True, only the first output of `fn` contributes the gradient of `fn`, while the other outputs
85
116
  will be returned straightly. It means the `fn` must return more than one outputs in this case.
86
117
  Default: False.
118
+ return_ids(bool): Whether return the tuple made by gradients and the index to specify which inputs
119
+ to be differentiated or the name of parameters of the training network that need to calculate the gradient.
120
+ If True, the output gradients will be replaced by the tuples made by gradients and the index to specify
121
+ which inputs to be differentiated or the name of parameters of the training network.
122
+ Default: False.
87
123
 
88
124
  Returns:
89
125
  Function, the gradient function to calculate gradient for the input function or cell.
90
126
  For example, as for `out1, out2 = fn(*args)`, when `has_aux` is set True, gradient function will return outputs
91
127
  like `(gradient, out2)` and `out2` does not contribute to the differentiation, otherwise `gradient`.
128
+ When return_ids is set to True, The format of the output will be the same with the output of grad when
129
+ return_ids is set to false, but every gradient in the output will be replaced by a tuple of position id or
130
+ parameter name and its gradient.
92
131
 
93
132
  Raises:
94
133
  ValueError: If both `grad_position` and `weights` are None.
@@ -102,7 +141,7 @@ def grad(fn, grad_position=0, weights=None, has_aux=False):
102
141
  >>> import mindspore
103
142
  >>> import mindspore.nn as nn
104
143
  >>> from mindspore import Tensor, ops
105
- >>> from mindspore.ops import grad
144
+ >>> from mindspore import grad
106
145
  >>>
107
146
  >>> # Cell object to be differentiated
108
147
  >>> class Net(nn.Cell):
@@ -131,7 +170,7 @@ def grad(fn, grad_position=0, weights=None, has_aux=False):
131
170
  >>> print(aux)
132
171
  (Tensor(shape=[2], dtype=Float32, value= [ 5.00000000e+00, 5.00000000e+00]),)
133
172
  >>>
134
- >>> # For given network to be differentiated with both inputs and weights, there are 3 cases.
173
+ >>> # For given network to be differentiated with both inputs and weights, there are 4 cases.
135
174
  >>> net = nn.Dense(10, 1)
136
175
  >>> loss_fn = nn.MSELoss()
137
176
  >>> def forward(inputs, labels):
@@ -163,17 +202,36 @@ def grad(fn, grad_position=0, weights=None, has_aux=False):
163
202
  >>> inputs_gradient, params_gradient = grad_fn(inputs, labels)
164
203
  >>> print(len(weights), len(params_gradient))
165
204
  2 2
205
+ >>> # Case 4: return the gradient with ids.
206
+ >>> import numpy as np
207
+ >>> import mindspore
208
+ >>> import mindspore.nn as nn
209
+ >>> from mindspore import Tensor, ops
210
+ >>> from mindspore import grad
211
+ >>>
212
+ >>> # Cell object to be differentiated
213
+ >>> class Net(nn.Cell):
214
+ ... def construct(self, x, y, z):
215
+ ... return x * y * z
216
+ >>> x = Tensor([1, 2], mindspore.float32)
217
+ >>> y = Tensor([-2, 3], mindspore.float32)
218
+ >>> z = Tensor([0, 3], mindspore.float32)
219
+ >>> net = Net()
220
+ >>> output = grad(net, grad_position=(1, 2), return_ids = True)(x, y, z)
221
+ >>> print(output)
222
+ ((1, Tensor(shape=[2], dtype=Float32, value=[ 0.00000000e+00, 6.00000000e+00])),
223
+ (2, Tensor(shape=[2], dtype=Float32, value=[-2.00000000e+00, 6.00000000e+00])))
166
224
  """
167
225
  if grad_position is None and weights is None:
168
226
  raise ValueError("`grad_position` and `weight` can not be None at the same time.")
169
227
 
170
228
  if grad_position is None:
171
- return _get_grad_op(True, False, has_aux)(fn, weights)
229
+ return _get_grad_op(True, False, has_aux, False, return_ids)(fn, weights)
172
230
 
173
231
  grad_position = _convert_grad_position_type(grad_position)
174
232
  if weights is None:
175
- return _get_grad_op(False, True, has_aux)(fn, None, grad_position)
176
- return _get_grad_op(True, True, has_aux)(fn, weights, grad_position)
233
+ return _get_grad_op(False, True, has_aux, False, return_ids)(fn, None, grad_position)
234
+ return _get_grad_op(True, True, has_aux, False, return_ids)(fn, weights, grad_position)
177
235
 
178
236
 
179
237
  def value_and_grad(fn, grad_position=0, weights=None, has_aux=False):
@@ -216,7 +274,7 @@ def value_and_grad(fn, grad_position=0, weights=None, has_aux=False):
216
274
  >>> import numpy as np
217
275
  >>> import mindspore
218
276
  >>> from mindspore import Tensor, ops, nn
219
- >>> from mindspore.ops import value_and_grad
277
+ >>> from mindspore import value_and_grad
220
278
  >>>
221
279
  >>> # Cell object to be differentiated
222
280
  >>> class Net(nn.Cell):
@@ -300,6 +358,55 @@ def value_and_grad(fn, grad_position=0, weights=None, has_aux=False):
300
358
  return _get_grad_op(True, True, has_aux, True)(fn, weights, grad_position)
301
359
 
302
360
 
361
+ def get_grad(gradients, identifier):
362
+ """
363
+ When `return_ids` of :func:`mindspore.grad` is set to True, use its return value as gradients. Then find
364
+ the specific gradient from `gradients` according to `identifier` .
365
+
366
+ As for gradient, two typical cases are included:
367
+
368
+ 1. `identifier` is the position of the specific tensor to get gradient.
369
+ 2. `identifier` is a parameter of a network.
370
+
371
+ Args:
372
+ gradients (Union[tuple[int, Tensor], tuple[tuple, tuple]]): The return value of :func:`mindspore.grad`
373
+ when `return_ids` is set to True.
374
+ identifier (Union[int, Parameter]): The position number of a tensor, or a parameter that is used in
375
+ :func:`mindspore.grad`.
376
+
377
+ Returns:
378
+ The gradient of the tensor on the position or in the parameter that specified by the `identifier`.
379
+
380
+ Raises:
381
+ RuntimeError: If gradient is not found.
382
+ TypeError: If type of Args does not belong to required ones.
383
+
384
+ Supported Platforms:
385
+ ``Ascend`` ``GPU`` ``CPU``
386
+
387
+ Examples:
388
+ >>> import numpy as np
389
+ >>> import mindspore
390
+ >>> import mindspore.nn as nn
391
+ >>> from mindspore import Tensor, ops
392
+ >>> from mindspore import grad, get_grad
393
+ >>>
394
+ >>> # Cell object to be differentiated
395
+ >>> class Net(nn.Cell):
396
+ ... def construct(self, x, y, z):
397
+ ... return x * y * z
398
+ >>> x = Tensor([1, 2], mindspore.float32)
399
+ >>> y = Tensor([-2, 3], mindspore.float32)
400
+ >>> z = Tensor([0, 3], mindspore.float32)
401
+ >>> net = Net()
402
+ >>> out_grad = grad(net, grad_position=(1, 2), return_ids=True)(x, y, z)
403
+ >>> output = get_grad(out_grad, 1)
404
+ >>> print(output)
405
+ [0. 6.]
406
+ """
407
+ return inner.GetGrad()(gradients, identifier)
408
+
409
+
303
410
  def _trans_jet_inputs(primals_item, series_item):
304
411
  """Trans inputs of jet"""
305
412
  value_type = [mstype.int32, mstype.int64, mstype.float32, mstype.float64]
@@ -376,15 +483,14 @@ def jet(fn, primals, series):
376
483
  >>> import numpy as np
377
484
  >>> import mindspore.nn as nn
378
485
  >>> import mindspore as ms
379
- >>> import mindspore.ops as P
486
+ >>> import mindspore.ops as ops
380
487
  >>> from mindspore import Tensor
381
- >>> from mindspore.ops.functional import jet
382
488
  >>> ms.set_context(mode=ms.GRAPH_MODE)
383
489
  >>> class Net(nn.Cell):
384
490
  ... def __init__(self):
385
491
  ... super().__init__()
386
- ... self.sin = P.Sin()
387
- ... self.exp = P.Exp()
492
+ ... self.sin = ops.Sin()
493
+ ... self.exp = ops.Exp()
388
494
  ... def construct(self, x):
389
495
  ... out1 = self.sin(x)
390
496
  ... out2 = self.exp(out1)
@@ -392,7 +498,7 @@ def jet(fn, primals, series):
392
498
  >>> primals = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
393
499
  >>> series = Tensor(np.array([[[1, 1], [1, 1]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]]).astype(np.float32))
394
500
  >>> net = Net()
395
- >>> out_primals, out_series = jet(net, primals, series)
501
+ >>> out_primals, out_series = ops.jet(net, primals, series)
396
502
  >>> print(out_primals, out_series)
397
503
  [[2.319777 2.4825778]
398
504
  [1.1515628 0.4691642]] [[[ 1.2533808 -1.0331168 ]
@@ -487,15 +593,14 @@ def derivative(fn, primals, order):
487
593
  >>> import numpy as np
488
594
  >>> import mindspore as ms
489
595
  >>> import mindspore.nn as nn
490
- >>> import mindspore.ops as P
596
+ >>> import mindspore.ops as ops
491
597
  >>> from mindspore import Tensor
492
- >>> from mindspore.ops.functional import derivative
493
598
  >>> ms.set_context(mode=ms.GRAPH_MODE)
494
599
  >>> class Net(nn.Cell):
495
600
  ... def __init__(self):
496
601
  ... super().__init__()
497
- ... self.sin = P.Sin()
498
- ... self.exp = P.Exp()
602
+ ... self.sin = ops.Sin()
603
+ ... self.exp = ops.Exp()
499
604
  ... def construct(self, x):
500
605
  ... out1 = self.sin(x)
501
606
  ... out2 = self.exp(out1)
@@ -503,7 +608,7 @@ def derivative(fn, primals, order):
503
608
  >>> primals = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
504
609
  >>> order = 3
505
610
  >>> net = Net()
506
- >>> out_primals, out_series = derivative(net, primals, order)
611
+ >>> out_primals, out_series = ops.derivative(net, primals, order)
507
612
  >>> print(out_primals, out_series)
508
613
  [[2.319777 2.4825778]
509
614
  [1.1515628 0.4691642]] [[-4.0515366 3.6724353 ]
@@ -541,10 +646,20 @@ def derivative(fn, primals, order):
541
646
  return out_primals, out_series
542
647
 
543
648
 
544
- def jvp(fn, inputs, v):
649
+ _grad_single = GradOperation(sens_param=True)
650
+ _grad_all = GradOperation(sens_param=True, get_all=True)
651
+
652
+
653
+ @constexpr
654
+ def _check_jvp_input_v_len(inputs_len, v_len):
655
+ if inputs_len != v_len:
656
+ raise ValueError(f'v has invalid length: should be {inputs_len}, but got {v_len}')
657
+
658
+
659
+ def jvp(fn, inputs, v, has_aux=False):
545
660
  """
546
661
  Compute the jacobian-vector-product of the given network. `jvp` matches
547
- `forward-mode differentiation <https://www.mindspore.cn/docs/en/r1.10/design/auto_gradient.html#forward-mode-ad>`_.
662
+ `forward-mode differentiation <https://www.mindspore.cn/docs/en/r2.0/design/auto_gradient.html#forward-mode-ad>`_.
548
663
 
549
664
  Args:
550
665
  fn (Union[Function, Cell]): The function or net that takes Tensor inputs and returns single Tensor or tuple of
@@ -552,10 +667,16 @@ def jvp(fn, inputs, v):
552
667
  inputs (Union[Tensor, tuple[Tensor], list[Tensor]]): The inputs to `fn` .
553
668
  v (Union[Tensor, tuple[Tensor], list[Tensor]]): The vector in jacobian-vector-product. The shape and type of `v`
554
669
  should be the same as `inputs` .
670
+ has_aux (bool): If True, only the first output of `fn` contributes the gradient of `fn`, while the other outputs
671
+ will be returned straightly. It means the `fn` must return more than one outputs in this case.
672
+ Default: False.
555
673
 
556
674
  Returns:
557
- - **net_output** (Union[Tensor, tuple[Tensor]]) - The result of `fn(inputs)` .
675
+ - **net_output** (Union[Tensor, tuple[Tensor]]) - The output of `fn(inputs)` . Specially, when `has_aux` is set
676
+ True, `netout` is the first output of `fn(inputs)` .
558
677
  - **jvp** (Union[Tensor, tuple[Tensor]]) - The result of jacobian-vector-product.
678
+ - **aux_value** (Union[Tensor, tuple[Tensor]], optional) - When `has_aux` is True, `aux_value` will be returned.
679
+ It means the second to last outputs of `fn(inputs)` . Specially, `aux_value` does not contribute to gradient.
559
680
 
560
681
  Raises:
561
682
  TypeError: `inputs` or `v` does not belong to required types.
@@ -564,32 +685,102 @@ def jvp(fn, inputs, v):
564
685
  ``Ascend`` ``GPU`` ``CPU``
565
686
 
566
687
  Examples:
567
- >>> from mindspore import ops
688
+ >>> import numpy as np
689
+ >>> from mindspore import jvp
568
690
  >>> from mindspore import Tensor
691
+ >>> import mindspore.nn as nn
569
692
  >>> class Net(nn.Cell):
570
693
  ... def construct(self, x, y):
571
694
  ... return x**3 + y
572
695
  >>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
573
696
  >>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
574
697
  >>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
575
- >>> output = ops.jvp(Net(), (x, y), (v, v))
698
+ >>> output = jvp(Net(), (x, y), (v, v))
576
699
  >>> print(output[0])
577
700
  [[ 2. 10.]
578
701
  [30. 68.]]
579
702
  >>> print(output[1])
580
703
  [[ 4. 13.]
581
704
  [28. 49.]]
705
+ >>>
706
+ >>> def fn(x, y):
707
+ ... return x ** 3 + y, y
708
+ >>> output, jvp_out, aux = jvp(fn, (x, y), (v, v), has_aux=True)
709
+ >>> print(output)
710
+ [[ 2. 10.]
711
+ [30. 68.]]
712
+ >>> print(jvp_out)
713
+ [[ 4. 13.]
714
+ [28. 49.]]
715
+ >>> print(aux)
716
+ [[ 1. 2.]
717
+ [3. 4.]]
582
718
  """
583
- jvp_inner = _JvpInner()
584
-
585
- @ms_function(hash_args=fn)
586
- def _wrap_container(*arg):
587
- args = arg[1:]
719
+ _check_has_aux_type(has_aux)
720
+
721
+ def aux_fn(*args):
722
+ outputs = fn(*args)
723
+ if not isinstance(outputs, tuple) or len(outputs) < 2:
724
+ raise ValueError("When 'has_aux' is True, origin 'fn' requires more than one outputs.")
725
+ res = outputs[0]
726
+ return res
727
+
728
+ def grad_single(u, first_grad_single_value):
729
+ if has_aux:
730
+ return _grad_single(aux_fn)(*first_grad_single_value, u)
731
+ return _grad_single(fn)(*first_grad_single_value, u)
732
+
733
+ def grad_all(u, first_grad):
734
+ if has_aux:
735
+ return _grad_all(aux_fn)(*first_grad, u)
736
+ return _grad_all(fn)(*first_grad, u)
737
+
738
+ def _wrap_container_inner(*arg):
739
+ jvp_inputs = arg[1:]
588
740
  vectors = arg[0]
589
- return jvp_inner(fn, vectors, *args)
741
+ if has_aux:
742
+ outputs = aux_fn(*jvp_inputs)
743
+ else:
744
+ outputs = fn(*jvp_inputs)
745
+ if isinstance(outputs, tuple):
746
+ u = ()
747
+ for item in outputs:
748
+ u = u + (mutable(oneslike(item)),)
749
+ else:
750
+ u = mutable(oneslike(outputs))
751
+ if len(jvp_inputs) == 1:
752
+ second_grad_net = _grad_single(grad_single)
753
+ gradient_outputs = second_grad_net(u, jvp_inputs, vectors)
754
+ else:
755
+ second_grad_net = _grad_single(grad_all)
756
+ gradient_outputs = second_grad_net(u, jvp_inputs, vectors)
757
+ if has_aux:
758
+ res = fn(*jvp_inputs)
759
+ if len(res) == 2:
760
+ return res[0], gradient_outputs, res[1]
761
+ return res[0], gradient_outputs, res[1:]
762
+ return outputs, gradient_outputs
763
+
764
+ if has_aux:
765
+ @jit(hash_args=aux_fn)
766
+ def _wrap_container(*arg):
767
+ return _wrap_container_inner(*arg)
768
+ else:
769
+ @jit(hash_args=fn)
770
+ def _wrap_container(*arg):
771
+ return _wrap_container_inner(*arg)
590
772
 
591
773
  if not isinstance(inputs, (Tensor, tuple, list)) or not isinstance(v, (Tensor, tuple, list)):
592
774
  _raise_type_error()
775
+
776
+ inputs_len = 1
777
+ v_len = 1
778
+ if isinstance(inputs, (tuple, list)):
779
+ inputs_len = len(inputs)
780
+ if isinstance(v, (tuple, list)):
781
+ v_len = len(v)
782
+ _check_jvp_input_v_len(inputs_len, v_len)
783
+
593
784
  if isinstance(v, list):
594
785
  v = tuple(v)
595
786
  if isinstance(inputs, (tuple, list)):
@@ -647,7 +838,7 @@ def linearize(fn, inputs):
647
838
  """
648
839
  linearize_inner = _LinearizeInner()
649
840
 
650
- @ms_function(hash_args=fn)
841
+ @jit(hash_args=fn)
651
842
  def _wrap_container(*arg):
652
843
  args = arg[1:-1]
653
844
  vectors = arg[-1]
@@ -664,24 +855,38 @@ def linearize(fn, inputs):
664
855
  return output, partial(_wrap_container, output, *inputs)
665
856
 
666
857
 
667
- def vjp(fn, inputs, v):
858
+ def _check_tensor(inputs):
859
+ if not isinstance(inputs, (Tensor, tuple)):
860
+ raise TypeError("The inputs type must be Tensor.")
861
+ if isinstance(inputs, tuple):
862
+ for item in inputs:
863
+ if not isinstance(item, (Tensor, tuple, list)):
864
+ raise TypeError("The inputs type must be Tensor.")
865
+ return True
866
+
867
+
868
+ def vjp(fn, *inputs, has_aux=False):
668
869
  """
669
870
  Compute the vector-jacobian-product of the given network. `vjp` matches
670
- `reverse-mode differentiation <https://www.mindspore.cn/docs/en/r1.10/design/auto_gradient.html#reverse-mode-ad>`_.
671
-
672
- Note:
673
- This function is subjected to change in the future.
871
+ `reverse-mode differentiation <https://www.mindspore.cn/docs/en/r2.0/design/auto_gradient.html#reverse-mode-ad>`_.
674
872
 
675
873
  Args:
676
874
  fn (Union[Function, Cell]): The function or net that takes Tensor inputs and returns single Tensor or tuple of
677
875
  Tensors.
678
876
  inputs (Union[Tensor, tuple[Tensor], list[Tensor]]): The inputs to `fn` .
679
- v (Union[Tensor, tuple[Tensor], list[Tensor]]): The vector in vector-jacobian-product. The shape and type of `v`
680
- should be the same as `fn(inputs)` .
877
+ has_aux (bool): If True, only the first output of `fn` contributes the gradient of `fn`, while the other outputs
878
+ will be returned straightly. It means the `fn` must return more than one outputs in this case.
879
+ Default: False.
681
880
 
682
881
  Returns:
683
- - **net_output** (Union[Tensor, tuple[Tensor]]) - The result of `fn(inputs)` .
684
- - **vjp** (Union[Tensor, tuple[Tensor]]) - The result of vector-jacobian-product.
882
+ Forward outputs and function to calculate vjp.
883
+
884
+ - **net_output** (Union[Tensor, tuple[Tensor]]) - The output of `fn(inputs)`. Specially, when `has_aux` is set
885
+ True, `netout` is the first output of `fn(inputs)`.
886
+ - **vjp_fn** (Function) - To calculate vector-jacobian-product. Its inputs are the vectors whose shape and
887
+ type should be the same as `netout` .
888
+ - **aux_value** (Union[Tensor, tuple[Tensor]], optional) - When `has_aux` is True, `aux_value` will be returned.
889
+ It means the second to last outputs of `fn(inputs)`. Specially, `aux_value` does not contribute to gradient.
685
890
 
686
891
  Raises:
687
892
  TypeError: `inputs` or `v` does not belong to required types.
@@ -690,7 +895,9 @@ def vjp(fn, inputs, v):
690
895
  ``Ascend`` ``GPU`` ``CPU``
691
896
 
692
897
  Examples:
693
- >>> from mindspore import ops
898
+ >>> import numpy as np
899
+ >>> import mindspore.nn as nn
900
+ >>> from mindspore import vjp
694
901
  >>> from mindspore import Tensor
695
902
  >>> class Net(nn.Cell):
696
903
  ... def construct(self, x, y):
@@ -698,41 +905,505 @@ def vjp(fn, inputs, v):
698
905
  >>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
699
906
  >>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
700
907
  >>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
701
- >>> output = ops.vjp(Net(), (x, y), v)
702
- >>> print(output[0])
908
+ >>> outputs, vjp_fn = vjp(Net(), x, y)
909
+ >>> print(outputs)
703
910
  [[ 2. 10.]
704
911
  [30. 68.]]
705
- >>> print(output[1])
912
+ >>> gradient = vjp_fn(v)
913
+ >>> print(gradient)
706
914
  (Tensor(shape=[2, 2], dtype=Float32, value=
707
915
  [[ 3.00000000e+00, 1.20000000e+01],
708
916
  [ 2.70000000e+01, 4.80000000e+01]]), Tensor(shape=[2, 2], dtype=Float32, value=
709
917
  [[ 1.00000000e+00, 1.00000000e+00],
710
918
  [ 1.00000000e+00, 1.00000000e+00]]))
919
+ >>> def fn(x, y):
920
+ ... return 2 * x + y, y ** 3
921
+ >>> outputs, vjp_fn, aux = vjp(fn, x, y, has_aux=True)
922
+ >>> gradient = vjp_fn(v)
923
+ >>> print(outputs)
924
+ [[ 3. 6.]
925
+ [ 9. 12.]]
926
+ >>> print(aux)
927
+ [[ 1. 8.]
928
+ [27. 64.]]
929
+ >>> print(gradient)
930
+ (Tensor(shape=[2, 2], dtype=Float32, value=
931
+ [[ 2.00000000e+00, 2.00000000e+00],
932
+ [ 2.00000000e+00, 2.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
933
+ [[ 1.00000000e+00, 1.00000000e+00],
934
+ [ 1.00000000e+00, 1.00000000e+00]]))
711
935
  """
712
- vjp_inner = _VjpInner()
936
+ _check_tensor(inputs)
937
+ _check_has_aux_type(has_aux)
938
+
939
+ def aux_fn(*args):
940
+ outputs = fn(*args)
941
+ if not isinstance(outputs, tuple) or len(outputs) < 2:
942
+ raise ValueError("When 'has_aux' is True, origin 'fn' requires more than one outputs.")
943
+ res = outputs[0]
944
+ return res
945
+
946
+ def wrap_container(*v):
947
+ _check_tensor(v)
948
+ if has_aux:
949
+ fn_ = aux_fn
950
+ else:
951
+ fn_ = fn
952
+ if len(v) == 1:
953
+ return _grad_all(fn_)(*inputs, v[0])
954
+ return _grad_all(fn_)(*inputs, v)
713
955
 
714
- @ms_function(hash_args=fn)
715
- def wrap_container(*arg):
716
- args = arg[:-1]
717
- vectors = arg[-1]
718
- return vjp_inner(fn, *args, vectors)
956
+ res = fn(*inputs)
957
+ if has_aux:
958
+ if len(res) == 2:
959
+ return res[0], wrap_container, res[1]
960
+ return res[0], wrap_container, res[1:]
961
+ return res, wrap_container
719
962
 
720
- if not isinstance(inputs, (Tensor, tuple, list)) or not isinstance(v, (Tensor, tuple, list)):
721
- _raise_type_error()
722
- if isinstance(v, list):
723
- v = tuple(v)
724
- if isinstance(inputs, (tuple, list)):
725
- return wrap_container(*inputs, v)
726
- return wrap_container(inputs, v)
963
+
964
+ @constexpr
965
+ def _jac_generate_target_dimension(x):
966
+ """For given length = len(x), this method generates target dimension tuple (1, 2, 3,..., length, 0)."""
967
+ target_dimension = tuple(index + 1 for index, _ in enumerate(x[1:])) + (0,)
968
+ return target_dimension
969
+
970
+
971
+ def _jacfwd_trans_item(item, inputs_shape, grad_position):
972
+ """transfer origin item to derivative of each output with respect to each input."""
973
+ output_wrt_input_all = ()
974
+ for i in grad_position:
975
+ origin_output_wrt_input = item[inputs_shape[i][1]:inputs_shape[i + 1][1]]
976
+ target_dimension = _jac_generate_target_dimension(origin_output_wrt_input.shape)
977
+ temp = transpose(origin_output_wrt_input, target_dimension)
978
+ output_wrt_input = reshape(temp, temp.shape[:-1] + inputs_shape[i + 1][0])
979
+ output_wrt_input_all += (output_wrt_input,)
980
+ return output_wrt_input_all
981
+
982
+
983
+ def _jac_postprocess(x, shape, grad_position, mode):
984
+ """reformat jacobian."""
985
+
986
+ if mode == 'forward':
987
+ func = _jacfwd_trans_item
988
+ args = (shape, grad_position)
989
+ else:
990
+ func = _jacrev_trans_item
991
+ args = (shape,)
992
+
993
+ if isinstance(x, tuple):
994
+ jacobian = ()
995
+ for item in x:
996
+ jacobian += func(item, *args)
997
+ res = jacobian
998
+ else:
999
+ res = func(x, *args)
1000
+ if len(res) == 1:
1001
+ return res[0]
1002
+ input_num = len(grad_position)
1003
+ if len(res) % input_num != 0:
1004
+ raise ValueError("The numbers of inputs and outputs do not match.")
1005
+ output_num = len(res) // input_num
1006
+ if input_num == 1 or output_num == 1:
1007
+ return res
1008
+ jac = ()
1009
+ for i in range(output_num):
1010
+ input_grad = ()
1011
+ for j in range(input_num):
1012
+ if mode == 'forward':
1013
+ grad_increment = (res[i * input_num + j],)
1014
+ else:
1015
+ grad_increment = (res[j * output_num + i],)
1016
+ input_grad += grad_increment
1017
+ jac += (input_grad,)
1018
+ return jac
1019
+
1020
+
1021
+ def _jacfwd_postprocess(x, inputs_shape, grad_position):
1022
+ """reformat forward-computed Jacobian."""
1023
+ return _jac_postprocess(x, inputs_shape, grad_position, 'forward')
1024
+
1025
+
1026
+ def _jacfwd_construct_v(inputs, grad_position):
1027
+ """
1028
+ For input (x1, x2), x1.shape = (a, b), x2.shape = (c, d), this method generates corresponding v (v1, v2),
1029
+ v1.shape = (N, a, b), v2.shape = (N, c, d), while N = a*b + c*d.
1030
+ """
1031
+ v = ()
1032
+ primals = ()
1033
+ inputs_shape = (((), 0),)
1034
+ num = 0
1035
+ items_num = ()
1036
+ cum_num = (0,)
1037
+ for item in inputs:
1038
+ num += size(item)
1039
+ inputs_shape += ((item.shape, num),)
1040
+ items_num += (size(item),)
1041
+ cum_num += (num,)
1042
+ for i, element in enumerate(inputs):
1043
+ item_size = items_num[i]
1044
+ if i in grad_position:
1045
+ temp2 = Tensor(np.eye(num, item_size, -cum_num[i], np.float32))
1046
+ else:
1047
+ temp2 = zeros((num, item_size), mstype.float32)
1048
+ input_v = reshape(temp2, (num,) + element.shape)
1049
+ primal = broadcast_to(element, (num,) + element.shape)
1050
+ v += (input_v,)
1051
+ primals += (primal,)
1052
+ if len(inputs) == 1:
1053
+ return primals, v[0], inputs_shape
1054
+ return primals, v, inputs_shape
1055
+
1056
+
1057
+ _vmap = _Vmap()
1058
+
1059
+
1060
+ def jacfwd(fn, grad_position=0, has_aux=False):
1061
+ """
1062
+ Compute Jacobian via forward mode, corresponding to
1063
+ `forward-mode differentiation <https://www.mindspore.cn/docs/en/r2.0/design/auto_gradient.html#forward-mode-ad>`_.
1064
+ When number of outputs is much greater than that of inputs, it's better to calculate Jacobian via forward mode than
1065
+ reverse mode to get better performance.
1066
+
1067
+ Args:
1068
+ fn (Union[Cell, Function]): Function to do GradOperation.
1069
+ grad_position (Union[int, tuple[int]], optional): If int, get the gradient with respect to single input.
1070
+ If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0.
1071
+ has_aux (bool, optional): If True, only the first output of `fn` contributes the gradient of `fn`,
1072
+ while the other outputs will be returned straightly. It means the `fn` must return more than one
1073
+ outputs in this case. Default: False.
1074
+
1075
+ Returns:
1076
+ Function, returns the Jacobian function for the input function or cell.
1077
+ For example, as for `out1, out2 = fn(*args)`, when `has_aux` is set True, gradient function will return outputs
1078
+ like `(Jacobian, out2)` and `out2` does not contribute to the differentiation, otherwise `Jacobian` .
1079
+
1080
+ Raises:
1081
+ TypeError: `grad_position` or `has_aux` does not belong to required types.
1082
+
1083
+ Supported Platforms:
1084
+ ``Ascend`` ``GPU`` ``CPU``
1085
+
1086
+ Examples:
1087
+ >>> import numpy as np
1088
+ >>> import mindspore.nn as nn
1089
+ >>> from mindspore import jacfwd
1090
+ >>> from mindspore import Tensor
1091
+ >>> class MultipleInputsMultipleOutputsNet(nn.Cell):
1092
+ ... def construct(self, x, y, z):
1093
+ ... return x ** 2 + y ** 2 + z ** 2, x * y * z
1094
+ >>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
1095
+ >>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
1096
+ >>> z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
1097
+ >>> net = MultipleInputsMultipleOutputsNet()
1098
+ >>> jac, aux = jacfwd(net, grad_position=0, has_aux=True)(x, y, z)
1099
+ >>> print(jac)
1100
+ [[[[ 2., 0.]
1101
+ [ 0., 0.]]
1102
+ [[ 0., 4.]
1103
+ [ 0., 0.]]]
1104
+ [[[ 0., 0.]
1105
+ [ 6., 0.]]
1106
+ [[ 0., 0.]
1107
+ [ 0., 8.]]]]
1108
+ >>> print(aux)
1109
+ [[ 1. 4.]
1110
+ [ 9. 16.]]
1111
+ """
1112
+ _check_has_aux_type(has_aux)
1113
+
1114
+ def aux_fn(*args):
1115
+ outputs = fn(*args)
1116
+ if not isinstance(outputs, tuple) or len(outputs) < 2:
1117
+ raise ValueError("When 'has_aux' is True, origin 'fn' requires more than one outputs.")
1118
+ res = outputs[0]
1119
+ return res
1120
+
1121
+ def grad_single(u, first_grad_single_value):
1122
+ if has_aux:
1123
+ return _grad_single(aux_fn)(*first_grad_single_value, u)
1124
+ return _grad_single(fn)(*first_grad_single_value, u)
1125
+
1126
+ def grad_all(u, first_grad):
1127
+ if has_aux:
1128
+ return _grad_all(aux_fn)(*first_grad, u)
1129
+ return _grad_all(fn)(*first_grad, u)
1130
+
1131
+ @jit
1132
+ def wrapped(*args):
1133
+ checked_grad_position = _check_grad_position(grad_position, len(args))
1134
+ primals, v, inputs_shape = _jacfwd_construct_v(args, checked_grad_position)
1135
+
1136
+ def inner_fn(jvp_inputs, vectors):
1137
+ outputs = fn(*jvp_inputs)
1138
+ if isinstance(outputs, tuple):
1139
+ u = ()
1140
+ for item in outputs:
1141
+ u = u + (mutable(oneslike(item)),)
1142
+ else:
1143
+ u = mutable(oneslike(outputs))
1144
+ if len(jvp_inputs) == 1:
1145
+ second_grad_net = _grad_single(grad_single)
1146
+ else:
1147
+ second_grad_net = _grad_single(grad_all)
1148
+ gradient_outputs = second_grad_net(u, jvp_inputs, vectors)
1149
+ return gradient_outputs
1150
+
1151
+ def inner_aux_fn(jvp_inputs, vectors):
1152
+ outputs = aux_fn(*jvp_inputs)
1153
+ u = mutable(oneslike(outputs))
1154
+ if len(jvp_inputs) == 1:
1155
+ second_grad_net = _grad_single(grad_single)
1156
+ else:
1157
+ second_grad_net = _grad_single(grad_all)
1158
+ gradient_outputs = second_grad_net(u, jvp_inputs, vectors)
1159
+ return gradient_outputs
1160
+
1161
+ if has_aux:
1162
+ res = _vmap(inner_aux_fn)(primals, v)
1163
+ jac_res = _jacfwd_postprocess(res, inputs_shape, checked_grad_position)
1164
+ forward_outputs = fn(*args)
1165
+ if len(forward_outputs) == 2:
1166
+ return jac_res, forward_outputs[1]
1167
+ return jac_res, forward_outputs[1:]
1168
+ res = _vmap(inner_fn)(primals, v)
1169
+ jac_res = _jacfwd_postprocess(res, inputs_shape, checked_grad_position)
1170
+ return jac_res
1171
+
1172
+ return wrapped
1173
+
1174
+
1175
+ def _jacrev_trans_item(item, outputs_shape):
1176
+ """transfer origin item to derivative of each output with respect to each input."""
1177
+ output_wrt_input_all = ()
1178
+ length = len(outputs_shape) - 1
1179
+ for i in range(length):
1180
+ origin_output_wrt_input = item[outputs_shape[i][1]:outputs_shape[i + 1][1]]
1181
+ target_dimension = _jac_generate_target_dimension(origin_output_wrt_input.shape)
1182
+ temp = transpose(origin_output_wrt_input, target_dimension)
1183
+ output_wrt_input = reshape(origin_output_wrt_input, outputs_shape[i + 1][0] + temp.shape[:-1])
1184
+ output_wrt_input_all += (output_wrt_input,)
1185
+ return output_wrt_input_all
1186
+
1187
+
1188
+ def _jacrev_postprocess(x, outputs_shape, grad_position):
1189
+ """reformat reverse-computed jacobian."""
1190
+ return _jac_postprocess(x, outputs_shape, grad_position, 'reverse')
1191
+
1192
+
1193
+ def _jacrev_construct_v(inputs, outputs, has_aux=False):
1194
+ """
1195
+ For outputs (y1, y2), y1.shape = (a, b), y2.shape = (c, d), this method generates corresponding v (v1, v2),
1196
+ v1.shape = (N, a, b), v2.shape = (N, c, d), while N = a*b + c*d.
1197
+ """
1198
+ if isinstance(outputs, Tensor):
1199
+ outputs = (outputs,)
1200
+ if has_aux:
1201
+ outputs = (outputs[0],)
1202
+ v = ()
1203
+ primals = ()
1204
+ outputs_shape = (((), 0),)
1205
+ num = 0
1206
+ items_num = ()
1207
+ cum_num = (0,)
1208
+ for item in outputs:
1209
+ item_num = size(item)
1210
+ num += item_num
1211
+ outputs_shape += ((item.shape, num),)
1212
+ items_num += (item_num,)
1213
+ cum_num += (num,)
1214
+ for element in inputs:
1215
+ primal = broadcast_to(element, (num,) + element.shape)
1216
+ primals += (primal,)
1217
+ for i, element in enumerate(outputs):
1218
+ item_size = items_num[i]
1219
+ temp2 = Tensor(np.eye(num, item_size, -cum_num[i], np.float32))
1220
+ output_v = reshape(temp2, (num,) + element.shape)
1221
+ v += (output_v,)
1222
+ if len(outputs) == 1 or has_aux:
1223
+ return primals, v[0], outputs_shape
1224
+ return primals, v, outputs_shape
1225
+
1226
+
1227
+ _grad = _Grad(get_by_position=True, has_aux=False, sens_param=True)
1228
+
1229
+
1230
+ def jacrev(fn, grad_position=0, has_aux=False):
1231
+ """
1232
+ Compute Jacobian via reverse mode, corresponding to
1233
+ `reverse-mode differentiation <https://www.mindspore.cn/docs/en/r2.0/design/auto_gradient.html#reverse-mode-ad>`_.
1234
+ When number of inputs is much greater than that of outputs, it's better to calculate Jacobian via reverse mode than
1235
+ forward mode to get better performance.
1236
+
1237
+ Args:
1238
+ fn (Union[Cell, Function]): Function to do GradOperation.
1239
+ grad_position (Union[int, tuple[int]], optional): If int, get the gradient with respect to single input.
1240
+ If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0.
1241
+ has_aux (bool, optional): If True, only the first output of `fn` contributes the gradient of `fn`,
1242
+ while the other outputs will be returned straightly. It means the `fn` must return more than
1243
+ one outputs in this case. Default: False.
1244
+
1245
+ Returns:
1246
+ Function, returns the Jacobian function for the input function or cell.
1247
+ For example, as for `out1, out2 = fn(*args)`, when `has_aux` is set True, gradient function will return outputs
1248
+ like `(Jacobian, out2)` and `out2` does not contribute to the differentiation, otherwise `Jacobian` .
1249
+
1250
+ Raises:
1251
+ TypeError: `grad_position` or `has_aux` does not belong to required types.
1252
+
1253
+ Supported Platforms:
1254
+ ``Ascend`` ``GPU`` ``CPU``
1255
+
1256
+ Examples:
1257
+ >>> import numpy as np
1258
+ >>> import mindspore.nn as nn
1259
+ >>> from mindspore import jacrev
1260
+ >>> from mindspore import Tensor
1261
+ >>> class MultipleInputsMultipleOutputsNet(nn.Cell):
1262
+ ... def construct(self, x, y, z):
1263
+ ... return x ** 2 + y ** 2 + z ** 2, x * y * z
1264
+ >>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
1265
+ >>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
1266
+ >>> z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
1267
+ >>> net = MultipleInputsMultipleOutputsNet()
1268
+ >>> jac, aux = jacrev(net, grad_position=0, has_aux=True)(x, y, z)
1269
+ >>> print(jac)
1270
+ [[[[ 2., 0.]
1271
+ [ 0., 0.]]
1272
+ [[ 0., 4.]
1273
+ [ 0., 0.]]]
1274
+ [[[ 0., 0.]
1275
+ [ 6., 0.]]
1276
+ [[ 0., 0.]
1277
+ [ 0., 8.]]]]
1278
+ >>> print(aux)
1279
+ [[ 1. 4.]
1280
+ [ 9. 16.]]
1281
+ """
1282
+ _check_has_aux_type(has_aux)
1283
+
1284
+ def aux_fn(*args):
1285
+ outputs = fn(*args)
1286
+ if not isinstance(outputs, tuple) or len(outputs) < 2:
1287
+ raise ValueError("When 'has_aux' is True, origin 'fn' requires more than one outputs.")
1288
+ res = outputs[0]
1289
+ return res
1290
+
1291
+ @jit
1292
+ def wrapped(*args):
1293
+ checked_grad_position = _check_grad_position(grad_position, len(args))
1294
+ outputs = fn(*args)
1295
+ primals, v, outputs_shape = _jacrev_construct_v(args, outputs, has_aux)
1296
+
1297
+ def inner_fn(vjp_inputs, vectors):
1298
+ gradient_outputs = _grad(fn, None, checked_grad_position)(*vjp_inputs, vectors)
1299
+ return gradient_outputs
1300
+
1301
+ def inner_aux_fn(vjp_inputs, vectors):
1302
+ gradient_outputs = _grad(aux_fn, None, checked_grad_position)(*vjp_inputs, vectors)
1303
+ return gradient_outputs
1304
+
1305
+ if has_aux:
1306
+ res = _vmap(inner_aux_fn)(primals, v)
1307
+ jac_res = _jacrev_postprocess(res, outputs_shape, checked_grad_position)
1308
+ forward_outputs = fn(*args)
1309
+ if len(forward_outputs) == 2:
1310
+ return jac_res, forward_outputs[1]
1311
+ return jac_res, forward_outputs[1:]
1312
+
1313
+ res = _vmap(inner_fn)(primals, v)
1314
+ jac_res = _jacrev_postprocess(res, outputs_shape, checked_grad_position)
1315
+ return jac_res
1316
+
1317
+ return wrapped
1318
+
1319
+
1320
+ def custom_vjp(fn=None):
1321
+ """
1322
+ Support vjp to custom bprop for function.
1323
+
1324
+ Args:
1325
+ fn (function): The `fn` that need to define custom bprop. Default: None.
1326
+
1327
+ Supported Platforms:
1328
+ ``Ascend`` ``GPU`` ``CPU``
1329
+ """
1330
+
1331
+ def deco(fn):
1332
+ class CustomVjp(Cell):
1333
+ """
1334
+ The CustomVjp decorates function into cell to support custom bprop.
1335
+ """
1336
+
1337
+ def __init__(self, fwd):
1338
+ super(CustomVjp, self).__init__()
1339
+ self.fwd = fwd
1340
+ self.bwd = None
1341
+ self.add_flags(custom_vjp=True)
1342
+
1343
+ def construct(self, *args):
1344
+ return self.fwd(*args)
1345
+
1346
+ def defbwd(self, bwd):
1347
+ self.bwd = bwd
1348
+
1349
+ def bprop(self, *args):
1350
+ return self.bwd(*args)
1351
+
1352
+ return CustomVjp(fn)
1353
+
1354
+ if fn is not None:
1355
+ return deco(fn)
1356
+ return deco
1357
+
1358
+
1359
+ def stop_gradient(value):
1360
+ """
1361
+ StopGradient is used for eliminating the effect of a value on the gradient, such as truncating
1362
+ the gradient propagation from an output of a function.
1363
+ For more details, please refer to `Stop Gradient
1364
+ <https://www.mindspore.cn/tutorials/en/r2.0/beginner/autograd.html#stop-gradient>`_.
1365
+
1366
+ Args:
1367
+ value (Any): The value whose effect on the gradient to be eliminated.
1368
+
1369
+ Returns:
1370
+ The same as `value`.
1371
+
1372
+ Supported Platforms:
1373
+ ``Ascend`` ``GPU`` ``CPU``
1374
+
1375
+ Examples:
1376
+ >>> import mindspore.ops as ops
1377
+ >>> from mindspore import Tensor
1378
+ >>> from mindspore import dtype as mstype
1379
+ >>> def net(x, y):
1380
+ ... out1 = ops.MatMul()(x, y)
1381
+ ... out2 = ops.MatMul()(x, y)
1382
+ ... out2 = ops.stop_gradient(out2)
1383
+ ... return out1, out2
1384
+ ...
1385
+ >>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
1386
+ >>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
1387
+ >>> grad_fn = ops.grad(net)
1388
+ >>> output = grad_fn(x, y)
1389
+ >>> print(output)
1390
+ [[1.4100001 1.6 6.5999994]
1391
+ [1.4100001 1.6 6.5999994]]
1392
+ """
1393
+ return P.StopGradient()(value)
727
1394
 
728
1395
 
729
1396
  __all__ = [
730
1397
  'grad',
731
1398
  'value_and_grad',
1399
+ 'jacfwd',
1400
+ 'jacrev',
732
1401
  'jet',
733
1402
  'derivative',
734
1403
  'jvp',
735
1404
  'vjp',
736
- 'linearize'
1405
+ 'linearize',
1406
+ 'stop_gradient',
1407
+ 'get_grad'
737
1408
  ]
738
1409
  __all__.sort()