mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the package registry page for more details.

Files changed (870)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  201. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  202. mindspore/log.py +9 -6
  203. mindspore/mindrecord/filereader.py +33 -4
  204. mindspore/mindrecord/filewriter.py +70 -35
  205. mindspore/mindrecord/mindpage.py +40 -34
  206. mindspore/mindrecord/shardreader.py +1 -1
  207. mindspore/mindrecord/shardsegment.py +1 -1
  208. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  209. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  210. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  211. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  212. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  213. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  214. mindspore/nn/cell.py +463 -169
  215. mindspore/nn/dynamic_lr.py +47 -43
  216. mindspore/nn/layer/activation.py +225 -82
  217. mindspore/nn/layer/basic.py +121 -79
  218. mindspore/nn/layer/channel_shuffle.py +21 -21
  219. mindspore/nn/layer/combined.py +33 -26
  220. mindspore/nn/layer/container.py +277 -22
  221. mindspore/nn/layer/conv.py +441 -304
  222. mindspore/nn/layer/dense.py +19 -13
  223. mindspore/nn/layer/embedding.py +62 -49
  224. mindspore/nn/layer/flash_attention.py +264 -0
  225. mindspore/nn/layer/image.py +50 -39
  226. mindspore/nn/layer/math.py +62 -51
  227. mindspore/nn/layer/normalization.py +219 -167
  228. mindspore/nn/layer/padding.py +58 -70
  229. mindspore/nn/layer/pooling.py +334 -287
  230. mindspore/nn/layer/rnn_cells.py +53 -38
  231. mindspore/nn/layer/rnns.py +59 -56
  232. mindspore/nn/layer/thor_layer.py +52 -44
  233. mindspore/nn/layer/timedistributed.py +6 -4
  234. mindspore/nn/layer/transformer.py +284 -164
  235. mindspore/nn/learning_rate_schedule.py +34 -25
  236. mindspore/nn/loss/__init__.py +3 -2
  237. mindspore/nn/loss/loss.py +554 -311
  238. mindspore/nn/optim/ada_grad.py +12 -9
  239. mindspore/nn/optim/adadelta.py +14 -11
  240. mindspore/nn/optim/adafactor.py +19 -16
  241. mindspore/nn/optim/adam.py +62 -47
  242. mindspore/nn/optim/adamax.py +13 -10
  243. mindspore/nn/optim/adasum.py +12 -8
  244. mindspore/nn/optim/asgd.py +10 -9
  245. mindspore/nn/optim/ftrl.py +20 -17
  246. mindspore/nn/optim/lamb.py +16 -12
  247. mindspore/nn/optim/lars.py +8 -6
  248. mindspore/nn/optim/lazyadam.py +25 -20
  249. mindspore/nn/optim/momentum.py +10 -7
  250. mindspore/nn/optim/optimizer.py +61 -9
  251. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  252. mindspore/nn/optim/rmsprop.py +17 -13
  253. mindspore/nn/optim/rprop.py +30 -17
  254. mindspore/nn/optim/sgd.py +40 -23
  255. mindspore/nn/optim/thor.py +24 -26
  256. mindspore/nn/probability/bijector/bijector.py +11 -11
  257. mindspore/nn/probability/bijector/exp.py +1 -1
  258. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  259. mindspore/nn/probability/bijector/invert.py +1 -1
  260. mindspore/nn/probability/bijector/power_transform.py +29 -29
  261. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  262. mindspore/nn/probability/bijector/softplus.py +5 -5
  263. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  264. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  265. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  266. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  267. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  268. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  269. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  270. mindspore/nn/probability/distribution/beta.py +8 -8
  271. mindspore/nn/probability/distribution/categorical.py +23 -15
  272. mindspore/nn/probability/distribution/cauchy.py +5 -6
  273. mindspore/nn/probability/distribution/distribution.py +3 -3
  274. mindspore/nn/probability/distribution/exponential.py +4 -4
  275. mindspore/nn/probability/distribution/gamma.py +10 -10
  276. mindspore/nn/probability/distribution/geometric.py +8 -8
  277. mindspore/nn/probability/distribution/gumbel.py +8 -9
  278. mindspore/nn/probability/distribution/half_normal.py +5 -5
  279. mindspore/nn/probability/distribution/laplace.py +5 -5
  280. mindspore/nn/probability/distribution/log_normal.py +12 -11
  281. mindspore/nn/probability/distribution/logistic.py +8 -8
  282. mindspore/nn/probability/distribution/normal.py +6 -5
  283. mindspore/nn/probability/distribution/poisson.py +10 -11
  284. mindspore/nn/probability/distribution/student_t.py +8 -9
  285. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  286. mindspore/nn/probability/distribution/uniform.py +11 -11
  287. mindspore/nn/reinforcement/tensor_array.py +2 -2
  288. mindspore/nn/sparse/sparse.py +9 -9
  289. mindspore/nn/wrap/cell_wrapper.py +188 -63
  290. mindspore/nn/wrap/grad_reducer.py +21 -12
  291. mindspore/nn/wrap/loss_scale.py +136 -49
  292. mindspore/numpy/__init__.py +4 -4
  293. mindspore/numpy/array_creations.py +55 -56
  294. mindspore/numpy/array_ops.py +134 -35
  295. mindspore/numpy/logic_ops.py +66 -20
  296. mindspore/numpy/math_ops.py +142 -139
  297. mindspore/numpy/utils_const.py +2 -2
  298. mindspore/offline_debug/convert_async.py +2 -2
  299. mindspore/ops/_grad_experimental/__init__.py +7 -5
  300. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  301. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  302. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  303. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  304. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  305. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  306. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  307. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  308. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  309. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  310. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  311. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  312. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  313. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  314. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  315. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  316. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  317. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  318. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  319. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  320. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  321. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  322. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  323. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  325. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  326. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  327. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  328. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  329. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  331. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  332. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  333. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  334. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  335. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  336. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  337. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  338. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  339. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  340. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  341. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  342. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  343. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  345. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  346. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  347. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  348. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  349. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  350. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  351. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  352. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  353. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  354. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  355. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  356. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  357. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  358. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  359. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  360. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  361. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  362. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  363. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  364. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  365. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  366. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  367. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  368. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  369. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  370. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  371. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  372. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  373. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  374. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  375. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  376. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  377. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  378. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  379. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  380. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  381. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  382. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  383. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  384. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  385. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  386. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  387. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  388. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  389. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  390. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  391. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  392. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  393. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  394. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  395. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  396. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  397. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  398. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  399. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  400. mindspore/ops/_primitive_cache.py +1 -1
  401. mindspore/ops/_tracefunc.py +241 -0
  402. mindspore/ops/_utils/utils.py +10 -2
  403. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  404. mindspore/ops/_vmap/vmap_base.py +5 -4
  405. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  406. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  407. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  408. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  409. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  410. mindspore/ops/arg_dtype_cast.py +54 -0
  411. mindspore/ops/composite/__init__.py +7 -5
  412. mindspore/ops/composite/base.py +78 -34
  413. mindspore/ops/composite/math_ops.py +5 -695
  414. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  415. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  416. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  417. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  418. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  419. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  420. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  421. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  422. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  423. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  424. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  425. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  426. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  427. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  428. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  429. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  430. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  431. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  432. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  433. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  434. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  436. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  437. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  438. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  440. mindspore/ops/deprecated.py +304 -0
  441. mindspore/ops/function/__init__.py +41 -4
  442. mindspore/ops/function/array_func.py +1108 -467
  443. mindspore/ops/function/clip_func.py +94 -27
  444. mindspore/ops/function/debug_func.py +3 -1
  445. mindspore/ops/function/grad/grad_func.py +82 -73
  446. mindspore/ops/function/image_func.py +28 -12
  447. mindspore/ops/function/linalg_func.py +135 -39
  448. mindspore/ops/function/math_func.py +3779 -894
  449. mindspore/ops/function/nn_func.py +1584 -657
  450. mindspore/ops/function/parameter_func.py +13 -3
  451. mindspore/ops/function/random_func.py +247 -153
  452. mindspore/ops/function/sparse_func.py +14 -11
  453. mindspore/ops/function/sparse_unary_func.py +173 -47
  454. mindspore/ops/function/spectral_func.py +8 -4
  455. mindspore/ops/function/vmap_func.py +8 -7
  456. mindspore/ops/functional.py +47 -16
  457. mindspore/ops/op_info_register.py +346 -86
  458. mindspore/ops/operations/__init__.py +38 -22
  459. mindspore/ops/operations/_grad_ops.py +145 -149
  460. mindspore/ops/operations/_inner_ops.py +298 -56
  461. mindspore/ops/operations/_ms_kernel.py +3 -3
  462. mindspore/ops/operations/_quant_ops.py +24 -28
  463. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  464. mindspore/ops/operations/_scalar_ops.py +115 -0
  465. mindspore/ops/operations/_sequence_ops.py +148 -10
  466. mindspore/ops/operations/_tensor_array.py +1 -1
  467. mindspore/ops/operations/_thor_ops.py +2 -2
  468. mindspore/ops/operations/array_ops.py +1239 -561
  469. mindspore/ops/operations/comm_ops.py +166 -90
  470. mindspore/ops/operations/control_ops.py +3 -3
  471. mindspore/ops/operations/custom_ops.py +124 -102
  472. mindspore/ops/operations/debug_ops.py +24 -11
  473. mindspore/ops/operations/image_ops.py +86 -71
  474. mindspore/ops/operations/inner_ops.py +18 -13
  475. mindspore/ops/operations/linalg_ops.py +30 -11
  476. mindspore/ops/operations/math_ops.py +1730 -435
  477. mindspore/ops/operations/nn_ops.py +1953 -943
  478. mindspore/ops/operations/other_ops.py +65 -43
  479. mindspore/ops/operations/random_ops.py +258 -98
  480. mindspore/ops/operations/rl_ops.py +4 -36
  481. mindspore/ops/operations/sparse_ops.py +38 -33
  482. mindspore/ops/operations/spectral_ops.py +8 -4
  483. mindspore/ops/primitive.py +66 -44
  484. mindspore/ops/signature.py +5 -5
  485. mindspore/parallel/_auto_parallel_context.py +80 -19
  486. mindspore/parallel/_cost_model_context.py +42 -0
  487. mindspore/parallel/_offload_context.py +162 -72
  488. mindspore/parallel/_parallel_serialization.py +2 -2
  489. mindspore/parallel/_ps_context.py +16 -4
  490. mindspore/parallel/_recovery_context.py +2 -1
  491. mindspore/parallel/_tensor.py +15 -13
  492. mindspore/parallel/_transformer/layers.py +8 -6
  493. mindspore/parallel/_transformer/loss.py +1 -0
  494. mindspore/parallel/_transformer/moe.py +7 -7
  495. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  496. mindspore/parallel/_transformer/transformer.py +34 -14
  497. mindspore/parallel/_utils.py +36 -14
  498. mindspore/parallel/algo_parameter_config.py +114 -20
  499. mindspore/parallel/checkpoint_transform.py +16 -18
  500. mindspore/parallel/shard.py +16 -13
  501. mindspore/profiler/__init__.py +1 -1
  502. mindspore/profiler/common/struct_type.py +3 -3
  503. mindspore/profiler/common/util.py +3 -2
  504. mindspore/profiler/envprofiling.py +11 -4
  505. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  506. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  507. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  508. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  509. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  510. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  511. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  512. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  513. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  514. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  515. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  516. mindspore/profiler/parser/flops_parser.py +15 -11
  517. mindspore/profiler/parser/framework_parser.py +92 -73
  518. mindspore/profiler/parser/hccl_parser.py +16 -12
  519. mindspore/profiler/parser/integrator.py +22 -11
  520. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  521. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  522. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  523. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  524. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  525. mindspore/profiler/parser/optime_parser.py +1 -1
  526. mindspore/profiler/parser/profiler_info.py +4 -5
  527. mindspore/profiler/parser/step_trace_parser.py +11 -14
  528. mindspore/profiler/profiling.py +678 -377
  529. mindspore/rewrite/api/node.py +211 -54
  530. mindspore/rewrite/api/node_type.py +5 -0
  531. mindspore/rewrite/api/pattern_engine.py +22 -23
  532. mindspore/rewrite/api/scoped_value.py +20 -17
  533. mindspore/rewrite/api/symbol_tree.py +252 -106
  534. mindspore/rewrite/api/tree_node_helper.py +3 -0
  535. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  536. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  537. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  538. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  539. mindspore/rewrite/common/rewrite_elog.py +5 -1
  540. mindspore/rewrite/namer.py +51 -51
  541. mindspore/rewrite/namespace.py +14 -5
  542. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  543. mindspore/rewrite/node/call_function.py +79 -0
  544. mindspore/rewrite/node/cell_container.py +135 -0
  545. mindspore/rewrite/node/control_flow.py +88 -0
  546. mindspore/rewrite/{node.py → node/node.py} +313 -247
  547. mindspore/rewrite/node/node_manager.py +254 -0
  548. mindspore/rewrite/node/node_topological_manager.py +243 -0
  549. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  550. mindspore/rewrite/parsers/assign_parser.py +225 -239
  551. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  552. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  553. mindspore/rewrite/parsers/constant_parser.py +9 -6
  554. mindspore/rewrite/parsers/container_parser.py +9 -7
  555. mindspore/rewrite/parsers/for_parser.py +36 -15
  556. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  557. mindspore/rewrite/parsers/if_parser.py +28 -24
  558. mindspore/rewrite/parsers/module_parser.py +202 -25
  559. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  560. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  561. mindspore/rewrite/parsers/return_parser.py +6 -6
  562. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  563. mindspore/rewrite/sparsify/sparsify.py +4 -1
  564. mindspore/rewrite/sparsify/utils.py +11 -5
  565. mindspore/rewrite/symbol_tree.py +577 -732
  566. mindspore/rewrite/symbol_tree_builder.py +9 -175
  567. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  568. mindspore/run_check/_check_version.py +46 -39
  569. mindspore/run_check/run_check.py +3 -2
  570. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  571. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  572. mindspore/scipy/__init__.py +1 -1
  573. mindspore/scipy/linalg.py +67 -61
  574. mindspore/scipy/ops.py +5 -41
  575. mindspore/scipy/ops_grad.py +3 -2
  576. mindspore/scipy/ops_wrapper.py +5 -5
  577. mindspore/scipy/optimize/line_search.py +8 -8
  578. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  579. mindspore/scipy/optimize/minimize.py +16 -12
  580. mindspore/scipy/utils.py +1 -52
  581. mindspore/scipy/utils_const.py +4 -4
  582. mindspore/train/__init__.py +4 -4
  583. mindspore/train/_utils.py +13 -5
  584. mindspore/train/amp.py +410 -148
  585. mindspore/train/anf_ir_pb2.py +16 -4
  586. mindspore/train/callback/_backup_and_restore.py +8 -11
  587. mindspore/train/callback/_callback.py +80 -3
  588. mindspore/train/callback/_checkpoint.py +82 -51
  589. mindspore/train/callback/_early_stop.py +12 -15
  590. mindspore/train/callback/_history.py +1 -1
  591. mindspore/train/callback/_lambda_callback.py +13 -13
  592. mindspore/train/callback/_landscape.py +21 -17
  593. mindspore/train/callback/_loss_monitor.py +9 -10
  594. mindspore/train/callback/_on_request_exit.py +16 -33
  595. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  596. mindspore/train/callback/_summary_collector.py +44 -30
  597. mindspore/train/callback/_time_monitor.py +62 -12
  598. mindspore/train/data_sink.py +10 -16
  599. mindspore/train/dataset_helper.py +154 -86
  600. mindspore/train/loss_scale_manager.py +14 -9
  601. mindspore/train/metrics/__init__.py +10 -2
  602. mindspore/train/metrics/accuracy.py +1 -1
  603. mindspore/train/metrics/auc.py +1 -1
  604. mindspore/train/metrics/bleu_score.py +2 -2
  605. mindspore/train/metrics/confusion_matrix.py +14 -14
  606. mindspore/train/metrics/cosine_similarity.py +3 -3
  607. mindspore/train/metrics/dice.py +1 -1
  608. mindspore/train/metrics/fbeta.py +1 -1
  609. mindspore/train/metrics/hausdorff_distance.py +8 -6
  610. mindspore/train/metrics/mean_surface_distance.py +5 -4
  611. mindspore/train/metrics/metric.py +49 -17
  612. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  613. mindspore/train/metrics/perplexity.py +1 -1
  614. mindspore/train/metrics/precision.py +2 -2
  615. mindspore/train/metrics/recall.py +2 -3
  616. mindspore/train/metrics/roc.py +7 -7
  617. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  618. mindspore/train/metrics/topk.py +7 -4
  619. mindspore/train/mind_ir_pb2.py +193 -48
  620. mindspore/train/model.py +377 -133
  621. mindspore/train/serialization.py +697 -245
  622. mindspore/train/summary/_summary_adapter.py +5 -2
  623. mindspore/train/summary/_writer_pool.py +4 -3
  624. mindspore/train/summary/summary_record.py +25 -23
  625. mindspore/train/train_thor/convert_utils.py +39 -23
  626. mindspore/train/train_thor/dataset_helper.py +4 -3
  627. mindspore/train/train_thor/model_thor.py +8 -8
  628. mindspore/version.py +1 -1
  629. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  630. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
  631. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  632. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  633. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  634. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  635. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  636. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  637. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  638. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  639. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  640. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  641. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  642. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  643. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  644. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  645. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  646. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  647. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  648. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  649. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  650. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  651. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  652. mindspore/_extends/graph_kernel/expander.py +0 -80
  653. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  654. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  655. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  656. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  657. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  658. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  659. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  660. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  661. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  662. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  663. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  664. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  665. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  666. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  667. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  668. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  669. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  670. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  671. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  672. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  673. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  674. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  675. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  676. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  677. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  678. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  679. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  680. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  681. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  682. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  683. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  684. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  685. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  686. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  687. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  688. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  689. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  690. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  691. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  692. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  693. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  694. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  695. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  696. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  697. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  698. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  699. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  700. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  701. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  702. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  703. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  704. mindspore/dataset/engine/graphdata.py +0 -1586
  705. mindspore/include/api/net.h +0 -142
  706. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  707. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  708. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  709. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  710. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  711. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  712. mindspore/ops/_grad/grad_other_ops.py +0 -89
  713. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  714. mindspore/ops/_grad/grad_sparse.py +0 -323
  715. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  716. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  717. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  718. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  719. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  720. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  721. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  722. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  723. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  724. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  725. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  726. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  727. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  728. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  729. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  730. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  731. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  732. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  733. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  734. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  735. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  736. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  737. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  738. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  739. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  740. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  741. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  742. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  743. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  744. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  745. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  746. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  747. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  748. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  749. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  751. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  752. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  753. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  754. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  755. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  756. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  758. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  759. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  760. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  761. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  762. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  763. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  764. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  765. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  766. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  767. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  768. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  769. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  770. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  771. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  772. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  774. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  775. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  776. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  777. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  778. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  779. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  780. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  784. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  785. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  786. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  787. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  789. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  790. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  791. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  792. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  793. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  794. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  796. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  797. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  798. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  799. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  800. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  801. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  802. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  803. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  804. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  805. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  806. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  807. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  808. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  809. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  810. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  811. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  812. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  814. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  815. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  816. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  817. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  818. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  819. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  820. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  821. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  822. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  823. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  824. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  825. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  826. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  827. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  828. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  829. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  830. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  831. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  832. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  833. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  834. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  835. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  836. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  837. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  839. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  840. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  841. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  842. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  843. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  844. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  845. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  846. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  847. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  848. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  849. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  850. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  851. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  852. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  853. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  854. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  855. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  856. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  857. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  858. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  859. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  860. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  861. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  862. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  863. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  864. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  865. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  866. mindspore/rewrite/node_visitor.py +0 -44
  867. mindspore/rewrite/topological_manager.py +0 -203
  868. mindspore/scipy/sparse/linalg.py +0 -192
  869. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  870. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -22,28 +22,95 @@ from mindspore.ops import functional as F
