mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (870)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  201. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  202. mindspore/log.py +9 -6
  203. mindspore/mindrecord/filereader.py +33 -4
  204. mindspore/mindrecord/filewriter.py +70 -35
  205. mindspore/mindrecord/mindpage.py +40 -34
  206. mindspore/mindrecord/shardreader.py +1 -1
  207. mindspore/mindrecord/shardsegment.py +1 -1
  208. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  209. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  210. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  211. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  212. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  213. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  214. mindspore/nn/cell.py +463 -169
  215. mindspore/nn/dynamic_lr.py +47 -43
  216. mindspore/nn/layer/activation.py +225 -82
  217. mindspore/nn/layer/basic.py +121 -79
  218. mindspore/nn/layer/channel_shuffle.py +21 -21
  219. mindspore/nn/layer/combined.py +33 -26
  220. mindspore/nn/layer/container.py +277 -22
  221. mindspore/nn/layer/conv.py +441 -304
  222. mindspore/nn/layer/dense.py +19 -13
  223. mindspore/nn/layer/embedding.py +62 -49
  224. mindspore/nn/layer/flash_attention.py +264 -0
  225. mindspore/nn/layer/image.py +50 -39
  226. mindspore/nn/layer/math.py +62 -51
  227. mindspore/nn/layer/normalization.py +219 -167
  228. mindspore/nn/layer/padding.py +58 -70
  229. mindspore/nn/layer/pooling.py +334 -287
  230. mindspore/nn/layer/rnn_cells.py +53 -38
  231. mindspore/nn/layer/rnns.py +59 -56
  232. mindspore/nn/layer/thor_layer.py +52 -44
  233. mindspore/nn/layer/timedistributed.py +6 -4
  234. mindspore/nn/layer/transformer.py +284 -164
  235. mindspore/nn/learning_rate_schedule.py +34 -25
  236. mindspore/nn/loss/__init__.py +3 -2
  237. mindspore/nn/loss/loss.py +554 -311
  238. mindspore/nn/optim/ada_grad.py +12 -9
  239. mindspore/nn/optim/adadelta.py +14 -11
  240. mindspore/nn/optim/adafactor.py +19 -16
  241. mindspore/nn/optim/adam.py +62 -47
  242. mindspore/nn/optim/adamax.py +13 -10
  243. mindspore/nn/optim/adasum.py +12 -8
  244. mindspore/nn/optim/asgd.py +10 -9
  245. mindspore/nn/optim/ftrl.py +20 -17
  246. mindspore/nn/optim/lamb.py +16 -12
  247. mindspore/nn/optim/lars.py +8 -6
  248. mindspore/nn/optim/lazyadam.py +25 -20
  249. mindspore/nn/optim/momentum.py +10 -7
  250. mindspore/nn/optim/optimizer.py +61 -9
  251. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  252. mindspore/nn/optim/rmsprop.py +17 -13
  253. mindspore/nn/optim/rprop.py +30 -17
  254. mindspore/nn/optim/sgd.py +40 -23
  255. mindspore/nn/optim/thor.py +24 -26
  256. mindspore/nn/probability/bijector/bijector.py +11 -11
  257. mindspore/nn/probability/bijector/exp.py +1 -1
  258. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  259. mindspore/nn/probability/bijector/invert.py +1 -1
  260. mindspore/nn/probability/bijector/power_transform.py +29 -29
  261. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  262. mindspore/nn/probability/bijector/softplus.py +5 -5
  263. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  264. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  265. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  266. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  267. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  268. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  269. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  270. mindspore/nn/probability/distribution/beta.py +8 -8
  271. mindspore/nn/probability/distribution/categorical.py +23 -15
  272. mindspore/nn/probability/distribution/cauchy.py +5 -6
  273. mindspore/nn/probability/distribution/distribution.py +3 -3
  274. mindspore/nn/probability/distribution/exponential.py +4 -4
  275. mindspore/nn/probability/distribution/gamma.py +10 -10
  276. mindspore/nn/probability/distribution/geometric.py +8 -8
  277. mindspore/nn/probability/distribution/gumbel.py +8 -9
  278. mindspore/nn/probability/distribution/half_normal.py +5 -5
  279. mindspore/nn/probability/distribution/laplace.py +5 -5
  280. mindspore/nn/probability/distribution/log_normal.py +12 -11
  281. mindspore/nn/probability/distribution/logistic.py +8 -8
  282. mindspore/nn/probability/distribution/normal.py +6 -5
  283. mindspore/nn/probability/distribution/poisson.py +10 -11
  284. mindspore/nn/probability/distribution/student_t.py +8 -9
  285. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  286. mindspore/nn/probability/distribution/uniform.py +11 -11
  287. mindspore/nn/reinforcement/tensor_array.py +2 -2
  288. mindspore/nn/sparse/sparse.py +9 -9
  289. mindspore/nn/wrap/cell_wrapper.py +188 -63
  290. mindspore/nn/wrap/grad_reducer.py +21 -12
  291. mindspore/nn/wrap/loss_scale.py +136 -49
  292. mindspore/numpy/__init__.py +4 -4
  293. mindspore/numpy/array_creations.py +55 -56
  294. mindspore/numpy/array_ops.py +134 -35
  295. mindspore/numpy/logic_ops.py +66 -20
  296. mindspore/numpy/math_ops.py +142 -139
  297. mindspore/numpy/utils_const.py +2 -2
  298. mindspore/offline_debug/convert_async.py +2 -2
  299. mindspore/ops/_grad_experimental/__init__.py +7 -5
  300. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  301. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  302. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  303. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  304. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  305. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  306. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  307. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  308. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  309. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  310. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  311. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  312. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  313. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  314. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  315. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  316. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  317. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  318. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  319. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  320. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  321. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  322. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  323. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  325. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  326. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  327. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  328. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  329. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  331. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  332. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  333. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  334. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  335. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  336. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  337. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  338. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  339. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  340. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  341. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  342. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  343. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  345. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  346. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  347. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  348. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  349. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  350. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  351. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  352. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  353. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  354. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  355. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  356. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  357. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  358. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  359. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  360. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  361. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  362. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  363. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  364. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  365. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  366. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  367. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  368. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  369. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  370. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  371. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  372. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  373. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  374. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  375. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  376. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  377. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  378. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  379. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  380. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  381. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  382. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  383. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  384. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  385. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  386. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  387. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  388. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  389. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  390. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  391. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  392. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  393. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  394. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  395. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  396. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  397. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  398. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  399. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  400. mindspore/ops/_primitive_cache.py +1 -1
  401. mindspore/ops/_tracefunc.py +241 -0
  402. mindspore/ops/_utils/utils.py +10 -2
  403. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  404. mindspore/ops/_vmap/vmap_base.py +5 -4
  405. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  406. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  407. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  408. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  409. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  410. mindspore/ops/arg_dtype_cast.py +54 -0
  411. mindspore/ops/composite/__init__.py +7 -5
  412. mindspore/ops/composite/base.py +78 -34
  413. mindspore/ops/composite/math_ops.py +5 -695
  414. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  415. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  416. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  417. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  418. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  419. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  420. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  421. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  422. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  423. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  424. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  425. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  426. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  427. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  428. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  429. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  430. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  431. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  432. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  433. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  434. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  436. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  437. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  438. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  440. mindspore/ops/deprecated.py +304 -0
  441. mindspore/ops/function/__init__.py +41 -4
  442. mindspore/ops/function/array_func.py +1108 -467
  443. mindspore/ops/function/clip_func.py +94 -27
  444. mindspore/ops/function/debug_func.py +3 -1
  445. mindspore/ops/function/grad/grad_func.py +82 -73
  446. mindspore/ops/function/image_func.py +28 -12
  447. mindspore/ops/function/linalg_func.py +135 -39
  448. mindspore/ops/function/math_func.py +3779 -894
  449. mindspore/ops/function/nn_func.py +1584 -657
  450. mindspore/ops/function/parameter_func.py +13 -3
  451. mindspore/ops/function/random_func.py +247 -153
  452. mindspore/ops/function/sparse_func.py +14 -11
  453. mindspore/ops/function/sparse_unary_func.py +173 -47
  454. mindspore/ops/function/spectral_func.py +8 -4
  455. mindspore/ops/function/vmap_func.py +8 -7
  456. mindspore/ops/functional.py +47 -16
  457. mindspore/ops/op_info_register.py +346 -86
  458. mindspore/ops/operations/__init__.py +38 -22
  459. mindspore/ops/operations/_grad_ops.py +145 -149
  460. mindspore/ops/operations/_inner_ops.py +298 -56
  461. mindspore/ops/operations/_ms_kernel.py +3 -3
  462. mindspore/ops/operations/_quant_ops.py +24 -28
  463. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  464. mindspore/ops/operations/_scalar_ops.py +115 -0
  465. mindspore/ops/operations/_sequence_ops.py +148 -10
  466. mindspore/ops/operations/_tensor_array.py +1 -1
  467. mindspore/ops/operations/_thor_ops.py +2 -2
  468. mindspore/ops/operations/array_ops.py +1239 -561
  469. mindspore/ops/operations/comm_ops.py +166 -90
  470. mindspore/ops/operations/control_ops.py +3 -3
  471. mindspore/ops/operations/custom_ops.py +124 -102
  472. mindspore/ops/operations/debug_ops.py +24 -11
  473. mindspore/ops/operations/image_ops.py +86 -71
  474. mindspore/ops/operations/inner_ops.py +18 -13
  475. mindspore/ops/operations/linalg_ops.py +30 -11
  476. mindspore/ops/operations/math_ops.py +1730 -435
  477. mindspore/ops/operations/nn_ops.py +1953 -943
  478. mindspore/ops/operations/other_ops.py +65 -43
  479. mindspore/ops/operations/random_ops.py +258 -98
  480. mindspore/ops/operations/rl_ops.py +4 -36
  481. mindspore/ops/operations/sparse_ops.py +38 -33
  482. mindspore/ops/operations/spectral_ops.py +8 -4
  483. mindspore/ops/primitive.py +66 -44
  484. mindspore/ops/signature.py +5 -5
  485. mindspore/parallel/_auto_parallel_context.py +80 -19
  486. mindspore/parallel/_cost_model_context.py +42 -0
  487. mindspore/parallel/_offload_context.py +162 -72
  488. mindspore/parallel/_parallel_serialization.py +2 -2
  489. mindspore/parallel/_ps_context.py +16 -4
  490. mindspore/parallel/_recovery_context.py +2 -1
  491. mindspore/parallel/_tensor.py +15 -13
  492. mindspore/parallel/_transformer/layers.py +8 -6
  493. mindspore/parallel/_transformer/loss.py +1 -0
  494. mindspore/parallel/_transformer/moe.py +7 -7
  495. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  496. mindspore/parallel/_transformer/transformer.py +34 -14
  497. mindspore/parallel/_utils.py +36 -14
  498. mindspore/parallel/algo_parameter_config.py +114 -20
  499. mindspore/parallel/checkpoint_transform.py +16 -18
  500. mindspore/parallel/shard.py +16 -13
  501. mindspore/profiler/__init__.py +1 -1
  502. mindspore/profiler/common/struct_type.py +3 -3
  503. mindspore/profiler/common/util.py +3 -2
  504. mindspore/profiler/envprofiling.py +11 -4
  505. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  506. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  507. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  508. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  509. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  510. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  511. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  512. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  513. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  514. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  515. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  516. mindspore/profiler/parser/flops_parser.py +15 -11
  517. mindspore/profiler/parser/framework_parser.py +92 -73
  518. mindspore/profiler/parser/hccl_parser.py +16 -12
  519. mindspore/profiler/parser/integrator.py +22 -11
  520. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  521. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  522. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  523. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  524. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  525. mindspore/profiler/parser/optime_parser.py +1 -1
  526. mindspore/profiler/parser/profiler_info.py +4 -5
  527. mindspore/profiler/parser/step_trace_parser.py +11 -14
  528. mindspore/profiler/profiling.py +678 -377
  529. mindspore/rewrite/api/node.py +211 -54
  530. mindspore/rewrite/api/node_type.py +5 -0
  531. mindspore/rewrite/api/pattern_engine.py +22 -23
  532. mindspore/rewrite/api/scoped_value.py +20 -17
  533. mindspore/rewrite/api/symbol_tree.py +252 -106
  534. mindspore/rewrite/api/tree_node_helper.py +3 -0
  535. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  536. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  537. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  538. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  539. mindspore/rewrite/common/rewrite_elog.py +5 -1
  540. mindspore/rewrite/namer.py +51 -51
  541. mindspore/rewrite/namespace.py +14 -5
  542. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  543. mindspore/rewrite/node/call_function.py +79 -0
  544. mindspore/rewrite/node/cell_container.py +135 -0
  545. mindspore/rewrite/node/control_flow.py +88 -0
  546. mindspore/rewrite/{node.py → node/node.py} +313 -247
  547. mindspore/rewrite/node/node_manager.py +254 -0
  548. mindspore/rewrite/node/node_topological_manager.py +243 -0
  549. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  550. mindspore/rewrite/parsers/assign_parser.py +225 -239
  551. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  552. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  553. mindspore/rewrite/parsers/constant_parser.py +9 -6
  554. mindspore/rewrite/parsers/container_parser.py +9 -7
  555. mindspore/rewrite/parsers/for_parser.py +36 -15
  556. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  557. mindspore/rewrite/parsers/if_parser.py +28 -24
  558. mindspore/rewrite/parsers/module_parser.py +202 -25
  559. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  560. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  561. mindspore/rewrite/parsers/return_parser.py +6 -6
  562. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  563. mindspore/rewrite/sparsify/sparsify.py +4 -1
  564. mindspore/rewrite/sparsify/utils.py +11 -5
  565. mindspore/rewrite/symbol_tree.py +577 -732
  566. mindspore/rewrite/symbol_tree_builder.py +9 -175
  567. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  568. mindspore/run_check/_check_version.py +46 -39
  569. mindspore/run_check/run_check.py +3 -2
  570. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  571. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  572. mindspore/scipy/__init__.py +1 -1
  573. mindspore/scipy/linalg.py +67 -61
  574. mindspore/scipy/ops.py +5 -41
  575. mindspore/scipy/ops_grad.py +3 -2
  576. mindspore/scipy/ops_wrapper.py +5 -5
  577. mindspore/scipy/optimize/line_search.py +8 -8
  578. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  579. mindspore/scipy/optimize/minimize.py +16 -12
  580. mindspore/scipy/utils.py +1 -52
  581. mindspore/scipy/utils_const.py +4 -4
  582. mindspore/train/__init__.py +4 -4
  583. mindspore/train/_utils.py +13 -5
  584. mindspore/train/amp.py +410 -148
  585. mindspore/train/anf_ir_pb2.py +16 -4
  586. mindspore/train/callback/_backup_and_restore.py +8 -11
  587. mindspore/train/callback/_callback.py +80 -3
  588. mindspore/train/callback/_checkpoint.py +82 -51
  589. mindspore/train/callback/_early_stop.py +12 -15
  590. mindspore/train/callback/_history.py +1 -1
  591. mindspore/train/callback/_lambda_callback.py +13 -13
  592. mindspore/train/callback/_landscape.py +21 -17
  593. mindspore/train/callback/_loss_monitor.py +9 -10
  594. mindspore/train/callback/_on_request_exit.py +16 -33
  595. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  596. mindspore/train/callback/_summary_collector.py +44 -30
  597. mindspore/train/callback/_time_monitor.py +62 -12
  598. mindspore/train/data_sink.py +10 -16
  599. mindspore/train/dataset_helper.py +154 -86
  600. mindspore/train/loss_scale_manager.py +14 -9
  601. mindspore/train/metrics/__init__.py +10 -2
  602. mindspore/train/metrics/accuracy.py +1 -1
  603. mindspore/train/metrics/auc.py +1 -1
  604. mindspore/train/metrics/bleu_score.py +2 -2
  605. mindspore/train/metrics/confusion_matrix.py +14 -14
  606. mindspore/train/metrics/cosine_similarity.py +3 -3
  607. mindspore/train/metrics/dice.py +1 -1
  608. mindspore/train/metrics/fbeta.py +1 -1
  609. mindspore/train/metrics/hausdorff_distance.py +8 -6
  610. mindspore/train/metrics/mean_surface_distance.py +5 -4
  611. mindspore/train/metrics/metric.py +49 -17
  612. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  613. mindspore/train/metrics/perplexity.py +1 -1
  614. mindspore/train/metrics/precision.py +2 -2
  615. mindspore/train/metrics/recall.py +2 -3
  616. mindspore/train/metrics/roc.py +7 -7
  617. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  618. mindspore/train/metrics/topk.py +7 -4
  619. mindspore/train/mind_ir_pb2.py +193 -48
  620. mindspore/train/model.py +377 -133
  621. mindspore/train/serialization.py +697 -245
  622. mindspore/train/summary/_summary_adapter.py +5 -2
  623. mindspore/train/summary/_writer_pool.py +4 -3
  624. mindspore/train/summary/summary_record.py +25 -23
  625. mindspore/train/train_thor/convert_utils.py +39 -23
  626. mindspore/train/train_thor/dataset_helper.py +4 -3
  627. mindspore/train/train_thor/model_thor.py +8 -8
  628. mindspore/version.py +1 -1
  629. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  630. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
  631. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  632. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  633. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  634. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  635. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  636. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  637. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  638. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  639. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  640. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  641. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  642. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  643. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  644. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  645. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  646. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  647. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  648. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  649. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  650. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  651. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  652. mindspore/_extends/graph_kernel/expander.py +0 -80
  653. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  654. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  655. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  656. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  657. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  658. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  659. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  660. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  661. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  662. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  663. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  664. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  665. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  666. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  667. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  668. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  669. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  670. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  671. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  672. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  673. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  674. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  675. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  676. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  677. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  678. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  679. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  680. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  681. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  682. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  683. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  684. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  685. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  686. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  687. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  688. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  689. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  690. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  691. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  692. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  693. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  694. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  695. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  696. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  697. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  698. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  699. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  700. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  701. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  702. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  703. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  704. mindspore/dataset/engine/graphdata.py +0 -1586
  705. mindspore/include/api/net.h +0 -142
  706. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  707. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  708. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  709. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  710. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  711. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  712. mindspore/ops/_grad/grad_other_ops.py +0 -89
  713. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  714. mindspore/ops/_grad/grad_sparse.py +0 -323
  715. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  716. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  717. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  718. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  719. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  720. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  721. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  722. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  723. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  724. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  725. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  726. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  727. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  728. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  729. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  730. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  731. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  732. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  733. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  734. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  735. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  736. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  737. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  738. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  739. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  740. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  741. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  742. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  743. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  744. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  745. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  746. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  747. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  748. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  749. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  751. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  752. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  753. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  754. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  755. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  756. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  758. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  759. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  760. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  761. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  762. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  763. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  764. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  765. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  766. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  767. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  768. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  769. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  770. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  771. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  772. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  774. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  775. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  776. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  777. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  778. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  779. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  780. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  784. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  785. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  786. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  787. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  789. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  790. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  791. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  792. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  793. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  794. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  796. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  797. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  798. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  799. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  800. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  801. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  802. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  803. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  804. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  805. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  806. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  807. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  808. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  809. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  810. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  811. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  812. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  814. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  815. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  816. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  817. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  818. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  819. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  820. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  821. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  822. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  823. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  824. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  825. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  826. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  827. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  828. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  829. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  830. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  831. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  832. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  833. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  834. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  835. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  836. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  837. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  839. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  840. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  841. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  842. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  843. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  844. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  845. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  846. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  847. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  848. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  849. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  850. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  851. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  852. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  853. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  854. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  855. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  856. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  857. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  858. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  859. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  860. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  861. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  862. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  863. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  864. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  865. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  866. mindspore/rewrite/node_visitor.py +0 -44
  867. mindspore/rewrite/topological_manager.py +0 -203
  868. mindspore/scipy/sparse/linalg.py +0 -192
  869. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  870. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -62,7 +62,8 @@ class _BatchNorm(Cell):
