mindspore 2.0.0a0__cp39-cp39-win_amd64.whl → 2.0.0rc1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (655)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -2
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +102 -0
  7. mindspore/_checkparam.py +1066 -1001
  8. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
  9. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
  10. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
  11. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
  12. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  13. mindspore/_extends/parse/__init__.py +5 -3
  14. mindspore/_extends/parse/namespace.py +16 -1
  15. mindspore/_extends/parse/parser.py +107 -22
  16. mindspore/_extends/parse/resources.py +0 -7
  17. mindspore/_extends/parse/standard_method.py +885 -413
  18. mindspore/amp.py +52 -57
  19. mindspore/boost/boost.py +2 -2
  20. mindspore/boost/boost_cell_wrapper.py +38 -20
  21. mindspore/boost/dim_reduce.py +3 -3
  22. mindspore/boost/group_loss_scale_manager.py +1 -1
  23. mindspore/common/__init__.py +4 -6
  24. mindspore/common/_decorator.py +2 -0
  25. mindspore/common/_register_for_adapter.py +55 -0
  26. mindspore/common/_stub_tensor.py +201 -0
  27. mindspore/common/_utils.py +41 -7
  28. mindspore/common/api.py +215 -141
  29. mindspore/common/dtype.py +8 -1
  30. mindspore/common/dump.py +2 -2
  31. mindspore/common/initializer.py +4 -2
  32. mindspore/common/jit_config.py +17 -13
  33. mindspore/common/mutable.py +33 -13
  34. mindspore/common/parameter.py +23 -21
  35. mindspore/common/seed.py +8 -24
  36. mindspore/common/sparse_tensor.py +62 -41
  37. mindspore/common/tensor.py +852 -1154
  38. mindspore/communication/__init__.py +2 -2
  39. mindspore/communication/_comm_helper.py +11 -4
  40. mindspore/communication/management.py +22 -21
  41. mindspore/config/op_info.config +501 -1008
  42. mindspore/context.py +201 -23
  43. mindspore/dataset/__init__.py +6 -6
  44. mindspore/dataset/audio/__init__.py +7 -7
  45. mindspore/dataset/audio/transforms.py +670 -30
  46. mindspore/dataset/audio/utils.py +47 -4
  47. mindspore/dataset/audio/validators.py +223 -1
  48. mindspore/dataset/callback/ds_callback.py +2 -2
  49. mindspore/dataset/core/config.py +210 -14
  50. mindspore/dataset/core/validator_helpers.py +2 -2
  51. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  52. mindspore/dataset/debug/debug_hook.py +65 -0
  53. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  54. mindspore/dataset/engine/__init__.py +7 -3
  55. mindspore/dataset/engine/cache_client.py +1 -1
  56. mindspore/dataset/engine/datasets.py +322 -66
  57. mindspore/dataset/engine/datasets_audio.py +80 -76
  58. mindspore/dataset/engine/datasets_standard_format.py +51 -38
  59. mindspore/dataset/engine/datasets_text.py +232 -118
  60. mindspore/dataset/engine/datasets_user_defined.py +41 -17
  61. mindspore/dataset/engine/datasets_vision.py +746 -225
  62. mindspore/dataset/engine/graphdata.py +75 -10
  63. mindspore/dataset/engine/iterators.py +45 -5
  64. mindspore/dataset/engine/offload.py +48 -28
  65. mindspore/dataset/engine/validators.py +117 -8
  66. mindspore/dataset/text/__init__.py +6 -5
  67. mindspore/dataset/text/transforms.py +86 -3
  68. mindspore/dataset/text/utils.py +6 -4
  69. mindspore/dataset/text/validators.py +25 -0
  70. mindspore/dataset/transforms/__init__.py +3 -2
  71. mindspore/dataset/transforms/c_transforms.py +1 -1
  72. mindspore/dataset/transforms/transforms.py +2 -2
  73. mindspore/dataset/utils/__init__.py +2 -1
  74. mindspore/dataset/utils/line_reader.py +121 -0
  75. mindspore/dataset/vision/__init__.py +2 -3
  76. mindspore/dataset/vision/c_transforms.py +9 -9
  77. mindspore/dataset/vision/py_transforms.py +5 -5
  78. mindspore/dataset/vision/py_transforms_util.py +2 -0
  79. mindspore/dataset/vision/transforms.py +160 -161
  80. mindspore/dataset/vision/utils.py +3 -3
  81. mindspore/experimental/map_parameter.py +38 -26
  82. mindspore/include/OWNERS +0 -1
  83. mindspore/include/api/callback/callback.h +9 -13
  84. mindspore/include/api/callback/ckpt_saver.h +2 -2
  85. mindspore/include/api/callback/loss_monitor.h +2 -2
  86. mindspore/include/api/callback/lr_scheduler.h +5 -5
  87. mindspore/include/api/callback/time_monitor.h +2 -2
  88. mindspore/include/api/callback/train_accuracy.h +4 -6
  89. mindspore/include/api/cfg.h +19 -6
  90. mindspore/include/api/context.h +44 -9
  91. mindspore/include/api/delegate.h +1 -1
  92. mindspore/include/api/metrics/accuracy.h +2 -2
  93. mindspore/include/api/metrics/metrics.h +4 -3
  94. mindspore/include/api/model.h +9 -4
  95. mindspore/include/api/model_parallel_runner.h +2 -2
  96. mindspore/include/api/net.h +12 -11
  97. mindspore/include/api/serialization.h +19 -3
  98. mindspore/include/api/types.h +3 -3
  99. mindspore/include/dataset/constants.h +7 -0
  100. mindspore/include/dataset/text.h +59 -0
  101. mindspore/jpeg62.dll +0 -0
  102. mindspore/log.py +1 -1
  103. mindspore/mindrecord/filereader.py +18 -0
  104. mindspore/mindrecord/filewriter.py +197 -34
  105. mindspore/mindrecord/shardreader.py +9 -0
  106. mindspore/mindrecord/shardwriter.py +1 -1
  107. mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
  108. mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
  109. mindspore/mindrecord/tools/csv_to_mr.py +3 -3
  110. mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
  111. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  112. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  113. mindspore/mindspore_backend.dll +0 -0
  114. mindspore/mindspore_common.dll +0 -0
  115. mindspore/mindspore_core.dll +0 -0
  116. mindspore/mindspore_glog.dll +0 -0
  117. mindspore/mindspore_shared_lib.dll +0 -0
  118. mindspore/nn/__init__.py +0 -4
  119. mindspore/nn/cell.py +204 -132
  120. mindspore/nn/dynamic_lr.py +1 -1
  121. mindspore/nn/grad/cell_grad.py +7 -6
  122. mindspore/nn/layer/__init__.py +5 -4
  123. mindspore/nn/layer/activation.py +40 -89
  124. mindspore/nn/layer/basic.py +255 -624
  125. mindspore/nn/layer/channel_shuffle.py +7 -6
  126. mindspore/nn/layer/combined.py +1 -1
  127. mindspore/nn/layer/container.py +41 -4
  128. mindspore/nn/layer/conv.py +64 -28
  129. mindspore/nn/layer/dense.py +9 -8
  130. mindspore/nn/layer/embedding.py +27 -25
  131. mindspore/nn/layer/image.py +53 -46
  132. mindspore/nn/layer/math.py +97 -105
  133. mindspore/nn/layer/normalization.py +117 -86
  134. mindspore/nn/layer/padding.py +185 -95
  135. mindspore/nn/layer/pooling.py +817 -414
  136. mindspore/nn/layer/rnn_cells.py +10 -15
  137. mindspore/nn/layer/rnns.py +37 -38
  138. mindspore/nn/layer/thor_layer.py +11 -12
  139. mindspore/nn/layer/timedistributed.py +5 -5
  140. mindspore/nn/layer/transformer.py +701 -0
  141. mindspore/nn/learning_rate_schedule.py +8 -8
  142. mindspore/nn/loss/__init__.py +5 -4
  143. mindspore/nn/loss/loss.py +334 -199
  144. mindspore/nn/optim/ada_grad.py +6 -6
  145. mindspore/nn/optim/adadelta.py +2 -3
  146. mindspore/nn/optim/adafactor.py +4 -5
  147. mindspore/nn/optim/adam.py +126 -62
  148. mindspore/nn/optim/adamax.py +3 -4
  149. mindspore/nn/optim/adasum.py +6 -6
  150. mindspore/nn/optim/asgd.py +2 -2
  151. mindspore/nn/optim/ftrl.py +67 -38
  152. mindspore/nn/optim/lamb.py +4 -5
  153. mindspore/nn/optim/lars.py +2 -2
  154. mindspore/nn/optim/lazyadam.py +43 -4
  155. mindspore/nn/optim/momentum.py +6 -5
  156. mindspore/nn/optim/optimizer.py +3 -1
  157. mindspore/nn/optim/proximal_ada_grad.py +2 -2
  158. mindspore/nn/optim/rmsprop.py +1 -1
  159. mindspore/nn/optim/rprop.py +8 -9
  160. mindspore/nn/optim/sgd.py +19 -13
  161. mindspore/nn/optim/thor.py +10 -15
  162. mindspore/nn/probability/__init__.py +0 -2
  163. mindspore/nn/probability/bijector/bijector.py +4 -4
  164. mindspore/nn/probability/bijector/invert.py +1 -1
  165. mindspore/nn/probability/bijector/softplus.py +2 -2
  166. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  167. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  168. mindspore/nn/probability/distribution/_utils/utils.py +9 -15
  169. mindspore/nn/probability/distribution/bernoulli.py +3 -3
  170. mindspore/nn/probability/distribution/beta.py +1 -1
  171. mindspore/nn/probability/distribution/categorical.py +5 -7
  172. mindspore/nn/probability/distribution/cauchy.py +3 -3
  173. mindspore/nn/probability/distribution/distribution.py +2 -2
  174. mindspore/nn/probability/distribution/exponential.py +2 -2
  175. mindspore/nn/probability/distribution/gamma.py +3 -3
  176. mindspore/nn/probability/distribution/geometric.py +1 -1
  177. mindspore/nn/probability/distribution/gumbel.py +3 -3
  178. mindspore/nn/probability/distribution/half_normal.py +15 -11
  179. mindspore/nn/probability/distribution/laplace.py +16 -13
  180. mindspore/nn/probability/distribution/logistic.py +2 -2
  181. mindspore/nn/probability/distribution/normal.py +1 -1
  182. mindspore/nn/probability/distribution/poisson.py +1 -1
  183. mindspore/nn/probability/distribution/student_t.py +20 -15
  184. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  185. mindspore/nn/probability/distribution/uniform.py +2 -2
  186. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  187. mindspore/nn/reinforcement/tensor_array.py +2 -2
  188. mindspore/nn/sparse/sparse.py +2 -2
  189. mindspore/nn/wrap/cell_wrapper.py +27 -10
  190. mindspore/nn/wrap/grad_reducer.py +2 -2
  191. mindspore/nn/wrap/loss_scale.py +40 -24
  192. mindspore/numpy/array_creations.py +33 -22
  193. mindspore/numpy/array_ops.py +35 -30
  194. mindspore/numpy/logic_ops.py +6 -27
  195. mindspore/numpy/math_ops.py +22 -19
  196. mindspore/numpy/utils.py +1 -1
  197. mindspore/numpy/utils_const.py +108 -58
  198. mindspore/opencv_core452.dll +0 -0
  199. mindspore/opencv_imgcodecs452.dll +0 -0
  200. mindspore/opencv_imgproc452.dll +0 -0
  201. mindspore/ops/_constants.py +0 -6
  202. mindspore/ops/_grad/__init__.py +2 -1
  203. mindspore/ops/_grad/grad_array_ops.py +86 -117
  204. mindspore/ops/_grad/grad_base.py +23 -1
  205. mindspore/ops/_grad/grad_clip_ops.py +2 -3
  206. mindspore/ops/_grad/grad_comm_ops.py +34 -24
  207. mindspore/ops/_grad/grad_implementations.py +9 -45
  208. mindspore/ops/_grad/grad_inner_ops.py +47 -4
  209. mindspore/ops/_grad/grad_math_ops.py +142 -117
  210. mindspore/ops/_grad/grad_nn_ops.py +71 -165
  211. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  212. mindspore/ops/_grad/grad_sparse.py +7 -6
  213. mindspore/ops/_grad_experimental/__init__.py +1 -0
  214. mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
  215. mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
  216. mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
  217. mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
  218. mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
  219. mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
  220. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  221. mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
  222. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  223. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
  224. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  225. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  226. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
  227. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  228. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  229. mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
  230. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  231. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
  232. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  233. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  234. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  235. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  236. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  237. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  238. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  239. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  240. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  241. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  242. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  243. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  244. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  245. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  246. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  247. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  248. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  249. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  250. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
  251. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  252. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  253. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  254. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  255. mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
  256. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  257. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
  258. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  259. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  260. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  261. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  262. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  263. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
  264. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  265. mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
  266. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  267. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  268. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
  269. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  270. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  271. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
  272. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  273. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  274. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  275. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  276. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  277. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  278. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  279. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  280. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  281. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
  282. mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
  283. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  284. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  285. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  286. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  287. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  288. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  289. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  290. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  291. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  292. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  293. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  294. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  295. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  296. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  297. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  298. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  299. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  300. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  301. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  302. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  303. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  304. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  305. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  306. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  307. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  308. mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
  309. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
  310. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  311. mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
  312. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  313. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  314. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  315. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  316. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  317. mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
  318. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
  319. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  320. mindspore/ops/_op_impl/tbe/__init__.py +27 -611
  321. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  322. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  323. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  324. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
  325. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  326. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  327. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  328. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  329. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  330. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  331. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  332. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  333. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  334. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
  335. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  336. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  337. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  338. mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
  339. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  340. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  341. mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
  342. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
  343. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  344. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  345. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  346. mindspore/ops/_register_for_op.py +1 -0
  347. mindspore/ops/_utils/__init__.py +1 -2
  348. mindspore/ops/_utils/utils.py +19 -40
  349. mindspore/ops/_vmap/vmap_array_ops.py +116 -38
  350. mindspore/ops/_vmap/vmap_base.py +16 -9
  351. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  352. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  353. mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
  354. mindspore/ops/_vmap/vmap_image_ops.py +12 -5
  355. mindspore/ops/_vmap/vmap_math_ops.py +46 -5
  356. mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
  357. mindspore/ops/_vmap/vmap_random_ops.py +1 -1
  358. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  359. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  360. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  361. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  362. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  363. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  364. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  365. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  366. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
  367. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  368. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  369. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  370. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  371. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  372. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  373. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  374. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  375. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  376. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  377. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
  378. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
  379. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  380. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  381. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  382. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  383. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  384. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  385. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
  386. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  387. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
  388. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
  389. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  390. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  391. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  392. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  393. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  394. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  395. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  396. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  397. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  398. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  399. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  400. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  401. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  402. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  403. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  404. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  405. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  406. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  407. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  408. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
  409. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  410. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  411. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  412. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  413. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  414. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  415. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  416. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
  417. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
  418. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  419. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
  420. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  421. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  422. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  423. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  424. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  425. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  426. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  427. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  428. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  429. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
  430. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
  431. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  432. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  433. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  434. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  435. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
  436. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  437. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  438. mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
  439. mindspore/ops/composite/__init__.py +7 -8
  440. mindspore/ops/composite/base.py +101 -47
  441. mindspore/ops/composite/math_ops.py +188 -158
  442. mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
  443. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
  444. mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
  445. mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
  446. mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
  447. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  448. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  449. mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
  450. mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
  451. mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
  452. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
  453. mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
  454. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  455. mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
  456. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
  457. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
  458. mindspore/ops/function/__init__.py +152 -8
  459. mindspore/ops/function/array_func.py +2555 -674
  460. mindspore/ops/function/clip_func.py +209 -13
  461. mindspore/ops/function/debug_func.py +2 -2
  462. mindspore/ops/function/grad/__init__.py +2 -1
  463. mindspore/ops/function/grad/grad_func.py +147 -62
  464. mindspore/ops/function/image_func.py +54 -38
  465. mindspore/ops/function/linalg_func.py +167 -16
  466. mindspore/ops/function/math_func.py +4849 -1492
  467. mindspore/ops/function/nn_func.py +2573 -988
  468. mindspore/ops/function/other_func.py +115 -0
  469. mindspore/ops/function/parameter_func.py +3 -3
  470. mindspore/ops/function/random_func.py +790 -73
  471. mindspore/ops/function/sparse_func.py +98 -78
  472. mindspore/ops/function/sparse_unary_func.py +54 -53
  473. mindspore/ops/function/spectral_func.py +27 -24
  474. mindspore/ops/function/vmap_func.py +22 -2
  475. mindspore/ops/functional.py +97 -37
  476. mindspore/ops/op_info_register.py +70 -28
  477. mindspore/ops/operations/__init__.py +47 -14
  478. mindspore/ops/operations/_csr_ops.py +7 -7
  479. mindspore/ops/operations/_embedding_cache_ops.py +5 -5
  480. mindspore/ops/operations/_grad_ops.py +276 -187
  481. mindspore/ops/operations/_inner_ops.py +319 -113
  482. mindspore/ops/operations/_ms_kernel.py +10 -8
  483. mindspore/ops/operations/_ocr_ops.py +9 -9
  484. mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
  485. mindspore/ops/operations/_quant_ops.py +137 -102
  486. mindspore/ops/operations/_rl_inner_ops.py +121 -60
  487. mindspore/ops/operations/_scalar_ops.py +466 -0
  488. mindspore/ops/operations/_sequence_ops.py +1004 -2
  489. mindspore/ops/operations/_tensor_array.py +10 -11
  490. mindspore/ops/operations/_thor_ops.py +1 -1
  491. mindspore/ops/operations/array_ops.py +801 -466
  492. mindspore/ops/operations/comm_ops.py +51 -49
  493. mindspore/ops/operations/control_ops.py +2 -2
  494. mindspore/ops/operations/custom_ops.py +123 -44
  495. mindspore/ops/operations/debug_ops.py +24 -24
  496. mindspore/ops/operations/image_ops.py +240 -153
  497. mindspore/ops/operations/inner_ops.py +34 -50
  498. mindspore/ops/operations/linalg_ops.py +31 -9
  499. mindspore/ops/operations/math_ops.py +988 -757
  500. mindspore/ops/operations/nn_ops.py +965 -819
  501. mindspore/ops/operations/other_ops.py +51 -40
  502. mindspore/ops/operations/random_ops.py +204 -122
  503. mindspore/ops/operations/rl_ops.py +8 -9
  504. mindspore/ops/operations/sparse_ops.py +254 -93
  505. mindspore/ops/operations/spectral_ops.py +35 -3
  506. mindspore/ops/primitive.py +111 -9
  507. mindspore/parallel/_auto_parallel_context.py +189 -83
  508. mindspore/parallel/_offload_context.py +185 -0
  509. mindspore/parallel/_parallel_serialization.py +99 -7
  510. mindspore/parallel/_ps_context.py +9 -5
  511. mindspore/parallel/_recovery_context.py +1 -1
  512. mindspore/parallel/_tensor.py +7 -1
  513. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  514. mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
  515. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  516. mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
  517. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  518. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
  519. mindspore/parallel/_utils.py +1 -2
  520. mindspore/parallel/algo_parameter_config.py +1 -1
  521. mindspore/parallel/checkpoint_transform.py +37 -34
  522. mindspore/parallel/shard.py +17 -18
  523. mindspore/profiler/common/validator/validate_path.py +2 -2
  524. mindspore/profiler/envprofiling.py +69 -47
  525. mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
  526. mindspore/profiler/parser/base_timeline_generator.py +49 -56
  527. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
  528. mindspore/profiler/parser/hwts_log_parser.py +1 -1
  529. mindspore/profiler/parser/integrator.py +15 -14
  530. mindspore/profiler/parser/minddata_analyzer.py +2 -2
  531. mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
  532. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  533. mindspore/profiler/parser/optime_parser.py +17 -18
  534. mindspore/profiler/parser/profiler_info.py +2 -1
  535. mindspore/profiler/profiling.py +218 -186
  536. mindspore/rewrite/__init__.py +3 -1
  537. mindspore/rewrite/api/node.py +1 -114
  538. mindspore/rewrite/api/node_type.py +3 -0
  539. mindspore/rewrite/api/pattern_engine.py +31 -1
  540. mindspore/rewrite/api/scoped_value.py +4 -4
  541. mindspore/rewrite/api/symbol_tree.py +3 -78
  542. mindspore/rewrite/api/tree_node_helper.py +1 -1
  543. mindspore/rewrite/ast_creator_register.py +1 -0
  544. mindspore/rewrite/ast_helpers/__init__.py +2 -2
  545. mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
  546. mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
  547. mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
  548. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
  549. mindspore/rewrite/namespace.py +0 -2
  550. mindspore/rewrite/node.py +157 -11
  551. mindspore/rewrite/parsers/assign_parser.py +231 -53
  552. mindspore/rewrite/parsers/class_def_parser.py +187 -109
  553. mindspore/rewrite/parsers/for_parser.py +24 -14
  554. mindspore/rewrite/parsers/function_def_parser.py +21 -4
  555. mindspore/rewrite/parsers/if_parser.py +6 -2
  556. mindspore/rewrite/sparsify/__init__.py +0 -0
  557. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  558. mindspore/rewrite/sparsify/sparsify.py +109 -0
  559. mindspore/rewrite/sparsify/utils.py +173 -0
  560. mindspore/rewrite/symbol_tree.py +256 -133
  561. mindspore/rewrite/symbol_tree_builder.py +38 -1
  562. mindspore/run_check/_check_version.py +69 -63
  563. mindspore/run_check/run_check.py +2 -1
  564. mindspore/tinyxml2.dll +0 -0
  565. mindspore/train/__init__.py +1 -1
  566. mindspore/train/_utils.py +28 -5
  567. mindspore/train/amp.py +273 -102
  568. mindspore/train/callback/_backup_and_restore.py +5 -5
  569. mindspore/train/callback/_callback.py +2 -2
  570. mindspore/train/callback/_checkpoint.py +3 -3
  571. mindspore/train/callback/_early_stop.py +3 -3
  572. mindspore/train/callback/_lambda_callback.py +2 -2
  573. mindspore/train/callback/_landscape.py +29 -31
  574. mindspore/train/callback/_loss_monitor.py +3 -3
  575. mindspore/train/callback/_on_request_exit.py +3 -3
  576. mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
  577. mindspore/train/callback/_summary_collector.py +23 -16
  578. mindspore/train/callback/_time_monitor.py +3 -3
  579. mindspore/train/checkpoint_pb2.py +68 -8
  580. mindspore/train/data_sink.py +15 -3
  581. mindspore/train/dataset_helper.py +10 -15
  582. mindspore/train/loss_scale_manager.py +8 -11
  583. mindspore/train/metrics/__init__.py +1 -1
  584. mindspore/train/metrics/bleu_score.py +1 -1
  585. mindspore/train/metrics/confusion_matrix.py +1 -1
  586. mindspore/train/metrics/cosine_similarity.py +1 -1
  587. mindspore/train/metrics/dice.py +2 -2
  588. mindspore/train/metrics/fbeta.py +1 -1
  589. mindspore/train/metrics/hausdorff_distance.py +4 -3
  590. mindspore/train/metrics/mean_surface_distance.py +2 -2
  591. mindspore/train/metrics/occlusion_sensitivity.py +1 -1
  592. mindspore/train/metrics/perplexity.py +1 -1
  593. mindspore/train/metrics/precision.py +1 -1
  594. mindspore/train/metrics/recall.py +1 -1
  595. mindspore/train/metrics/roc.py +2 -2
  596. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  597. mindspore/train/mind_ir_pb2.py +116 -37
  598. mindspore/train/model.py +45 -28
  599. mindspore/train/serialization.py +295 -188
  600. mindspore/train/summary/_summary_adapter.py +1 -1
  601. mindspore/train/summary/summary_record.py +43 -13
  602. mindspore/train/train_thor/convert_utils.py +2 -2
  603. mindspore/train/train_thor/dataset_helper.py +3 -3
  604. mindspore/turbojpeg.dll +0 -0
  605. mindspore/version.py +1 -1
  606. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
  607. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
  608. mindspore/compression/__init__.py +0 -19
  609. mindspore/compression/common/constant.py +0 -124
  610. mindspore/compression/export/__init__.py +0 -19
  611. mindspore/compression/export/quant_export.py +0 -515
  612. mindspore/compression/quant/__init__.py +0 -28
  613. mindspore/compression/quant/qat.py +0 -634
  614. mindspore/compression/quant/quant_utils.py +0 -462
  615. mindspore/compression/quant/quantizer.py +0 -68
  616. mindspore/nn/layer/quant.py +0 -1868
  617. mindspore/nn/layer/rnn_utils.py +0 -90
  618. mindspore/nn/probability/dpn/__init__.py +0 -22
  619. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  620. mindspore/nn/probability/dpn/vae/cvae.py +0 -140
  621. mindspore/nn/probability/dpn/vae/vae.py +0 -124
  622. mindspore/nn/probability/infer/__init__.py +0 -22
  623. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  624. mindspore/nn/probability/infer/variational/svi.py +0 -84
  625. mindspore/nn/probability/toolbox/__init__.py +0 -22
  626. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  627. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
  628. mindspore/nn/probability/transforms/__init__.py +0 -22
  629. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  630. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  631. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  632. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  633. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  634. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  635. mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
  636. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  637. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
  638. mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
  639. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
  640. mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
  642. mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
  643. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
  644. mindspore/ops/composite/array_ops.py +0 -241
  645. mindspore/ops/composite/clip_ops.py +0 -134
  646. mindspore/ops/composite/random_ops.py +0 -426
  647. mindspore/ops/composite/vmap_ops.py +0 -38
  648. mindspore/parallel/nn/__init__.py +0 -42
  649. mindspore/parallel/nn/loss.py +0 -22
  650. mindspore/parallel/nn/moe.py +0 -21
  651. mindspore/parallel/nn/op_parallel_config.py +0 -22
  652. mindspore/parallel/nn/transformer.py +0 -31
  653. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  654. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  655. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