22
22
  from mindspore.nn.cell import Cell
23
23
  from mindspore.common.tensor import Tensor
24
24
  from mindspore.common import dtype as mstype
25
- from mindspore.ops.primitive import constexpr
25
+ from mindspore.common.api import jit
26
+ from mindspore.ops.primitive import _primexpr
26
27
  from mindspore import _checkparam as Validator
27
- from mindspore.ops._primitive_cache import _get_cache_prim
28
28
 
29
29
  __all__ = [
30
30
  'clip_by_value',
31
+ 'clip_by_norm',
31
32
  'clamp',
32
33
  'clip',
33
34
  'clip_by_global_norm',
34
35
  ]
35
36
 
36
37
  hyper_map = C.HyperMap()
37
- max_op = _get_cache_prim(P.Maximum)()
38
- min_op = _get_cache_prim(P.Minimum)()
39
- cast_op = _get_cache_prim(P.Cast)()
40
- scalar2tensor_op = _get_cache_prim(P.ScalarToTensor)()
41
- partial_op = _get_cache_prim(P.Partial)()
38
+ max_op = P.Maximum()
39
+ min_op = P.Minimum()
40
+ cast_op = P.Cast()
41
+ scalar2tensor_op = P.ScalarToTensor()
42
+ partial_op = P.Partial()
42
43
  expand_dims = P.ExpandDims().add_prim_attr("grad_scale", True)