62
62
  moving_mean_init='zeros',
63
63
  moving_var_init='ones',
64
64
  use_batch_statistics=None,
65
- data_format='NCHW'):
65
+ data_format='NCHW',
66
+ dtype=mstype.float32):
66
67
  """Initialize _BatchNorm."""
67
68
  super(_BatchNorm, self).__init__()
68
69
  validator.check_value_type('num_features', num_features, [int], self.cls_name)
@@ -87,13 +88,13 @@ class _BatchNorm(Cell):
87
88
  self.moving_mean_init = moving_mean_init
88
89
  self.moving_var_init = moving_var_init
89
90
  self.moving_mean = Parameter(initializer(
90
- moving_mean_init, num_features), name="mean", requires_grad=False)
91
+ moving_mean_init, num_features, dtype=dtype), name="mean", requires_grad=False)
91
92
  self.moving_variance = Parameter(initializer(
92
- moving_var_init, num_features), name="variance", requires_grad=False)
93
+ moving_var_init, num_features, dtype=dtype), name="variance", requires_grad=False)
93
94
  self.gamma = Parameter(initializer(
94
- gamma_init, num_features), name="gamma", requires_grad=affine)
95
+ gamma_init, num_features, dtype=dtype), name="gamma", requires_grad=affine)
95
96
  self.beta = Parameter(initializer(
96
- beta_init, num_features), name="beta", requires_grad=affine)
97
+ beta_init, num_features, dtype=dtype), name="beta", requires_grad=affine)
97
98
 
