mindspore 2.0.0rc1__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (884) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  201. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  202. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  203. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  204. mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
  205. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  206. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  207. mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
  208. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  209. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  210. mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
  211. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  212. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  213. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  214. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  215. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  216. mindspore/log.py +9 -6
  217. mindspore/mindrecord/filereader.py +33 -4
  218. mindspore/mindrecord/filewriter.py +70 -35
  219. mindspore/mindrecord/mindpage.py +40 -34
  220. mindspore/mindrecord/shardreader.py +1 -1
  221. mindspore/mindrecord/shardsegment.py +1 -1
  222. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  223. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  224. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  225. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  226. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  227. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  228. mindspore/nn/cell.py +463 -169
  229. mindspore/nn/dynamic_lr.py +47 -43
  230. mindspore/nn/layer/activation.py +225 -82
  231. mindspore/nn/layer/basic.py +121 -79
  232. mindspore/nn/layer/channel_shuffle.py +21 -21
  233. mindspore/nn/layer/combined.py +33 -26
  234. mindspore/nn/layer/container.py +277 -22
  235. mindspore/nn/layer/conv.py +441 -304
  236. mindspore/nn/layer/dense.py +19 -13
  237. mindspore/nn/layer/embedding.py +62 -49
  238. mindspore/nn/layer/flash_attention.py +264 -0
  239. mindspore/nn/layer/image.py +50 -39
  240. mindspore/nn/layer/math.py +62 -51
  241. mindspore/nn/layer/normalization.py +219 -167
  242. mindspore/nn/layer/padding.py +58 -70
  243. mindspore/nn/layer/pooling.py +334 -287
  244. mindspore/nn/layer/rnn_cells.py +53 -38
  245. mindspore/nn/layer/rnns.py +59 -56
  246. mindspore/nn/layer/thor_layer.py +52 -44
  247. mindspore/nn/layer/timedistributed.py +6 -4
  248. mindspore/nn/layer/transformer.py +284 -164
  249. mindspore/nn/learning_rate_schedule.py +34 -25
  250. mindspore/nn/loss/__init__.py +3 -2
  251. mindspore/nn/loss/loss.py +554 -311
  252. mindspore/nn/optim/ada_grad.py +12 -9
  253. mindspore/nn/optim/adadelta.py +14 -11
  254. mindspore/nn/optim/adafactor.py +19 -16
  255. mindspore/nn/optim/adam.py +62 -47
  256. mindspore/nn/optim/adamax.py +13 -10
  257. mindspore/nn/optim/adasum.py +12 -8
  258. mindspore/nn/optim/asgd.py +10 -9
  259. mindspore/nn/optim/ftrl.py +20 -17
  260. mindspore/nn/optim/lamb.py +16 -12
  261. mindspore/nn/optim/lars.py +8 -6
  262. mindspore/nn/optim/lazyadam.py +25 -20
  263. mindspore/nn/optim/momentum.py +10 -7
  264. mindspore/nn/optim/optimizer.py +61 -9
  265. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  266. mindspore/nn/optim/rmsprop.py +17 -13
  267. mindspore/nn/optim/rprop.py +30 -17
  268. mindspore/nn/optim/sgd.py +40 -23
  269. mindspore/nn/optim/thor.py +24 -26
  270. mindspore/nn/probability/bijector/bijector.py +11 -11
  271. mindspore/nn/probability/bijector/exp.py +1 -1
  272. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  273. mindspore/nn/probability/bijector/invert.py +1 -1
  274. mindspore/nn/probability/bijector/power_transform.py +29 -29
  275. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  276. mindspore/nn/probability/bijector/softplus.py +5 -5
  277. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  278. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  279. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  280. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  281. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  282. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  283. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  284. mindspore/nn/probability/distribution/beta.py +8 -8
  285. mindspore/nn/probability/distribution/categorical.py +23 -15
  286. mindspore/nn/probability/distribution/cauchy.py +5 -6
  287. mindspore/nn/probability/distribution/distribution.py +3 -3
  288. mindspore/nn/probability/distribution/exponential.py +4 -4
  289. mindspore/nn/probability/distribution/gamma.py +10 -10
  290. mindspore/nn/probability/distribution/geometric.py +8 -8
  291. mindspore/nn/probability/distribution/gumbel.py +8 -9
  292. mindspore/nn/probability/distribution/half_normal.py +5 -5
  293. mindspore/nn/probability/distribution/laplace.py +5 -5
  294. mindspore/nn/probability/distribution/log_normal.py +12 -11
  295. mindspore/nn/probability/distribution/logistic.py +8 -8
  296. mindspore/nn/probability/distribution/normal.py +6 -5
  297. mindspore/nn/probability/distribution/poisson.py +10 -11
  298. mindspore/nn/probability/distribution/student_t.py +8 -9
  299. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  300. mindspore/nn/probability/distribution/uniform.py +11 -11
  301. mindspore/nn/reinforcement/tensor_array.py +2 -2
  302. mindspore/nn/sparse/sparse.py +9 -9
  303. mindspore/nn/wrap/cell_wrapper.py +188 -63
  304. mindspore/nn/wrap/grad_reducer.py +21 -12
  305. mindspore/nn/wrap/loss_scale.py +136 -49
  306. mindspore/numpy/__init__.py +4 -4
  307. mindspore/numpy/array_creations.py +55 -56
  308. mindspore/numpy/array_ops.py +134 -35
  309. mindspore/numpy/logic_ops.py +66 -20
  310. mindspore/numpy/math_ops.py +142 -139
  311. mindspore/numpy/utils_const.py +2 -2
  312. mindspore/offline_debug/convert_async.py +2 -2
  313. mindspore/ops/_grad_experimental/__init__.py +7 -5
  314. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  315. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  316. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  317. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  318. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  319. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  320. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  321. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  322. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  323. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  324. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  325. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  326. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  327. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  328. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  329. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  330. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  331. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  332. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  333. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  334. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  335. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  336. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  337. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  338. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  339. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  340. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  341. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  342. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  343. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  344. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  345. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  346. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  347. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  348. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  349. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  350. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  351. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  352. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  353. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  354. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  355. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  356. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  357. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  358. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  359. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  360. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  361. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  362. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  363. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  364. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  365. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  366. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  367. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  368. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  369. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  370. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  371. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  372. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  373. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  374. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  375. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  376. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  377. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  378. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  379. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  380. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  381. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  382. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  383. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  384. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  385. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  386. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  387. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  388. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  389. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  390. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  391. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  392. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  393. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  394. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  395. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  396. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  397. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  398. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  399. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  400. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  401. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  402. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  403. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  404. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  405. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  406. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  407. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  408. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  409. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  410. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  411. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  412. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  413. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  414. mindspore/ops/_primitive_cache.py +1 -1
  415. mindspore/ops/_tracefunc.py +241 -0
  416. mindspore/ops/_utils/utils.py +10 -2
  417. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  418. mindspore/ops/_vmap/vmap_base.py +5 -4
  419. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  420. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  421. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  422. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  423. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  424. mindspore/ops/arg_dtype_cast.py +54 -0
  425. mindspore/ops/composite/__init__.py +7 -5
  426. mindspore/ops/composite/base.py +78 -34
  427. mindspore/ops/composite/math_ops.py +5 -695
  428. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  429. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  430. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  431. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  432. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  433. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  434. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  436. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  437. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  438. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  440. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  441. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  442. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  443. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  444. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  445. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  446. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  447. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  448. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  449. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  450. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  451. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  452. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  453. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  454. mindspore/ops/deprecated.py +304 -0
  455. mindspore/ops/function/__init__.py +41 -4
  456. mindspore/ops/function/array_func.py +1108 -467
  457. mindspore/ops/function/clip_func.py +94 -27
  458. mindspore/ops/function/debug_func.py +3 -1
  459. mindspore/ops/function/grad/grad_func.py +82 -73
  460. mindspore/ops/function/image_func.py +28 -12
  461. mindspore/ops/function/linalg_func.py +135 -39
  462. mindspore/ops/function/math_func.py +3779 -894
  463. mindspore/ops/function/nn_func.py +1584 -657
  464. mindspore/ops/function/parameter_func.py +13 -3
  465. mindspore/ops/function/random_func.py +247 -153
  466. mindspore/ops/function/sparse_func.py +14 -11
  467. mindspore/ops/function/sparse_unary_func.py +173 -47
  468. mindspore/ops/function/spectral_func.py +8 -4
  469. mindspore/ops/function/vmap_func.py +8 -7
  470. mindspore/ops/functional.py +47 -16
  471. mindspore/ops/op_info_register.py +346 -86
  472. mindspore/ops/operations/__init__.py +38 -22
  473. mindspore/ops/operations/_grad_ops.py +145 -149
  474. mindspore/ops/operations/_inner_ops.py +298 -56
  475. mindspore/ops/operations/_ms_kernel.py +3 -3
  476. mindspore/ops/operations/_quant_ops.py +24 -28
  477. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  478. mindspore/ops/operations/_scalar_ops.py +115 -0
  479. mindspore/ops/operations/_sequence_ops.py +148 -10
  480. mindspore/ops/operations/_tensor_array.py +1 -1
  481. mindspore/ops/operations/_thor_ops.py +2 -2
  482. mindspore/ops/operations/array_ops.py +1239 -561
  483. mindspore/ops/operations/comm_ops.py +166 -90
  484. mindspore/ops/operations/control_ops.py +3 -3
  485. mindspore/ops/operations/custom_ops.py +124 -102
  486. mindspore/ops/operations/debug_ops.py +24 -11
  487. mindspore/ops/operations/image_ops.py +86 -71
  488. mindspore/ops/operations/inner_ops.py +18 -13
  489. mindspore/ops/operations/linalg_ops.py +30 -11
  490. mindspore/ops/operations/math_ops.py +1730 -435
  491. mindspore/ops/operations/nn_ops.py +1953 -943
  492. mindspore/ops/operations/other_ops.py +65 -43
  493. mindspore/ops/operations/random_ops.py +258 -98
  494. mindspore/ops/operations/rl_ops.py +4 -36
  495. mindspore/ops/operations/sparse_ops.py +38 -33
  496. mindspore/ops/operations/spectral_ops.py +8 -4
  497. mindspore/ops/primitive.py +66 -44
  498. mindspore/ops/signature.py +5 -5
  499. mindspore/parallel/_auto_parallel_context.py +80 -19
  500. mindspore/parallel/_cost_model_context.py +42 -0
  501. mindspore/parallel/_offload_context.py +162 -72
  502. mindspore/parallel/_parallel_serialization.py +2 -2
  503. mindspore/parallel/_ps_context.py +16 -4
  504. mindspore/parallel/_recovery_context.py +2 -1
  505. mindspore/parallel/_tensor.py +15 -13
  506. mindspore/parallel/_transformer/layers.py +8 -6
  507. mindspore/parallel/_transformer/loss.py +1 -0
  508. mindspore/parallel/_transformer/moe.py +7 -7
  509. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  510. mindspore/parallel/_transformer/transformer.py +34 -14
  511. mindspore/parallel/_utils.py +36 -14
  512. mindspore/parallel/algo_parameter_config.py +114 -20
  513. mindspore/parallel/checkpoint_transform.py +16 -18
  514. mindspore/parallel/shard.py +16 -13
  515. mindspore/profiler/__init__.py +1 -1
  516. mindspore/profiler/common/struct_type.py +3 -3
  517. mindspore/profiler/common/util.py +3 -2
  518. mindspore/profiler/envprofiling.py +11 -4
  519. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  520. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  521. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  522. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  523. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  524. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  525. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  526. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  527. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  528. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  529. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  530. mindspore/profiler/parser/flops_parser.py +15 -11
  531. mindspore/profiler/parser/framework_parser.py +92 -73
  532. mindspore/profiler/parser/hccl_parser.py +16 -12
  533. mindspore/profiler/parser/integrator.py +22 -11
  534. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  535. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  536. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  537. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  538. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  539. mindspore/profiler/parser/optime_parser.py +1 -1
  540. mindspore/profiler/parser/profiler_info.py +4 -5
  541. mindspore/profiler/parser/step_trace_parser.py +11 -14
  542. mindspore/profiler/profiling.py +678 -377
  543. mindspore/rewrite/api/node.py +211 -54
  544. mindspore/rewrite/api/node_type.py +5 -0
  545. mindspore/rewrite/api/pattern_engine.py +22 -23
  546. mindspore/rewrite/api/scoped_value.py +20 -17
  547. mindspore/rewrite/api/symbol_tree.py +252 -106
  548. mindspore/rewrite/api/tree_node_helper.py +3 -0
  549. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  550. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  551. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  552. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  553. mindspore/rewrite/common/rewrite_elog.py +5 -1
  554. mindspore/rewrite/namer.py +51 -51
  555. mindspore/rewrite/namespace.py +14 -5
  556. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  557. mindspore/rewrite/node/call_function.py +79 -0
  558. mindspore/rewrite/node/cell_container.py +135 -0
  559. mindspore/rewrite/node/control_flow.py +88 -0
  560. mindspore/rewrite/{node.py → node/node.py} +313 -247
  561. mindspore/rewrite/node/node_manager.py +254 -0
  562. mindspore/rewrite/node/node_topological_manager.py +243 -0
  563. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  564. mindspore/rewrite/parsers/assign_parser.py +225 -239
  565. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  566. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  567. mindspore/rewrite/parsers/constant_parser.py +9 -6
  568. mindspore/rewrite/parsers/container_parser.py +9 -7
  569. mindspore/rewrite/parsers/for_parser.py +36 -15
  570. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  571. mindspore/rewrite/parsers/if_parser.py +28 -24
  572. mindspore/rewrite/parsers/module_parser.py +202 -25
  573. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  574. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  575. mindspore/rewrite/parsers/return_parser.py +6 -6
  576. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  577. mindspore/rewrite/sparsify/sparsify.py +4 -1
  578. mindspore/rewrite/sparsify/utils.py +11 -5
  579. mindspore/rewrite/symbol_tree.py +577 -732
  580. mindspore/rewrite/symbol_tree_builder.py +9 -175
  581. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  582. mindspore/run_check/_check_version.py +46 -39
  583. mindspore/run_check/run_check.py +3 -2
  584. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  585. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  586. mindspore/scipy/__init__.py +1 -1
  587. mindspore/scipy/linalg.py +67 -61
  588. mindspore/scipy/ops.py +5 -41
  589. mindspore/scipy/ops_grad.py +3 -2
  590. mindspore/scipy/ops_wrapper.py +5 -5
  591. mindspore/scipy/optimize/line_search.py +8 -8
  592. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  593. mindspore/scipy/optimize/minimize.py +16 -12
  594. mindspore/scipy/utils.py +1 -52
  595. mindspore/scipy/utils_const.py +4 -4
  596. mindspore/train/__init__.py +4 -4
  597. mindspore/train/_utils.py +13 -5
  598. mindspore/train/amp.py +410 -148
  599. mindspore/train/anf_ir_pb2.py +16 -4
  600. mindspore/train/callback/_backup_and_restore.py +8 -11
  601. mindspore/train/callback/_callback.py +80 -3
  602. mindspore/train/callback/_checkpoint.py +82 -51
  603. mindspore/train/callback/_early_stop.py +12 -15
  604. mindspore/train/callback/_history.py +1 -1
  605. mindspore/train/callback/_lambda_callback.py +13 -13
  606. mindspore/train/callback/_landscape.py +21 -17
  607. mindspore/train/callback/_loss_monitor.py +9 -10
  608. mindspore/train/callback/_on_request_exit.py +16 -33
  609. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  610. mindspore/train/callback/_summary_collector.py +44 -30
  611. mindspore/train/callback/_time_monitor.py +62 -12
  612. mindspore/train/data_sink.py +10 -16
  613. mindspore/train/dataset_helper.py +154 -86
  614. mindspore/train/loss_scale_manager.py +14 -9
  615. mindspore/train/metrics/__init__.py +10 -2
  616. mindspore/train/metrics/accuracy.py +1 -1
  617. mindspore/train/metrics/auc.py +1 -1
  618. mindspore/train/metrics/bleu_score.py +2 -2
  619. mindspore/train/metrics/confusion_matrix.py +14 -14
  620. mindspore/train/metrics/cosine_similarity.py +3 -3
  621. mindspore/train/metrics/dice.py +1 -1
  622. mindspore/train/metrics/fbeta.py +1 -1
  623. mindspore/train/metrics/hausdorff_distance.py +8 -6
  624. mindspore/train/metrics/mean_surface_distance.py +5 -4
  625. mindspore/train/metrics/metric.py +49 -17
  626. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  627. mindspore/train/metrics/perplexity.py +1 -1
  628. mindspore/train/metrics/precision.py +2 -2
  629. mindspore/train/metrics/recall.py +2 -3
  630. mindspore/train/metrics/roc.py +7 -7
  631. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  632. mindspore/train/metrics/topk.py +7 -4
  633. mindspore/train/mind_ir_pb2.py +193 -48
  634. mindspore/train/model.py +377 -133
  635. mindspore/train/serialization.py +697 -245
  636. mindspore/train/summary/_summary_adapter.py +5 -2
  637. mindspore/train/summary/_writer_pool.py +4 -3
  638. mindspore/train/summary/summary_record.py +25 -23
  639. mindspore/train/train_thor/convert_utils.py +39 -23
  640. mindspore/train/train_thor/dataset_helper.py +4 -3
  641. mindspore/train/train_thor/model_thor.py +8 -8
  642. mindspore/version.py +1 -1
  643. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  644. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +647 -818
  645. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  646. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  647. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  648. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  649. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  650. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  651. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  652. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  653. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  654. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  655. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  656. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  657. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  658. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  659. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  660. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  661. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  662. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  663. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  664. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  665. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  666. mindspore/_extends/graph_kernel/expander.py +0 -80
  667. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  668. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  669. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  670. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  671. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  672. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  673. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  674. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  675. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  676. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  677. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  678. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  679. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  680. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  681. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  682. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  683. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  684. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  685. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  686. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  687. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  688. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  689. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  690. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  691. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  692. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  693. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  694. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  695. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  696. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  697. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  698. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  699. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  700. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  701. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  702. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  703. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  704. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  705. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  706. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  707. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  708. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  709. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  710. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  711. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  712. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  713. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  714. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  715. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  716. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  717. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  718. mindspore/dataset/engine/graphdata.py +0 -1586
  719. mindspore/include/api/net.h +0 -142
  720. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  721. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  722. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  723. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  724. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  725. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  726. mindspore/ops/_grad/grad_other_ops.py +0 -89
  727. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  728. mindspore/ops/_grad/grad_sparse.py +0 -323
  729. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  730. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  731. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  732. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  733. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  734. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  735. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  736. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  737. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  738. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  739. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  740. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  741. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  742. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  743. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  744. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  745. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  746. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  747. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  748. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  749. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  751. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  752. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  753. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  754. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  755. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  756. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  758. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  759. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  760. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  761. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  762. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  763. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  764. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  765. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  766. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  767. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  768. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  769. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  770. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  771. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  772. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  774. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  775. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  776. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  777. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  778. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  779. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  780. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  784. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  785. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  786. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  787. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  789. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  790. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  791. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  792. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  793. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  794. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  796. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  797. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  798. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  799. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  800. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  801. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  802. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  803. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  804. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  805. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  806. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  807. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  808. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  809. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  810. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  811. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  812. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  814. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  815. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  816. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  817. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  818. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  819. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  820. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  821. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  822. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  823. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  824. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  825. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  826. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  827. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  828. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  829. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  830. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  831. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  832. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  833. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  834. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  835. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  836. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  837. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  839. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  840. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  841. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  842. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  843. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  844. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  845. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  846. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  847. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  848. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  849. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  850. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  851. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  852. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  853. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  854. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  855. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  856. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  857. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  858. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  859. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  860. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  861. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  862. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  863. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  864. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  865. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  866. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  867. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  868. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  869. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  870. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  871. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  872. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  873. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  874. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  875. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  876. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  877. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  878. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  879. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  880. mindspore/rewrite/node_visitor.py +0 -44
  881. mindspore/rewrite/topological_manager.py +0 -203
  882. mindspore/scipy/sparse/linalg.py +0 -192
  883. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  884. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -14,6 +14,7 @@