1
+ # Copyright 2020-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ from __future__ import division
18
18
 
19
19
  import itertools
20
20
  import numbers
21
+ import hashlib
21
22
 
22
23
  from mindspore.ops import operations as P
23
24
  from mindspore.ops import functional as F
@@ -25,21 +26,26 @@ from mindspore.ops.operations import _inner_ops as inner
25
26
  from mindspore.common.parameter import Parameter
26
27
  from mindspore.common.initializer import initializer, Initializer
27
28
  from mindspore.common.tensor import Tensor
28
- from mindspore.ops.primitive import constexpr
29
+ from mindspore.ops.primitive import constexpr, _primexpr
29
30
  import mindspore.context as context
30
- from mindspore._checkparam import Rel
31
- from mindspore._checkparam import Validator as validator
31
+ from mindspore import _checkparam as validator
32
32
  from mindspore._extends import cell_attr_register
33
33
  from mindspore.communication.management import get_group_size, get_rank
34
34
  from mindspore.communication import management
35
35
  from mindspore.common import dtype as mstype
36
36
  from mindspore.parallel._utils import _is_in_auto_parallel_mode
37
37
  from mindspore.nn.cell import Cell
38
+ from mindspore import log as logger
38
39
 
39
40
  __all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm',
40
41
  'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d']