43
44
  get_square_sum = C.MultitypeFuncGraph("get_square_sum")
44
45
  apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
45
46
 
46
47
 
48
+ def _old_norm(norm_type, x):
49
+ """Add norm function"""
50
+ out = F.pow((F.reduce_sum(F.pow(x, norm_type))), 1. / norm_type).astype(x.dtype)
51
+ return out
52
+
53
+
54
+ @jit
55
+ def _cal_total_norm(x, norm_type):
56
+ if norm_type == float('inf'):
57
+ func = lambda data: data.abs().max()
58
+ total_norm = max(hyper_map(func, x))
59
+ else:
60
+ total_norm = _old_norm(norm_type, F.stack(hyper_map(partial_op(_old_norm, norm_type), x)))
61
+ return total_norm
62
+
63
+
64
+ def clip_by_norm(x, max_norm, norm_type=2.0, error_if_nonfinite=False):
65
+ r"""
66
+ Clip norm of a set of input Tensors. This norm is the result of calculating the norm of all elements in the input
67
+ separately, connecting them into a vector, and then calculating the norm.
68
+
69
+ Note:
70
+ The interface is suitable for gradient clipping scenarios, and only supports input of type float.
71
+
72
+ Args:
73
+ x (Union(Tensor, list[Tensor], tuple[Tensor])): Input that wishes to be clipped.
74
+ max_norm (Union(float, int)): The upper limit of the norm for this group of network parameters.
75
+ norm_type (Union(float, int)): Norm type. Default: ``2.0``.
76
+ error_if_nonfinite (bool): If it is ``True``, an exception is thrown if the total norm from the input
77
+ is nan, inf or -inf. If it is ``False``, no exception will be thrown.Default: ``False`` .
78
+
79
+ Returns:
80
+ Tensors, a list or tuple of Tensors, representing clipped Tensors.
81
+
82
+ Raises:
83
+ RuntimeError: If the total norm from the `x` is nan, inf or -inf.
84
+
85
+ Supported Platforms:
86
+ ``Ascend`` ``GPU`` ``CPU``
87
+
88
+ Examples:
89
+ >>> from mindspore import Tensor, ops
90
+ >>> x = Tensor([[0.8748, 0.1425, 0.0076], [0.7721, 0.4084, 0.0552], [4.6376, 0.2914, 2.1120]])
91
+ >>> out = ops.clip_by_norm(x, max_norm=1)
92
+ >>> print(out)
93
+ [[0.16650201 0.02712224 0.00144652]
94
+ [0.14695495 0.07773139 0.0105063 ]
95
+ [0.8826814 0.0554626 0.40198016]]
96
+ """
97
+ is_tensor = False
98
+ if isinstance(x, Tensor):
99
+ x = [x]
100
+ is_tensor = True
101
+ total_norm = _cal_total_norm(x, norm_type)
102
+ if error_if_nonfinite and F.logical_or(total_norm.isnan(), total_norm.isinf()):
103
+ raise RuntimeError(f"For clip_by_norm, the total norm of order {norm_type} from input is non-finite.")
104
+ clip_coef = max_norm / (total_norm + 1e-6)
105
+ if clip_coef < 1:
106
+ ret = hyper_map(partial_op(F.mul, clip_coef), x)
107
+ else:
108
+ ret = x
109
+ if is_tensor:
110
+ return ret[0]
111
+ return ret
112
+
113
+
47
114
  def clip_by_value(x, clip_value_min=None, clip_value_max=None):