14
14
  # ============================================================================
15
15
  """RNN Cells module, include RNNCell, GRUCell, LSTMCell."""
16
16
  from __future__ import absolute_import
17
+ from functools import wraps
17
18
 
18
19
  import math
19
20
  import numpy as np
@@ -39,8 +40,8 @@ def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
39
40
  @constexpr(check=False)
40
41
  def _check_is_tensor(param_name, input_data, cls_name):
41
42
  """Internal function, used to check whether the input data is Tensor."""
42
- if input_data is not None and not isinstance(P.typeof(input_data), mstype.tensor_type):
43
- raise TypeError(f"For '{cls_name}', the '{param_name}' must be '{mstype.tensor_type}', "
43
+ if input_data is not None and not isinstance(P.typeof(input_data), mstype.TensorType):
44
+ raise TypeError(f"For '{cls_name}', the '{param_name}' must be '{mstype.TensorType}', "
44
45
  f"but got '{P.typeof(input_data)}'")
45
46
 
46
47
 
@@ -68,6 +69,8 @@ def _check_batch_size_equal(batch_size_x, batch_size_hx, cls_name):
68
69
 
69
70
 
70
71
  def _check_lstmcell_init(func):
72
+ """Internal function, used to check init args."""
73
+ @wraps(func)
71
74
  def wrapper(*args, **kwargs):
72
75
  logger.warning(f"LSTMCell has been changed from 'single LSTM layer' to 'single LSTM cell', "
73
76
  f"if you still need use single LSTM layer, please use `nn.LSTM` instead.")
@@ -142,7 +145,8 @@ def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
142
145
 
143
146
  class RNNCellBase(Cell):
144
147
  '''Basic class for RNN Cells'''
145
- def __init__(self, input_size: int, hidden_size: int, has_bias: bool, num_chunks: int):
148
+ def __init__(self, input_size: int, hidden_size: int, has_bias: bool, num_chunks: int,
149
+ dtype=mstype.float32):
146
150
  super().__init__()
147
151
  validator.check_value_type("has_bias", has_bias, [bool], self.cls_name)
148
152
  validator.check_positive_int(hidden_size, "hidden_size", self.cls_name)
@@ -150,20 +154,20 @@ class RNNCellBase(Cell):
150
154
  self.input_size = input_size
151
155
  self.hidden_size = hidden_size
152
156
  self.has_bias = has_bias
153
- self.weight_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, input_size).astype(np.float32)))
154
- self.weight_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, hidden_size).astype(np.float32)))
157
+ self.weight_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, input_size), dtype=dtype))
158
+ self.weight_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size, hidden_size), dtype=dtype))
155
159
  if has_bias:
156
- self.bias_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size).astype(np.float32)))
157
- self.bias_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size).astype(np.float32)))
160
+ self.bias_ih = Parameter(Tensor(np.random.randn(num_chunks * hidden_size), dtype=dtype))
161
+ self.bias_hh = Parameter(Tensor(np.random.randn(num_chunks * hidden_size), dtype=dtype))
158
162
  else:
159
163
  self.bias_ih = None
160
164
  self.bias_hh = None
161
- self.reset_parameters()
165
+ self.reset_parameters(dtype=dtype)
162
166
 
163
- def reset_parameters(self):
167
+ def reset_parameters(self, dtype=mstype.float32):
164
168
  stdv = 1 / math.sqrt(self.hidden_size)
165
169
  for weight in self.get_parameters():
166
- weight.set_data(initializer(Uniform(stdv), weight.shape))
170
+ weight.set_data(initializer(Uniform(stdv), weight.shape, dtype))
167
171
 
168
172
 
169
173
  class RNNCell(RNNCellBase):
@@ -181,13 +185,14 @@ class RNNCell(RNNCellBase):
181
185
  Args:
182
186
  input_size (int): Number of features of input.
183
187
  hidden_size (int): Number of features of hidden layer.
184
- has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
185
- nonlinearity (str): The non-linearity to use. Can be either `tanh` or `relu`. Default: `tanh`.
188
+ has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: ``True`` .
189
+ nonlinearity (str): The non-linearity to use. Can be either ``"tanh"`` or ``"relu"`` .
190
+ Default: ``"tanh"`` .
191
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
186
192
 
