mindspore 2.0.0a0__cp37-none-any.whl → 2.0.0rc1__cp37-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; more details are available on the original registry page.

Files changed (693)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +9064 -0
  3. mindspore/__init__.py +4 -2
  4. mindspore/_akg/akg/composite/build_module.py +11 -0
  5. mindspore/_akg/akg/config/repository_cuda.json +11 -0
  6. mindspore/_akg/akg/tvm/contrib/nvcc.py +4 -3
  7. mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
  8. mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
  9. mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
  10. mindspore/_check_jit_forbidden_api.py +102 -0
  11. mindspore/_checkparam.py +1066 -1001
  12. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
  13. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
  14. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
  15. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
  16. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  17. mindspore/_extends/parse/__init__.py +5 -3
  18. mindspore/_extends/parse/namespace.py +16 -1
  19. mindspore/_extends/parse/parser.py +107 -22
  20. mindspore/_extends/parse/resources.py +0 -7
  21. mindspore/_extends/parse/standard_method.py +885 -413
  22. mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
  23. mindspore/amp.py +52 -57
  24. mindspore/bin/cache_admin +0 -0
  25. mindspore/bin/cache_server +0 -0
  26. mindspore/boost/boost.py +2 -2
  27. mindspore/boost/boost_cell_wrapper.py +38 -20
  28. mindspore/boost/dim_reduce.py +3 -3
  29. mindspore/boost/group_loss_scale_manager.py +1 -1
  30. mindspore/common/__init__.py +4 -6
  31. mindspore/common/_decorator.py +2 -0
  32. mindspore/common/_register_for_adapter.py +55 -0
  33. mindspore/common/_stub_tensor.py +201 -0
  34. mindspore/common/_utils.py +41 -7
  35. mindspore/common/api.py +215 -141
  36. mindspore/common/dtype.py +8 -1
  37. mindspore/common/dump.py +2 -2
  38. mindspore/common/initializer.py +4 -2
  39. mindspore/common/jit_config.py +17 -13
  40. mindspore/common/mutable.py +33 -13
  41. mindspore/common/parameter.py +23 -21
  42. mindspore/common/seed.py +8 -24
  43. mindspore/common/sparse_tensor.py +62 -41
  44. mindspore/common/tensor.py +852 -1154
  45. mindspore/communication/__init__.py +2 -2
  46. mindspore/communication/_comm_helper.py +11 -4
  47. mindspore/communication/management.py +22 -21
  48. mindspore/config/op_info.config +501 -1008
  49. mindspore/config/super_bar_config.json +512 -0
  50. mindspore/context.py +201 -23
  51. mindspore/dataset/__init__.py +6 -6
  52. mindspore/dataset/audio/__init__.py +7 -7
  53. mindspore/dataset/audio/transforms.py +670 -30
  54. mindspore/dataset/audio/utils.py +47 -4
  55. mindspore/dataset/audio/validators.py +223 -1
  56. mindspore/dataset/callback/ds_callback.py +2 -2
  57. mindspore/dataset/core/config.py +210 -14
  58. mindspore/dataset/core/validator_helpers.py +2 -2
  59. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  60. mindspore/dataset/debug/debug_hook.py +65 -0
  61. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  62. mindspore/dataset/engine/__init__.py +7 -3
  63. mindspore/dataset/engine/cache_client.py +1 -1
  64. mindspore/dataset/engine/datasets.py +322 -66
  65. mindspore/dataset/engine/datasets_audio.py +80 -76
  66. mindspore/dataset/engine/datasets_standard_format.py +51 -38
  67. mindspore/dataset/engine/datasets_text.py +232 -118
  68. mindspore/dataset/engine/datasets_user_defined.py +41 -17
  69. mindspore/dataset/engine/datasets_vision.py +746 -225
  70. mindspore/dataset/engine/graphdata.py +75 -10
  71. mindspore/dataset/engine/iterators.py +45 -5
  72. mindspore/dataset/engine/offload.py +48 -28
  73. mindspore/dataset/engine/validators.py +117 -8
  74. mindspore/dataset/text/__init__.py +6 -5
  75. mindspore/dataset/text/transforms.py +86 -3
  76. mindspore/dataset/text/utils.py +6 -4
  77. mindspore/dataset/text/validators.py +25 -0
  78. mindspore/dataset/transforms/__init__.py +3 -2
  79. mindspore/dataset/transforms/c_transforms.py +1 -1
  80. mindspore/dataset/transforms/transforms.py +2 -2
  81. mindspore/dataset/utils/__init__.py +2 -1
  82. mindspore/dataset/utils/line_reader.py +121 -0
  83. mindspore/dataset/vision/__init__.py +2 -3
  84. mindspore/dataset/vision/c_transforms.py +9 -9
  85. mindspore/dataset/vision/py_transforms.py +5 -5
  86. mindspore/dataset/vision/py_transforms_util.py +2 -0
  87. mindspore/dataset/vision/transforms.py +160 -161
  88. mindspore/dataset/vision/utils.py +3 -3
  89. mindspore/experimental/map_parameter.py +38 -26
  90. mindspore/include/OWNERS +0 -1
  91. mindspore/include/api/callback/callback.h +9 -13
  92. mindspore/include/api/callback/ckpt_saver.h +2 -2
  93. mindspore/include/api/callback/loss_monitor.h +2 -2
  94. mindspore/include/api/callback/lr_scheduler.h +5 -5
  95. mindspore/include/api/callback/time_monitor.h +2 -2
  96. mindspore/include/api/callback/train_accuracy.h +4 -6
  97. mindspore/include/api/cfg.h +19 -6
  98. mindspore/include/api/context.h +44 -9
  99. mindspore/include/api/delegate.h +1 -1
  100. mindspore/include/api/metrics/accuracy.h +2 -2
  101. mindspore/include/api/metrics/metrics.h +4 -3
  102. mindspore/include/api/model.h +9 -4
  103. mindspore/include/api/model_parallel_runner.h +2 -2
  104. mindspore/include/api/net.h +12 -11
  105. mindspore/include/api/serialization.h +19 -3
  106. mindspore/include/api/types.h +3 -3
  107. mindspore/include/dataset/constants.h +7 -0
  108. mindspore/include/dataset/text.h +59 -0
  109. mindspore/include/mindapi/base/type_id.h +1 -0
  110. mindspore/lib/libdnnl.so.2 +0 -0
  111. mindspore/lib/libicudata.so.69 +0 -0
  112. mindspore/lib/libicui18n.so.69 +0 -0
  113. mindspore/lib/libicuuc.so.69 +0 -0
  114. mindspore/lib/libmindspore.so +0 -0
  115. mindspore/lib/libmindspore_backend.so +0 -0
  116. mindspore/lib/libmindspore_common.so +0 -0
  117. mindspore/lib/libmindspore_core.so +0 -0
  118. mindspore/lib/libmindspore_glog.so.0 +0 -0
  119. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  120. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  121. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  122. mindspore/lib/libmindspore_shared_lib.so +0 -0
  123. mindspore/lib/libmpi_adapter.so +0 -0
  124. mindspore/lib/libmpi_collective.so +0 -0
  125. mindspore/lib/libnnacl.so +0 -0
  126. mindspore/lib/libopencv_core.so.4.5 +0 -0
  127. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  128. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  129. mindspore/lib/libps_cache.so +0 -0
  130. mindspore/lib/plugin/ascend/libakg.so +0 -0
  131. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  132. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  133. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  134. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  135. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  136. mindspore/lib/plugin/cpu/libakg.so +0 -0
  137. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  138. mindspore/lib/plugin/{libmindspore_ascend.so → libmindspore_ascend.so.2} +0 -0
  139. mindspore/log.py +1 -1
  140. mindspore/mindrecord/filereader.py +18 -0
  141. mindspore/mindrecord/filewriter.py +197 -34
  142. mindspore/mindrecord/shardreader.py +9 -0
  143. mindspore/mindrecord/shardwriter.py +1 -1
  144. mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
  145. mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
  146. mindspore/mindrecord/tools/csv_to_mr.py +3 -3
  147. mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
  148. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  149. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  150. mindspore/nn/__init__.py +0 -4
  151. mindspore/nn/cell.py +204 -132
  152. mindspore/nn/dynamic_lr.py +1 -1
  153. mindspore/nn/grad/cell_grad.py +7 -6
  154. mindspore/nn/layer/__init__.py +5 -4
  155. mindspore/nn/layer/activation.py +40 -89
  156. mindspore/nn/layer/basic.py +255 -624
  157. mindspore/nn/layer/channel_shuffle.py +7 -6
  158. mindspore/nn/layer/combined.py +1 -1
  159. mindspore/nn/layer/container.py +41 -4
  160. mindspore/nn/layer/conv.py +64 -28
  161. mindspore/nn/layer/dense.py +9 -8
  162. mindspore/nn/layer/embedding.py +27 -25
  163. mindspore/nn/layer/image.py +53 -46
  164. mindspore/nn/layer/math.py +97 -105
  165. mindspore/nn/layer/normalization.py +117 -86
  166. mindspore/nn/layer/padding.py +185 -95
  167. mindspore/nn/layer/pooling.py +817 -414
  168. mindspore/nn/layer/rnn_cells.py +10 -15
  169. mindspore/nn/layer/rnns.py +37 -38
  170. mindspore/nn/layer/thor_layer.py +11 -12
  171. mindspore/nn/layer/timedistributed.py +5 -5
  172. mindspore/nn/layer/transformer.py +701 -0
  173. mindspore/nn/learning_rate_schedule.py +8 -8
  174. mindspore/nn/loss/__init__.py +5 -4
  175. mindspore/nn/loss/loss.py +334 -199
  176. mindspore/nn/optim/ada_grad.py +6 -6
  177. mindspore/nn/optim/adadelta.py +2 -3
  178. mindspore/nn/optim/adafactor.py +4 -5
  179. mindspore/nn/optim/adam.py +126 -62
  180. mindspore/nn/optim/adamax.py +3 -4
  181. mindspore/nn/optim/adasum.py +6 -6
  182. mindspore/nn/optim/asgd.py +2 -2
  183. mindspore/nn/optim/ftrl.py +67 -38
  184. mindspore/nn/optim/lamb.py +4 -5
  185. mindspore/nn/optim/lars.py +2 -2
  186. mindspore/nn/optim/lazyadam.py +43 -4
  187. mindspore/nn/optim/momentum.py +6 -5
  188. mindspore/nn/optim/optimizer.py +3 -1
  189. mindspore/nn/optim/proximal_ada_grad.py +2 -2
  190. mindspore/nn/optim/rmsprop.py +1 -1
  191. mindspore/nn/optim/rprop.py +8 -9
  192. mindspore/nn/optim/sgd.py +19 -13
  193. mindspore/nn/optim/thor.py +10 -15
  194. mindspore/nn/probability/__init__.py +0 -2
  195. mindspore/nn/probability/bijector/bijector.py +4 -4
  196. mindspore/nn/probability/bijector/invert.py +1 -1
  197. mindspore/nn/probability/bijector/softplus.py +2 -2
  198. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  199. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  200. mindspore/nn/probability/distribution/_utils/utils.py +9 -15
  201. mindspore/nn/probability/distribution/bernoulli.py +3 -3
  202. mindspore/nn/probability/distribution/beta.py +1 -1
  203. mindspore/nn/probability/distribution/categorical.py +5 -7
  204. mindspore/nn/probability/distribution/cauchy.py +3 -3
  205. mindspore/nn/probability/distribution/distribution.py +2 -2
  206. mindspore/nn/probability/distribution/exponential.py +2 -2
  207. mindspore/nn/probability/distribution/gamma.py +3 -3
  208. mindspore/nn/probability/distribution/geometric.py +1 -1
  209. mindspore/nn/probability/distribution/gumbel.py +3 -3
  210. mindspore/nn/probability/distribution/half_normal.py +15 -11
  211. mindspore/nn/probability/distribution/laplace.py +16 -13
  212. mindspore/nn/probability/distribution/logistic.py +2 -2
  213. mindspore/nn/probability/distribution/normal.py +1 -1
  214. mindspore/nn/probability/distribution/poisson.py +1 -1
  215. mindspore/nn/probability/distribution/student_t.py +20 -15
  216. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  217. mindspore/nn/probability/distribution/uniform.py +2 -2
  218. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  219. mindspore/nn/reinforcement/tensor_array.py +2 -2
  220. mindspore/nn/sparse/sparse.py +2 -2
  221. mindspore/nn/wrap/cell_wrapper.py +27 -10
  222. mindspore/nn/wrap/grad_reducer.py +2 -2
  223. mindspore/nn/wrap/loss_scale.py +40 -24
  224. mindspore/numpy/array_creations.py +33 -22
  225. mindspore/numpy/array_ops.py +35 -30
  226. mindspore/numpy/logic_ops.py +6 -27
  227. mindspore/numpy/math_ops.py +22 -19
  228. mindspore/numpy/utils.py +1 -1
  229. mindspore/numpy/utils_const.py +108 -58
  230. mindspore/ops/_constants.py +0 -6
  231. mindspore/ops/_grad/__init__.py +2 -1
  232. mindspore/ops/_grad/grad_array_ops.py +86 -117
  233. mindspore/ops/_grad/grad_base.py +23 -1
  234. mindspore/ops/_grad/grad_clip_ops.py +2 -3
  235. mindspore/ops/_grad/grad_comm_ops.py +34 -24
  236. mindspore/ops/_grad/grad_implementations.py +9 -45
  237. mindspore/ops/_grad/grad_inner_ops.py +47 -4
  238. mindspore/ops/_grad/grad_math_ops.py +142 -117
  239. mindspore/ops/_grad/grad_nn_ops.py +71 -165
  240. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  241. mindspore/ops/_grad/grad_sparse.py +7 -6
  242. mindspore/ops/_grad_experimental/__init__.py +1 -0
  243. mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
  244. mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
  245. mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
  246. mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
  247. mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
  248. mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
  249. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  250. mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
  251. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  252. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
  253. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  254. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  255. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
  256. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  257. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  258. mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
  259. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  260. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
  261. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  262. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  263. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  264. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  265. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  266. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  267. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  268. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  269. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  270. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  271. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  272. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  273. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  274. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  275. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  276. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  277. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  278. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  279. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
  280. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  281. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  282. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  283. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  284. mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
  285. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  286. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
  287. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  288. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  289. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  290. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  291. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  292. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
  293. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  294. mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
  295. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  296. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  297. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
  298. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  299. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  300. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
  301. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  302. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  303. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  304. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  305. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  306. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  307. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  308. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  309. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  310. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
  311. mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
  312. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  313. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  314. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  315. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  316. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  317. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  318. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  319. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  320. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  321. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  322. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  323. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  324. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  325. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  326. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  327. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  328. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  329. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  331. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  332. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  333. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  334. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  335. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  336. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  337. mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
  338. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
  339. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  340. mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
  341. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  342. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  343. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  345. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  346. mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
  347. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
  348. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  349. mindspore/ops/_op_impl/tbe/__init__.py +27 -611
  350. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  351. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  352. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  353. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
  354. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  355. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  356. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  357. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  358. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  359. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  360. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  361. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  362. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  363. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
  364. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  365. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  366. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  367. mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
  368. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  369. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  370. mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
  371. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
  372. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  373. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  374. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  375. mindspore/ops/_register_for_op.py +1 -0
  376. mindspore/ops/_utils/__init__.py +1 -2
  377. mindspore/ops/_utils/utils.py +19 -40
  378. mindspore/ops/_vmap/vmap_array_ops.py +116 -38
  379. mindspore/ops/_vmap/vmap_base.py +16 -9
  380. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  381. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  382. mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
  383. mindspore/ops/_vmap/vmap_image_ops.py +12 -5
  384. mindspore/ops/_vmap/vmap_math_ops.py +46 -5
  385. mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
  386. mindspore/ops/_vmap/vmap_random_ops.py +1 -1
  387. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  388. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  389. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  390. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  391. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  392. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  393. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  394. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  395. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
  396. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  397. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  398. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  399. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  400. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  401. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  402. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  403. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  404. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  405. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  406. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
  407. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
  408. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  409. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  410. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  411. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  412. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  413. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  414. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
  415. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  416. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
  417. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
  418. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  419. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  420. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  421. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  422. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  423. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  424. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  425. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  426. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  427. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  428. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  429. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  430. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  431. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  432. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  433. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  434. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  435. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  436. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  437. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
  438. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  439. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  440. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  441. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  442. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  443. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  444. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  445. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
  446. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
  447. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  448. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
  449. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  450. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  451. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  452. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  453. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  454. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  455. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  456. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  457. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  458. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
  459. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
  460. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  461. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  462. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  463. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  464. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
  465. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  466. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  467. mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
  468. mindspore/ops/composite/__init__.py +7 -8
  469. mindspore/ops/composite/base.py +101 -47
  470. mindspore/ops/composite/math_ops.py +188 -158
  471. mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
  472. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
  473. mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
  474. mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
  475. mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
  476. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  477. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  478. mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
  479. mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
  480. mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
  481. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
  482. mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
  483. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  484. mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
  485. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
  486. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
  487. mindspore/ops/function/__init__.py +152 -8
  488. mindspore/ops/function/array_func.py +2555 -674
  489. mindspore/ops/function/clip_func.py +209 -13
  490. mindspore/ops/function/debug_func.py +2 -2
  491. mindspore/ops/function/grad/__init__.py +2 -1
  492. mindspore/ops/function/grad/grad_func.py +147 -62
  493. mindspore/ops/function/image_func.py +54 -38
  494. mindspore/ops/function/linalg_func.py +167 -16
  495. mindspore/ops/function/math_func.py +4849 -1492
  496. mindspore/ops/function/nn_func.py +2573 -988
  497. mindspore/ops/function/other_func.py +115 -0
  498. mindspore/ops/function/parameter_func.py +3 -3
  499. mindspore/ops/function/random_func.py +790 -73
  500. mindspore/ops/function/sparse_func.py +98 -78
  501. mindspore/ops/function/sparse_unary_func.py +54 -53
  502. mindspore/ops/function/spectral_func.py +27 -24
  503. mindspore/ops/function/vmap_func.py +22 -2
  504. mindspore/ops/functional.py +97 -37
  505. mindspore/ops/op_info_register.py +70 -28
  506. mindspore/ops/operations/__init__.py +47 -14
  507. mindspore/ops/operations/_csr_ops.py +7 -7
  508. mindspore/ops/operations/_embedding_cache_ops.py +5 -5
  509. mindspore/ops/operations/_grad_ops.py +276 -187
  510. mindspore/ops/operations/_inner_ops.py +319 -113
  511. mindspore/ops/operations/_ms_kernel.py +10 -8
  512. mindspore/ops/operations/_ocr_ops.py +9 -9
  513. mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
  514. mindspore/ops/operations/_quant_ops.py +137 -102
  515. mindspore/ops/operations/_rl_inner_ops.py +121 -60
  516. mindspore/ops/operations/_scalar_ops.py +466 -0
  517. mindspore/ops/operations/_sequence_ops.py +1004 -2
  518. mindspore/ops/operations/_tensor_array.py +10 -11
  519. mindspore/ops/operations/_thor_ops.py +1 -1
  520. mindspore/ops/operations/array_ops.py +801 -466
  521. mindspore/ops/operations/comm_ops.py +51 -49
  522. mindspore/ops/operations/control_ops.py +2 -2
  523. mindspore/ops/operations/custom_ops.py +123 -44
  524. mindspore/ops/operations/debug_ops.py +24 -24
  525. mindspore/ops/operations/image_ops.py +240 -153
  526. mindspore/ops/operations/inner_ops.py +34 -50
  527. mindspore/ops/operations/linalg_ops.py +31 -9
  528. mindspore/ops/operations/math_ops.py +988 -757
  529. mindspore/ops/operations/nn_ops.py +965 -819
  530. mindspore/ops/operations/other_ops.py +51 -40
  531. mindspore/ops/operations/random_ops.py +204 -122
  532. mindspore/ops/operations/rl_ops.py +8 -9
  533. mindspore/ops/operations/sparse_ops.py +254 -93
  534. mindspore/ops/operations/spectral_ops.py +35 -3
  535. mindspore/ops/primitive.py +111 -9
  536. mindspore/parallel/_auto_parallel_context.py +189 -83
  537. mindspore/parallel/_offload_context.py +185 -0
  538. mindspore/parallel/_parallel_serialization.py +99 -7
  539. mindspore/parallel/_ps_context.py +9 -5
  540. mindspore/parallel/_recovery_context.py +1 -1
  541. mindspore/parallel/_tensor.py +7 -1
  542. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  543. mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
  544. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  545. mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
  546. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  547. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
  548. mindspore/parallel/_utils.py +1 -2
  549. mindspore/parallel/algo_parameter_config.py +1 -1
  550. mindspore/parallel/checkpoint_transform.py +37 -34
  551. mindspore/parallel/shard.py +17 -18
  552. mindspore/profiler/common/validator/validate_path.py +2 -2
  553. mindspore/profiler/envprofiling.py +69 -47
  554. mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
  555. mindspore/profiler/parser/base_timeline_generator.py +49 -56
  556. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
  557. mindspore/profiler/parser/hwts_log_parser.py +1 -1
  558. mindspore/profiler/parser/integrator.py +15 -14
  559. mindspore/profiler/parser/minddata_analyzer.py +2 -2
  560. mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
  561. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  562. mindspore/profiler/parser/optime_parser.py +17 -18
  563. mindspore/profiler/parser/profiler_info.py +2 -1
  564. mindspore/profiler/profiling.py +218 -186
  565. mindspore/rewrite/__init__.py +3 -1
  566. mindspore/rewrite/api/node.py +1 -114
  567. mindspore/rewrite/api/node_type.py +3 -0
  568. mindspore/rewrite/api/pattern_engine.py +31 -1
  569. mindspore/rewrite/api/scoped_value.py +4 -4
  570. mindspore/rewrite/api/symbol_tree.py +3 -78
  571. mindspore/rewrite/api/tree_node_helper.py +1 -1
  572. mindspore/rewrite/ast_creator_register.py +1 -0
  573. mindspore/rewrite/ast_helpers/__init__.py +2 -2
  574. mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
  575. mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
  576. mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
  577. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
  578. mindspore/rewrite/namespace.py +0 -2
  579. mindspore/rewrite/node.py +157 -11
  580. mindspore/rewrite/parsers/assign_parser.py +231 -53
  581. mindspore/rewrite/parsers/class_def_parser.py +187 -109
  582. mindspore/rewrite/parsers/for_parser.py +24 -14
  583. mindspore/rewrite/parsers/function_def_parser.py +21 -4
  584. mindspore/rewrite/parsers/if_parser.py +6 -2
  585. mindspore/rewrite/sparsify/__init__.py +0 -0
  586. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  587. mindspore/rewrite/sparsify/sparsify.py +109 -0
  588. mindspore/rewrite/sparsify/utils.py +173 -0
  589. mindspore/rewrite/symbol_tree.py +256 -133
  590. mindspore/rewrite/symbol_tree_builder.py +38 -1
  591. mindspore/run_check/_check_version.py +69 -63
  592. mindspore/run_check/run_check.py +2 -1
  593. mindspore/scipy/linalg.py +10 -114
  594. mindspore/scipy/ops.py +2 -2
  595. mindspore/scipy/ops_wrapper.py +1 -1
  596. mindspore/scipy/optimize/_bfgs.py +1 -1
  597. mindspore/scipy/optimize/_lagrange.py +200 -0
  598. mindspore/scipy/optimize/line_search.py +3 -2
  599. mindspore/scipy/optimize/minimize.py +41 -2
  600. mindspore/scipy/sparse/__init__.py +2 -2
  601. mindspore/scipy/sparse/linalg.py +4 -464
  602. mindspore/scipy/utils.py +1 -1
  603. mindspore/scipy/utils_const.py +7 -1
  604. mindspore/train/__init__.py +1 -1
  605. mindspore/train/_utils.py +28 -5
  606. mindspore/train/amp.py +273 -102
  607. mindspore/train/callback/_backup_and_restore.py +5 -5
  608. mindspore/train/callback/_callback.py +2 -2
  609. mindspore/train/callback/_checkpoint.py +3 -3
  610. mindspore/train/callback/_early_stop.py +3 -3
  611. mindspore/train/callback/_lambda_callback.py +2 -2
  612. mindspore/train/callback/_landscape.py +29 -31
  613. mindspore/train/callback/_loss_monitor.py +3 -3
  614. mindspore/train/callback/_on_request_exit.py +3 -3
  615. mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
  616. mindspore/train/callback/_summary_collector.py +23 -16
  617. mindspore/train/callback/_time_monitor.py +3 -3
  618. mindspore/train/checkpoint_pb2.py +68 -8
  619. mindspore/train/data_sink.py +15 -3
  620. mindspore/train/dataset_helper.py +10 -15
  621. mindspore/train/loss_scale_manager.py +8 -11
  622. mindspore/train/metrics/__init__.py +1 -1
  623. mindspore/train/metrics/bleu_score.py +1 -1
  624. mindspore/train/metrics/confusion_matrix.py +1 -1
  625. mindspore/train/metrics/cosine_similarity.py +1 -1
  626. mindspore/train/metrics/dice.py +2 -2
  627. mindspore/train/metrics/fbeta.py +1 -1
  628. mindspore/train/metrics/hausdorff_distance.py +4 -3
  629. mindspore/train/metrics/mean_surface_distance.py +2 -2
  630. mindspore/train/metrics/occlusion_sensitivity.py +1 -1
  631. mindspore/train/metrics/perplexity.py +1 -1
  632. mindspore/train/metrics/precision.py +1 -1
  633. mindspore/train/metrics/recall.py +1 -1
  634. mindspore/train/metrics/roc.py +2 -2
  635. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  636. mindspore/train/mind_ir_pb2.py +116 -37
  637. mindspore/train/model.py +45 -28
  638. mindspore/train/serialization.py +295 -188
  639. mindspore/train/summary/_summary_adapter.py +1 -1
  640. mindspore/train/summary/summary_record.py +43 -13
  641. mindspore/train/train_thor/convert_utils.py +2 -2
  642. mindspore/train/train_thor/dataset_helper.py +3 -3
  643. mindspore/version.py +1 -1
  644. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
  645. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +648 -574
  646. mindspore/compression/__init__.py +0 -19
  647. mindspore/compression/common/constant.py +0 -124
  648. mindspore/compression/export/__init__.py +0 -19
  649. mindspore/compression/export/quant_export.py +0 -515
  650. mindspore/compression/quant/__init__.py +0 -28
  651. mindspore/compression/quant/qat.py +0 -634
  652. mindspore/compression/quant/quant_utils.py +0 -462
  653. mindspore/compression/quant/quantizer.py +0 -68
  654. mindspore/nn/layer/quant.py +0 -1868
  655. mindspore/nn/layer/rnn_utils.py +0 -90
  656. mindspore/nn/probability/dpn/__init__.py +0 -22
  657. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  658. mindspore/nn/probability/dpn/vae/cvae.py +0 -140
  659. mindspore/nn/probability/dpn/vae/vae.py +0 -124
  660. mindspore/nn/probability/infer/__init__.py +0 -22
  661. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  662. mindspore/nn/probability/infer/variational/svi.py +0 -84
  663. mindspore/nn/probability/toolbox/__init__.py +0 -22
  664. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  665. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
  666. mindspore/nn/probability/transforms/__init__.py +0 -22
  667. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  668. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  669. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  670. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  671. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  672. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  673. mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
  674. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  675. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
  676. mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
  677. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
  678. mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
  679. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
  680. mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
  681. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
  682. mindspore/ops/composite/array_ops.py +0 -241
  683. mindspore/ops/composite/clip_ops.py +0 -134
  684. mindspore/ops/composite/random_ops.py +0 -426
  685. mindspore/ops/composite/vmap_ops.py +0 -38
  686. mindspore/parallel/nn/__init__.py +0 -42
  687. mindspore/parallel/nn/loss.py +0 -22
  688. mindspore/parallel/nn/moe.py +0 -21
  689. mindspore/parallel/nn/op_parallel_config.py +0 -22
  690. mindspore/parallel/nn/transformer.py +0 -31
  691. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  692. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  693. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -18,13 +18,21 @@ from __future__ import absolute_import