41
42
 
42
- SYNC_BN_GROUP_NAME = ""
43
+
44
+ def _check_dim(val, target, cls_name):
45
+ def _check(val, target, cls_name):
46
+ if val != target:
47
+ raise ValueError(f"For '{cls_name}', the in_shape must have {target} dims, but got {val}.")
48
+ _check(val, target, cls_name)
43
49
 
44
50
 
45
51
  class _BatchNorm(Cell):
@@ -121,11 +127,13 @@ class _BatchNorm(Cell):
121
127
  self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
122
128
  self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)
123
129
 
130
+
124
131
  @staticmethod
125
- @constexpr
132
+ @_primexpr
126
133
  def _check_input_dim(shape, cls_name):
127
134
  raise NotImplementedError
128
135
 
136
+
129
137
  def construct(self, x):
130
138
  self._check_input_dim(self.shape(x), self.cls_name)
131
139
  if self.use_batch_statistics is None:
@@ -164,7 +172,7 @@ class _BatchNorm(Cell):
164
172
  class BatchNorm1d(_BatchNorm):
165
173
  r"""
166
174
  This layer
167
- applies Batch Normalization over a 2D input (a mini-batch of 1D inputs) to
175
+ applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D or 2D inputs) to
168
176
  reduce internal covariate shift. Batch Normalization is widely used in convolutional networks.
169
177
  For the setailed contents, refer to `Batch Normalization: Accelerating Deep Network Training by
170
178
  Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It
@@ -179,14 +187,14 @@ class BatchNorm1d(_BatchNorm):
179
187
  recommended to be changed after net was initialized.
180
188
 
181
189
  Args:
182
- num_features (int): `C` from an expected input of size (N, C).
183
- eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
190
+ num_features (int): number of features or channels `C` of the input `x` .
191
+ eps (float): :math:`\epsilon` added to the denominator for numerical stability. Default: 1e-5.
184
192
  momentum (float): A floating hyperparameter of the momentum for the
185
193
  running_mean and running_var computation. Default: 0.9.
186
- affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
187
- gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
194
+ affine (bool): A bool value. When set to True, :math:`\gamma` and :math:`\beta` can be learned. Default: True.
195
+ gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
188
196
  The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
189
- beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
197
+ beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
190
198
  The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
191
199
  moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
192
200
  The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
@@ -200,10 +208,11 @@ class BatchNorm1d(_BatchNorm):
200
208
  Default: 'NCHW'.
201
209
 
202
210
  Inputs:
203
- - **x** (Tensor) - Tensor of shape :math:`(N, C_{in})`.
211
+ - **x** (Tensor) - Tensor of shape :math:`(N, C)` or :math:`(N, C, L)` ,
212
+ where `N` is the batch size, `C` is the number of features or channels, and `L` is the sequence length.
204
213
 
205
214
  Outputs:
206
- Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out})`.
215
+ Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C)` or :math:`(N, C, L)` .
207
216
 
208
217
  Raises:
209
218
  TypeError: If `num_features` is not an int.
@@ -228,11 +237,13 @@ class BatchNorm1d(_BatchNorm):
228
237
  """