187
193
  Inputs:
188
194
  - **x** (Tensor) - Tensor of shape :math:`(batch\_size, input\_size)` .
189
195
  - **hx** (Tensor) - Tensor of data type mindspore.float32 and shape :math:`(batch\_size, hidden\_size)` .
190
- Data type of `hx` must be the same as `x`.
191
196
 
192
197
  Outputs:
193
198
  - **hx'** (Tensor) - Tensor of shape :math:`(batch\_size, hidden\_size)` .
@@ -201,9 +206,11 @@ class RNNCell(RNNCellBase):
201
206
  ``Ascend`` ``GPU`` ``CPU``
202
207
 
203
208
  Examples:
204
- >>> net = nn.RNNCell(10, 16)
205
- >>> x = Tensor(np.ones([5, 3, 10]).astype(np.float32))
206
- >>> hx = Tensor(np.ones([3, 16]).astype(np.float32))
209
+ >>> import mindspore as ms
210
+ >>> import numpy as np
211
+ >>> net = ms.nn.RNNCell(10, 16)
212
+ >>> x = ms.Tensor(np.ones([5, 3, 10]).astype(np.float32))
213
+ >>> hx = ms.Tensor(np.ones([3, 16]).astype(np.float32))
207
214
  >>> output = []
208
215
  >>> for i in range(5):