18
18
  from functools import partial
19
19
 
20
20
  import mindspore.context as context
21
- from mindspore._checkparam import Validator as validator
22
- from mindspore._checkparam import Rel
21
+ from mindspore import _checkparam as validator
23
22
  from mindspore.ops.primitive import Primitive, PrimitiveWithInfer, prim_attr_register
24
23
  from mindspore.common import dtype as mstype
25
24
  from mindspore.common.dtype import QuantDtype
26
25
 
27
- if context.get_context('device_target') == "Ascend":
26
+
27
+ def _support_te():
28
+ try:
29
+ import te # pylint: disable=unused-import
30
+ return True
31
+ # pylint: disable=broad-except
32
+ except Exception:
33
+ return False
34
+
35
+ if context.get_context('device_target') == "Ascend" and _support_te():
28
36
  import mindspore.ops._op_impl._custom_op
29
37
 
30
38
  __all__ = ["MinMaxUpdatePerLayer",
@@ -108,8 +116,22 @@ class FakeQuantParam(Primitive):
108
116
 
109
117
  @classmethod
110
118
  def linear_quant_param(cls, quant_dtype, scale, zp, is_per_channel=False, **kwargs):
111
- kwargs[FakeQuantParam.attr_key_linear_quant_scale] = scale
112
- kwargs[FakeQuantParam.attr_key_linear_quant_zero_point] = zp
119
+ """
120
+ Create a linear quantization operator based on scale and zero-point parameter.
121
+ """
122
+ validator.check_value_type("scale", scale, [float, tuple, list], "FakeQuantParam")
123
+ if isinstance(scale, float):
124
+ scale_list = [scale]
125
+ else:
126
+ scale_list = scale
127
+ validator.check_value_type("zero_point", zp, [int, tuple, list], "FakeQuantParam")
128
+ if isinstance(zp, int):
129
+ zp_list = [zp]
130
+ else:
131
+ zp_list = zp
132
+ validator.check_value_type("is_per_channel", is_per_channel, [bool], "FakeQuantParam")
133
+ kwargs[FakeQuantParam.attr_key_linear_quant_scale] = scale_list
134
+ kwargs[FakeQuantParam.attr_key_linear_quant_zero_point] = zp_list
113
135
  return cls(quant_dtype, FakeQuantParam.attr_value_linear_quant_algo_name, is_per_channel, **kwargs)
114
136
 
115
137
 
@@ -147,14 +169,14 @@ class MinMaxUpdatePerLayer(PrimitiveWithInfer):
147
169
  f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
148
170
 
149
171
  self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
150
- self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
172
+ self.ema_decay = validator.check_float_range(ema_decay, 0, 1, validator.INC_BOTH, 'ema_decay', self.name)
151
173
  self.init_prim_io_names(inputs=['x', 'min', 'max'],
152
174
  outputs=['min_up', 'max_up'])
153
175
 
154
176
  def infer_shape(self, x_shape, min_shape, max_shape):
155
- validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
177
+ validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
156
178
  validator.check("min shape", min_shape, "max shape",
157
- max_shape, Rel.EQ, self.name)
179
+ max_shape, validator.EQ, self.name)
158
180
  validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