48
115
  r"""
49
116
  Clips tensor values to a specified min and max.
@@ -74,8 +141,8 @@ def clip_by_value(x, clip_value_min=None, clip_value_max=None):
74
141
  Args:
75
142
  x (Union(Tensor, list[Tensor], tuple[Tensor])): Input data, which type is Tensor or a list or tuple of Tensor.
76
143
  Tensors of arbitrary dimensions are supported.
77
- clip_value_min (Union(Tensor, float, int)): The minimum value. Default: None.
78
- clip_value_max (Union(Tensor, float, int)): The maximum value. Default: None.
144
+ clip_value_min (Union(Tensor, float, int)): The minimum value. Default: ``None`` .
145
+ clip_value_max (Union(Tensor, float, int)): The maximum value. Default: ``None`` .
79
146
 
80
147
  Returns:
81
148
  (Union(Tensor, tuple[Tensor], list[Tensor])), a clipped Tensor or a tuple or a list of clipped Tensor.
@@ -115,6 +182,7 @@ def clip_by_value(x, clip_value_min=None, clip_value_max=None):
115
182
  [[ 5. 20. 5. 7.]
116
183
  [ 5. 11. 6. 20.]]
117
184
  """
185
+
118
186
  def _clip_by_value(clip_min, clip_max, x):
119
187
  if not isinstance(x, Tensor):