209
216
  ... hx = net(x[i], hx)
@@ -213,8 +220,9 @@ class RNNCell(RNNCellBase):
213
220
  """
214
221
  _non_linearity = ['tanh', 'relu']
215
222
 
216
- def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True, nonlinearity: str = "tanh"):
217
- super().__init__(input_size, hidden_size, has_bias, num_chunks=1)
223
+ def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True, nonlinearity: str = "tanh",
224
+ dtype=mstype.float32):
225
+ super().__init__(input_size, hidden_size, has_bias, num_chunks=1, dtype=dtype)
218
226
  validator.check_value_type("nonlinearity", nonlinearity, [str], self.cls_name)
219
227
  validator.check_string(nonlinearity, self._non_linearity, "nonlinearity", self.cls_name)
220
228
  self.nonlinearity = nonlinearity
@@ -263,15 +271,16 @@ class LSTMCell(RNNCellBase):
263
271
  Args:
264
272
  input_size (int): Number of features of input.
265
273
  hidden_size (int): Number of features of hidden layer.
266
- has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
274
+ has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: ``True`` .
275
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
267
276
 
268
277
  Inputs:
269
- - **x** (Tensor) - Tensor of shape :math:`(batch\_size, input\_size)`.
278
+ - **x** (Tensor) - Tensor of shape :math:`(batch\_size, input\_size)` .
270
279
  - **hx** (tuple) - A tuple of two Tensors (h_0, c_0) both of data type mindspore.float32
271
- and shape :math:`(batch\_size, hidden\_size)`. The data type of `hx` must be the same as `x`.
280
+ and shape :math:`(batch\_size, hidden\_size)` .
272
281
 
273
282
  Outputs:
274
- - **hx'** (Tensor) - A tuple of two Tensors (h', c') both of data shape :math:`(batch\_size, hidden\_size)`.
283
+ - **hx'** (Tensor) - A tuple of two Tensors (h', c') both of data shape :math:`(batch\_size, hidden\_size)` .
275
284
 
276
285
  Raises:
277
286
  TypeError: If `input_size`, `hidden_size` is not an int.
@@ -281,10 +290,12 @@ class LSTMCell(RNNCellBase):
281
290
  ``Ascend`` ``GPU`` ``CPU``
282
291
 
283
292
  Examples:
284
- >>> net = nn.LSTMCell(10, 16)
285
- >>> x = Tensor(np.ones([5, 3, 10]).astype(np.float32))
286
- >>> h = Tensor(np.ones([3, 16]).astype(np.float32))
287
- >>> c = Tensor(np.ones([3, 16]).astype(np.float32))
293
+ >>> import mindspore as ms
294
+ >>> import numpy as np
295
+ >>> net = ms.nn.LSTMCell(10, 16)
296
+ >>> x = ms.Tensor(np.ones([5, 3, 10]).astype(np.float32))
297
+ >>> h = ms.Tensor(np.ones([3, 16]).astype(np.float32))
298
+ >>> c = ms.Tensor(np.ones([3, 16]).astype(np.float32))
288
299
  >>> output = []
289
300
  >>> for i in range(5):
290
301
  ... hx = net(x[i], (h, c))
@@ -293,8 +304,9 @@ class LSTMCell(RNNCellBase):
293
304
  (3, 16)
294
305
  """
295
306
  @_check_lstmcell_init
296
- def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True):
297
- super().__init__(input_size, hidden_size, has_bias, num_chunks=4)
307
+ def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True,
308
+ dtype=mstype.float32):
309
+ super().__init__(input_size, hidden_size, has_bias, num_chunks=4, dtype=dtype)
298
310
  self.support_non_tensor_inputs = True
299
311
 
300
312
  def construct(self, x, hx):
@@ -343,15 +355,15 @@ class GRUCell(RNNCellBase):
343
355
  Args:
344
356
  input_size (int): Number of features of input.
345
357
  hidden_size (int): Number of features of hidden layer.
346
- has_bias (bool): Whether the cell has bias `b_in` and `b_hn`. Default: True.
358
+ has_bias (bool): Whether the cell has bias `b_in` and `b_hn`. Default: ``True`` .
359
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
347
360
 
348
361
  Inputs:
349
- - **x** (Tensor) - Tensor of shape :math:`(batch\_size, input\_size)`.
350
- - **hx** (Tensor) - Tensor of data type mindspore.float32 and shape :math:`(batch\_size, hidden\_size)`.
351
- Data type of `hx` must be the same as `x`.
362
+ - **x** (Tensor) - Tensor of shape :math:`(batch\_size, input\_size)` .
363
+ - **hx** (Tensor) - Tensor of data type mindspore.float32 and shape :math:`(batch\_size, hidden\_size)` .
352
364
 