229
238
 
230
239
  @staticmethod
231
- @constexpr
240
+ @_primexpr
232
241
  def _check_input_dim(shape, cls_name):
242
+ def _check(dim):
243
+ if dim not in (2, 3):
244
+ raise ValueError(f"For '{cls_name}', the must have 2 dims or 3 dims, but got {dim}.")
233
245
  dim = len(shape)
234
- if dim != 2:
235
- raise ValueError(f"For '{cls_name}', the in_shape must have 2 dims, but got {dim}.")
246
+ _check(dim)
236
247
 
237
248
 
238
249
  class BatchNorm2d(_BatchNorm):
@@ -254,22 +265,22 @@ class BatchNorm2d(_BatchNorm):
254
265
  Note that the formula for updating the :math:`moving\_mean` and :math:`moving\_var` is
255
266
 
256
267
  .. math::
257
- \text{moving_mean}=\text{moving_meanmomentum}+μ_β\text{(1−momentum)}\\
258
- \text{moving_var}=\text{moving_varmomentum}+σ^2_β\text{(1−momentum)}
268
+ \text{moving_mean}=\text{moving_mean*momentum}+μ_β\text{*(1−momentum)}\\
269
+ \text{moving_var}=\text{moving_var*momentum}+σ^2_β\text{*(1−momentum)}
259
270
 