98
99
  self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
99
100
 
@@ -188,28 +189,39 @@ class BatchNorm1d(_BatchNorm):
188
189
 
189
190
  Args:
190
191
  num_features (int): number of features or channels `C` of the input `x` .
191
- eps (float): :math:`\epsilon` added to the denominator for numerical stability. Default: 1e-5.
192
+ eps (float): :math:`\epsilon` added to the denominator for numerical stability. Default: ``1e-5`` .
192
193
  momentum (float): A floating hyperparameter of the momentum for the
193
- running_mean and running_var computation. Default: 0.9.
194
- affine (bool): A bool value. When set to True, :math:`\gamma` and :math:`\beta` can be learned. Default: True.
194
+ running_mean and running_var computation. Default: ``0.9`` .
195
+ affine (bool): A bool value. When set to ``True`` , :math:`\gamma` and :math:`\beta` can be learned.
196
+ Default: ``True`` .
195
197
  gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
196
- The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
198
+ The values of str refer to the function `mindspore.common.initializer
199
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
200
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'ones'`` .
197
201
  beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
198
- The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
202
+ The values of str refer to the function `mindspore.common.initializer
203
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
204
+ including ``'zeros'`` , ``'ones'``, etc. Default: ``'zeros'`` .
199
205
  moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