353
365
  Outputs:
354
- - **hx'** (Tensor) - Tensor of shape :math:`(batch\_size, hidden\_size)`.
366
+ - **hx'** (Tensor) - Tensor of shape :math:`(batch\_size, hidden\_size)` .
355
367
 
356
368
  Raises:
357
369
  TypeError: If `input_size`, `hidden_size` is not an int.
@@ -361,9 +373,11 @@ class GRUCell(RNNCellBase):
361
373
  ``Ascend`` ``GPU`` ``CPU``
362
374
 
363
375
  Examples:
364
- >>> net = nn.GRUCell(10, 16)
365
- >>> x = Tensor(np.ones([5, 3, 10]).astype(np.float32))
366
- >>> hx = Tensor(np.ones([3, 16]).astype(np.float32))
376
+ >>> import mindspore as ms
377
+ >>> import numpy as np
378
+ >>> net = ms.nn.GRUCell(10, 16)
379
+ >>> x = ms.Tensor(np.ones([5, 3, 10]).astype(np.float32))
380
+ >>> hx = ms.Tensor(np.ones([3, 16]).astype(np.float32))
367
381
  >>> output = []
368
382
  >>> for i in range(5):
369
383
  ... hx = net(x[i], hx)
@@ -371,8 +385,9 @@ class GRUCell(RNNCellBase):
371
385
  >>> print(output[0].shape)
372
386
  (3, 16)
373
387
  """
374
- def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True):
375
- super().__init__(input_size, hidden_size, has_bias, num_chunks=3)
388
+ def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True,
389
+ dtype=mstype.float32):
390
+ super().__init__(input_size, hidden_size, has_bias, num_chunks=3, dtype=dtype)
376
391
 
377
392
  def construct(self, x, hx):
378
393
  _check_is_tensor('x', x, self.cls_name)
@@ -35,10 +35,9 @@ from mindspore.nn.layer.rnn_cells import _rnn_relu_cell, _rnn_tanh_cell, _gru_ce
35
35
  __all__ = ['LSTM', 'GRU', 'RNN']
36
36
 
37
37
 
38
- @_primexpr
39
38
  def _init_state(shape, dtype, is_lstm):
40
- hx = P.Zeros()(shape, dtype)
41
- cx = P.Zeros()(shape, dtype)
39
+ hx = P.zeros(shape, dtype)
40
+ cx = P.zeros(shape, dtype)
42
41
  if is_lstm:
43
42
  return (hx, cx)
44
43
  return hx
@@ -58,8 +57,8 @@ def _check_input_dtype_same_and_valid(args_name, args_value, valid_values, cls_n
58
57
  @constexpr(check=False)
59
58
  def _check_is_tensor(param_name, input_data, cls_name):
60
59
  """Internal function, used to check whether the input data is Tensor."""
61
- if input_data is not None and not isinstance(P.typeof(input_data), mstype.tensor_type):
62
- raise TypeError(f"For '{cls_name}', the '{param_name}' must be '{mstype.tensor_type}', "
60
+ if input_data is not None and not isinstance(P.typeof(input_data), mstype.TensorType):
61
+ raise TypeError(f"For '{cls_name}', the '{param_name}' must be '{mstype.TensorType}', "
63
62
  f"but got '{P.typeof(input_data)}'")
64
63
 
65
64
 
@@ -260,8 +259,8 @@ class _DynamicGRUAscend(Cell):
260
259
  def construct(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh):
261
260
  '''Dynamic GRU module on Ascend'''
262
261
  if b_ih is None:
263
- b_ih = P.Zeros()(w_ih.shape[0], w_ih.dtype)
264
- b_hh = P.Zeros()(w_ih.shape[0], w_ih.dtype)
262
+ b_ih = P.zeros(w_ih.shape[0], w_ih.dtype)
263
+ b_hh = P.zeros(w_ih.shape[0], w_ih.dtype)
265
264
  outputs, _, _, _, _, _ = self.gru(self.cast(x, self.dtype), \
266
265
  self.cast(self.transpose(w_ih, (1, 0)), self.dtype), \
267
266
  self.cast(self.transpose(w_hh, (1, 0)), self.dtype), \
@@ -344,7 +343,7 @@ class _DynamicLSTMAscend(Cell):
344
343
  w_hh = self.concat_dim0((w_hh_i, w_hh_g, w_hh_f, w_hh_o))
345
344
  weight = self.concat_dim1((w_ih, w_hh))
346
345
  if b_ih is None:
347
- bias = P.Zeros()(w_ih.shape[0], w_ih.dtype)
346
+ bias = P.zeros(w_ih.shape[0], w_ih.dtype)
348
347
  else:
349
348
  b_ih_i, b_ih_f, b_ih_g, b_ih_o = self.split(b_ih)
350
349
  b_hh_i, b_hh_f, b_hh_g, b_hh_o = self.split(b_hh)
@@ -373,7 +372,7 @@ class _RNNBase(Cell):
373
372
  '''Basic class for RNN operators'''
374
373
 
375
374
  def __init__(self, mode, input_size, hidden_size, num_layers=1, has_bias=True,
376
- batch_first=False, dropout=0., bidirectional=False):
375
+ batch_first=False, dropout=0., bidirectional=False, dtype=mstype.float32):
377
376
  super().__init__()
378
377
  validator.check_positive_int(hidden_size, "hidden_size", self.cls_name)
379
378
  validator.check_positive_int(input_size, "input_size", self.cls_name)
@@ -436,17 +435,17 @@ class _RNNBase(Cell):
436
435
  suffix = '_reverse' if direction == 1 else ''
437
436
 
438
437
  self.w_ih_list.append(Parameter(
439
- Tensor(np.random.uniform(-stdv, stdv, (gate_size, layer_input_size)).astype(np.float32)),
440
- name='weight_ih_l{}{}'.format(layer, suffix)))
438
+ Tensor(np.random.uniform(-stdv, stdv, (gate_size, layer_input_size)).astype(np.float32),
439
+ dtype=dtype), name='weight_ih_l{}{}'.format(layer, suffix)))
441
440
  self.w_hh_list.append(Parameter(
442
- Tensor(np.random.uniform(-stdv, stdv, (gate_size, hidden_size)).astype(np.float32)),
443
- name='weight_hh_l{}{}'.format(layer, suffix)))
441
+ Tensor(np.random.uniform(-stdv, stdv, (gate_size, hidden_size)).astype(np.float32),
442
+ dtype=dtype), name='weight_hh_l{}{}'.format(layer, suffix)))
444
443
  if has_bias:
445
444
  self.b_ih_list.append(Parameter(
446
- Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32)),
445
+ Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32), dtype=dtype),
447
446
  name='bias_ih_l{}{}'.format(layer, suffix)))
448
447
  self.b_hh_list.append(Parameter(
449
- Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32)),
448
+ Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32), dtype=dtype),
450
449
  name='bias_hh_l{}{}'.format(layer, suffix)))
451
450
  self.w_ih_list = ParameterTuple(self.w_ih_list)
452
451
  self.w_hh_list = ParameterTuple(self.w_hh_list)
@@ -579,9 +578,7 @@ class _RNNBase(Cell):
579
578
 
580
579
  class RNN(_RNNBase):
581
580
  r"""