120
188
  raise TypeError("For 'clip_by_value', the type of argument 'x' must be "
@@ -161,7 +229,7 @@ def clamp(input, min=None, max=None):
161
229
 
162
230
  out_i= \left\{
163
231
  \begin{array}{align}
164
- max & \text{ if } x_i\ge max \\
232
+ max & \text{ if } x_i\ge max \\
165
233
  x_i & \text{ if } min \lt x_i \lt max \\
166
234
  min & \text{ if } x_i \le min \\
167
235
  \end{array}\right.
@@ -176,8 +244,8 @@ def clamp(input, min=None, max=None):
176
244
  Args:
177
245
  input (Union(Tensor, list[Tensor], tuple[Tensor])): Input data, which type is Tensor or a list or tuple of
178
246
  Tensor. Tensors of arbitrary dimensions are supported.
179
- min (Union(Tensor, float, int), optional): The minimum value. Default: None.
180
- max (Union(Tensor, float, int), optional): The maximum value. Default: None.
247
+ min (Union(Tensor, float, int), optional): The minimum value. Default: ``None`` .
248
+ max (Union(Tensor, float, int), optional): The maximum value. Default: ``None`` .
181
249
 
182
250
  Returns:
183
251
  Union(Tensor, tuple[Tensor], list[Tensor]), a clipped Tensor or a tuple or a list of clipped Tensor.
@@ -193,23 +261,23 @@ def clamp(input, min=None, max=None):
193
261
  ``Ascend`` ``GPU`` ``CPU``
194
262
 
195
263
  Examples:
196
- >>> # case 1: the data type of x is Tensor
264
+ >>> # case 1: the data type of input is Tensor
197
265
  >>> import mindspore
198
266
  >>> from mindspore import Tensor, ops
199
267
  >>> import numpy as np
200
268
  >>> min_value = Tensor(5, mindspore.float32)
201
269
  >>> max_value = Tensor(20, mindspore.float32)
202
- >>> x = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32)
203
- >>> output = ops.clamp(x, min_value, max_value)
270
+ >>> input = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32)
271
+ >>> output = ops.clamp(input, min_value, max_value)
204
272
  >>> print(output)
205
273
  [[ 5. 20. 5. 7.]
206
274
  [ 5. 11. 6. 20.]]
207
- >>> # case 2: the data type of x is list[Tensor]
275
+ >>> # case 2: the data type of input is list[Tensor]
208
276
  >>> min_value = 5
209
277
  >>> max_value = 20
210
- >>> x = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32)
211
- >>> y = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32)
212
- >>> output = ops.clamp([x,y], min_value, max_value)
278
+ >>> input_x = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32)
279
+ >>> input_y = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32)
280
+ >>> output = ops.clamp([input_x,input_y], min_value, max_value)
213
281
  >>> for out in output:
214
282
  ... print(out)
215
283
  [[ 5. 20. 5. 7.]
@@ -220,14 +288,14 @@ def clamp(input, min=None, max=None):
220
288
  return clip_by_value(input, min, max)
221
289
 
222
290
 
223
- def clip(x, min=None, max=None):
291
+ def clip(input, min=None, max=None):
224
292
  r"""
225
293
  Alias for :func:`mindspore.ops.clamp` .
226
294
 
227
295
  Supported Platforms:
228
296
  ``Ascend`` ``GPU`` ``CPU``
229
297
  """
230
- return clamp(x, min, max)
298
+ return clamp(input, min, max)
231
299
 
232
300
 
233
301
  @get_square_sum.register("Tensor")
@@ -251,7 +319,7 @@ class _ClipByGlobalNorm(Cell):
251
319
 
252
320
  Args:
253
321
  clip_norm (Union(float, int)): The clipping ratio. Default: 1.0
254
- use_norm (Union(float, None)): The global norm. Default: None
322
+ use_norm (Union(float, None)): The global norm. Default: ``None``
255
323
 
256
324
  Inputs:
257
325
  - **x** (Union(tuple[Tensor], list[Tensor])) - Input data to clip.
@@ -281,10 +349,9 @@ class _ClipByGlobalNorm(Cell):
281
349
  return clip_x
282
350
 
283
351
 
284
- @constexpr
352
+ @_primexpr
285
353
  def _check_value(clip_norm):
286
354
  Validator.check_number("clip_norm", clip_norm, 0.0, Validator.GT, "clip_by_global_norm")
287
- return clip_norm
288
355
 
289
356
 
290
357
  def clip_by_global_norm(x, clip_norm=1.0, use_norm=None):
@@ -299,8 +366,8 @@ def clip_by_global_norm(x, clip_norm=1.0, use_norm=None):
299
366
 
300
367
  Args:
301
368
  x (Union(tuple[Tensor], list[Tensor])): Input data to clip.
302
- clip_norm (Union(float, int)): The clipping ratio, it should be greater than 0. Default: 1.0
303
- use_norm (None): The global norm. Default: None. Currently only none is supported.
369
+ clip_norm (Union(float, int)): The clipping ratio, it should be greater than 0. Default: ``1.0`` .
370
+ use_norm (None): The global norm. Default: ``None`` . Currently only none is supported.
304
371
 
305
372
  Returns:
306
373
  tuple[Tensor], a clipped Tensor. It has the same data type as `x` and each Tensor in the output tuple is the
@@ -324,6 +391,6 @@ def clip_by_global_norm(x, clip_norm=1.0, use_norm=None):
324
391
  [ 4.47213590e-01, 1.49071202e-01]]))