159
181
  return min_shape, max_shape
160
182
 
@@ -203,9 +225,10 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
203
225
  f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
204
226
 
205
227
  self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
206
- self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
228
+ self.ema_decay = validator.check_float_range(ema_decay, 0, 1, validator.INC_BOTH, 'ema_decay', self.name)
207
229
  if self.is_ascend:
208
- self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
230
+ self.channel_axis = validator.check_int_range(channel_axis, 0, 1, validator.INC_BOTH,
231
+ 'channel_axis', self.name)
209
232
  else:
210
233
  self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
211
234
  self.init_prim_io_names(
@@ -215,9 +238,9 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
215
238
  if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
216
239
  raise ValueError(f"For '{self.name}' x rank must be in '{self.ascend_support_x_rank}'")
217
240
  if not self.is_ascend:
218
- validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
241
+ validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
219
242
  validator.check("min shape", min_shape, "max shape",
220
- max_shape, Rel.EQ, self.name)
243
+ max_shape, validator.EQ, self.name)
221
244
  validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
222
245
  return min_shape, max_shape
223
246
 
@@ -273,9 +296,9 @@ class FakeLearnedScaleQuantPerLayer(PrimitiveWithInfer):
273
296
  outputs=['out'])
274
297
 
275
298
  def infer_shape(self, input_x_shape, alpha_shape, quant_max_shape):