260
271
  where :math:`moving\_mean` is the updated mean, :math:`moving\_var` is the updated variance,
261
272
  :math:`μ_β, σ^2_β` are the observed value (mean and variance) of each batch of data.
262
273
 
263
274
  Args:
264
- num_features (int): The number of channels of the input tensor. Expected input size is (N, C, H, W),
275
+ num_features (int): The number of channels of the input tensor. Expected input size is :math:`(N, C, H, W)`,
265
276
  `C` represents the number of channels.
266
- eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
277
+ eps (float): :math:`\epsilon` added to the denominator for numerical stability. Default: 1e-5.
267
278
  momentum (float): A floating hyperparameter of the momentum for the
268
279
  running_mean and running_var computation. Default: 0.9.
269
- affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
270
- gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
280
+ affine (bool): A bool value. When set to True, :math:`\gamma` and :math:`\beta` can be learned. Default: True.
281
+ gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
271
282
  The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
272
- beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
283
+ beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
273
284
  The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
274
285
  moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
275
286
  The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
@@ -288,10 +299,10 @@ class BatchNorm2d(_BatchNorm):
288
299
  Default: 'NCHW'.
289
300
 
290
301
  Inputs:
291
- - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
302
+ - **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`.
292
303
 
293
304
  Outputs:
294
- Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
305
+ Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C, H, W)`.
295
306
 