325
392
  """
326
393
 
327
- clip_norm = _check_value(clip_norm)
394
+ _check_value(clip_norm)
328
395
  clip_val = _ClipByGlobalNorm(clip_norm, use_norm)(x)
329
396
  return clip_val
@@ -34,7 +34,7 @@ def print_(*input_x):
34
34
  This function is used for debugging. When too much data is printed at the same time,
35
35
  in order not to affect the main process, the framework may discard some data. If you need to record the
36
36
  data completely, you are recommended to use the `Summary` function, and can check
37
- `Summary <https://www.mindspore.cn/mindinsight/docs/en/r2.0/summary_record.html>`_.
37
+ `Summary <https://www.mindspore.cn/mindinsight/docs/en/r2.2/summary_record.html>`_.
38
38
 
39
39
  Args:
40
40
  input_x (Union[Tensor, bool, int, float, str, tuple, list]): The inputs of print_.
@@ -50,6 +50,8 @@ def print_(*input_x):
50
50
  ``Ascend`` ``GPU`` ``CPU``
51
51
 
52
52
  Examples:
53
+ >>> import numpy as np
54
+ >>> from mindspore import Tensor
53
55
  >>> x = Tensor(np.ones([2, 1]).astype(np.int32))
54
56
  >>> y = Tensor(np.ones([2, 2]).astype(np.int32))
55
57
  >>> result = ops.print_('Print Tensor x and Tensor y:', x, y)
@@ -22,15 +22,14 @@ from mindspore.common import Tensor
22
22
  from mindspore.common import dtype as mstype
23
23
  from mindspore.nn.cell import Cell
24
24
  from mindspore.nn.grad.cell_grad import _LinearizeInner
25
- from mindspore.ops.primitive import constexpr
26
- from mindspore.ops.function.array_func import ones, expand_dims, size, reshape, broadcast_to, transpose
25
+ from mindspore.ops.primitive import constexpr, _primexpr
26
+ from mindspore.ops.function.array_func import ones, expand_dims, size, reshape, broadcast_to, transpose, zeros
27
27
  from mindspore.ops.composite import _Vmap, _Grad, _TaylorOperation, GradOperation
28
28
  from mindspore.ops import operations as P
29
29
  from mindspore.ops.operations import _inner_ops as inner
30
30
 
31
31
  cast = P.Cast()
32
32
  dtype = P.DType()
33
- zeros = P.Zeros()
34
33
  oneslike = P.OnesLike()
35
34
 
36
35
 
@@ -108,24 +107,24 @@ def grad(fn, grad_position=0, weights=None, has_aux=False, return_ids=False):
108
107
  If int, get the gradient with respect to single input.
109
108
  If tuple, get the gradients with respect to selected inputs. `grad_position` begins with 0.
110
109
  If None, none derivative of any input will be figured out, and in this case, `weights` is required.
111
- Default: 0.
110
+ Default: ``0`` .
112
111
  weights (Union[ParameterTuple, Parameter, list[Parameter]]): The parameters of the training network that need to
113
112
  calculate the gradient. `weights` can be got through `weights = net.trainable_params()` .
114
- Default: None.
115
- has_aux (bool): If True, only the first output of `fn` contributes the gradient of `fn`, while the other outputs
116
- will be returned straightly. It means the `fn` must return more than one outputs in this case.
117
- Default: False.
113
+ Default: ``None`` .
114
+ has_aux (bool): If ``True`` , only the first output of `fn` contributes the gradient of `fn`, while the other
115
+ outputs will be returned straightly. It means the `fn` must return more than one outputs in this case.
116
+ Default: ``False`` .
118
117
  return_ids(bool): Whether return the tuple made by gradients and the index to specify which inputs
119
118
  to be differentiated or the name of parameters of the training network that need to calculate the gradient.
120
- If True, the output gradients will be replaced by the tuples made by gradients and the index to specify
119
+ If ``True`` , the output gradients will be replaced by the tuples made by gradients and the index to specify
121
120
  which inputs to be differentiated or the name of parameters of the training network.
122
- Default: False.
121
+ Default: ``False`` .
123
122
 
124
123
  Returns:
125
124
  Function, the gradient function to calculate gradient for the input function or cell.
126
- For example, as for `out1, out2 = fn(*args)`, when `has_aux` is set True, gradient function will return outputs
127
- like `(gradient, out2)` and `out2` does not contribute to the differentiation, otherwise `gradient`.
128
- When return_ids is set to True, The format of the output will be the same with the output of grad when
125
+ For example, as for `out1, out2 = fn(*args)`, when `has_aux` is set ``True`` , gradient function will return
126
+ outputs like `(gradient, out2)` and `out2` does not contribute to the differentiation, otherwise `gradient`.
127
+ When return_ids is set to ``True`` , The format of the output will be the same with the output of grad when
129
128
  return_ids is set to false, but every gradient in the output will be replaced by a tuple of position id or
130
129
  parameter name and its gradient.
131
130
 
@@ -139,9 +138,7 @@ def grad(fn, grad_position=0, weights=None, has_aux=False, return_ids=False):
139
138
  Examples:
140
139
  >>> import numpy as np
141
140
  >>> import mindspore
142
- >>> import mindspore.nn as nn
143
- >>> from mindspore import Tensor, ops
144
- >>> from mindspore import grad
141
+ >>> from mindspore import Tensor, ops, nn, grad
145
142
  >>>
146
143
  >>> # Cell object to be differentiated
147
144
  >>> class Net(nn.Cell):
@@ -250,13 +247,13 @@ def value_and_grad(fn, grad_position=0, weights=None, has_aux=False):
250
247
  If int, get the gradient with respect to single input.
251
248
  If tuple, get the gradients with respect to selected inputs. `grad_position` begins with 0.
252
249
  If None, none derivative of any input will be solved, and in this case, `weights` is required.
253
- Default: 0.
250
+ Default: ``0`` .
254
251
  weights (Union[ParameterTuple, Parameter, list[Parameter]]): The parameters of the training network that need to
255
252
  calculate the gradient. `weights` can be got through `weights = net.trainable_params()` .
256
- Default: None.
257
- has_aux (bool): If True, only the first output of `fn` contributes the gradient of `fn`, while the other outputs
258
- will be returned straightly. It means the `fn` must return more than one outputs in this case.
259
- Default: False.
253
+ Default: ``None`` .
254
+ has_aux (bool): If ``True`` , only the first output of `fn` contributes the gradient of `fn`, while the other
255
+ outputs will be returned straightly. It means the `fn` must return more than one outputs in this case.
256
+ Default: ``False`` .
260
257
 
261
258
  Returns:
262
259
  Function, returns the gradient function to calculate forward output and gradient for the input function or cell.
@@ -385,10 +382,8 @@ def get_grad(gradients, identifier):
385
382
  ``Ascend`` ``GPU`` ``CPU``
386
383
 
387
384
  Examples:
388
- >>> import numpy as np
389
385
  >>> import mindspore
390
- >>> import mindspore.nn as nn
391
- >>> from mindspore import Tensor, ops
386
+ >>> from mindspore import Tensor, nn
392
387
  >>> from mindspore import grad, get_grad
393
388
  >>>
394
389
  >>> # Cell object to be differentiated
@@ -659,7 +654,8 @@ def _check_jvp_input_v_len(inputs_len, v_len):
659
654
  def jvp(fn, inputs, v, has_aux=False):
660
655
  """
661
656
  Compute the jacobian-vector-product of the given network. `jvp` matches
662
- `forward-mode differentiation <https://www.mindspore.cn/docs/en/r2.0/design/auto_gradient.html#forward-mode-ad>`_.
657
+ `forward-mode differentiation
658
+ <https://www.mindspore.cn/docs/en/r2.2/design/programming_paradigm.html#forward-mode-ad>`_.
663
659
 
664
660
  Args:
665
661
  fn (Union[Function, Cell]): The function or net that takes Tensor inputs and returns single Tensor or tuple of