276
- validator.check_int(len(input_x_shape), 1, Rel.GE, "input_x rank", self.name)
277
- validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name)
278
- validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
299
+ validator.check_int(len(input_x_shape), 1, validator.GE, "input_x rank", self.name)
300
+ validator.check_int(len(alpha_shape), 1, validator.GE, "alpha rank", self.name)
301
+ validator.check_int(len(quant_max_shape), 1, validator.GE, "quant max rank", self.name)
279
302
  return input_x_shape
280
303
 
281
304
  def infer_dtype(self, input_x_type, alpha_type, quant_max_type):
@@ -314,9 +337,9 @@ class FakeLearnedScaleQuantPerLayerGrad(PrimitiveWithInfer):
314
337
  inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])
315
338
 
316
339
  def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
317
- validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name)
318
- validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name)
319
- validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
340
+ validator.check("dout shape", dout_shape, "x_shape", x_shape, validator.EQ, self.name)
341
+ validator.check_int(len(alpha_shape), 1, validator.GE, "alpha rank", self.name)
342
+ validator.check_int(len(quant_max_shape), 1, validator.GE, "quant max rank", self.name)
320
343
  return dout_shape, alpha_shape
321
344
 
322
345
  def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
@@ -345,9 +368,9 @@ class FakeLearnedScaleQuantPerLayerGradD(PrimitiveWithInfer):
345
368
  inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])