200
- The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
206
+ The values of str refer to the function `mindspore.common.initializer
207
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
208
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'zeros'`` .
201
209
  moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
202
- The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
203
- use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
204
- use the mean value and variance value of specified value. If None, the training process will use the mean
205
- and variance of current batch data and track the running mean and variance, the evaluation process will use
206
- the running mean and variance. Default: None.
207
- data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'.
208
- Default: 'NCHW'.
210
+ The values of str refer to the function `mindspore.common.initializer
211
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
212
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'ones'`` .
213
+ use_batch_statistics (bool): If ``true`` , use the mean value and variance value of current batch data. If
214
+ ``false`` , use the mean value and variance value of specified value. If ``None`` , the training process
215
+ will use the mean and variance of current batch data and track the running mean and variance, the
216
+ evaluation process will use the running mean and variance. Default: ``None`` .
217
+ data_format (str): The optional value for data format, is ``'NHWC'`` or ``'NCHW'`` .
218
+ Default: ``'NCHW'`` .
219
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
209
220
 
210
221
  Inputs:
211
222
  - **x** (Tensor) - Tensor of shape :math:`(N, C)` or :math:`(N, C, L)` ,
212
223
  where `N` is the batch size, `C` is the number of features or channels, and `L` is the sequence length.
224
+ Supported types: float16, float32.
213
225
 
214
226
  Outputs:
215
227
  Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C)` or :math:`(N, C, L)` .
@@ -225,11 +237,10 @@ class BatchNorm1d(_BatchNorm):
225
237
 
226
238
  Examples:
227
239
  >>> import numpy as np
228
- >>> import mindspore.nn as nn
229
- >>> from mindspore import Tensor
230
- >>> net = nn.BatchNorm1d(num_features=4)
231
- >>> x = Tensor(np.array([[0.7, 0.5, 0.5, 0.6],
232
- ... [0.5, 0.4, 0.6, 0.9]]).astype(np.float32))
240
+ >>> import mindspore as ms
241
+ >>> net = ms.nn.BatchNorm1d(num_features=4)
242
+ >>> x = ms.Tensor(np.array([[0.7, 0.5, 0.5, 0.6],
243
+ ... [0.5, 0.4, 0.6, 0.9]]).astype(np.float32))
233
244
  >>> output = net(x)
234
245
  >>> print(output)
235
246
  [[ 0.6999965 0.4999975 0.4999975 0.59999704 ]
@@ -274,32 +285,42 @@ class BatchNorm2d(_BatchNorm):
274
285
  Args:
275
286
  num_features (int): The number of channels of the input tensor. Expected input size is :math:`(N, C, H, W)`,
276
287
  `C` represents the number of channels.
277
- eps (float): :math:`\epsilon` added to the denominator for numerical stability. Default: 1e-5.
288
+ eps (float): :math:`\epsilon` added to the denominator for numerical stability. Default: ``1e-5`` .
278
289
  momentum (float): A floating hyperparameter of the momentum for the
279
- running_mean and running_var computation. Default: 0.9.
280
- affine (bool): A bool value. When set to True, :math:`\gamma` and :math:`\beta` can be learned. Default: True.
290
+ running_mean and running_var computation. Default: ``0.9`` .
291
+ affine (bool): A bool value. When set to ``True`` , :math:`\gamma` and :math:`\beta` can be learned.
292
+ Default: ``True`` .
281
293
  gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
282
- The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
294
+ The values of str refer to the function `mindspore.common.initializer
295
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
296
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'ones'`` .
283
297
  beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
284
- The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
298
+ The values of str refer to the function `mindspore.common.initializer
299
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
300
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'zeros'`` .
285
301
  moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
286
- The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
302
+ The values of str refer to the function `mindspore.common.initializer
303
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
304
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'zeros'`` .
287
305
  moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
288
- The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
289
- use_batch_statistics (bool):
306
+ The values of str refer to the function `mindspore.common.initializer
307
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
308
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'ones'`` .
309
+ use_batch_statistics (bool): Default: ``None`` .
290
310
 
291
- - If true, use the mean value and variance value of current batch data and track running mean
311
+ - If ``true`` , use the mean value and variance value of current batch data and track running mean
292
312
  and running variance.
293
- - If false, use the mean value and variance value of specified value, and not track statistical value.
294
- - If None, the use_batch_statistics is automatically set to true or false according to the training
295
- and evaluation mode. During training, the parameter is set to true, and during evaluation, the
296
- parameter is set to false. Default: None.
313
+ - If ``false`` , use the mean value and variance value of specified value, and not track statistical value.
314
+ - If ``None`` , the use_batch_statistics is automatically set to ``true`` or ``false`` according to the
315
+ training and evaluation mode. During training, the parameter is set to true, and during evaluation, the
316
+ parameter is set to false.
297
317
 
298
- data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'.
299
- Default: 'NCHW'.
318
+ data_format (str): The optional value for data format, is ``'NHWC'`` or ``'NCHW'`` .
319
+ Default: ``'NCHW'`` .
320
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
300
321
 
301
322
  Inputs:
302
- - **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`.
323
+ - **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`. Supported types: float16, float32.
303
324
 
304
325
  Outputs:
305
326
  Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C, H, W)`.
@@ -316,10 +337,9 @@ class BatchNorm2d(_BatchNorm):
316
337
 
317
338
  Examples:
318
339
  >>> import numpy as np