@@ -667,16 +663,17 @@ def jvp(fn, inputs, v, has_aux=False):
667
663
  inputs (Union[Tensor, tuple[Tensor], list[Tensor]]): The inputs to `fn` .
668
664
  v (Union[Tensor, tuple[Tensor], list[Tensor]]): The vector in jacobian-vector-product. The shape and type of `v`
669
665
  should be the same as `inputs` .
670
- has_aux (bool): If True, only the first output of `fn` contributes the gradient of `fn`, while the other outputs
671
- will be returned straightly. It means the `fn` must return more than one outputs in this case.
672
- Default: False.
666
+ has_aux (bool): If ``True`` , only the first output of `fn` contributes the gradient of `fn`, while the other
667
+ outputs will be returned straightly. It means the `fn` must return more than one outputs in this case.
668
+ Default: ``False`` .
673
669
 
674
670
  Returns:
675
671
  - **net_output** (Union[Tensor, tuple[Tensor]]) - The output of `fn(inputs)` . Specially, when `has_aux` is set
676
- True, `netout` is the first output of `fn(inputs)` .
672
+ ``True`` , `netout` is the first output of `fn(inputs)` .
677
673
  - **jvp** (Union[Tensor, tuple[Tensor]]) - The result of jacobian-vector-product.
678
- - **aux_value** (Union[Tensor, tuple[Tensor]], optional) - When `has_aux` is True, `aux_value` will be returned.
679
- It means the second to last outputs of `fn(inputs)` . Specially, `aux_value` does not contribute to gradient.
674
+ - **aux_value** (Union[Tensor, tuple[Tensor]], optional) - When `has_aux` is ``True`` , `aux_value` will be
675
+ returned. It means the second to last outputs of `fn(inputs)` . Specially, `aux_value` does not contribute to
676
+ gradient.
680
677
 
681
678
  Raises:
682
679
  TypeError: `inputs` or `v` does not belong to required types.
@@ -865,18 +862,26 @@ def _check_tensor(inputs):
865
862
  return True
866
863
 
867
864
 
868
- def vjp(fn, *inputs, has_aux=False):
865
+ _vjp_grad_op = _Grad(get_all=True, sens_param=True, merge_forward=True)
866
+ _vjp_grad_op_with_weight = _Grad(get_all=True, get_by_list=True, sens_param=True, merge_forward=True)
867
+
868
+
869
+ def vjp(fn, *inputs, weights=None, has_aux=False):
869
870
  """
870
871
  Compute the vector-jacobian-product of the given network. `vjp` matches
871
- `reverse-mode differentiation <https://www.mindspore.cn/docs/en/r2.0/design/auto_gradient.html#reverse-mode-ad>`_.
872
+ `reverse-mode differentiation
873
+ <https://www.mindspore.cn/docs/en/r2.2/design/programming_paradigm.html#reverse-mode-ad>`_.
872
874
 
873
875
  Args:
874
876
  fn (Union[Function, Cell]): The function or net that takes Tensor inputs and returns single Tensor or tuple of
875
877
  Tensors.
876
878
  inputs (Union[Tensor, tuple[Tensor], list[Tensor]]): The inputs to `fn` .
879
+ weights (Union[ParameterTuple, Parameter, list[Parameter]]): The parameters of the training network that need to
880
+ calculate the gradient. `weights` can be got through `weights = net.trainable_params()` .
881
+ Default: ``None`` .
877
882
  has_aux (bool): If True, only the first output of `fn` contributes the gradient of `fn`, while the other outputs
878
883
  will be returned straightly. It means the `fn` must return more than one outputs in this case.
879
- Default: False.
884
+ Default: ``False``.
880
885
 
881
886
  Returns:
882
887
  Forward outputs and function to calculate vjp.
@@ -949,9 +954,12 @@ def vjp(fn, *inputs, has_aux=False):
949
954
  fn_ = aux_fn
950
955
  else:
951
956
  fn_ = fn
957
+ sens = v
952
958
  if len(v) == 1:
953
- return _grad_all(fn_)(*inputs, v[0])
954
- return _grad_all(fn_)(*inputs, v)
959
+ sens = v[0]
960
+ if weights is None:
961
+ return _vjp_grad_op(fn_)(*inputs, sens)
962
+ return _vjp_grad_op_with_weight(fn_, weights)(*inputs, sens)
955
963
 
956
964
  res = fn(*inputs)
957
965
  if has_aux:
@@ -961,10 +969,13 @@ def vjp(fn, *inputs, has_aux=False):
961
969
  return res, wrap_container
962
970
 
963
971
 
964
- @constexpr
972
+ @_primexpr
965
973
  def _jac_generate_target_dimension(x):
966
974
  """For given length = len(x), this method generates target dimension tuple (1, 2, 3,..., length, 0)."""
967
- target_dimension = tuple(index + 1 for index, _ in enumerate(x[1:])) + (0,)
975
+ dim = ()
976
+ for index in range(len(x[1:])):
977
+ dim += (index + 1,)
978
+ target_dimension = dim + (0,)
968
979
  return target_dimension
969
980
 
970
981
 
@@ -1009,11 +1020,7 @@ def _jac_postprocess(x, shape, grad_position, mode):
1009
1020
  for i in range(output_num):
1010
1021
  input_grad = ()
1011
1022
  for j in range(input_num):
1012
- if mode == 'forward':
1013
- grad_increment = (res[i * input_num + j],)
1014
- else:
1015
- grad_increment = (res[j * output_num + i],)
1016
- input_grad += grad_increment
1023
+ input_grad += (res[i * input_num + j],) if mode == 'forward' else (res[j * output_num + i],)
1017
1024
  jac += (input_grad,)
1018
1025
  return jac
1019
1026
 
@@ -1060,22 +1067,23 @@ _vmap = _Vmap()
1060
1067
  def jacfwd(fn, grad_position=0, has_aux=False):
1061
1068
  """
1062
1069
  Compute Jacobian via forward mode, corresponding to
1063
- `forward-mode differentiation <https://www.mindspore.cn/docs/en/r2.0/design/auto_gradient.html#forward-mode-ad>`_.
1070
+ `forward-mode differentiation
1071
+ <https://www.mindspore.cn/docs/en/r2.2/design/programming_paradigm.html#forward-mode-ad>`_.
1064
1072
  When number of outputs is much greater than that of inputs, it's better to calculate Jacobian via forward mode than
1065
1073
  reverse mode to get better performance.
1066
1074
 
1067
1075
  Args:
1068
1076
  fn (Union[Cell, Function]): Function to do GradOperation.
1069
1077
  grad_position (Union[int, tuple[int]], optional): If int, get the gradient with respect to single input.