296
307
  Raises:
297
308
  TypeError: If `num_features` is not an int.
@@ -320,11 +331,10 @@ class BatchNorm2d(_BatchNorm):
320
331
  """
321
332
 
322
333
  @staticmethod
323
- @constexpr
334
+ @_primexpr
324
335
  def _check_input_dim(shape, cls_name):
325
336
  dim = len(shape)
326
- if dim != 4:
327
- raise ValueError(f"For '{cls_name}', the in_shape must have 4 dims, but got {dim}.")
337
+ _check_dim(dim, 4, cls_name)
328
338
 
329
339
 
330
340
  class BatchNorm3d(Cell):
@@ -344,7 +354,7 @@ class BatchNorm3d(Cell):
344
354
  where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
345
355
 
346
356
  Args:
347
- num_features (int): `C` from an expected input of size (N, C, D, H, W).
357
+ num_features (int): `C` from an expected input of size :math:`(N, C, D, H, W)` .
348
358
  eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
349
359
  momentum (float): A floating hyperparameter of the momentum for the
350
360
  running_mean and running_var computation. Default: 0.9.
@@ -414,11 +424,11 @@ class BatchNorm3d(Cell):
414
424
  self.reshape = P.Reshape()
415
425
 
416
426
  @staticmethod
417
- @constexpr
427
+ @_primexpr
418
428
  def _check_input_dim(shape, cls_name):
419
429
  dim = len(shape)
420
- if dim != 5:
421
- raise ValueError(f"For '{cls_name}', the in_shape must have 5 dims, but got {dim}.")
430
+ _check_dim(dim, 5, cls_name)
431
+
422
432
 
423
433
  def construct(self, x):
424
434
  x_shape = self.shape(x)
@@ -429,6 +439,16 @@ class BatchNorm3d(Cell):
429
439
  return bn3d_out
430
440
 
431
441
 
442
+ SYNCBN_GROUP_DICT = None
443
+
444
+
445
+ def _syncbatchnorm_group_dict():
446
+ global SYNCBN_GROUP_DICT
447
+ if SYNCBN_GROUP_DICT is None:
448
+ SYNCBN_GROUP_DICT = dict()
449
+ return SYNCBN_GROUP_DICT
450
+
451
+
432
452
  class SyncBatchNorm(_BatchNorm):
433
453
  r"""
434
454
  Sync Batch Normalization layer over a N-dimension input.
@@ -446,15 +466,16 @@ class SyncBatchNorm(_BatchNorm):
446
466
  Currently, SyncBatchNorm only supports 2D and 4D inputs.
447
467
 
448
468
  Args:
449
- num_features (int): `C` from an expected input of size (N, C, H, W).
450
- eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
469
+ num_features (int): `C` from an expected input of size :math:`(N, C, H, W)`.
470
+ eps (float): :math:`\epsilon`, a value added to the denominator for numerical stability. Default: 1e-5.
451
471
  momentum (float): A floating hyperparameter of the momentum for the
452
472
  running_mean and running_var computation. Default: 0.9.