346
369
 
347
370
  def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
348
- validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name)
349
- validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name)
350
- validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
371
+ validator.check("dout shape", dout_shape, "x_shape", x_shape, validator.EQ, self.name)
372
+ validator.check_int(len(alpha_shape), 1, validator.GE, "alpha rank", self.name)
373
+ validator.check_int(len(quant_max_shape), 1, validator.GE, "quant max rank", self.name)
351
374
  return dout_shape, dout_shape
352
375
 
353
376
  def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
@@ -423,7 +446,8 @@ class FakeLearnedScaleQuantPerChannel(PrimitiveWithInfer):
423
446
  self.training = validator.check_value_type(
424
447
  'training', training, (bool,), self.name)
425
448
  if self.is_ascend:
426
- self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
449
+ self.channel_axis = validator.check_int_range(channel_axis, 0, 1, validator.INC_BOTH,
450
+ 'channel_axis', self.name)
427
451
  else:
428
452
  self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
429
453
  self.init_prim_io_names(inputs=['input_x', 'alpha', 'quant_max'],
@@ -433,12 +457,12 @@ class FakeLearnedScaleQuantPerChannel(PrimitiveWithInfer):
433
457
  if self.is_ascend and len(input_x_shape) not in self.ascend_support_x_rank:
434
458
  raise ValueError(f"For '{self.name}' x rank must be in '{self.ascend_support_x_rank}'")
435
459
  if not self.is_ascend:
436
- validator.check_int(len(input_x_shape), 1, Rel.GE, "input_x rank", self.name)
460
+ validator.check_int(len(input_x_shape), 1, validator.GE, "input_x rank", self.name)
437
461
  if len(input_x_shape) == 1:
438
462
  self.channel_axis = 0
439
463
 
440
464
  validator.check_equal_int(alpha_shape[0], input_x_shape[self.channel_axis], "alpha rank", self.name)
441
- validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
465
+ validator.check_int(len(quant_max_shape), 1, validator.GE, "quant max rank", self.name)
442
466
  return input_x_shape
443
467
 
444
468
  def infer_dtype(self, input_x_type, alpha_type, quant_max_type):
@@ -479,7 +503,7 @@ class FakeLearnedScaleQuantPerChannelGrad(PrimitiveWithInfer):
479
503
  inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])