582
- Stacked Elman RNN layers.
583
-
584
- Apply RNN layer with :math:`\tanh` or :math:`\text{ReLU}` non-linearity to the input.
581
+ Stacked Elman RNN layers, applying RNN layer with :math:`\tanh` or :math:`\text{ReLU}` non-linearity to the input.
585
582
 
586
583
  For each element in the input sequence, each layer computes the following function:
587
584
 
@@ -596,23 +593,23 @@ class RNN(_RNNBase):
596
593
  Args:
597
594
  input_size (int): Number of features of input.
598
595
  hidden_size (int): Number of features of hidden layer.
599
- num_layers (int): Number of layers of stacked RNN. Default: 1.
596
+ num_layers (int): Number of layers of stacked RNN. Default: ``1`` .
600
597
  nonlinearity (str): The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
601
- has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
602
- batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: False.
598
+ has_bias (bool): Whether the cell has bias :math:`b_{ih}` and :math:`b_{hh}`. Default: ``True`` .
599
+ batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: ``False`` .
603
600
  dropout (float): If not 0.0, append `Dropout` layer on the outputs of each
604
- RNN layer except the last layer. Default 0.0. The range of dropout is [0.0, 1.0).
601
+ RNN layer except the last layer. Default ``0.0`` . The range of dropout is [0.0, 1.0).
605
602
  bidirectional (bool): Specifies whether it is a bidirectional RNN,
606
- num_directions=2 if bidirectional=True otherwise 1. Default: False.
603
+ num_directions=2 if bidirectional=True otherwise 1. Default: ``False`` .
604
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
607
605
 
608
606
  Inputs:
609
607
  - **x** (Tensor) - Tensor of data type mindspore.float32 or mindspore.float16 and
610
608
  shape :math:`(seq\_len, batch\_size, input\_size)` or :math:`(batch\_size, seq\_len, input\_size)` .
611
609
  - **hx** (Tensor) - Tensor of data type mindspore.float32 or mindspore.float16 and
612
610
  shape :math:`(num\_directions * num\_layers, batch\_size, hidden\_size)` .
613
- The data type of `hx` must be the same as `x`.
614
611
  - **seq_length** (Tensor) - The length of each sequence in an input batch.
615
- Tensor of shape :math:`(batch\_size)` . Default: None.
612
+ Tensor of shape :math:`(batch\_size)` . Default: ``None`` .
616
613
  This input indicates the real sequence length before padding to avoid padded elements
617
614
  have been used to compute hidden state and affect the final output. It is recommended to
618
615
  use this input when `x` has padding elements.
@@ -635,9 +632,11 @@ class RNN(_RNNBase):
635
632
  ``Ascend`` ``GPU`` ``CPU``
636
633
 
637
634
  Examples:
638
- >>> net = nn.RNN(10, 16, 2, has_bias=True, batch_first=True, bidirectional=False)
639
- >>> x = Tensor(np.ones([3, 5, 10]).astype(np.float32))
640
- >>> h0 = Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
635
+ >>> import mindspore as ms
636
+ >>> import numpy as np
637
+ >>> net = ms.nn.RNN(10, 16, 2, has_bias=True, batch_first=True, bidirectional=False)
638
+ >>> x = ms.Tensor(np.ones([3, 5, 10]).astype(np.float32))
639
+ >>> h0 = ms.Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
641
640
  >>> output, hn = net(x, h0)
642
641
  >>> print(output.shape)
643
642
  (3, 5, 16)
@@ -696,22 +695,22 @@ class GRU(_RNNBase):
696
695
  Args:
697
696
  input_size (int): Number of features of input.
698
697
  hidden_size (int): Number of features of hidden layer.
699
- num_layers (int): Number of layers of stacked GRU. Default: 1.
700
- has_bias (bool): Whether the cell has bias `b_in` and `b_hn`. Default: True.
701
- batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: False.
698
+ num_layers (int): Number of layers of stacked GRU. Default: ``1`` .
699
+ has_bias (bool): Whether the cell has bias `b_in` and `b_hn`. Default: ``True`` .
700
+ batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: ``False`` .
702
701
  dropout (float): If not 0.0, append `Dropout` layer on the outputs of each
703
- GRU layer except the last layer. Default 0.0. The range of dropout is [0.0, 1.0).
702
+ GRU layer except the last layer. Default ``0.0`` . The range of dropout is [0.0, 1.0).
704
703
  bidirectional (bool): Specifies whether it is a bidirectional GRU,
705
- num_directions=2 if bidirectional=True otherwise 1. Default: False.
704
+ num_directions=2 if bidirectional=True otherwise 1. Default: ``False`` .
705
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
706
706
 
707
707
  Inputs:
708
708
  - **x** (Tensor) - Tensor of data type mindspore.float32 or mindspore.float16 and
709
- shape (seq_len, batch_size, `input_size`) or (batch_size, seq_len, `input_size`).
709
+ shape (seq\_len, batch\_size, `input\_size`) or :math:`(batch\_size, seq\_len, input\_size)`.
710
710
  - **hx** (Tensor) - Tensor of data type mindspore.float32 or mindspore.float16 and
