mindspore 2.1.0__cp37-none-any.whl → 2.2.11__cp37-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (577)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +139 -22
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  25. mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
  26. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +56 -1
  31. mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +13 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +67 -72
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +86 -106
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +29 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +33 -7
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  164. mindspore/lib/libmindspore_shared_lib.so +0 -0
  165. mindspore/lib/libnnacl.so +0 -0
  166. mindspore/lib/libopencv_core.so.4.5 +0 -0
  167. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  169. mindspore/lib/libps_cache.so +0 -0
  170. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  183. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  187. mindspore/lib/plugin/ascend/libakg.so +0 -0
  188. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  189. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  190. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  191. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  193. mindspore/lib/plugin/cpu/libakg.so +0 -0
  194. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  195. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  196. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  197. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  198. mindspore/nn/__init__.py +0 -2
  199. mindspore/nn/cell.py +313 -74
  200. mindspore/nn/dynamic_lr.py +21 -21
  201. mindspore/nn/layer/activation.py +22 -30
  202. mindspore/nn/layer/basic.py +15 -13
  203. mindspore/nn/layer/channel_shuffle.py +1 -1
  204. mindspore/nn/layer/container.py +271 -9
  205. mindspore/nn/layer/conv.py +323 -204
  206. mindspore/nn/layer/dense.py +8 -5
  207. mindspore/nn/layer/embedding.py +33 -27
  208. mindspore/nn/layer/flash_attention.py +61 -95
  209. mindspore/nn/layer/image.py +8 -6
  210. mindspore/nn/layer/math.py +16 -25
  211. mindspore/nn/layer/normalization.py +107 -66
  212. mindspore/nn/layer/padding.py +1 -1
  213. mindspore/nn/layer/pooling.py +131 -109
  214. mindspore/nn/layer/rnn_cells.py +27 -22
  215. mindspore/nn/layer/rnns.py +13 -16
  216. mindspore/nn/layer/thor_layer.py +1 -1
  217. mindspore/nn/layer/transformer.py +221 -154
  218. mindspore/nn/learning_rate_schedule.py +9 -1
  219. mindspore/nn/loss/loss.py +235 -174
  220. mindspore/nn/optim/ada_grad.py +2 -1
  221. mindspore/nn/optim/adadelta.py +1 -0
  222. mindspore/nn/optim/adafactor.py +2 -1
  223. mindspore/nn/optim/adam.py +7 -4
  224. mindspore/nn/optim/adamax.py +3 -2
  225. mindspore/nn/optim/adasum.py +2 -2
  226. mindspore/nn/optim/asgd.py +2 -3
  227. mindspore/nn/optim/ftrl.py +6 -5
  228. mindspore/nn/optim/lamb.py +7 -4
  229. mindspore/nn/optim/lars.py +1 -1
  230. mindspore/nn/optim/lazyadam.py +5 -3
  231. mindspore/nn/optim/momentum.py +2 -1
  232. mindspore/nn/optim/optimizer.py +53 -4
  233. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  234. mindspore/nn/optim/rmsprop.py +4 -3
  235. mindspore/nn/optim/rprop.py +23 -12
  236. mindspore/nn/optim/sgd.py +26 -11
  237. mindspore/nn/optim/thor.py +9 -7
  238. mindspore/nn/probability/bijector/bijector.py +5 -5
  239. mindspore/nn/probability/bijector/power_transform.py +27 -27
  240. mindspore/nn/probability/bijector/softplus.py +3 -3
  241. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  242. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  243. mindspore/nn/probability/distribution/beta.py +3 -3
  244. mindspore/nn/probability/distribution/categorical.py +7 -7
  245. mindspore/nn/probability/distribution/cauchy.py +0 -1
  246. mindspore/nn/probability/distribution/distribution.py +3 -3
  247. mindspore/nn/probability/distribution/gamma.py +3 -3
  248. mindspore/nn/probability/distribution/geometric.py +4 -4
  249. mindspore/nn/probability/distribution/gumbel.py +4 -4
  250. mindspore/nn/probability/distribution/log_normal.py +2 -2
  251. mindspore/nn/probability/distribution/logistic.py +2 -2
  252. mindspore/nn/probability/distribution/poisson.py +4 -4
  253. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  254. mindspore/nn/probability/distribution/uniform.py +6 -6
  255. mindspore/nn/wrap/__init__.py +4 -2
  256. mindspore/nn/wrap/cell_wrapper.py +87 -34
  257. mindspore/nn/wrap/grad_reducer.py +8 -5
  258. mindspore/nn/wrap/loss_scale.py +105 -42
  259. mindspore/numpy/array_creations.py +1 -2
  260. mindspore/numpy/array_ops.py +3 -2
  261. mindspore/numpy/utils_const.py +5 -5
  262. mindspore/offline_debug/convert_async.py +2 -2
  263. mindspore/ops/_grad_experimental/__init__.py +0 -5
  264. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  265. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  266. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  267. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  268. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  269. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  270. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  271. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  272. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  273. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  274. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  275. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  276. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  277. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  278. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  279. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  280. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  281. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  282. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  283. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  284. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  285. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  286. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  287. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  288. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  289. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  290. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  291. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  292. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  293. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  294. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  295. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  296. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  297. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  298. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  299. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  300. mindspore/ops/_primitive_cache.py +1 -1
  301. mindspore/ops/_tracefunc.py +45 -13
  302. mindspore/ops/_utils/utils.py +6 -1
  303. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  304. mindspore/ops/_vmap/vmap_base.py +3 -3
  305. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  306. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  307. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  308. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  309. mindspore/ops/arg_dtype_cast.py +54 -0
  310. mindspore/ops/composite/base.py +37 -10
  311. mindspore/ops/composite/math_ops.py +5 -4
  312. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  313. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  314. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  315. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  316. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  317. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  319. mindspore/ops/deprecated.py +304 -0
  320. mindspore/ops/function/__init__.py +4 -1
  321. mindspore/ops/function/array_func.py +174 -193
  322. mindspore/ops/function/clip_func.py +81 -13
  323. mindspore/ops/function/debug_func.py +1 -1
  324. mindspore/ops/function/grad/grad_func.py +18 -9
  325. mindspore/ops/function/image_func.py +10 -4
  326. mindspore/ops/function/linalg_func.py +5 -5
  327. mindspore/ops/function/math_func.py +575 -386
  328. mindspore/ops/function/nn_func.py +568 -260
  329. mindspore/ops/function/random_func.py +88 -57
  330. mindspore/ops/function/sparse_func.py +1 -1
  331. mindspore/ops/function/sparse_unary_func.py +14 -12
  332. mindspore/ops/function/vmap_func.py +6 -5
  333. mindspore/ops/functional.py +15 -10
  334. mindspore/ops/op_info_register.py +244 -25
  335. mindspore/ops/operations/__init__.py +31 -19
  336. mindspore/ops/operations/_grad_ops.py +71 -7
  337. mindspore/ops/operations/_inner_ops.py +350 -17
  338. mindspore/ops/operations/_quant_ops.py +4 -8
  339. mindspore/ops/operations/_sequence_ops.py +42 -0
  340. mindspore/ops/operations/array_ops.py +68 -282
  341. mindspore/ops/operations/comm_ops.py +107 -59
  342. mindspore/ops/operations/custom_ops.py +94 -70
  343. mindspore/ops/operations/debug_ops.py +8 -4
  344. mindspore/ops/operations/image_ops.py +18 -12
  345. mindspore/ops/operations/inner_ops.py +26 -3
  346. mindspore/ops/operations/math_ops.py +192 -144
  347. mindspore/ops/operations/nn_ops.py +857 -489
  348. mindspore/ops/operations/other_ops.py +0 -22
  349. mindspore/ops/operations/random_ops.py +53 -111
  350. mindspore/ops/operations/sparse_ops.py +3 -1
  351. mindspore/ops/primitive.py +24 -18
  352. mindspore/parallel/_auto_parallel_context.py +68 -8
  353. mindspore/parallel/_cost_model_context.py +2 -2
  354. mindspore/parallel/_offload_context.py +17 -3
  355. mindspore/parallel/_parallel_serialization.py +12 -5
  356. mindspore/parallel/_ps_context.py +12 -0
  357. mindspore/parallel/_tensor.py +18 -13
  358. mindspore/parallel/_transformer/layers.py +5 -3
  359. mindspore/parallel/_transformer/loss.py +1 -0
  360. mindspore/parallel/_transformer/moe.py +2 -2
  361. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  362. mindspore/parallel/_transformer/transformer.py +23 -3
  363. mindspore/parallel/_utils.py +11 -7
  364. mindspore/parallel/algo_parameter_config.py +85 -5
  365. mindspore/parallel/checkpoint_transform.py +19 -12
  366. mindspore/parallel/shard.py +21 -14
  367. mindspore/profiler/common/struct_type.py +3 -3
  368. mindspore/profiler/common/util.py +4 -2
  369. mindspore/profiler/envprofiling.py +1 -1
  370. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  371. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  372. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  373. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  374. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  375. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  376. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  377. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  378. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  379. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  380. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  381. mindspore/profiler/parser/flops_parser.py +15 -11
  382. mindspore/profiler/parser/framework_parser.py +38 -22
  383. mindspore/profiler/parser/hccl_parser.py +16 -12
  384. mindspore/profiler/parser/integrator.py +22 -11
  385. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  386. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  387. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  388. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  389. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  390. mindspore/profiler/parser/optime_parser.py +1 -1
  391. mindspore/profiler/parser/profiler_info.py +21 -2
  392. mindspore/profiler/parser/step_trace_parser.py +11 -14
  393. mindspore/profiler/profiling.py +179 -89
  394. mindspore/rewrite/api/node.py +102 -19
  395. mindspore/rewrite/api/node_type.py +5 -1
  396. mindspore/rewrite/api/pattern_engine.py +1 -1
  397. mindspore/rewrite/api/scoped_value.py +9 -17
  398. mindspore/rewrite/api/symbol_tree.py +131 -47
  399. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  400. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  401. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  402. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  403. mindspore/rewrite/common/rewrite_elog.py +5 -1
  404. mindspore/rewrite/namer.py +33 -24
  405. mindspore/rewrite/namespace.py +14 -5
  406. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  407. mindspore/rewrite/node/call_function.py +79 -0
  408. mindspore/rewrite/node/cell_container.py +135 -0
  409. mindspore/rewrite/node/control_flow.py +88 -0
  410. mindspore/rewrite/{node.py → node/node.py} +273 -234
  411. mindspore/rewrite/node/node_manager.py +254 -0
  412. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  413. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  414. mindspore/rewrite/parsers/assign_parser.py +216 -221
  415. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  416. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  417. mindspore/rewrite/parsers/constant_parser.py +9 -6
  418. mindspore/rewrite/parsers/container_parser.py +9 -7
  419. mindspore/rewrite/parsers/for_parser.py +42 -21
  420. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  421. mindspore/rewrite/parsers/if_parser.py +28 -24
  422. mindspore/rewrite/parsers/module_parser.py +196 -25
  423. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  424. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  425. mindspore/rewrite/parsers/return_parser.py +6 -6
  426. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  427. mindspore/rewrite/sparsify/utils.py +1 -1
  428. mindspore/rewrite/symbol_tree.py +523 -578
  429. mindspore/rewrite/symbol_tree_builder.py +9 -193
  430. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  431. mindspore/run_check/_check_version.py +6 -4
  432. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  433. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  434. mindspore/scipy/linalg.py +1 -1
  435. mindspore/scipy/ops.py +55 -5
  436. mindspore/scipy/optimize/__init__.py +3 -2
  437. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  438. mindspore/scipy/optimize/minimize.py +7 -3
  439. mindspore/train/_utils.py +7 -3
  440. mindspore/train/amp.py +323 -123
  441. mindspore/train/anf_ir_pb2.py +14 -2
  442. mindspore/train/callback/_backup_and_restore.py +2 -12
  443. mindspore/train/callback/_callback.py +29 -4
  444. mindspore/train/callback/_checkpoint.py +23 -8
  445. mindspore/train/callback/_early_stop.py +2 -2
  446. mindspore/train/callback/_landscape.py +4 -4
  447. mindspore/train/callback/_loss_monitor.py +2 -2
  448. mindspore/train/callback/_on_request_exit.py +2 -2
  449. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  450. mindspore/train/callback/_summary_collector.py +15 -8
  451. mindspore/train/callback/_time_monitor.py +58 -5
  452. mindspore/train/data_sink.py +5 -11
  453. mindspore/train/dataset_helper.py +84 -57
  454. mindspore/train/loss_scale_manager.py +2 -2
  455. mindspore/train/metrics/__init__.py +3 -3
  456. mindspore/train/metrics/cosine_similarity.py +1 -1
  457. mindspore/train/metrics/hausdorff_distance.py +3 -2
  458. mindspore/train/metrics/mean_surface_distance.py +3 -2
  459. mindspore/train/metrics/metric.py +39 -19
  460. mindspore/train/metrics/roc.py +2 -2
  461. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  462. mindspore/train/mind_ir_pb2.py +85 -36
  463. mindspore/train/model.py +187 -47
  464. mindspore/train/serialization.py +487 -161
  465. mindspore/train/summary/_summary_adapter.py +1 -1
  466. mindspore/train/summary/_writer_pool.py +3 -2
  467. mindspore/train/summary/summary_record.py +37 -17
  468. mindspore/train/train_thor/convert_utils.py +3 -3
  469. mindspore/train/train_thor/dataset_helper.py +1 -1
  470. mindspore/version.py +1 -1
  471. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
  472. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +476 -527
  473. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
  474. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  475. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  476. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  477. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  478. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  479. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  480. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  481. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  482. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  483. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  484. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  485. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  486. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  487. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  488. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  489. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  490. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  491. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  492. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  493. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  494. mindspore/_extends/graph_kernel/expander.py +0 -80
  495. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  496. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  497. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  498. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  499. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  500. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  501. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  502. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  503. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  504. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  505. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  506. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  507. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  508. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  509. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  510. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  511. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  512. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  513. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  514. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  515. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  516. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  517. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  518. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  519. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  520. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  521. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  522. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  523. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  524. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  525. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  526. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  527. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  528. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  529. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  530. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  531. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  532. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  533. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  534. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  535. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  536. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  537. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  538. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  539. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  540. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  541. mindspore/dataset/datapreprocess/__init__.py +0 -20
  542. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  543. mindspore/include/api/net.h +0 -142
  544. mindspore/nn/lr_scheduler.py +0 -262
  545. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  546. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  547. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  548. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  549. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  550. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  551. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  552. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  553. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  554. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  555. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  556. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  557. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  558. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  559. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  560. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  561. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  563. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  567. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  568. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  569. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  571. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  574. mindspore/rewrite/node_visitor.py +0 -44
  575. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  576. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  577. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -96,7 +96,7 @@ class LPPool1d(Cell):
  f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}

  Args:
- norm_type (Union[int, float]): Type of normalization, represents p in the formula, can not be 0.
+ norm_type (Union[int, float]): Type of normalization, represents :math:`p` in the formula, can not be 0.

  - if p = 1, the result is the sum of the elements within the pooling kernel(proportional to average
  pooling).
@@ -168,7 +168,7 @@ class LPPool2d(Cell):
  f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}

  Args:
- norm_type(Union[int, float]) - Type of normalization, represents p in the formula, can not be 0.
+ norm_type(Union[int, float]) - Type of normalization, represents :math:`p` in the formula, can not be 0.

  - if p = 1, the result is the sum of the elements within the pooling kernel(proportional to average
  pooling).
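
The LP-pooling formula above generalizes sum and max pooling through the exponent p. A minimal NumPy sketch over a single 1-D window (an illustration added for this review, not part of the package diff) shows that p = 1 yields the plain window sum while large p approaches max pooling:

import numpy as np

def lp_pool_window(window, p):
    # f(X) = (sum_{x in X} x**p) ** (1/p), the formula in the docstring above.
    window = np.asarray(window, dtype=float)
    return np.sum(window ** p) ** (1.0 / p)

window = [1.0, 2.0, 3.0, 4.0]
print(lp_pool_window(window, 1))    # 10.0 -- p = 1 reduces to the window sum
print(lp_pool_window(window, 100))  # ~4.0 -- large p approaches max pooling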
@@ -297,16 +297,21 @@ class MaxPool3d(_PoolNd):
  or a tuple of three int numbers that represent depth, height and width of movement respectively.
  The value must be a positive integer. If the value is None, the default value `kernel_size` is used.
  Default: ``1`` .
- pad_mode (str): The optional value for pad mode, is ``"same"`` , ``"valid"`` or ``"pad"`` , not case sensitive.
- Default: ``"valid"`` .
-
- - ``"same"``: The output shape is the same as the input shape evenly divided by `stride`.
-
- - ``"valid"``: The possible largest height and width of output
- will be returned without padding. Extra pixels will be discarded.
-
- - ``"pad"``: pads the input. Pads the top, bottom, left, and right sides of the input with `padding` number
- of zeros. If this mode is set, `padding` must be greater than or equal to 0.
+ pad_mode (str, optional): Specifies the padding mode with a padding value of 0. It can be set to:
+ ``"same"`` , ``"valid"`` or ``"pad"`` . Default: ``"valid"`` .
+
+ - ``"same"``: Pad the input around its depth/height/width dimension so that the shape of input and output
+ are the same when `stride` is set to ``1``.
+ The amount of padding is calculated by the operator internally. If the amount is even,
+ it is uniformly distributed around the input; if it is odd, the excess amount goes
+ to the front/right/bottom side.
+ If this mode is set, `padding` must be 0.
+ - ``"valid"``: No padding is applied to the input, and the output returns the maximum
+ possible depth, height and width. Extra pixels that could not complete a full stride will
+ be discarded. If this mode is set, `padding` must be 0.
+ - ``"pad"``: Pad the input with a specified amount. In this mode, the amount of padding
+ in the depth, height and width dimension is determined by the `padding` parameter.
+ If this mode is set, `padding` must be greater than or equal to 0.

  padding (Union(int, tuple[int], list[int])): Pooling padding value. Default: ``0`` .
  `padding` can only be an integer or a tuple/list containing one or three integers.
@@ -337,7 +342,7 @@ class MaxPool3d(_PoolNd):
  :math:`(C_{out}, D_{out}, H_{out}, W_{out})`. It has the same data type as `x`.
  - **argmax** (Tensor) - Index corresponding to the maximum value. Data type is int64.

- If `pad_mode` is in `pad` mode, the output shape calculation formula is as follows:
+ If `pad_mode` is in ``"pad"`` mode, the output shape calculation formula is as follows:

  .. math::
  D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times
@@ -356,9 +361,9 @@ class MaxPool3d(_PoolNd):
  TypeError: If `kernel_size` , `stride` , `padding` or `dilation` is neither an int nor a tuple.
  ValueError: If `kernel_size` or `stride` is less than 1.
  ValueError: If the `padding` parameter is neither an integer nor a tuple of length 3.
- ValueError: If `pad_mode` is not set to 'pad', setting return_indices to True or dilation to a value
+ ValueError: If `pad_mode` is not set to ``"pad"``, setting return_indices to True or dilation to a value
  other than 1.
- ValueError: If `padding` is non-zero when `pad_mode` is not 'pad'.
+ ValueError: If `padding` is non-zero when `pad_mode` is not ``"pad"``.

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``
@@ -369,13 +374,13 @@ class MaxPool3d(_PoolNd):
  >>> import numpy as np
  >>> np_x = np.random.randint(0, 10, [5, 3, 4, 6, 7])
  >>> x = Tensor(np_x, ms.float32)
- >>> pool1 = nn.MaxPool3d(kernel_size=2, stride=1, pad_mode='pad', padding=1, dilation=3, return_indices=True)
+ >>> pool1 = nn.MaxPool3d(kernel_size=2, stride=1, pad_mode="pad", padding=1, dilation=3, return_indices=True)
  >>> output = pool1(x)
  >>> print(output[0].shape)
  (5, 3, 3, 5, 6)
  >>> print(output[1].shape)
  (5, 3, 3, 5, 6)
- >>> pool2 = nn.MaxPool3d(kernel_size=2, stride=1, pad_mode='pad', padding=1, dilation=3, return_indices=False)
+ >>> pool2 = nn.MaxPool3d(kernel_size=2, stride=1, pad_mode="pad", padding=1, dilation=3, return_indices=False)
  >>> output2 = pool2(x)
  >>> print(output2.shape)
  (5, 3, 3, 5, 6)
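
The ``"pad"``-mode shape formula above (cut off by the hunk boundary) is the usual floor form; a quick pure-Python check (an illustration, not part of the package) reproduces the doctest shapes for this MaxPool3d example:

import math

def pooled_len(length, kernel, stride=1, padding=0, dilation=1):
    # L_out = floor((L_in + 2*padding - dilation*(kernel - 1) - 1) / stride + 1),
    # the "pad"-mode formula begun in the MaxPool3d docstring above.
    return math.floor((length + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1)

# Doctest input (5, 3, 4, 6, 7) with kernel_size=2, stride=1, padding=1, dilation=3:
d, h, w = (pooled_len(n, 2, 1, 1, 3) for n in (4, 6, 7))
print((5, 3, d, h, w))  # (5, 3, 3, 5, 6), matching the doctest output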
@@ -437,16 +442,20 @@ class MaxPool2d(_PoolNd):
  stride (Union[int, tuple[int]]): The distance of kernel moving, an int number or a single element tuple that
  represents the height and width of movement are both stride, or a tuple of two int numbers that
  represent height and width of movement respectively. Default: ``1`` .
- pad_mode (str): The optional value for pad mode, is ``"same"`` , ``"valid"`` or ``"pad"`` , not case sensitive.
- Default: ``"valid"`` .
-
- - ``"same"``: The output shape is the same as the input shape evenly divided by `stride`.
-
- - ``"valid"``: The possible largest height and width of output
- will be returned without padding. Extra pixels will be discarded.
-
- - ``"pad"``: pads the input. Pads the top, bottom, left, and right sides of the input with `padding` number
- of zeros. If this mode is set, `padding` must be greater than or equal to 0.
+ pad_mode (str, optional): Specifies the padding mode with a padding value of 0. It can be set to:
+ ``"same"`` , ``"valid"`` or ``"pad"`` . Default: ``"valid"`` .
+
+ - ``"same"``: Pad the input around its edges so that the shape of input and output
+ are the same when `stride` is set to ``1``.
+ The amount of padding is calculated by the operator internally. If the amount is even, it is
+ uniformly distributed around the input; if it is odd, the excess amount goes to the right/bottom side.
+ If this mode is set, `padding` must be 0.
+ - ``"valid"``: No padding is applied to the input, and the output returns the maximum
+ possible height and width. Extra pixels that could not complete a full stride will
+ be discarded. If this mode is set, `padding` must be 0.
+ - ``"pad"``: Pad the input with a specified amount. In this mode, the amount of padding
+ in the height and width directions is determined by the `padding` parameter.
+ If this mode is set, `padding` must be greater than or equal to 0.

  padding (Union(int, tuple[int], list[int])): Specifies the padding value of the pooling operation.
  Default: ``0`` . `padding` can only be an integer or a tuple/list containing one or two integers. If
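
A hedged sketch of how the three pad modes described above could translate into per-dimension padding amounts. The ``"same"`` arithmetic (stride-1 shape preservation, odd excess to the right/bottom) follows the prose only and is an assumption, not the package's implementation:

import math

def pad_amounts(length, kernel, stride, pad_mode, padding=0):
    # Returns (pad_before, pad_after) for one spatial dimension.
    if pad_mode == "valid":
        return (0, 0)                # no padding; leftover pixels are discarded
    if pad_mode == "pad":
        return (padding, padding)    # user-specified amount on both sides
    if pad_mode == "same":
        out_len = math.ceil(length / stride)
        total = max((out_len - 1) * stride + kernel - length, 0)
        before = total // 2          # odd excess goes to the right/bottom side
        return (before, total - before)
    raise ValueError(pad_mode)

print(pad_amounts(5, 3, 1, "same"))  # (1, 1): stride 1 preserves the input shape
print(pad_amounts(5, 2, 2, "same"))  # (0, 1): odd total, excess on the right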
@@ -489,17 +498,17 @@ class MaxPool2d(_PoolNd):

  Raises:
  TypeError: If `kernel_size` or `stride` is neither int nor tuple.
- ValueError: If `pad_mode` is neither 'valid' nor 'same' with not case sensitive.
- ValueError: If `data_format` is neither 'NCHW' nor 'NHWC'.
+ ValueError: If `pad_mode` is neither ``"valid"`` nor ``"same"`` with not case sensitive.
+ ValueError: If `data_format` is neither ``'NCHW'`` nor ``'NHWC'`` .
  ValueError: If `kernel_size` or `stride` is less than 1.
  ValueError: If length of shape of `x` is not equal to 3 or 4.
- ValueError: If `pad_mode` is not 'pad', `padding`, `dilation`, `return_indices`, `ceil_mode` parameters are not
- set to their default values.
+ ValueError: If `pad_mode` is not ``"pad"``, `padding`, `dilation`, `return_indices`, `ceil_mode` parameters
+ are not set to their default values.
  ValueError: If the length of the tuple/list `padding` parameter is not 2.
  ValueError: If The length of the tuple dilation parameter is not 2.
  ValueError: If dilation parameter is neither an integer nor a tuple.
- ValueError: If `pad_mode` is 'pad' and `data_format` is 'NHWC'.
- ValueError: If `padding` is non-zero when `pad_mode` is not 'pad'.
+ ValueError: If `pad_mode` is ``"pad"`` and `data_format` is ``'NHWC'``.
+ ValueError: If `padding` is non-zero when `pad_mode` is not ``"pad"``.

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``
@@ -514,7 +523,7 @@ class MaxPool2d(_PoolNd):
  (1, 2, 2, 2)
  >>> np_x = np.random.randint(0, 10, [5, 3, 4, 5])
  >>> x = ms.Tensor(np_x, ms.float32)
- >>> pool2 = ms.nn.MaxPool2d(kernel_size=2, stride=1, pad_mode='pad', padding=1, dilation=1, return_indices=True)
+ >>> pool2 = ms.nn.MaxPool2d(kernel_size=2, stride=1, pad_mode="pad", padding=1, dilation=1, return_indices=True)
  >>> output = pool2(x)
  >>> print(output[0].shape)
  (5, 3, 5, 6)
@@ -596,18 +605,20 @@ class MaxPool1d(_PoolNd):
  kernel_size (int): The size of kernel used to take the max value, Default: ``1`` .
  stride (int): The distance of kernel moving, an int number that represents
  the width of movement is stride, Default: ``1`` .
- pad_mode (str): The optional value for pad mode, is ``"same"`` , ``"valid"`` or ``"pad"`` , not case sensitive.
- Default: ``"valid"`` .
-
- - ``"same"``: Adopts the way of completion. The total number of padding will be calculated in horizontal
- and vertical directions and evenly distributed to top and bottom, left and right if possible.
- Otherwise, the last extra padding will be done from the bottom and the right side.
-
- - ``"valid"``: Adopts the way of discarding. The possible largest height and width of output
- will be returned without padding. Extra pixels will be discarded.
-
- - ``"pad"``: Performs padding on the input. Adds padding size of zeros to both ends of the input.
- If this mode is set, padding must be greater than or equal to 0.
+ pad_mode (str, optional): Specifies the padding mode with a padding value of 0. It can be set to:
+ ``"same"`` , ``"valid"`` or ``"pad"`` . Default: ``"valid"`` .
+
+ - ``"same"``: Pad the input at the beginning and end so that the shape of input and output
+ are the same when `stride` is set to ``1``.
+ The amount of padding is calculated by the operator internally. If the amount is even, it is
+ uniformly distributed around the input; if it is odd, the excess padding goes to the right side.
+ If this mode is set, `padding` must be 0.
+ - ``"valid"``: No padding is applied to the input, and the output returns the maximum
+ possible length. Extra pixels that could not complete a full stride will
+ be discarded. If this mode is set, `padding` must be 0.
+ - ``"pad"``: Pad the input with a specified amount. In this mode, the amount of padding
+ at the beginning and end is determined by the `padding` parameter.
+ If this mode is set, `padding` must be greater than or equal to 0.

  padding (Union(int, tuple[int], list[int])): Padding value for the pooling. Default value is 0.
  padding can only be an integer or a tuple/list containing a single integer, in which case padding times or
@@ -641,16 +652,16 @@ class MaxPool1d(_PoolNd):

  Raises:
  TypeError: If `kernel_size` or `strides` is not an int.
- ValueError: If `pad_mode` is not 'valid', 'same' or 'pad', case-insensitive.
- ValueError: If `data_format` is neither 'NCHW' nor 'NHWC'.
+ ValueError: If `pad_mode` is not ``"valid"``, ``"same"`` or ``"pad"``, case-insensitive.
+ ValueError: If `data_format` is neither ``'NCHW'`` nor ``'NHWC'``.
  ValueError: If `kernel_size` or `strides` is less than 1.
  ValueError: If length of shape of `x` is not equal to 2 or 3.
- ValueError: If `pad_mode` is not 'pad', `padding`, `dilation`, `return_indices`, `ceil_mode` parameters are not
- set to their default values.
+ ValueError: If `pad_mode` is not ``"pad"``, `padding`, `dilation`, `return_indices`, `ceil_mode` parameters
+ are not set to their default values.
  ValueError: If the length of the tuple/list `padding` parameter is not 1.
  ValueError: If The length of the tuple dilation parameter is not 1.
  ValueError: If dilation parameter is neither an integer nor a tuple.
- ValueError: If `padding` is non-zero when `pad_mode` is not 'pad'.
+ ValueError: If `padding` is non-zero when `pad_mode` is not ``"pad"``.

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``
@@ -667,7 +678,7 @@ class MaxPool1d(_PoolNd):
  (1, 2, 2)
  >>> np_x = np.random.randint(0, 10, [5, 3, 4])
  >>> x = ms.Tensor(np_x, ms.float32)
- >>> mpool2 = nn.MaxPool1d(kernel_size=2, stride=1, pad_mode='pad', padding=1, dilation=1, return_indices=True)
+ >>> mpool2 = nn.MaxPool1d(kernel_size=2, stride=1, pad_mode="pad", padding=1, dilation=1, return_indices=True)
  >>> output = mpool2(x)
  >>> print(output[0].shape)
  (5, 3, 5)
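
The `return_indices=True` examples above return a `(values, argmax)` pair. A small NumPy illustration of what the second output holds in the 1-D case (a hypothetical helper assuming stride 1 and no padding; no MindSpore required):

import numpy as np

def maxpool1d_with_indices(x, kernel):
    # Sliding windows of length `kernel` with stride 1 (valid mode).
    windows = np.lib.stride_tricks.sliding_window_view(x, kernel)
    values = windows.max(axis=-1)
    # Per-window argmax shifted to indices into the original input,
    # analogous to the int64 argmax output described above.
    indices = windows.argmax(axis=-1) + np.arange(values.shape[0])
    return values, indices

x = np.array([3.0, 1.0, 4.0, 1.0, 5.0])
values, indices = maxpool1d_with_indices(x, 2)
print(values)   # [3. 4. 4. 5.]
print(indices)  # [0 2 2 4]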
@@ -787,19 +798,23 @@ class AvgPool3d(_PoolNd):
  element tuple that represents the depth, height and width of movement, or a tuple of three positive integers
  that represents depth, height and width of movement respectively. If the value is None, the default value
  `kernel_size` is used. Default: ``1`` .
- pad_mode (str, optional): Specifies the padding method of pooling, optional values are ``"same"``, ``"valid"``
- or ``"pad"`` , case insensitive. Default: ``"valid"`` .
-
- - same: The depth, height and width of the output is the same as the value after the input is divided
- by stride.
-
- - valid: Returns the output obtained by effective calculation without padding.
- The excess pixels that do not meet the calculation will be discarded.
-
- - pad: Pads the input. Fill the front, back, top, and bottom of the input with 0s of size `padding`.
+ pad_mode (str, optional): Specifies the padding mode with a padding value of 0. It can be set to:
+ ``"same"`` , ``"valid"`` or ``"pad"`` . Default: ``"valid"`` .
+
+ - ``"same"``: Pad the input around its depth/height/width dimension so that the shape of input and output
+ are the same when `stride` is set to ``1``.
+ The amount of padding is calculated by the operator internally. If the amount is even,
+ it is uniformly distributed around the input; if it is odd, the excess amount goes
+ to the front/right/bottom side.
+ If this mode is set, `padding` must be 0.
+ - ``"valid"``: No padding is applied to the input, and the output returns the maximum
+ possible depth, height and width. Extra pixels that could not complete a full stride will
+ be discarded. If this mode is set, `padding` must be 0.
+ - ``"pad"``: Pad the input with a specified amount. In this mode, the amount of padding
+ in the depth, height and width dimension is determined by the `padding` parameter.
  If this mode is set, `padding` must be greater than or equal to 0.

- padding (Union(int, tuple[int], list[int]), optional): Pooling padding value, only 'pad' mode can be set to
+ padding (Union(int, tuple[int], list[int]), optional): Pooling padding value, only ``"pad"`` mode can be set to
  non-zero. Default: ``0`` . Only the following paddings are supported:

  - `padding` is an integer or a tuple/list containing one integer, it will be padded in six directions of
@@ -851,7 +866,7 @@ class AvgPool3d(_PoolNd):
  ValueError: If element of `padding` is less than 0.
  ValueError: If length of shape of `x` is neither 4 nor 5.
  ValueError: If `divisor_override` is less than or equal to 0.
- ValueError: If `padding` is non-zero when `pad_mode` is not 'pad'.
+ ValueError: If `padding` is non-zero when `pad_mode` is not ``"pad"``.

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``
@@ -864,7 +879,7 @@ class AvgPool3d(_PoolNd):
  >>> print(output.shape)
  (1, 2, 2, 2, 3)
  >>> x1 = ms.ops.randn(6, 5, 7, 7, 5).astype(ms.float32)
- >>> pool2 = ms.nn.AvgPool3d(4, stride=2, pad_mode='pad', padding=(2, 2, 1), divisor_override=10)
+ >>> pool2 = ms.nn.AvgPool3d(4, stride=2, pad_mode="pad", padding=(2, 2, 1), divisor_override=10)
  >>> output2 = pool2(x1)
  >>> print(output2.shape)
  (6, 5, 4, 4, 2)
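
A quick check (illustration only) that the AvgPool3d doctest shape follows the standard floor-form output length, assuming no dilation:

import math

def avg_pooled_len(length, kernel, stride, padding):
    # Standard floor-form output length, assuming no dilation.
    return math.floor((length + 2 * padding - kernel) / stride) + 1

# Doctest input (6, 5, 7, 7, 5) with kernel 4, stride 2, padding (2, 2, 1):
dims = [avg_pooled_len(n, 4, 2, p) for n, p in ((7, 2), (7, 2), (5, 1))]
print((6, 5, *dims))  # (6, 5, 4, 4, 2), matching the doctest output

As for `divisor_override=10` in the example, the presumed semantics (mirroring similar pooling APIs; an assumption, not stated in the hunks shown) is that each window sum is divided by 10 instead of by the window's element count.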
@@ -912,19 +927,22 @@ class AvgPool2d(_PoolNd):
  stride (Union[int, tuple[int]]): The distance of kernel moving, an int number or a single element tuple that
  represents the height and width of movement are both strides, or a tuple of two int numbers that
  represent height and width of movement respectively. Default: ``1`` .
- pad_mode (str) - Specifies the padding method of pooling, optional values are ``"same"``, ``"valid"`` or
- ``"pad"`` , case insensitive. Default: ``"valid"`` .
-
- - ``"same"``: The height and width of the output is the same as the value after the input is divided by
- stride.
-
- - ``"valid"``: Returns the output obtained by effective calculation without padding.
- The excess pixels that do not meet the calculation will be discarded.
-
- - ``"pad"``: pads the input. Pads the top, bottom, left, and right sides of the input with `padding` number
- of zeros. If this mode is set, `padding` must be greater than or equal to 0.
+ pad_mode (str, optional): Specifies the padding mode with a padding value of 0. It can be set to:
+ ``"same"`` , ``"valid"`` or ``"pad"`` . Default: ``"valid"`` .
+
+ - ``"same"``: Pad the input around its edges so that the shape of input and output
+ are the same when `stride` is set to ``1``.
+ The amount of padding is calculated by the operator internally. If the amount is even, it is
+ uniformly distributed around the input; if it is odd, the excess amount goes to the right/bottom side.
+ If this mode is set, `padding` must be 0.
+ - ``"valid"``: No padding is applied to the input, and the output returns the maximum
+ possible height and width. Extra pixels that could not complete a full stride will
+ be discarded. If this mode is set, `padding` must be 0.
+ - ``"pad"``: Pad the input with a specified amount. In this mode, the amount of padding
+ in the height and width directions is determined by the `padding` parameter.
+ If this mode is set, `padding` must be greater than or equal to 0.

- padding (Union(int, tuple[int], list[int])): Pooling padding value, only 'pad' mode can be set to non-zero.
+ padding (Union(int, tuple[int], list[int])): Pooling padding value, only ``"pad"`` mode can be set to non-zero.
  Default: ``0`` . `padding` can only be an integer or a tuple/list containing one or two integers.
  If `padding` is an integer or a tuple/list containing one integer, it will be padded `padding` times in the
  four directions of the input. If `padding` is a tuple/list containing two integers, it will be padded
@@ -955,15 +973,15 @@ class AvgPool2d(_PoolNd):

  Raises:
  TypeError: If `kernel_size` or `strides` is neither int nor tuple.
- ValueError: If `pad_mode` is not 'valid' ,'same' or 'pad' with not case sensitive.
- ValueError: If `data_format` is neither 'NCHW' nor 'NHWC'.
+ ValueError: If `pad_mode` is not ``"valid"`` , ``"same"`` or ``"pad"`` with not case sensitive.
+ ValueError: If `data_format` is neither ``'NCHW'`` nor ``'NHWC'``.
  ValueError: If `padding`, `ceil_mode`, `count_include_pad`, or `divisor_override` is used
- or `pad_mode` is `pad` when `data_format` is 'NHWC'.
+ or `pad_mode` is ``"pad"`` when `data_format` is 'NHWC'.
  ValueError: If `kernel_size` or `strides` is less than 1.
  ValueError: If length of `padding` tuple/list is not 1 or 2.
  ValueError: If length of shape of `x` is not equal to 3 or 4.
  ValueError: If `divisor_override` is less than or equal to 0.
- ValueError: If `padding` is non-zero when `pad_mode` is not 'pad'.
+ ValueError: If `padding` is non-zero when `pad_mode` is not ``"pad"``.

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``
@@ -977,7 +995,7 @@ class AvgPool2d(_PoolNd):
  >>> print(output.shape)
  (1, 2, 2, 2)
  >>> x = ms.ops.randn(6, 6, 8, 8)
- >>> pool2 = ms.nn.AvgPool2d(4, stride=1, pad_mode='pad', padding=2, divisor_override=5)
+ >>> pool2 = ms.nn.AvgPool2d(4, stride=1, pad_mode="pad", padding=2, divisor_override=5)
  >>> output2 = pool2(x)
  >>> print(output2.shape)
  (6, 6, 9, 9)
@@ -1062,18 +1080,22 @@ class AvgPool1d(_PoolNd):
  kernel_size (int): The size of kernel window used to take the average value, Default: ``1`` .
  stride (int): The distance of kernel moving, an int number that represents
  the width of movement is strides, Default: ``1`` .
- pad_mode (str) - Specifies the padding method of pooling, optional values are ``"same"``, ``"valid"`` or
- ``"pad"`` , case insensitive. Default: ``"valid"`` .
-
- - same: The width of the output is the same as the value after the input is divided by stride.
-
- - valid: Returns the output obtained by effective calculation without padding.
- The excess pixels that do not meet the calculation will be discarded.
-
- - pad: Performs padding on the input. Adds padding size of zeros to both ends of the input.
- If this mode is set, padding must be greater than or equal to ``0`` .
+ pad_mode (str, optional): Specifies the padding mode with a padding value of 0. It can be set to:
+ ``"same"`` , ``"valid"`` or ``"pad"`` . Default: ``"valid"`` .
+
+ - ``"same"``: Pad the input at the beginning and end so that the shape of input and output
+ are the same when `stride` is set to ``1``.
+ The amount of padding is calculated by the operator internally. If the amount is even, it is
+ uniformly distributed around the input; if it is odd, the excess padding goes to the right side.
+ If this mode is set, `padding` must be 0.
+ - ``"valid"``: No padding is applied to the input, and the output returns the maximum
+ possible length. Extra pixels that could not complete a full stride will
+ be discarded. If this mode is set, `padding` must be 0.
+ - ``"pad"``: Pad the input with a specified amount. In this mode, the amount of padding
+ at the beginning and end is determined by the `padding` parameter.
+ If this mode is set, `padding` must be greater than or equal to 0.

- padding (Union(int, tuple[int], list[int])): Pooling padding value, only 'pad' mode can be set to non-zero.
+ padding (Union(int, tuple[int], list[int])): Pooling padding value, only ``"pad"`` mode can be set to non-zero.
  Default: ``0`` . padding can only be an integer or a tuple/list containing a single integer, in which case
  padding times or padding[0] times are padded on both sides of the input.
  ceil_mode (bool): If ``True`` , use ceil to compute the output shape instead of floor. Default: ``False`` .
@@ -1093,11 +1115,11 @@ class AvgPool1d(_PoolNd):
 
     Raises:
         TypeError: If `kernel_size` or `stride` is not an int.
-        ValueError: If `pad_mode` is not 'valid' ,'same' or 'pad' with not case sensitive.
+        ValueError: If `pad_mode` is not ``"valid"`` , ``"same"`` or ``"pad"`` (case-insensitive).
         ValueError: If `kernel_size` or `strides` is less than 1.
         ValueError: If length of `padding` tuple/list is not 1.
         ValueError: If length of shape of `x` is not equal to 2 or 3.
-        ValueError: If `padding` is non-zero when `pad_mode` is not 'pad'.
+        ValueError: If `padding` is non-zero when `pad_mode` is not ``"pad"``.
 
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
@@ -1111,7 +1133,7 @@ class AvgPool1d(_PoolNd):
         >>> result = output.shape
         >>> print(result)
         (1, 3, 1)
-        >>> pool2 = ms.nn.AvgPool1d(4, stride=1, ceil_mode=True, pad_mode='pad', padding=2)
+        >>> pool2 = ms.nn.AvgPool1d(4, stride=1, ceil_mode=True, pad_mode="pad", padding=2)
         >>> x1 = ms.ops.randn(6, 6, 8)
         >>> output = pool2(x1)
         >>> print(output.shape)
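The printed shape in this doctest follows the usual pooling length formula, L_out = floor((L_in + 2 * padding - kernel_size) / stride) + 1, with ceil instead of floor when `ceil_mode=True`. A small checker; `avgpool1d_out_len` is a hypothetical helper for illustration, not part of the MindSpore API:

    import math

    def avgpool1d_out_len(l_in, kernel_size, stride=1, padding=0, ceil_mode=False):
        # Standard pooling length formula, as described in the docstring above.
        span = l_in + 2 * padding - kernel_size
        steps = math.ceil(span / stride) if ceil_mode else span // stride
        return steps + 1

    # For the doctest above: L_in=8, kernel=4, stride=1, padding=2 -> 9
    print(avgpool1d_out_len(8, 4, stride=1, padding=2, ceil_mode=True))  # 9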
@@ -1528,7 +1550,7 @@ class AdaptiveMaxPool2d(Cell):
 
     Outputs:
         Tensor, with the same type as the `input`.
-        Shape of the output is `input_shape[:len(input_shape) - len(out_shape)] + out_shape`.
+        Shape of the output is :math:`input\_shape[:len(input\_shape) - len(out\_shape)] + out\_shape`.
 
     Raises:
         TypeError: If `output_size` is not int or tuple.
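The shape rule reads more naturally with an example: the trailing spatial dimensions are replaced by `output_size` while the leading dimensions pass through. A sketch, assuming the 2.2 `mindspore.nn.AdaptiveMaxPool2d` API:

    import mindspore as ms

    x = ms.ops.randn(1, 3, 32, 24)
    pool = ms.nn.AdaptiveMaxPool2d(output_size=(8, 8))
    # input_shape[:len(input_shape) - len(out_shape)] + out_shape -> (1, 3) + (8, 8)
    print(pool(x).shape)  # (1, 3, 8, 8)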
@@ -1860,7 +1882,7 @@ class MaxUnpool1d(Cell):
 
     .. math::
         \begin{array}{ll} \\
-        H_{out} = (H{in} - 1) \times stride[0] - 2 \times padding[0] + kernel\_size[0] \\
+        H_{out} = (H_{in} - 1) \times stride[0] - 2 \times padding[0] + kernel\_size[0] \\
         \end{array}
 
     Args:
@@ -1877,8 +1899,8 @@ class MaxUnpool1d(Cell):
           Values of indices must belong to :math:`[0, H_{in} - 1]`.
           Data type must be in int32 or int64.
         - **output_size** (tuple[int], optional) - The output size. Default: ``None`` .
-          If output_size is None, then the shape of output computed by kernel_size, stride and padding.
-          If output_size is not None, then output_size must be :math:`(N, C, H)` , :math:`(C, H)` or
+          If output_size is ``None``, then the shape of the output is computed from kernel_size, stride and padding.
+          If output_size is not ``None``, then output_size must be :math:`(N, C, H)` , :math:`(C, H)` or
           :math:`(H)` and output_size must belong to
           :math:`[(N, C, H_{out} - stride[0]), (N, C, H_{out} + stride[0])]`.
 
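A round-trip sketch of the formula above: `MaxPool1d` called with `return_indices=True` (supported in MindSpore 2.x) yields the indices that `MaxUnpool1d` consumes, and the unpooled length matches H_out:

    import mindspore as ms

    x = ms.ops.randn(1, 1, 8)
    pool = ms.nn.MaxPool1d(kernel_size=2, stride=2, return_indices=True)
    out, indices = pool(x)                  # out: (1, 1, 4)
    unpool = ms.nn.MaxUnpool1d(kernel_size=2, stride=2)
    # H_out = (4 - 1) * 2 - 2 * 0 + 2 = 8
    print(unpool(out, indices).shape)       # (1, 1, 8)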
@@ -1942,8 +1964,8 @@ class MaxUnpool2d(Cell):
 
     .. math::
         \begin{array}{ll} \\
-        H_{out} = (H{in} - 1) \times stride[0] - 2 \times padding[0] + kernel\_size[0] \\
-        W_{out} = (W{in} - 1) \times stride[1] - 2 \times padding[1] + kernel\_size[1] \\
+        H_{out} = (H_{in} - 1) \times stride[0] - 2 \times padding[0] + kernel\_size[0] \\
+        W_{out} = (W_{in} - 1) \times stride[1] - 2 \times padding[1] + kernel\_size[1] \\
         \end{array}
 
     Args:
@@ -1966,8 +1988,8 @@ class MaxUnpool2d(Cell):
           Values of indices must belong to :math:`[0, H_{in} \times W_{in} - 1]`.
           Data type must be in int32 or int64.
         - **output_size** (tuple[int], optional) - The output size. Default: ``None`` .
-          If output_size is None, then the shape of output computed by kernel_size, stride and padding.
-          If output_size is not None, then output_size must be :math:`(N, C, H, W)`, :math:`(C, H, W)` or
+          If output_size is ``None``, then the shape of the output is computed from kernel_size, stride and padding.
+          If output_size is not ``None``, then output_size must be :math:`(N, C, H, W)`, :math:`(C, H, W)` or
           :math:`(H, W)` and output_size must belong to
           :math:`[(N, C, H_{out} - stride[0], W_{out} - stride[1]), (N, C, H_{out} + stride[0], W_{out} + stride[1])]`.
 
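The `output_size` override matters when the forward pool discarded pixels, e.g. an odd-sized input with stride 2: the formula then under-reports the original size. A sketch under the same assumptions as the 1d example:

    import mindspore as ms

    x = ms.ops.randn(1, 1, 9, 9)   # odd size: a 2x2/stride-2 pool drops the last row and column
    pool = ms.nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
    out, indices = pool(x)          # out: (1, 1, 4, 4)
    unpool = ms.nn.MaxUnpool2d(kernel_size=2, stride=2)
    print(unpool(out, indices).shape)                            # (1, 1, 8, 8) from the formula
    print(unpool(out, indices, output_size=(1, 1, 9, 9)).shape)  # (1, 1, 9, 9), within H_out +/- stride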
@@ -2034,9 +2056,9 @@ class MaxUnpool3d(Cell):
 
     .. math::
         \begin{array}{ll} \\
-        D_{out} = (D{in} - 1) \times stride[0] - 2 \times padding[0] + kernel\_size[0] \\
-        H_{out} = (H{in} - 1) \times stride[1] - 2 \times padding[1] + kernel\_size[1] \\
-        W_{out} = (W{in} - 1) \times stride[2] - 2 \times padding[2] + kernel\_size[2] \\
+        D_{out} = (D_{in} - 1) \times stride[0] - 2 \times padding[0] + kernel\_size[0] \\
+        H_{out} = (H_{in} - 1) \times stride[1] - 2 \times padding[1] + kernel\_size[1] \\
+        W_{out} = (W_{in} - 1) \times stride[2] - 2 \times padding[2] + kernel\_size[2] \\
         \end{array}
 
     Args:
@@ -2060,8 +2082,8 @@ class MaxUnpool3d(Cell):
           Values of indices must belong to :math:`[0, D_{in} \times H_{in} \times W_{in} - 1]`.
           Data type must be in int32 or int64.
         - **output_size** (tuple[int], optional) - The output size. Default: ``None`` .
-          If output_size is None, then the shape of output computed by kernel_size, stride and padding.
-          If output_size is not None, then output_size must be :math:`(N, C, D, H, W)` , :math:`(C, D, H, W)` or
+          If output_size is ``None``, then the shape of the output is computed from kernel_size, stride and padding.
+          If output_size is not ``None``, then output_size must be :math:`(N, C, D, H, W)` , :math:`(C, D, H, W)` or
           :math:`(D, H, W)` and output_size must belong to
           :math:`[(N, C, D_{out} - stride[0], H_{out} - stride[1], W_{out} - stride[2]),
           (N, C, D_{out} + stride[0], H_{out} + stride[1], W_{out} + stride[2])]`.
@@ -83,7 +83,7 @@ def _check_lstmcell_init(func):
 
 
 def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
-    '''RNN cell function with tanh activation'''
+    """RNN cell function with tanh activation"""
     if b_ih is None:
         igates = P.MatMul(False, True)(inputs, w_ih)
         hgates = P.MatMul(False, True)(hidden, w_hh)
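For readers skimming these hunks: `P.MatMul(False, True)` multiplies by the transposed second operand, so the tanh cell computes h' = tanh(x @ W_ih.T + b_ih + h @ W_hh.T + b_hh). A NumPy sketch of the same arithmetic (a reference model, not MindSpore code):

    import numpy as np

    def rnn_tanh_cell_ref(x, h, w_ih, w_hh, b_ih=None, b_hh=None):
        # Mirrors _rnn_tanh_cell: MatMul(False, True) is x @ w.T
        igates = x @ w_ih.T + (0 if b_ih is None else b_ih)
        hgates = h @ w_hh.T + (0 if b_hh is None else b_hh)
        return np.tanh(igates + hgates)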
@@ -94,7 +94,7 @@ def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
 
 
 def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
-    '''RNN cell function with relu activation'''
+    """RNN cell function with relu activation"""
     if b_ih is None:
         igates = P.MatMul(False, True)(inputs, w_ih)
         hgates = P.MatMul(False, True)(hidden, w_hh)
@@ -105,7 +105,7 @@ def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
 
 
 def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
-    '''LSTM cell function'''
+    """LSTM cell function"""
     hx, cx = hidden
     if b_ih is None:
         gates = P.MatMul(False, True)(inputs, w_ih) + P.MatMul(False, True)(hx, w_hh)
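The LSTM cell follows the same pattern with one fused matmul whose result is split into four gate blocks (hence `num_chunks=4` in `RNNCellBase` below). A NumPy sketch, assuming the usual input/forget/cell/output chunk order:

    import numpy as np

    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    def lstm_cell_ref(x, hx, cx, w_ih, w_hh, b_ih=None, b_hh=None):
        # One fused projection, then a 4-way split into i, f, g, o gates.
        gates = x @ w_ih.T + hx @ w_hh.T
        if b_ih is not None:
            gates = gates + b_ih + b_hh
        i, f, g, o = np.split(gates, 4, axis=-1)
        c = sigmoid(f) * cx + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
        return h, c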
@@ -125,7 +125,7 @@ def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
 
 
 def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
-    '''GRU cell function'''
+    """GRU cell function"""
    if b_ih is None:
        gi = P.MatMul(False, True)(inputs, w_ih)
        gh = P.MatMul(False, True)(hidden, w_hh)
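The GRU cell keeps the input and hidden projections separate because the reset gate scales only the hidden contribution to the candidate state (`num_chunks=3`). A NumPy sketch, assuming the usual reset/update/new chunk order:

    import numpy as np

    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    def gru_cell_ref(x, h, w_ih, w_hh, b_ih=None, b_hh=None):
        gi = x @ w_ih.T + (0 if b_ih is None else b_ih)
        gh = h @ w_hh.T + (0 if b_hh is None else b_hh)
        i_r, i_z, i_n = np.split(gi, 3, axis=-1)
        h_r, h_z, h_n = np.split(gh, 3, axis=-1)
        r = sigmoid(i_r + h_r)        # reset gate
        z = sigmoid(i_z + h_z)        # update gate
        n = np.tanh(i_n + r * h_n)    # candidate state
        return (1 - z) * n + z * h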
@@ -144,8 +144,9 @@ def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
 
 
 class RNNCellBase(Cell):
-    '''Basic class for RNN Cells'''
-    def __init__(self, input_size: int, hidden_size: int, has_bias: bool, num_chunks: int):
+    """Basic class for RNN Cells"""
+    def __init__(self, input_size: int, hidden_size: int, has_bias: bool, num_chunks: int,
+                 dtype=mstype.float32):
         super().__init__()
         validator.check_value_type("has_bias", has_bias, [bool], self.cls_name)
         validator.check_positive_int(hidden_size, "hidden_size", self.cls_name)
@@ -153,20 +154,20 @@ class RNNCellBase(Cell):
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.has_bias = has_bias
-        self.weight_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, input_size).astype(np.float32)))
-        self.weight_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, hidden_size).astype(np.float32)))
+        self.weight_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, input_size), dtype=dtype))
+        self.weight_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, hidden_size), dtype=dtype))
         if has_bias:
-            self.bias_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size).astype(np.float32)))
-            self.bias_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size).astype(np.float32)))
+            self.bias_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size), dtype=dtype))
+            self.bias_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size), dtype=dtype))
         else:
             self.bias_ih = None
             self.bias_hh = None
-        self.reset_parameters()
+        self.reset_parameters(dtype=dtype)
 
-    def reset_parameters(self):
+    def reset_parameters(self, dtype=mstype.float32):
         stdv = 1 / math.sqrt(self.hidden_size)
         for weight in self.get_parameters():
-            weight.set_data(initializer(Uniform(stdv), weight.shape))
+            weight.set_data(initializer(Uniform(stdv), weight.shape, dtype))
 
 
 class RNNCell(RNNCellBase):
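The net effect of this hunk is a new, optional `dtype` argument threaded from the cell constructors through parameter creation and re-initialization. A usage sketch against the 2.2.11 API:

    import mindspore as ms

    cell = ms.nn.RNNCell(10, 16, dtype=ms.float16)  # dtype is new in this release
    print(cell.weight_ih.dtype)                     # Float16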
@@ -187,11 +188,11 @@ class RNNCell(RNNCellBase):
         has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: ``True`` .
         nonlinearity (str): The non-linearity to use. Can be either ``"tanh"`` or ``"relu"`` .
             Default: ``"tanh"`` .
+        dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
 
     Inputs:
         - **x** (Tensor) - Tensor of shape :math:`(batch\_size, input\_size)` .
         - **hx** (Tensor) - Tensor of data type mindspore.float32 and shape :math:`(batch\_size, hidden\_size)` .
-          Data type of `hx` must be the same as `x`.
 
     Outputs:
         - **hx'** (Tensor) - Tensor of shape :math:`(batch\_size, hidden\_size)` .
@@ -219,8 +220,9 @@ class RNNCell(RNNCellBase):
     """
     _non_linearity = ['tanh', 'relu']
 
-    def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True, nonlinearity: str = "tanh"):
-        super().__init__(input_size, hidden_size, has_bias, num_chunks=1)
+    def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True, nonlinearity: str = "tanh",
+                 dtype=mstype.float32):
+        super().__init__(input_size, hidden_size, has_bias, num_chunks=1, dtype=dtype)
         validator.check_value_type("nonlinearity", nonlinearity, [str], self.cls_name)
         validator.check_string(nonlinearity, self._non_linearity, "nonlinearity", self.cls_name)
         self.nonlinearity = nonlinearity
@@ -270,11 +272,12 @@ class LSTMCell(RNNCellBase):
         input_size (int): Number of features of input.
         hidden_size (int): Number of features of hidden layer.
         has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: ``True`` .
+        dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
 
     Inputs:
         - **x** (Tensor) - Tensor of shape :math:`(batch\_size, input\_size)` .
         - **hx** (tuple) - A tuple of two Tensors (h_0, c_0) both of data type mindspore.float32
-          and shape :math:`(batch\_size, hidden\_size)` . The data type of `hx` must be the same as `x`.
+          and shape :math:`(batch\_size, hidden\_size)` .
 
     Outputs:
         - **hx'** (Tensor) - A tuple of two Tensors (h', c') both of data shape :math:`(batch\_size, hidden\_size)` .
@@ -301,8 +304,9 @@ class LSTMCell(RNNCellBase):
         (3, 16)
     """
     @_check_lstmcell_init
-    def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True):
-        super().__init__(input_size, hidden_size, has_bias, num_chunks=4)
+    def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True,
+                 dtype=mstype.float32):
+        super().__init__(input_size, hidden_size, has_bias, num_chunks=4, dtype=dtype)
         self.support_non_tensor_inputs = True
 
     def construct(self, x, hx):
@@ -352,11 +356,11 @@ class GRUCell(RNNCellBase):
         input_size (int): Number of features of input.
         hidden_size (int): Number of features of hidden layer.
         has_bias (bool): Whether the cell has bias `b_in` and `b_hn`. Default: ``True`` .
+        dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
 
     Inputs:
         - **x** (Tensor) - Tensor of shape :math:`(batch\_size, input\_size)` .
         - **hx** (Tensor) - Tensor of data type mindspore.float32 and shape :math:`(batch\_size, hidden\_size)` .
-          Data type of `hx` must be the same as `x`.
 
     Outputs:
         - **hx'** (Tensor) - Tensor of shape :math:`(batch\_size, hidden\_size)` .
@@ -381,8 +385,9 @@ class GRUCell(RNNCellBase):
         >>> print(output[0].shape)
         (3, 16)
     """
-    def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True):
-        super().__init__(input_size, hidden_size, has_bias, num_chunks=3)
+    def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True,
+                 dtype=mstype.float32):
+        super().__init__(input_size, hidden_size, has_bias, num_chunks=3, dtype=dtype)
 
     def construct(self, x, hx):
         _check_is_tensor('x', x, self.cls_name)
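Cells process a single time step; the enclosing `nn.RNN`/`nn.GRU` layers wrap the loop over a sequence. A manual-unroll sketch with the updated GRUCell (shapes are illustrative):

    import mindspore as ms

    cell = ms.nn.GRUCell(10, 16)
    x = ms.ops.randn(5, 3, 10)            # (seq_len, batch, input_size)
    h = ms.ops.zeros((3, 16), ms.float32)
    outputs = []
    for t in range(5):                    # what nn.GRU does internally, step by step
        h = cell(x[t], h)
        outputs.append(h)
    print(outputs[-1].shape)              # (3, 16)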