453
- affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
454
- gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
473
+ affine (bool): A bool value. When set to True, :math:`\gamma` and :math:`\beta` can be learned.
474
+ Default: True.
475
+ gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
455
476
  The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
456
477
  'he_uniform', etc. Default: 'ones'.
457
- beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
478
+ beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
458
479
  The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
459
480
  'he_uniform', etc. Default: 'zeros'.
460
481
  moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
@@ -495,11 +516,11 @@ class SyncBatchNorm(_BatchNorm):
495
516
 
496
517
  For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
497
518
  Please see the `Ascend tutorial
498
- <https://www.mindspore.cn/tutorials/experts/en/r2.0.0-alpha/parallel/train_ascend.html#preparations>`_
519
+ <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_ascend.html#preparations>`_
499
520
  for more details.
500
521
 
501
522
  For the GPU devices, users need to prepare the host file and mpi, please see the `GPU tutorial
502
- <https://www.mindspore.cn/tutorials/experts/en/r2.0.0-alpha/parallel/train_gpu.html#preparation>`_ .
523
+ <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_gpu.html#preparation>`_ .
503
524
 
504
525
  This example should be run with multiple devices.
505
526
 
@@ -525,7 +546,7 @@ class SyncBatchNorm(_BatchNorm):
525
546
  [[ 0.999995 0.999995 ]
526
547
  [ 0.999995 0.999995 ]]]]
527
548
  """
528
-
549
+ @cell_attr_register(attrs=['num_features', 'process_groups'])
529
550
  def __init__(self,
530
551
  num_features,
531
552
  eps=1e-5,
@@ -548,7 +569,7 @@ class SyncBatchNorm(_BatchNorm):
548
569
  moving_var_init,
549
570
  use_batch_statistics)
550
571
  self.is_global = False
551
- global SYNC_BN_GROUP_NAME
572
+ self.group_name = None
552
573
  self.process_groups = process_groups
553
574
  if self.process_groups != 0:
554
575
  self.rank_id = get_rank()
@@ -560,43 +581,53 @@ class SyncBatchNorm(_BatchNorm):
560
581
  elif self.rank_size > 1:
561
582
  self.is_global = True
562
583
  self.group_device_num = self.rank_size
563
- self.device_list = [i for i in range(0, self.rank_size)]
564
584
  if context.get_context("device_target") == "Ascend":
565
- if SYNC_BN_GROUP_NAME == "":
566
- SYNC_BN_GROUP_NAME = "sync_bn_group0"
567
- management.create_group(SYNC_BN_GROUP_NAME, self.device_list)
585
+ self.group_name = "hccl_world_group"
568
586
  elif context.get_context("device_target") == "GPU":
569
- if SYNC_BN_GROUP_NAME == "":
570
- SYNC_BN_GROUP_NAME = "nccl_world_group"
587
+ self.group_name = "nccl_world_group"
571
588
 
572
589
  if self.is_global:
573
590
  self.bn_train = inner.SyncBatchNorm(epsilon=self.eps,
574
591
  momentum=self.momentum,
575
- group=SYNC_BN_GROUP_NAME,
592
+ group=self.group_name,
576
593
  device_num=self.group_device_num)
577
594
 
578
595
  def _create_sync_groups(self):
579
- for i in range(len(self.process_groups)):
580
- validator.check_isinstance("process_groups[%d]" % i, self.process_groups[i], list)
581
- self.group_device_num = len(self.process_groups[i])
582
- if self.rank_id in self.process_groups[i] and self.group_device_num > 1:
596
+ """ create groups by process groups. """
597
+ for sub_group in self.process_groups:
598
+ validator.check_isinstance("sub group", sub_group, list)
599
+ self.group_device_num = len(sub_group)
600
+ if self.rank_id in sub_group and self.group_device_num > 1:
583
601
  self.is_global = True
584
- global SYNC_BN_GROUP_NAME
585
- if SYNC_BN_GROUP_NAME == "":
586
- SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
587
- management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i])
602
+ rank_list_name = '_'.join('%s' % id for id in sub_group)
603
+ group_dict = _syncbatchnorm_group_dict()
604
+ if rank_list_name not in group_dict:
605
+ md5 = hashlib.md5()
606
+ md5.update(rank_list_name.encode('utf-8'))
607
+ hash_name = md5.hexdigest()
608
+ self.group_name = str(self.group_device_num) + '_' + hash_name
609
+ group_dict[rank_list_name] = self.group_name
610
+ management.create_group(self.group_name, sub_group)
611
+ logger.info("create group for sync batchnorm, the rank list is {}, the group name is {}".format(
612
+ rank_list_name, self.group_name))
613
+ else:
614
+ self.group_name = group_dict[rank_list_name]
615
+ logger.info("the group for {} already exists, no need to create".format(rank_list_name))
588
616
 
589
617
  @staticmethod
590
- @constexpr
618
+ @_primexpr
591
619
  def _check_input_dim(shape, cls_name):
620
+ def _check(dim):
621
+ if dim not in (2, 4):
622
+ raise ValueError(f"For '{cls_name}', the must have 2 dims or 4 dims, but got {dim}.")
592
623
  dim = len(shape)
593
- if dim not in (2, 4):
594
- raise ValueError(f"For '{cls_name}', the must have 2 dims or 4 dims, but got {dim}.")
624
+ _check(dim)
625
+
595
626
 
596
627
  def _check_rank_ids(self, process_groups, rank_size):
597
628
  seen = set()
598
629
  for rid in itertools.chain(*process_groups):
599
- validator.check_int_range(rid, 0, rank_size, Rel.INC_LEFT, "rank id in process_groups", self.cls_name)
630
+ validator.check_int_range(rid, 0, rank_size, validator.INC_LEFT, "rank id in process_groups", self.cls_name)
600
631
  if rid in seen:
601
632
  raise ValueError(f"For '{self.cls_name}', rank id in 'process_groups' must not be duplicated, "
602
633
  f"but got {process_groups}.")
@@ -625,13 +656,13 @@ class LayerNorm(Cell):
625
656
  begin_params_axis (int): The first parameter(beta, gamma)dimension: scale and centering parameters
626
657
  will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
627
658
  the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
628
- gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
659
+ gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
629
660
  The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
630
661
  'he_uniform', etc. Default: 'ones'.
631
- beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
662
+ beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
632
663
  The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
633
664
  'he_uniform', etc. Default: 'zeros'.
634
- epsilon (float): A value added to the denominator for numerical stability. Default: 1e-7.
665
+ epsilon (float): :math:`\epsilon` added to the denominator for numerical stability. Default: 1e-7.
635
666
 
636
667
  Inputs:
637
668
  - **x** (Tensor) - The shape of `x` is :math:`(x_1, x_2, ..., x_R)`,
@@ -775,7 +806,7 @@ class InstanceNorm1d(_InstanceNorm):
775
806
  where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
776
807
 
777
808
  Args:
778
- num_features (int): `C` from an expected input of size (N, C, L).
809
+ num_features (int): `C` from an expected input of size :math:`(N, C, L)`.
779
810
  eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
780
811
  momentum (float): A floating hyperparameter of the momentum for the
781
812
  running_mean and running_var computation. Default: 0.1.
@@ -823,11 +854,11 @@ class InstanceNorm1d(_InstanceNorm):
823
854
  """