480
504
 
481
505
  def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
482
- validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name)
506
+ validator.check("dout shape", dout_shape, "x_shape", x_shape, validator.EQ, self.name)
483
507
  return dout_shape, alpha_shape
484
508
 
485
509
  def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
@@ -510,9 +534,9 @@ class FakeLearnedScaleQuantPerChannelGradD(PrimitiveWithInfer):
510
534
  inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])
511
535
 
512
536
  def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
513
- validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name)
514
- validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name)
515
- validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
537
+ validator.check("dout shape", dout_shape, "x_shape", x_shape, validator.EQ, self.name)
538
+ validator.check_int(len(alpha_shape), 1, validator.GE, "alpha rank", self.name)
539
+ validator.check_int(len(quant_max_shape), 1, validator.GE, "quant max rank", self.name)
516
540
  return dout_shape, dout_shape
517
541
 
518
542
  def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
@@ -576,7 +600,7 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
576
600
  num_bits=8,
577
601
  narrow_range=False):
578
602
  self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
579
- self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
603
+ self.num_bits = validator.check_int_range(self.num_bits, 2, 16, validator.INC_BOTH, 'num_bits', self.name)
580
604
  self.narrow_range = validator.check_value_type(
581
605
  'narrow_range', narrow_range, (bool,), self.name)
582
606
 
@@ -588,9 +612,9 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
588
612
  raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.")
589
613
 
590
614
  def infer_shape(self, x_shape, min_shape, max_shape):
591
- validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
592
- validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
593
- validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
615
+ validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
616
+ validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
617
+ validator.check_int(len(min_shape), 1, validator.EQ, "min shape", self.name)
594
618
  self.check_broadcast(min_shape, x_shape)
595
619
  return x_shape
596
620
 
@@ -640,7 +664,7 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
640
664
  num_bits=8,
641
665
  narrow_range=False):
642
666
  self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
643
- self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
667
+ self.num_bits = validator.check_int_range(self.num_bits, 2, 16, validator.INC_BOTH, 'num_bits', self.name)
644
668
  self.narrow_range = validator.check_value_type(
645
669
  'narrow_range', narrow_range, (bool,), self.name)
646
670
 
@@ -652,10 +676,10 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
652
676
  raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.")
653
677
 
654
678
  def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
655
- validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
656
- validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
657
- validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
658
- validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
679
+ validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
680
+ validator.check("dout shape", dout_shape, "x shape", x_shape, validator.EQ, self.name)
681
+ validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
682
+ validator.check_int(len(min_shape), 1, validator.EQ, "min shape", self.name)
659
683
  self.check_broadcast(min_shape, x_shape)
660
684
  return x_shape, min_shape, max_shape
661
685
 
@@ -699,15 +723,15 @@ class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer):
699
723
  num_bits=8,
700
724
  narrow_range=False):
701
725
  self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
702
- self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
726
+ self.num_bits = validator.check_int_range(self.num_bits, 2, 16, validator.INC_BOTH, 'num_bits', self.name)
703
727
  self.narrow_range = validator.check_value_type(
704
728
  'narrow_range', narrow_range, (bool,), self.name)
705
729
 
706
730
  def infer_shape(self, x_shape, min_shape, max_shape):
707
- validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
708
- validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
709
- validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
710
- validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name)
731
+ validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
732
+ validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
733
+ validator.check_int(len(min_shape), 1, validator.EQ, "min shape", self.name)
734
+ validator.check("min shape", min_shape[0], "x shape", x_shape[-1], validator.EQ, self.name)
711
735
  return x_shape
712
736
 
713
737
  def infer_dtype(self, x_type, min_type, max_type):
@@ -757,16 +781,16 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
757
781
  num_bits=8,
758
782
  narrow_range=False):
759
783
  self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
760
- self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
784
+ self.num_bits = validator.check_int_range(self.num_bits, 2, 16, validator.INC_BOTH, 'num_bits', self.name)
761
785
  self.narrow_range = validator.check_value_type(
762
786
  'narrow_range', narrow_range, (bool,), self.name)
763
787
 
764
788
  def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
765
- validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
766
- validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
767
- validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
768
- validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
769
- validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name)
789
+ validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
790
+ validator.check("dout shape", dout_shape, "x shape", x_shape, validator.EQ, self.name)
791
+ validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
792
+ validator.check_int(len(min_shape), 1, validator.EQ, "min shape", self.name)
793
+ validator.check("min shape", min_shape[0], "x shape", x_shape[-1], validator.EQ, self.name)
770
794
  return x_shape, min_shape, max_shape
771
795
 
772
796
  def infer_dtype(self, dout_type, x_type, min_type, max_type):
@@ -855,15 +879,15 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
855
879
  self.narrow_range = validator.check_value_type(
856
880
  'narrow_range', narrow_range, (bool,), self.name)
857
881
  self.training = validator.check_value_type('training', training, (bool,), self.name)
858
- self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
882
+ self.ema_decay = validator.check_float_range(ema_decay, 0, 1, validator.INC_BOTH, 'ema_decay', self.name)
859
883
  self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
860
884
  self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
861
885
  self.init_prim_io_names(inputs=['x', 'min', 'max'],
862
886
  outputs=['out'])
863
887
 
864
888
  def infer_shape(self, x_shape, min_shape, max_shape):
865
- validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
866
- validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
889
+ validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
890
+ validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
867
891
  validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
868
892
  return x_shape
869
893
 
@@ -909,9 +933,9 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
909
933
 
910
934
  def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
911
935
  validator.check("dout shape", dout_shape, "x shape",
912
- x_shape, Rel.EQ, self.name)
936
+ x_shape, validator.EQ, self.name)
913
937
  validator.check("min shape", min_shape, "max shape",
914
- max_shape, Rel.EQ, self.name)
938
+ max_shape, validator.EQ, self.name)
915
939
  validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
916
940
  return dout_shape
917
941
 
@@ -981,11 +1005,12 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
981
1005
  'narrow_range', narrow_range, (bool,), self.name)
982
1006
  self.training = validator.check_value_type(
983
1007
  'training', training, (bool,), self.name)