319
- >>> import mindspore.nn as nn
320
- >>> from mindspore import Tensor
321
- >>> net = nn.BatchNorm2d(num_features=3)
322
- >>> x = Tensor(np.ones([1, 3, 2, 2]).astype(np.float32))
340
+ >>> import mindspore as ms
341
+ >>> net = ms.nn.BatchNorm2d(num_features=3)
342
+ >>> x = ms.Tensor(np.ones([1, 3, 2, 2]).astype(np.float32))
323
343
  >>> output = net(x)
324
344
  >>> print(output)
325
345
  [[[[ 0.999995 0.999995 ]
@@ -355,25 +375,35 @@ class BatchNorm3d(Cell):
355
375
 
356
376
  Args:
357
377
  num_features (int): `C` from an expected input of size :math:`(N, C, D, H, W)` .
358
- eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
378
+ eps (float): A value added to the denominator for numerical stability. Default: ``1e-5`` .
359
379
  momentum (float): A floating hyperparameter of the momentum for the
360
- running_mean and running_var computation. Default: 0.9.
361
- affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
380
+ running_mean and running_var computation. Default: ``0.9`` .
381
+ affine (bool): A bool value. When set to ``True`` , gamma and beta can be learned. Default: ``True`` .
362
382
  gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
363
- The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
383
+ The values of str refer to the function `mindspore.common.initializer
384
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
385
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'ones'`` .
364
386
  beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
365
- The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
387
+ The values of str refer to the function `mindspore.common.initializer
388
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
389
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'zeros'`` .
366
390
  moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
367
- The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
391
+ The values of str refer to the function `mindspore.common.initializer
392
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
393
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'zeros'`` .
368
394
  moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
369
- The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
370
- use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
371
- use the mean value and variance value of specified value. If None, the training process will use the mean
372
- and variance of current batch data and track the running mean and variance, the evaluation process will use
373
- the running mean and variance. Default: None.
395
+ The values of str refer to the function `mindspore.common.initializer
396
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
397
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'ones'`` .
398
+ use_batch_statistics (bool): If ``true`` , use the mean value and variance value of current batch data. If
399
+ ``false``, use the mean value and variance value of specified value. If ``None`` , the training process
400
+ will use the mean and variance of current batch data and track the running mean and variance, the
401
+ evaluation process will use the running mean and variance. Default: ``None`` .
402
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
374
403
 
375
404
  Inputs:
376
405
  - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
406
+ Supported types: float16, float32.
377
407
 
378
408
  Outputs:
379
409
  Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, D_{out},H_{out}, W_{out})`.
@@ -389,10 +419,9 @@ class BatchNorm3d(Cell):
389
419
 
390
420
  Examples:
391
421
  >>> import numpy as np
392
- >>> import mindspore.nn as nn
393
- >>> from mindspore import Tensor
394
- >>> net = nn.BatchNorm3d(num_features=3)
395
- >>> x = Tensor(np.ones([16, 3, 10, 32, 32]).astype(np.float32))
422
+ >>> import mindspore as ms
423
+ >>> net = ms.nn.BatchNorm3d(num_features=3)
424
+ >>> x = ms.Tensor(np.ones([16, 3, 10, 32, 32]).astype(np.float32))
396
425
  >>> output = net(x)
397
426
  >>> print(output.shape)
398
427
  (16, 3, 10, 32, 32)
@@ -407,7 +436,8 @@ class BatchNorm3d(Cell):
407
436
  beta_init='zeros',
408
437
  moving_mean_init='zeros',
409
438
  moving_var_init='ones',
410
- use_batch_statistics=None):
439
+ use_batch_statistics=None,
440
+ dtype=mstype.float32):
411
441
  """Initialize BatchNorm3d."""
412
442
  super(BatchNorm3d, self).__init__()
413
443
  self.bn2d = BatchNorm2d(num_features=num_features,
@@ -419,7 +449,8 @@ class BatchNorm3d(Cell):
419
449
  moving_mean_init=moving_mean_init,
420
450
  moving_var_init=moving_var_init,
421
451
  use_batch_statistics=use_batch_statistics,
422
- data_format="NCHW")
452
+ data_format="NCHW",
453
+ dtype=dtype)
423
454
  self.shape = P.Shape()
424
455
  self.reshape = P.Reshape()
425
456
 
@@ -464,34 +495,36 @@ class SyncBatchNorm(_BatchNorm):
464
495
 
465
496
  Note:
466
497
  Currently, SyncBatchNorm only supports 2D and 4D inputs.
498
+ :math:`\gamma` and :math:`\beta` are trainable scale and shift.
467
499
 
468
500
  Args:
469
501
  num_features (int): `C` from an expected input of size :math:`(N, C, H, W)`.
470
- eps (float): :math:`\epsilon`, a value added to the denominator for numerical stability. Default: 1e-5.
502
+ eps (float): :math:`\epsilon`, a value added to the denominator for numerical stability. Default: ``1e-5`` .
471
503
  momentum (float): A floating hyperparameter of the momentum for the
472
- running_mean and running_var computation. Default: 0.9.
473
- affine (bool): A bool value. When set to True, :math:`\gamma` and :math:`\beta` can be learned.
474
- Default: True.
504
+ running_mean and running_var computation. Default: ``0.9`` .
505
+ affine (bool): A bool value. When set to ``True`` , :math:`\gamma` and :math:`\beta` can be learned.
506
+ Default: ``True`` .
475
507
  gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
476
- The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
477
- 'he_uniform', etc. Default: 'ones'.
508
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` ,
509
+ ``'xavier_uniform'`` , ``'he_uniform'`` , etc. Default: ``'ones'`` .
478
510
  beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
479
- The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
480
- 'he_uniform', etc. Default: 'zeros'.
511
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` ,
512
+ ``'xavier_uniform'`` , ``'he_uniform'`` , etc. Default: ``'zeros'`` .
481
513
  moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
482
- The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
483
- 'he_uniform', etc. Default: 'zeros'.
514
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` ,
515
+ ``'xavier_uniform'`` , ``'he_uniform'`` , etc. Default: ``'zeros'`` .
484
516
  moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
485
- The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
486
- 'he_uniform', etc. Default: 'ones'.
487
- use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
488
- use the mean value and variance value of specified value. If None, training process will use the mean and
489
- variance of current batch data and track the running mean and variance, eval process will use the running
490
- mean and variance. Default: None.
517
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` ,
518
+ ``'xavier_uniform'`` , ``'he_uniform'`` , etc. Default: ``'ones'`` .
519
+ use_batch_statistics (bool): If ``true`` , use the mean value and variance value of current batch data. If
520
+ ``false`` , use the mean value and variance value of specified value. If ``None`` , training process will
521
+ use the mean and variance of current batch data and track the running mean and variance, eval process will
522
+ use the running mean and variance. Default: ``None`` .
491
523
  process_groups (list): A list to divide devices into different sync groups, containing N sublists.
492
524
  Each sublist contains int numbers identifying rank ids which need to be synchronized in the same
493
- group. All int values must be in [0, rank_size) and different from each other. Default: None, indicating
494
- synchronization across all devices.
525
+ group. All int values must be in [0, rank_size) and different from each other. Default: ``None`` ,
526
+ indicating synchronization across all devices.
527
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
495
528
 
496
529
  Inputs:
497
530
  - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -516,27 +549,27 @@ class SyncBatchNorm(_BatchNorm):
516
549
 
517
550
  For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
518
551
  Please see the `Ascend tutorial
519
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_ascend.html#preparations>`_
552
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
520
553
  for more details.
521
554
 
522
- For the GPU devices, users need to prepare the host file and mpi, please see the `GPU tutorial
523
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_gpu.html#preparation>`_ .
555
+ For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
556
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
557
+
558
+ For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
559
+ Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
524
560
 
525
561
  This example should be run with multiple devices.
526
562
 
527
563
  >>> import numpy as np
528
564
  >>> import mindspore as ms
529
565
  >>> from mindspore.communication import init
530
- >>> from mindspore import Tensor
531
- >>> from mindspore import nn
532
- >>> from mindspore import dtype as mstype
533
566
  >>>
534
567
  >>> ms.set_context(mode=ms.GRAPH_MODE)
535
568
  >>> init()
536
569
  >>> ms.reset_auto_parallel_context()
537
570
  >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL)
538
- >>> sync_bn_op = nn.SyncBatchNorm(num_features=3, process_groups=[[0, 1], [2, 3]])
539
- >>> x = Tensor(np.ones([1, 3, 2, 2]), mstype.float32)
571
+ >>> sync_bn_op = ms.nn.SyncBatchNorm(num_features=3, process_groups=[[0, 1], [2, 3]])
572
+ >>> x = ms.Tensor(np.ones([1, 3, 2, 2]), ms.float32)
540
573
  >>> output = sync_bn_op(x)
541
574
  >>> print(output)
542
575
  [[[[ 0.999995 0.999995 ]
@@ -557,7 +590,8 @@ class SyncBatchNorm(_BatchNorm):
557
590
  moving_mean_init='zeros',
558
591
  moving_var_init='ones',
559
592
  use_batch_statistics=None,
560
- process_groups=None):
593
+ process_groups=None,
594
+ dtype=mstype.float32):
561
595
  """Initialize SyncBatchNorm."""
562
596
  super(SyncBatchNorm, self).__init__(num_features,
563
597
  eps,
@@ -567,7 +601,8 @@ class SyncBatchNorm(_BatchNorm):
567
601
  beta_init,
568
602
  moving_mean_init,
569
603
  moving_var_init,
570
- use_batch_statistics)
604
+ use_batch_statistics,
605
+ dtype=dtype)
571
606
  self.is_global = False
572
607
  self.group_name = None
573
608
  self.process_groups = process_groups
@@ -642,27 +677,28 @@ class LayerNorm(Cell):
642
677
  normalization on a mini-batch of inputs for each single training case as described
643
678
  in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike Batch
644
679
  Normalization, Layer Normalization performs exactly the same computation at training and
645
- testing time. It is applied across all channels
646
- and pixel but only one batch size. It can be described using the following formula:
680
+ testing time. It is applied across all channels and pixels, but only within a single sample of the batch.
681
+ :math:`\gamma` and :math:`\beta` are trainable scale and shift.
682
+ It can be described using the following formula:
647
683
 
648
684
  .. math::
649
685
  y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
650
686
 
651
687
  Args:
652
688
  normalized_shape (Union(tuple[int], list[int])): The normalization is performed over axis
653
- `begin_norm_axis ... R - 1`.
689
+ `begin_norm_axis ... R - 1`. R is the dimension size of input `x`.
654
690
  begin_norm_axis (int): The first normalization dimension: normalization will be performed along dimensions
655
- `begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
656
- begin_params_axis (int): The first parameter(beta, gamma)dimension: scale and centering parameters
657
- will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
658
- the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
691
+ `begin_norm_axis: R`, the value should be in [-1, R). Default: ``-1`` .
692
+ begin_params_axis (int): The begin axis of the parameter input :math:`(\gamma, \beta)` to
693
+ apply LayerNorm, the value should be in [-1, R). Default: ``-1`` .
659
694
  gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
660
- The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
661
- 'he_uniform', etc. Default: 'ones'.
695
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` ,
696
+ ``'xavier_uniform'`` , ``'he_uniform'`` , etc. Default: ``'ones'`` .
662
697
  beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
663
- The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
664
- 'he_uniform', etc. Default: 'zeros'.
665
- epsilon (float): :math:`\epsilon` added to the denominator for numerical stability. Default: 1e-7.
698
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` ,
699
+ ``'xavier_uniform'`` , ``'he_uniform'`` , etc. Default: ``'zeros'`` .
700
+ epsilon (float): A value added to the denominator for numerical stability (:math:`\epsilon`). Default: ``1e-7`` .
701
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
666
702
 
667
703
  Inputs:
668
704
  - **x** (Tensor) - The shape of `x` is :math:`(x_1, x_2, ..., x_R)`,
@@ -680,9 +716,11 @@ class LayerNorm(Cell):
680
716
  ``Ascend`` ``GPU`` ``CPU``
681
717
 
682
718
  Examples:
683
- >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
719
+ >>> import mindspore as ms
720
+ >>> import numpy as np
721
+ >>> x = ms.Tensor(np.ones([20, 5, 10, 10]), ms.float32)
684
722
  >>> shape1 = x.shape[1:]
685
- >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
723
+ >>> m = ms.nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
686
724
  >>> output = m(x).shape
687
725
  >>> print(output)
688
726
  (20, 5, 10, 10)
@@ -694,21 +732,27 @@ class LayerNorm(Cell):
694
732
  begin_params_axis=-1,
695
733
  gamma_init='ones',
696
734
  beta_init='zeros',
697
- epsilon=1e-7
735
+ epsilon=1e-7,
736
+ dtype=mstype.float32
698
737
  ):
699
738
  """Initialize LayerNorm."""
700
739
  super(LayerNorm, self).__init__()
701
740
  if not isinstance(normalized_shape, (tuple, list)):
702
741
  raise TypeError(f"For '{self.cls_name}', the type of 'normalized_shape' must be tuple[int] or list[int], "
703
742
  f"but got {normalized_shape} and the type is {type(normalized_shape)}.")
743
+ if not normalized_shape:
744
+ raise ValueError(
745
+ f"Expected normalized_shape to be at least 1-dimensional, i.e., containing at "
746
+ f"least one element, but got normalized_shape = {normalized_shape}"
747
+ )
704
748
  self.normalized_shape = normalized_shape
705
749
  self.begin_norm_axis = begin_norm_axis
706
750
  self.begin_params_axis = begin_params_axis
707
751
  self.epsilon = epsilon
708
752
  self.gamma = Parameter(initializer(
709
- gamma_init, normalized_shape), name="gamma")
753
+ gamma_init, normalized_shape, dtype=dtype), name="gamma")
710
754
  self.beta = Parameter(initializer(
711
- beta_init, normalized_shape), name="beta")
755
+ beta_init, normalized_shape, dtype=dtype), name="beta")
712
756
  self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis,
713
757
  begin_params_axis=self.begin_params_axis,
714
758
  epsilon=self.epsilon)
@@ -731,7 +775,8 @@ class _InstanceNorm(Cell):
731
775
  momentum=0.1,
732
776
  affine=True,
733
777
  gamma_init='ones',
734
- beta_init='zeros'):
778
+ beta_init='zeros',
779
+ dtype=mstype.float32):
735
780
  """Initialize Normalization base class."""
736
781
  super(_InstanceNorm, self).__init__()
737
782
  validator.check_value_type('num_features', num_features, [int], self.cls_name)
@@ -748,12 +793,13 @@ class _InstanceNorm(Cell):
748
793
  f"but got {momentum}.")
749
794
  self.num_features = num_features
750
795
  self.eps = eps
751
- self.moving_mean = Parameter(initializer('zeros', num_features), name="mean", requires_grad=False)
752
- self.moving_variance = Parameter(initializer('ones', num_features), name="variance", requires_grad=False)
796
+ self.moving_mean = Parameter(initializer('zeros', num_features, dtype=dtype), name="mean", requires_grad=False)
797
+ self.moving_variance = Parameter(initializer('ones', num_features, dtype=dtype), name="variance",
798
+ requires_grad=False)
753
799
  self.gamma = Parameter(initializer(
754
- gamma_init, num_features), name="gamma", requires_grad=affine)
800
+ gamma_init, num_features, dtype=dtype), name="gamma", requires_grad=affine)
755
801
  self.beta = Parameter(initializer(
756
- beta_init, num_features), name="beta", requires_grad=affine)
802
+ beta_init, num_features, dtype=dtype), name="beta", requires_grad=affine)
757
803
 
758
804
  self.shape = P.Shape()
759
805
  self.momentum = momentum
@@ -807,16 +853,17 @@ class InstanceNorm1d(_InstanceNorm):
807
853
 
808
854
  Args:
809
855
  num_features (int): `C` from an expected input of size :math:`(N, C, L)`.
810
- eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
856
+ eps (float): A value added to the denominator for numerical stability. Default: ``1e-5`` .
811
857
  momentum (float): A floating hyperparameter of the momentum for the
812
- running_mean and running_var computation. Default: 0.1.
813
- affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
858
+ running_mean and running_var computation. Default: ``0.1`` .
859
+ affine (bool): A bool value. When set to ``True`` , gamma and beta can be learned. Default: ``True`` .
814
860
  gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
815
- The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
816
- When initialized with Tensor, the shape should be :math:`(C)`. Default: 'zeros'.
861
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` , etc.
862
+ When initialized with Tensor, the shape should be :math:`(C)`. Default: ``'ones'`` .
817
863
  beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
818
- The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
819
- When initialized with Tensor, the shape should be :math:`(C)`. Default: 'zeros'.
864
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` , etc.
865
+ When initialized with Tensor, the shape should be :math:`(C)`. Default: ``'zeros'`` .
866
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
820
867
 
821
868
  Inputs:
822
869
  - **x** (Tensor) - Tensor of shape :math:`(N, C, L)`. Data type: float16 or float32.
@@ -842,12 +889,10 @@ class InstanceNorm1d(_InstanceNorm):
842
889
  ``GPU``
843
890
 
844
891
  Examples:
845
- >>> import mindspore
892
+ >>> import mindspore as ms
846
893
  >>> import numpy as np
847
- >>> import mindspore.nn as nn
848
- >>> from mindspore import Tensor
849
- >>> net = nn.InstanceNorm1d(3)
850
- >>> x = Tensor(np.ones([2, 3, 5]), mindspore.float32)
894
+ >>> net = ms.nn.InstanceNorm1d(3)
895
+ >>> x = ms.Tensor(np.ones([2, 3, 5]), ms.float32)
851
896
  >>> output = net(x)
852
897
  >>> print(output.shape)
853
898
  (2, 3, 5)
@@ -886,16 +931,17 @@ class InstanceNorm2d(_InstanceNorm):
886
931
 
887
932
  Args:
888
933
  num_features (int): `C` from an expected input of size :math:`(N, C, H, W)`.
889
- eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
934
+ eps (float): A value added to the denominator for numerical stability. Default: ``1e-5`` .
890
935
  momentum (float): A floating hyperparameter of the momentum for the
891
- running_mean and running_var computation. Default: 0.1.
892
- affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
936
+ running_mean and running_var computation. Default: ``0.1`` .
937
+ affine (bool): A bool value. When set to ``True`` , gamma and beta can be learned. Default: ``True`` .
893
938
  gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
894
- The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
895
- When initialized with Tensor, the shape should be :math:`(C)`. Default: 'zeros'.
939
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` , etc.
940
+ When initialized with Tensor, the shape should be :math:`(C)`. Default: ``'ones'`` .
896
941
  beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
897
- The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
898
- When initialized with Tensor, the shape should be :math:`(C)`. Default: 'zeros'.
942
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` , etc.
943
+ When initialized with Tensor, the shape should be :math:`(C)`. Default: ``'zeros'`` .
944
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
899
945
 
900
946
  Inputs:
901
947
  - **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`. Data type: float16 or float32.
@@ -921,12 +967,10 @@ class InstanceNorm2d(_InstanceNorm):
921
967
  ``GPU``
922
968
 
923
969
  Examples:
924
- >>> import mindspore
970
+ >>> import mindspore as ms
925
971
  >>> import numpy as np
926
- >>> import mindspore.nn as nn
927
- >>> from mindspore import Tensor
928
- >>> net = nn.InstanceNorm2d(3)
929
- >>> x = Tensor(np.ones([2, 3, 2, 2]), mindspore.float32)
972
+ >>> net = ms.nn.InstanceNorm2d(3)
973
+ >>> x = ms.Tensor(np.ones([2, 3, 2, 2]), ms.float32)
930
974
  >>> output = net(x)
931
975
  >>> print(output.shape)
932
976
  (2, 3, 2, 2)
@@ -964,16 +1008,17 @@ class InstanceNorm3d(_InstanceNorm):
964
1008
 
965
1009
  Args:
966
1010
  num_features (int): `C` from an expected input of size :math:`(N, C, D, H, W)`.
967
- eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
1011
+ eps (float): A value added to the denominator for numerical stability. Default: ``1e-5`` .
968
1012
  momentum (float): A floating hyperparameter of the momentum for the
969
- running_mean and running_var computation. Default: 0.1.
970
- affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
1013
+ running_mean and running_var computation. Default: ``0.1`` .
1014
+ affine (bool): A bool value. When set to ``True`` , gamma and beta can be learned. Default: ``True`` .
971
1015
  gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
972
- The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
973
- When initialized with Tensor, the shape should be :math:`(C)`. Default: 'zeros'.
1016
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` , etc.
1017
+ When initialized with Tensor, the shape should be :math:`(C)`. Default: ``'ones'`` .
974
1018
  beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
975
- The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
976
- When initialized with Tensor, the shape should be :math:`(C)`. Default: 'zeros'.
1019
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` , etc.
1020
+ When initialized with Tensor, the shape should be :math:`(C)`. Default: ``'zeros'`` .
1021
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
977
1022
 
978
1023
  Inputs:
979
1024
  - **x** (Tensor) - Tensor of shape :math:`(N, C, D, H, W)`. Data type: float16 or float32.
@@ -999,12 +1044,10 @@ class InstanceNorm3d(_InstanceNorm):
999
1044
  ``GPU``
1000
1045
 
1001
1046
  Examples:
1002
- >>> import mindspore
1047
+ >>> import mindspore as ms
1003
1048
  >>> import numpy as np
1004
- >>> import mindspore.nn as nn
1005
- >>> from mindspore import Tensor
1006
- >>> net = nn.InstanceNorm3d(3)
1007
- >>> x = Tensor(np.ones([2, 3, 5, 2, 2]), mindspore.float32)
1049
+ >>> net = ms.nn.InstanceNorm3d(3)
1050
+ >>> x = ms.Tensor(np.ones([2, 3, 5, 2, 2]), ms.float32)
1008
1051
  >>> output = net(x)
1009
1052
  >>> print(output.shape)
1010
1053
  (2, 3, 5, 2, 2)
@@ -1025,7 +1068,9 @@ class GroupNorm(Cell):
1025
1068
  normalization on a mini-batch of inputs for each single training case as described
1026
1069
  in the paper `Group Normalization <https://arxiv.org/pdf/1803.08494.pdf>`_. Group Normalization
1027
1070
  divides the channels into groups and computes within each group the mean and variance for normalization,
1028
- and it performs very stable over a wide range of batch size. It can be described using the following formula:
1071
+ and it performs very stable over a wide range of batch size. :math:`\gamma` and :math:`\beta` are trainable scale
1072
+ and shift.
1073
+ It can be described using the following formula:
1029
1074
 
1030
1075
  .. math::
1031
1076
  y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
@@ -1033,14 +1078,18 @@ class GroupNorm(Cell):
1033
1078
  Args:
1034
1079
  num_groups (int): The number of groups to be divided along the channel dimension.
1035
1080
  num_channels (int): The number of input channels.
1036
- eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
1037
- affine (bool): A bool value, this layer will have learnable affine parameters when set to true. Default: True.
1081
+ eps (float): A value added to the denominator for numerical stability. Default: ``1e-05`` .
1082
+ affine (bool): A bool value, this layer will have learnable affine parameters when set to ``true`` .
1083
+ Default: ``True`` .
1038
1084
  gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
1039
- The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
1040
- 'he_uniform', etc. Default: 'ones'. If gamma_init is a Tensor, the shape must be :math:`(num\_channels)`.
1085
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` ,
1086
+ ``'xavier_uniform'`` , ``'he_uniform'`` , etc. Default: ``'ones'`` . If gamma_init is a Tensor, the shape
1087
+ must be :math:`(num\_channels)`.
1041
1088
  beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
1042
- The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
1043
- 'he_uniform', etc. Default: 'zeros'. If beta_init is a Tensor, the shape must be :math:`(num\_channels)`.
1089
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` ,
1090
+ ``'xavier_uniform'`` , ``'he_uniform'`` , etc. Default: ``'zeros'`` . If beta_init is a Tensor, the shape
1091
+ must be :math:`(num\_channels)`.
1092
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
1044
1093
 
1045
1094
  Inputs:
1046
1095
  - **x** (Tensor) - The input feature with shape :math:`(N, C, H, W)` .
@@ -1059,8 +1108,10 @@ class GroupNorm(Cell):
1059
1108
  ``Ascend`` ``GPU`` ``CPU``
1060
1109
 
1061
1110
  Examples:
1062
- >>> group_norm_op = nn.GroupNorm(2, 2)
1063
- >>> x = Tensor(np.ones([1, 2, 4, 4], np.float32))
1111
+ >>> import mindspore as ms
1112
+ >>> import numpy as np
1113
+ >>> group_norm_op = ms.nn.GroupNorm(2, 2)
1114
+ >>> x = ms.Tensor(np.ones([1, 2, 4, 4], np.float32))
1064
1115
  >>> output = group_norm_op(x)
1065
1116
  >>> print(output)
1066
1117
  [[[[0. 0. 0. 0.]
@@ -1073,7 +1124,8 @@ class GroupNorm(Cell):
1073
1124
  [0. 0. 0. 0.]]]]
1074
1125
  """
1075
1126
 
1076
- def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'):
1127
+ def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros',
1128
+ dtype=mstype.float32):
1077
1129
  """Initialize GroupNorm."""
1078
1130
  super(GroupNorm, self).__init__()
1079
1131
  self.num_groups = validator.check_positive_int(num_groups, "num_groups", self.cls_name)
@@ -1085,27 +1137,27 @@ class GroupNorm(Cell):
1085
1137
  self.affine = validator.check_bool(affine, arg_name="affine", prim_name=self.cls_name)
1086
1138
 
1087
1139
  self.gamma = Parameter(initializer(
1088
- gamma_init, num_channels), name="gamma", requires_grad=affine)
1140
+ gamma_init, num_channels, dtype=dtype), name="gamma", requires_grad=affine)
1089
1141
  self.beta = Parameter(initializer(
1090
- beta_init, num_channels), name="beta", requires_grad=affine)
1142
+ beta_init, num_channels, dtype=dtype), name="beta", requires_grad=affine)
1143
+ self.reduce_mean = P.ReduceMean(keep_dims=True)
1144
+ self.reduce_sum = P.ReduceSum(keep_dims=True)
1091
1145
  self.shape = F.shape
1092
1146
  self.reshape = F.reshape
1093
- self.reduce_mean = P.ReduceMean(keep_dims=True)
1094
1147
  self.square = F.square
1095
- self.reduce_sum = P.ReduceSum(keep_dims=True)
1096
1148
  self.sqrt = P.Sqrt()
1097
1149
 
1098
1150
  def _cal_output(self, x):
1099
1151
  """calculate groupnorm output"""
1100
- batch, channel, height, width = self.shape(x)
1152
+ batch, channel, height, width = F.shape(x)
1101
1153
  self._channel_check(channel, self.num_channels, self.cls_name)
1102
- x = self.reshape(x, (batch, self.num_groups, -1))
1154
+ x = F.reshape(x, (batch, self.num_groups, -1))
1103
1155
  mean = self.reduce_mean(x, 2)
1104
- var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups)
1156
+ var = F.div(self.reduce_sum(F.square(F.sub(x, mean)), 2), (channel * height * width / self.num_groups))
1105
1157
  std = self.sqrt(var + self.eps)
1106
- x = (x - mean) / std
1107
- x = self.reshape(x, (batch, channel, height, width))
1108
- output = x * self.reshape(self.gamma, (-1, 1, 1)) + self.reshape(self.beta, (-1, 1, 1))
1158
+ x = F.div(F.sub(x, mean), std)
1159
+ x = F.reshape(x, (batch, channel, height, width))
1160
+ output = F.add(x * F.reshape(self.gamma, (-1, 1, 1)), F.reshape(self.beta, (-1, 1, 1)))
1109
1161
  return output
1110
1162
 
1111
1163
  @staticmethod
@@ -1133,7 +1185,7 @@ class GroupNorm(Cell):
1133
1185
  return 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels)
1134
1186
 
1135
1187
  def construct(self, x):
1136
- self._check_input_dim(self.shape(x), self.cls_name)
1188
+ self._check_input_dim(F.shape(x), self.cls_name)
1137
1189
  self._check_dtype(x.dtype, [mstype.float16, mstype.float32], self.cls_name)
1138
1190
  output = self._cal_output(x)
1139
1191
  return output