1070
- If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0.
1071
- has_aux (bool, optional): If True, only the first output of `fn` contributes the gradient of `fn`,
1078
+ If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: ``0`` .
1079
+ has_aux (bool, optional): If ``True`` , only the first output of `fn` contributes the gradient of `fn`,
1072
1080
  while the other outputs will be returned straightly. It means the `fn` must return more than one
1073
- outputs in this case. Default: False.
1081
+ outputs in this case. Default: ``False`` .
1074
1082
 
1075
1083
  Returns:
1076
1084
  Function, returns the Jacobian function for the input function or cell.
1077
- For example, as for `out1, out2 = fn(*args)`, when `has_aux` is set True, gradient function will return outputs
1078
- like `(Jacobian, out2)` and `out2` does not contribute to the differentiation, otherwise `Jacobian` .
1085
+ For example, as for `out1, out2 = fn(*args)`, when `has_aux` is set ``True`` , gradient function will return
1086
+ outputs like `(Jacobian, out2)` and `out2` does not contribute to the differentiation, otherwise `Jacobian` .
1079
1087
 
1080
1088
  Raises:
1081
1089
  TypeError: `grad_position` or `has_aux` does not belong to required types.
@@ -1097,14 +1105,14 @@ def jacfwd(fn, grad_position=0, has_aux=False):
1097
1105
  >>> net = MultipleInputsMultipleOutputsNet()
1098
1106
  >>> jac, aux = jacfwd(net, grad_position=0, has_aux=True)(x, y, z)
1099
1107
  >>> print(jac)
1100
- [[[[ 2., 0.]
1101
- [ 0., 0.]]
1102
- [[ 0., 4.]
1103
- [ 0., 0.]]]
1104
- [[[ 0., 0.]
1105
- [ 6., 0.]]
1106
- [[ 0., 0.]
1107
- [ 0., 8.]]]]
1108
+ [[[[ 2. 0.]
1109
+ [ 0. 0.]]
1110
+ [[ 0. 4.]
1111
+ [ 0. 0.]]]
1112
+ [[[ 0. 0.]
1113
+ [ 6. 0.]]
1114
+ [[ 0. 0.]
1115
+ [ 0. 8.]]]]
1108
1116
  >>> print(aux)
1109
1117
  [[ 1. 4.]
1110
1118
  [ 9. 16.]]
@@ -1230,22 +1238,23 @@ _grad = _Grad(get_by_position=True, has_aux=False, sens_param=True)
1230
1238
  def jacrev(fn, grad_position=0, has_aux=False):
1231
1239
  """
1232
1240
  Compute Jacobian via reverse mode, corresponding to
1233
- `reverse-mode differentiation <https://www.mindspore.cn/docs/en/r2.0/design/auto_gradient.html#reverse-mode-ad>`_.
1241
+ `reverse-mode differentiation
1242
+ <https://www.mindspore.cn/docs/en/r2.2/design/programming_paradigm.html#reverse-mode-ad>`_.
1234
1243
  When number of inputs is much greater than that of outputs, it's better to calculate Jacobian via reverse mode than
1235
1244
  forward mode to get better performance.
1236
1245
 
1237
1246
  Args:
1238
1247
  fn (Union[Cell, Function]): Function to do GradOperation.
1239
1248
  grad_position (Union[int, tuple[int]], optional): If int, get the gradient with respect to single input.
1240
- If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0.
1241
- has_aux (bool, optional): If True, only the first output of `fn` contributes the gradient of `fn`,
1249
+ If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: ``0`` .
1250
+ has_aux (bool, optional): If ``True`` , only the first output of `fn` contributes the gradient of `fn`,
1242
1251
  while the other outputs will be returned straightly. It means the `fn` must return more than
1243
- one outputs in this case. Default: False.
1252
+ one outputs in this case. Default: ``False`` .
1244
1253
 
1245
1254
  Returns:
1246
1255
  Function, returns the Jacobian function for the input function or cell.
1247
- For example, as for `out1, out2 = fn(*args)`, when `has_aux` is set True, gradient function will return outputs
1248
- like `(Jacobian, out2)` and `out2` does not contribute to the differentiation, otherwise `Jacobian` .
1256
+ For example, as for `out1, out2 = fn(*args)`, when `has_aux` is set ``True`` , gradient function will return
1257
+ outputs like `(Jacobian, out2)` and `out2` does not contribute to the differentiation, otherwise `Jacobian` .
1249
1258
 
1250
1259
  Raises:
1251
1260
  TypeError: `grad_position` or `has_aux` does not belong to required types.
@@ -1267,14 +1276,14 @@ def jacrev(fn, grad_position=0, has_aux=False):
1267
1276
  >>> net = MultipleInputsMultipleOutputsNet()
1268
1277
  >>> jac, aux = jacrev(net, grad_position=0, has_aux=True)(x, y, z)
1269
1278
  >>> print(jac)
1270
- [[[[ 2., 0.]
1271
- [ 0., 0.]]
1272
- [[ 0., 4.]
1273
- [ 0., 0.]]]
1274
- [[[ 0., 0.]
1275
- [ 6., 0.]]
1276
- [[ 0., 0.]
1277
- [ 0., 8.]]]]
1279
+ [[[[ 2. 0.]
1280
+ [ 0. 0.]]
1281
+ [[ 0. 4.]
1282
+ [ 0. 0.]]]
1283
+ [[[ 0. 0.]
1284
+ [ 6. 0.]]
1285
+ [[ 0. 0.]
1286
+ [ 0. 8.]]]]
1278
1287
  >>> print(aux)
1279
1288
  [[ 1. 4.]
1280
1289
  [ 9. 16.]]
@@ -1322,7 +1331,7 @@ def custom_vjp(fn=None):
1322
1331
  Support vjp to custom bprop for function.
1323
1332
 
1324
1333
  Args:
1325
- fn (function): The `fn` that need to define custom bprop. Default: None.
1334
+ fn (function): The `fn` that need to define custom bprop. Default: ``None``.
1326
1335
 
1327
1336
  Supported Platforms:
1328
1337
  ``Ascend`` ``GPU`` ``CPU``
@@ -1361,7 +1370,7 @@ def stop_gradient(value):
1361
1370
  StopGradient is used for eliminating the effect of a value on the gradient, such as truncating
1362
1371
  the gradient propagation from an output of a function.
1363
1372
  For more details, please refer to `Stop Gradient
1364
- <https://www.mindspore.cn/tutorials/en/r2.0/beginner/autograd.html#stop-gradient>`_.
1373
+ <https://www.mindspore.cn/tutorials/en/r2.2/beginner/autograd.html#stop-gradient>`_.
1365
1374
 
1366
1375
  Args:
1367
1376
  value (Any): The value whose effect on the gradient to be eliminated.