824
855
 
825
856
  @staticmethod
826
- @constexpr
857
+ @_primexpr
827
858
  def _check_input_dim(shape, cls_name):
828
859
  dim = len(shape)
829
- if dim != 3:
830
- raise ValueError(f"For '{cls_name}', the in_shape must have 3 dims, but got {dim}.")
860
+ _check_dim(dim, 3, cls_name)
861
+
831
862
 
832
863
 
833
864
  class InstanceNorm2d(_InstanceNorm):
@@ -854,7 +885,7 @@ class InstanceNorm2d(_InstanceNorm):
854
885
  where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
855
886
 
856
887
  Args:
857
- num_features (int): `C` from an expected input of size (N, C, H, W).
888
+ num_features (int): `C` from an expected input of size :math:`(N, C, H, W)`.
858
889
  eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
859
890
  momentum (float): A floating hyperparameter of the momentum for the
860
891
  running_mean and running_var computation. Default: 0.1.
@@ -902,11 +933,10 @@ class InstanceNorm2d(_InstanceNorm):
902
933
  """
903
934
 
904
935
  @staticmethod
905
- @constexpr
936
+ @_primexpr
906
937
  def _check_input_dim(shape, cls_name):
907
938
  dim = len(shape)
908
- if dim != 4:
909
- raise ValueError(f"For '{cls_name}', the in_shape must have 4 dims, but got {dim}.")
939
+ _check_dim(dim, 4, cls_name)
910
940
 
911
941
 
912
942
  class InstanceNorm3d(_InstanceNorm):
@@ -933,7 +963,7 @@ class InstanceNorm3d(_InstanceNorm):
933
963
  where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
934
964
 
935
965
  Args:
936
- num_features (int): `C` from an expected input of size (N, C, D, H, W).
966
+ num_features (int): `C` from an expected input of size :math:`(N, C, D, H, W)`.
937
967
  eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
938
968
  momentum (float): A floating hyperparameter of the momentum for the
939
969
  running_mean and running_var computation. Default: 0.1.
@@ -979,12 +1009,12 @@ class InstanceNorm3d(_InstanceNorm):
979
1009
  >>> print(output.shape)
980
1010
  (2, 3, 5, 2, 2)
981
1011
  """
1012
+
982
1013
  @staticmethod
983
- @constexpr
1014
+ @_primexpr
984
1015
  def _check_input_dim(shape, cls_name):
985
1016
  dim = len(shape)
986
- if dim != 5:
987
- raise ValueError(f"For '{cls_name}', the in_shape must have 5 dims, but got {dim}.")
1017
+ _check_dim(dim, 5, cls_name)
988
1018
 
989
1019
 
990
1020
  class GroupNorm(Cell):
@@ -1007,10 +1037,10 @@ class GroupNorm(Cell):
1007
1037
  affine (bool): A bool value, this layer will have learnable affine parameters when set to true. Default: True.
1008
1038
  gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
1009
1039
  The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
1010
- 'he_uniform', etc. Default: 'ones'. If gamma_init is a Tensor, the shape must be [num_channels].
1040
+ 'he_uniform', etc. Default: 'ones'. If gamma_init is a Tensor, the shape must be :math:`(num\_channels)`.
1011
1041
  beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
1012
1042
  The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
1013
- 'he_uniform', etc. Default: 'zeros'. If beta_init is a Tensor, the shape must be [num_channels].
1043
+ 'he_uniform', etc. Default: 'zeros'. If beta_init is a Tensor, the shape must be :math:`(num\_channels)`.
1014
1044
 
1015
1045
  Inputs:
1016
1046
  - **x** (Tensor) - The input feature with shape :math:`(N, C, H, W)` .
@@ -1079,19 +1109,20 @@ class GroupNorm(Cell):
1079
1109
  return output
1080
1110
 
1081
1111
  @staticmethod
1082
- @constexpr
1112
+ @_primexpr
1083
1113
  def _check_input_dim(shape, cls_name):
1084
1114
  dim = len(shape)
1085
- if dim != 4:
1086
- raise ValueError(f"For '{cls_name}', the in_shape must have 4 dims, but got {dim}.")
1115
+ _check_dim(dim, 4, cls_name)
1087
1116
 
1088
1117
  @staticmethod
1089
- @constexpr
1118
+ @_primexpr
1090
1119
  def _channel_check(channel, num_channel, prim_name=None):
1091
- msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
1092
- if channel != num_channel:
1093
- raise ValueError(f"{msg_prefix} channel(the second dim of the input 'x') must be equal to num_channels, "
1094
- f"but got channel: {channel}, num_channels: {num_channel}.")
1120
+ def _check():
1121
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
1122
+ if channel != num_channel:
1123
+ raise ValueError(f"{msg_prefix} channel(the second dim of the input 'x') must be equal to "
1124
+ f"num_channels, but got channel: {channel}, num_channels: {num_channel}.")
1125
+ _check()
1095
1126
 
1096
1127
  @staticmethod
1097
1128
  @constexpr