984
- self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
1008
+ self.ema_decay = validator.check_float_range(ema_decay, 0, 1, validator.INC_BOTH, 'ema_decay', self.name)
985
1009
  self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
986
1010
  self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
987
1011
  if self.is_ascend:
988
- self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
1012
+ self.channel_axis = validator.check_int_range(channel_axis, 0, 1, validator.INC_BOTH,
1013
+ 'channel_axis', self.name)
989
1014
  else:
990
1015
  self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
991
1016
  self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])
@@ -994,10 +1019,10 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
994
1019
  if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
995
1020
  raise ValueError(f"For '{self.name}' x rank must be in '{self.ascend_support_x_rank}'")
996
1021
  if not self.is_ascend:
997
- validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
1022
+ validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
998
1023
  if len(x_shape) == 1:
999
1024
  self.channel_axis = 0
1000
- validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
1025
+ validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
1001
1026
  validator.check_equal_int(min_shape[0], x_shape[self.channel_axis], "min shape", self.name)
1002
1027
  validator.check_equal_int(max_shape[0], x_shape[self.channel_axis], "max shape", self.name)
1003
1028
  return x_shape
@@ -1093,7 +1118,7 @@ class BatchNormFold(PrimitiveWithInfer):
1093
1118
  @prim_attr_register
1094
1119
  def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
1095
1120
  """Initialize batch norm fold layer"""
1096
- self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
1121
+ self.momentum = validator.check_float_range(momentum, 0, 1, validator.INC_BOTH, 'momentum', self.name)
1097
1122
  self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
1098
1123
  self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
1099
1124
  self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
@@ -1102,8 +1127,9 @@ class BatchNormFold(PrimitiveWithInfer):
1102
1127
  outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std'])
1103
1128
 
1104
1129
  def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
1105
- validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
1106
- validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
1130
+ validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, validator.EQ, self.name)
1131
+ validator.check("mean_shape[0]", mean_shape[0], "input channel",
1132
+ x_shape[self.channel_axis], validator.EQ, self.name)
1107
1133
  validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
1108
1134
  return mean_shape, mean_shape, mean_shape, mean_shape
1109
1135
 
@@ -1144,13 +1170,13 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
1144
1170
  def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape,
1145
1171
  global_step_shape):
1146
1172
  validator.check("d_batch_mean shape", d_batch_mean_shape,
1147
- "d_batch_std shape", d_batch_std_shape, Rel.EQ, self.name)
1173
+ "d_batch_std shape", d_batch_std_shape, validator.EQ, self.name)
1148
1174
  validator.check("d_batch_mean shape", d_batch_mean_shape,
1149
- "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
1175
+ "batch_mean shape", batch_mean_shape, validator.EQ, self.name)
1150
1176
  validator.check("d_batch_mean shape", d_batch_mean_shape,
1151
- "batch_std shape", batch_std_shape, Rel.EQ, self.name)
1177
+ "batch_std shape", batch_std_shape, validator.EQ, self.name)
1152
1178
  validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0],
1153
- "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
1179
+ "input channel", x_shape[self.channel_axis], validator.EQ, self.name)
1154
1180
  validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
1155
1181
  return x_shape
1156
1182
 
@@ -1195,9 +1221,10 @@ class CorrectionMul(PrimitiveWithInfer):
1195
1221
  outputs=['out'])
1196
1222
 
1197
1223
  def infer_shape(self, x_shape, batch_std_shape, running_std_shape):
1198
- validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
1224
+ validator.check("batch_std shape", batch_std_shape, "running_std shape",
1225
+ running_std_shape, validator.EQ, self.name)
1199
1226
  validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
1200
- Rel.EQ, self.name)
1227
+ validator.EQ, self.name)
1201
1228
  return x_shape
1202
1229
 
1203
1230
  def infer_dtype(self, x_type, batch_std_type, running_std_type):
@@ -1229,11 +1256,11 @@ class CorrectionMulGrad(PrimitiveWithInfer):
1229
1256
  outputs=['dx', 'mul_dx'])
1230
1257
 
1231
1258
  def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape):
1232
- validator.check("dout shape", dout_shape, "x_shape x", x_shape, Rel.EQ, self.name)
1259
+ validator.check("dout shape", dout_shape, "x_shape x", x_shape, validator.EQ, self.name)
1233
1260
  validator.check("gamma_shape[0]", gamma_shape[0], "dout channel size", dout_shape[self.channel_axis],
1234
- Rel.EQ, self.name)
1261
+ validator.EQ, self.name)
1235
1262
  validator.check("running_std_shape[0]", running_std_shape[0],
1236
- "dout channel size", dout_shape[self.channel_axis], Rel.EQ, self.name)
1263
+ "dout channel size", dout_shape[self.channel_axis], validator.EQ, self.name)
1237
1264
  if context.get_context('device_target') == "Ascend":
1238
1265
  return x_shape, x_shape
1239
1266
  return x_shape, gamma_shape
@@ -1319,14 +1346,16 @@ class BatchNormFold2(PrimitiveWithInfer):
1319
1346
 
1320
1347
  def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape,
1321
1348
  running_mean_shape, global_step_shape):
1322
- validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
1323
- validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
1324
- validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
1349
+ validator.check("batch_std shape", batch_std_shape, "running_std shape",
1350
+ running_std_shape, validator.EQ, self.name)
1351
+ validator.check("batch_std shape", batch_std_shape, "batch_mean shape",
1352
+ batch_mean_shape, validator.EQ, self.name)
1353
+ validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, validator.EQ, self.name)
1325
1354
  validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape,
1326
- Rel.EQ, self.name)
1327
- validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
1355
+ validator.EQ, self.name)
1356
+ validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, validator.EQ, self.name)
1328
1357
  validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
1329
- Rel.EQ, self.name)
1358
+ validator.EQ, self.name)
1330
1359
  validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
1331
1360
  return x_shape
1332
1361
 
@@ -1369,13 +1398,15 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
1369
1398
  def infer_shape(self, dout_shape, x_shape, gamma_shape,
1370
1399
  batch_std_shape, batch_mean_shape,
1371
1400
  running_std_shape, running_mean_shape, global_step_shape):
1372
- validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
1373
- validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
1401
+ validator.check("batch_std shape", batch_std_shape, "batch_mean shape",
1402
+ batch_mean_shape, validator.EQ, self.name)
1403
+ validator.check("batch_std shape", batch_std_shape, "running_std shape",
1404
+ running_std_shape, validator.EQ, self.name)
1374
1405
  validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape,
1375
- Rel.EQ, self.name)
1376
- validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
1406
+ validator.EQ, self.name)
1407
+ validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, validator.EQ, self.name)
1377
1408
  validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
1378
- Rel.EQ, self.name)
1409
+ validator.EQ, self.name)
1379
1410
  validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
1380
1411
  return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape
1381
1412
 
@@ -1406,7 +1437,7 @@ class BatchNormFoldD(PrimitiveWithInfer):
1406
1437
  def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
1407
1438
  """Initialize _BatchNormFold layer"""
1408
1439
  from mindspore.ops._op_impl._custom_op import batchnorm_fold