711
- shape (num_directions * `num_layers`, batch_size, `hidden_size`). The data type of `hx` must be the same as
712
- `x`.
711
+ shape :math:`(num\_directions * num\_layers, batch\_size, hidden\_size)`.
713
712
  - **seq_length** (Tensor) - The length of each sequence in an input batch.
714
- Tensor of shape :math:`(\text{batch_size})`. Default: None.
713
+ Tensor of shape :math:`(\text{batch_size})`. Default: ``None`` .
715
714
  This input indicates the real sequence length before padding to avoid padded elements
716
715
  have been used to compute hidden state and affect the final output. It is recommended to
717
716
  use this input when **x** has padding elements.
@@ -719,9 +718,9 @@ class GRU(_RNNBase):
719
718
  Outputs:
720
719
  Tuple, a tuple contains (`output`, `h_n`).
721
720
 
722
- - **output** (Tensor) - Tensor of shape (seq_len, batch_size, num_directions * `hidden_size`) or
723
- (batch_size, seq_len, num_directions * `hidden_size`).
724
- - **hx_n** (Tensor) - Tensor of shape (num_directions * `num_layers`, batch_size, `hidden_size`).
721
+ - **output** (Tensor) - Tensor of shape :math:`(seq\_len, batch\_size, num\_directions * hidden\_size)` or
722
+ :math:`(batch\_size, seq\_len, num\_directions * hidden\_size)`.
723
+ - **hx_n** (Tensor) - Tensor of shape :math:`(num\_directions * num\_layers, batch\_size, hidden\_size)`.
725
724
 
726
725
  Raises:
727
726
  TypeError: If `input_size`, `hidden_size` or `num_layers` is not an int.
@@ -733,9 +732,11 @@ class GRU(_RNNBase):
733
732
  ``Ascend`` ``GPU`` ``CPU``
734
733
 
735
734
  Examples:
736
- >>> net = nn.GRU(10, 16, 2, has_bias=True, batch_first=True, bidirectional=False)
737
- >>> x = Tensor(np.ones([3, 5, 10]).astype(np.float32))
738
- >>> h0 = Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
735
+ >>> import mindspore as ms
736
+ >>> import numpy as np
737
+ >>> net = ms.nn.GRU(10, 16, 2, has_bias=True, batch_first=True, bidirectional=False)
738
+ >>> x = ms.Tensor(np.ones([3, 5, 10]).astype(np.float32))
739
+ >>> h0 = ms.Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
739
740
  >>> output, hn = net(x, h0)
740
741
  >>> print(output.shape)
741
742
  (3, 5, 16)
@@ -794,22 +795,22 @@ class LSTM(_RNNBase):
794
795
  Args:
795
796
  input_size (int): Number of features of input.
796
797
  hidden_size (int): Number of features of hidden layer.
797
- num_layers (int): Number of layers of stacked LSTM . Default: 1.
798
- has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
799
- batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: False.
798
+ num_layers (int): Number of layers of stacked LSTM . Default: ``1`` .
799
+ has_bias (bool): Whether the cell has bias :math:`b_{ih}` and :math:`b_{hh}`. Default: ``True`` .
800
+ batch_first (bool): Specifies whether the first dimension of input `x` is batch_size. Default: ``False`` .
800
801
  dropout (float, int): If not 0, append `Dropout` layer on the outputs of each
801
- LSTM layer except the last layer. Default 0. The range of dropout is [0.0, 1.0).
802
+ LSTM layer except the last layer. Default ``0`` . The range of dropout is [0.0, 1.0).
802
803
  bidirectional (bool): Specifies whether it is a bidirectional LSTM,
803
- num_directions=2 if bidirectional=True otherwise 1. Default: False.
804
+ num_directions=2 if bidirectional=True otherwise 1. Default: ``False`` .
805
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
804
806
 
805
807
  Inputs:
806
808
  - **x** (Tensor) - Tensor of data type mindspore.float32 or mindspore.float16 and
807
- shape :math:`(seq\_len, batch\_size, input\_size)` or :math:`(batch\_size, seq\_len, input\_size)`.
809
+ shape :math:`(seq\_len, batch\_size, input\_size)` or :math:`(batch\_size, seq\_len, input\_size)` .
808
810
  - **hx** (tuple) - A tuple of two Tensors (h_0, c_0) both of data type mindspore.float32
809
- or mindspore.float16 and shape :math:`(num\_directions * num\_layers, batch\_size, hidden\_size)`.
810
- The data type of `hx` must be the same as `x`.
811
+ or mindspore.float16 and shape :math:`(num\_directions * num\_layers, batch\_size, hidden\_size)` .
811
812
  - **seq_length** (Tensor) - The length of each sequence in an input batch.
812
- Tensor of shape :math:`(batch\_size)`. Default: None.
813
+ Tensor of shape :math:`(batch\_size)`. Default: ``None`` .
813
814
  This input indicates the real sequence length before padding to avoid padded elements
814
815
  have been used to compute hidden state and affect the final output. It is recommended to
815
816
  use this input when **x** has padding elements.
@@ -831,10 +832,12 @@ class LSTM(_RNNBase):
831
832
  ``Ascend`` ``GPU`` ``CPU``
832
833
 
833
834
  Examples:
834
- >>> net = nn.LSTM(10, 16, 2, has_bias=True, batch_first=True, bidirectional=False)
835
- >>> x = Tensor(np.ones([3, 5, 10]).astype(np.float32))
836
- >>> h0 = Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
837
- >>> c0 = Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
835
+ >>> import mindspore as ms
836
+ >>> import numpy as np
837
+ >>> net = ms.nn.LSTM(10, 16, 2, has_bias=True, batch_first=True, bidirectional=False)
838
+ >>> x = ms.Tensor(np.ones([3, 5, 10]).astype(np.float32))
839
+ >>> h0 = ms.Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
840
+ >>> c0 = ms.Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
838
841
  >>> output, (hn, cn) = net(x, (h0, c0))
839
842
  >>> print(output.shape)
840
843
  (3, 5, 16)