1409
- self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
1440
+ self.momentum = validator.check_float_range(momentum, 0, 1, validator.INC_BOTH, 'momentum', self.name)
1410
1441
  self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
1411
1442
  self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
1412
1443
  self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
@@ -1416,8 +1447,8 @@ class BatchNormFoldD(PrimitiveWithInfer):
1416
1447
  'mean_updated', 'variance_updated'])
1417
1448
 
1418
1449
  def infer_shape(self, x_shape, x_sum_shape, x_square_sum_shape, mean_shape, variance_shape):
1419
- validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
1420
- validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[1], Rel.EQ, self.name)
1450
+ validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, validator.EQ, self.name)
1451
+ validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[1], validator.EQ, self.name)
1421
1452
  return x_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape
1422
1453
 
1423
1454
  def infer_dtype(self, x_type, x_sum_type, x_square_sum_type, mean_type, variance_type):
@@ -1487,12 +1518,14 @@ class BatchNormFold2D(PrimitiveWithInfer):
1487
1518
  outputs=['y'])
1488
1519
 
1489
1520
  def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape):
1490
- validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
1491
- validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
1492
- validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
1493
- validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
1521
+ validator.check("batch_std shape", batch_std_shape, "running_std shape",
1522
+ running_std_shape, validator.EQ, self.name)
1523
+ validator.check("batch_std shape", batch_std_shape, "batch_mean shape",
1524
+ batch_mean_shape, validator.EQ, self.name)
1525
+ validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, validator.EQ, self.name)
1526
+ validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, validator.EQ, self.name)
1494
1527
  validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
1495
- Rel.EQ, self.name)
1528
+ validator.EQ, self.name)
1496
1529
  return x_shape
1497
1530
 
1498
1531
  def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type):
@@ -1517,11 +1550,13 @@ class BatchNormFold2GradD(PrimitiveWithInfer):
1517
1550
 
1518
1551
  def infer_shape(self, dout_shape, dout_reduce_shape, dout_x_reduce_shape, gamma_shape, batch_std_shape,
1519
1552
  batch_mean_shape, running_std_shape):
1520
- validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
1521
- validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
1522
- validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
1553
+ validator.check("batch_std shape", batch_std_shape, "batch_mean shape",
1554
+ batch_mean_shape, validator.EQ, self.name)
1555
+ validator.check("batch_std shape", batch_std_shape, "running_std shape",
1556
+ running_std_shape, validator.EQ, self.name)
1557
+ validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, validator.EQ, self.name)
1523
1558
  validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
1524
- Rel.EQ, self.name)
1559
+ validator.EQ, self.name)
1525
1560
  return gamma_shape, gamma_shape, gamma_shape, dout_shape
1526
1561
 
1527
1562
  def infer_dtype(self, dout_type, dout_reduce_type, dout_x_reduce_type, gamma_type, batch_std_type,
@@ -1553,7 +1588,7 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
1553
1588
  outputs=['dout_reduce', 'dout_x_reduce'])
1554
1589
 
1555
1590
  def infer_shape(self, dout_shape, x_shape):
1556
- validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
1591
+ validator.check("dout shape", dout_shape, "x shape", x_shape, validator.EQ, self.name)
1557
1592
  return (dout_shape[self.channel_axis],), (dout_shape[self.channel_axis],)
1558
1593
 
1559
1594
  def infer_dtype(self, dout_type, x_type):
@@ -1595,17 +1630,17 @@ class ActsULQ(PrimitiveWithInfer):
1595
1630
  def __init__(self, fixed_min=False, num_bits=8):
1596
1631
  validator.check_value_type("fixed_min", fixed_min, [bool], self.name)
1597
1632
  validator.check_value_type("num_bits", num_bits, [int], self.name)
1598
- validator.check_int(num_bits, 8, Rel.EQ, "value of num_bits", self.name)
1633
+ validator.check_int(num_bits, 8, validator.EQ, "value of num_bits", self.name)
1599
1634
 
1600
1635
  def infer_shape(self, x_shape, clamp_min_shape, clamp_max_shape):
1601
1636
  """infer shape of primitive"""
1602
- validator.check_int(len(clamp_min_shape), len(x_shape), Rel.EQ, "dims of clamp_min", self.name)
1603
- validator.check_int(len(clamp_max_shape), len(x_shape), Rel.EQ, "dims of clamp_max", self.name)
1637
+ validator.check_int(len(clamp_min_shape), len(x_shape), validator.EQ, "dims of clamp_min", self.name)
1638
+ validator.check_int(len(clamp_max_shape), len(x_shape), validator.EQ, "dims of clamp_max", self.name)
1604
1639
 
1605
1640
  x_shape_len = len(x_shape)
1606
1641
  for i in range(x_shape_len):
1607
- validator.check_int(clamp_min_shape[i], 1, Rel.EQ, "dims of clamp_min", self.name)
1608
- validator.check_int(clamp_max_shape[i], 1, Rel.EQ, "dims of clamp_max", self.name)
1642
+ validator.check_int(clamp_min_shape[i], 1, validator.EQ, "dims of clamp_min", self.name)
1643
+ validator.check_int(clamp_max_shape[i], 1, validator.EQ, "dims of clamp_max", self.name)
1609
1644
 
1610
1645
  return x_shape, x_shape, x_shape, x_shape
1611
1646
 
@@ -1746,12 +1781,12 @@ class WtsARQ(PrimitiveWithInfer):
1746
1781
  @prim_attr_register
1747
1782
  def __init__(self, num_bits, offset_flag):
1748
1783
  validator.check_value_type("num_bits", num_bits, [int], self.name)
1749
- validator.check_int(num_bits, 8, Rel.EQ, "value of num_bits", self.name)
1784
+ validator.check_int(num_bits, 8, validator.EQ, "value of num_bits", self.name)
1750
1785
  validator.check_value_type("offset_flag", offset_flag, [bool], self.name)
1751
1786
 
1752
1787
  def infer_shape(self, w_shape, w_min_shape, w_max_shape):
1753
- validator.check_int(len(w_min_shape), len(w_shape), Rel.EQ, "dims of w_min", self.name)
1754
- validator.check_int(len(w_max_shape), len(w_shape), Rel.EQ, "dims of w_max", self.name)
1788
+ validator.check_int(len(w_min_shape), len(w_shape), validator.EQ, "dims of w_min", self.name)
1789
+ validator.check_int(len(w_max_shape), len(w_shape), validator.EQ, "dims of w_max", self.name)
1755
1790
  return w_shape
1756
1791
 
1757
1792
  def infer_dtype(self, w_dtype, w_min_dtype, w_max_dtype):
@@ -1808,6 +1843,6 @@ class IFMR(Primitive):
1808
1843
  validator.check_value_type("search_range", search_range, [list, tuple], self.name)
1809
1844
  for item in search_range:
1810
1845
  validator.check_positive_float(item, "item of search_range", self.name)
1811
- validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], Rel.GE, self.name)
1846
+ validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], validator.GE, self.name)
1812
1847
  validator.check_value_type("search_step", search_step, [float], self.name)
1813
1848
  validator.check_value_type("offset_flag", with_offset, [bool], self.name)