mindspore 2.0.0a0__cp37-cp37m-win_amd64.whl → 2.0.0rc1__cp37-cp37m-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (655)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -2
  3. mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +102 -0
  7. mindspore/_checkparam.py +1066 -1001
  8. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
  9. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
  10. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
  11. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
  12. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  13. mindspore/_extends/parse/__init__.py +5 -3
  14. mindspore/_extends/parse/namespace.py +16 -1
  15. mindspore/_extends/parse/parser.py +107 -22
  16. mindspore/_extends/parse/resources.py +0 -7
  17. mindspore/_extends/parse/standard_method.py +885 -413
  18. mindspore/amp.py +52 -57
  19. mindspore/boost/boost.py +2 -2
  20. mindspore/boost/boost_cell_wrapper.py +38 -20
  21. mindspore/boost/dim_reduce.py +3 -3
  22. mindspore/boost/group_loss_scale_manager.py +1 -1
  23. mindspore/common/__init__.py +4 -6
  24. mindspore/common/_decorator.py +2 -0
  25. mindspore/common/_register_for_adapter.py +55 -0
  26. mindspore/common/_stub_tensor.py +201 -0
  27. mindspore/common/_utils.py +41 -7
  28. mindspore/common/api.py +215 -141
  29. mindspore/common/dtype.py +8 -1
  30. mindspore/common/dump.py +2 -2
  31. mindspore/common/initializer.py +4 -2
  32. mindspore/common/jit_config.py +17 -13
  33. mindspore/common/mutable.py +33 -13
  34. mindspore/common/parameter.py +23 -21
  35. mindspore/common/seed.py +8 -24
  36. mindspore/common/sparse_tensor.py +62 -41
  37. mindspore/common/tensor.py +852 -1154
  38. mindspore/communication/__init__.py +2 -2
  39. mindspore/communication/_comm_helper.py +11 -4
  40. mindspore/communication/management.py +22 -21
  41. mindspore/config/op_info.config +501 -1008
  42. mindspore/context.py +201 -23
  43. mindspore/dataset/__init__.py +6 -6
  44. mindspore/dataset/audio/__init__.py +7 -7
  45. mindspore/dataset/audio/transforms.py +670 -30
  46. mindspore/dataset/audio/utils.py +47 -4
  47. mindspore/dataset/audio/validators.py +223 -1
  48. mindspore/dataset/callback/ds_callback.py +2 -2
  49. mindspore/dataset/core/config.py +210 -14
  50. mindspore/dataset/core/validator_helpers.py +2 -2
  51. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  52. mindspore/dataset/debug/debug_hook.py +65 -0
  53. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  54. mindspore/dataset/engine/__init__.py +7 -3
  55. mindspore/dataset/engine/cache_client.py +1 -1
  56. mindspore/dataset/engine/datasets.py +322 -66
  57. mindspore/dataset/engine/datasets_audio.py +80 -76
  58. mindspore/dataset/engine/datasets_standard_format.py +51 -38
  59. mindspore/dataset/engine/datasets_text.py +232 -118
  60. mindspore/dataset/engine/datasets_user_defined.py +41 -17
  61. mindspore/dataset/engine/datasets_vision.py +746 -225
  62. mindspore/dataset/engine/graphdata.py +75 -10
  63. mindspore/dataset/engine/iterators.py +45 -5
  64. mindspore/dataset/engine/offload.py +48 -28
  65. mindspore/dataset/engine/validators.py +117 -8
  66. mindspore/dataset/text/__init__.py +6 -5
  67. mindspore/dataset/text/transforms.py +86 -3
  68. mindspore/dataset/text/utils.py +6 -4
  69. mindspore/dataset/text/validators.py +25 -0
  70. mindspore/dataset/transforms/__init__.py +3 -2
  71. mindspore/dataset/transforms/c_transforms.py +1 -1
  72. mindspore/dataset/transforms/transforms.py +2 -2
  73. mindspore/dataset/utils/__init__.py +2 -1
  74. mindspore/dataset/utils/line_reader.py +121 -0
  75. mindspore/dataset/vision/__init__.py +2 -3
  76. mindspore/dataset/vision/c_transforms.py +9 -9
  77. mindspore/dataset/vision/py_transforms.py +5 -5
  78. mindspore/dataset/vision/py_transforms_util.py +2 -0
  79. mindspore/dataset/vision/transforms.py +160 -161
  80. mindspore/dataset/vision/utils.py +3 -3
  81. mindspore/experimental/map_parameter.py +38 -26
  82. mindspore/include/OWNERS +0 -1
  83. mindspore/include/api/callback/callback.h +9 -13
  84. mindspore/include/api/callback/ckpt_saver.h +2 -2
  85. mindspore/include/api/callback/loss_monitor.h +2 -2
  86. mindspore/include/api/callback/lr_scheduler.h +5 -5
  87. mindspore/include/api/callback/time_monitor.h +2 -2
  88. mindspore/include/api/callback/train_accuracy.h +4 -6
  89. mindspore/include/api/cfg.h +19 -6
  90. mindspore/include/api/context.h +44 -9
  91. mindspore/include/api/delegate.h +1 -1
  92. mindspore/include/api/metrics/accuracy.h +2 -2
  93. mindspore/include/api/metrics/metrics.h +4 -3
  94. mindspore/include/api/model.h +9 -4
  95. mindspore/include/api/model_parallel_runner.h +2 -2
  96. mindspore/include/api/net.h +12 -11
  97. mindspore/include/api/serialization.h +19 -3
  98. mindspore/include/api/types.h +3 -3
  99. mindspore/include/dataset/constants.h +7 -0
  100. mindspore/include/dataset/text.h +59 -0
  101. mindspore/jpeg62.dll +0 -0
  102. mindspore/log.py +1 -1
  103. mindspore/mindrecord/filereader.py +18 -0
  104. mindspore/mindrecord/filewriter.py +197 -34
  105. mindspore/mindrecord/shardreader.py +9 -0
  106. mindspore/mindrecord/shardwriter.py +1 -1
  107. mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
  108. mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
  109. mindspore/mindrecord/tools/csv_to_mr.py +3 -3
  110. mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
  111. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  112. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  113. mindspore/mindspore_backend.dll +0 -0
  114. mindspore/mindspore_common.dll +0 -0
  115. mindspore/mindspore_core.dll +0 -0
  116. mindspore/mindspore_glog.dll +0 -0
  117. mindspore/mindspore_shared_lib.dll +0 -0
  118. mindspore/nn/__init__.py +0 -4
  119. mindspore/nn/cell.py +204 -132
  120. mindspore/nn/dynamic_lr.py +1 -1
  121. mindspore/nn/grad/cell_grad.py +7 -6
  122. mindspore/nn/layer/__init__.py +5 -4
  123. mindspore/nn/layer/activation.py +40 -89
  124. mindspore/nn/layer/basic.py +255 -624
  125. mindspore/nn/layer/channel_shuffle.py +7 -6
  126. mindspore/nn/layer/combined.py +1 -1
  127. mindspore/nn/layer/container.py +41 -4
  128. mindspore/nn/layer/conv.py +64 -28
  129. mindspore/nn/layer/dense.py +9 -8
  130. mindspore/nn/layer/embedding.py +27 -25
  131. mindspore/nn/layer/image.py +53 -46
  132. mindspore/nn/layer/math.py +97 -105
  133. mindspore/nn/layer/normalization.py +117 -86
  134. mindspore/nn/layer/padding.py +185 -95
  135. mindspore/nn/layer/pooling.py +817 -414
  136. mindspore/nn/layer/rnn_cells.py +10 -15
  137. mindspore/nn/layer/rnns.py +37 -38
  138. mindspore/nn/layer/thor_layer.py +11 -12
  139. mindspore/nn/layer/timedistributed.py +5 -5
  140. mindspore/nn/layer/transformer.py +701 -0
  141. mindspore/nn/learning_rate_schedule.py +8 -8
  142. mindspore/nn/loss/__init__.py +5 -4
  143. mindspore/nn/loss/loss.py +334 -199
  144. mindspore/nn/optim/ada_grad.py +6 -6
  145. mindspore/nn/optim/adadelta.py +2 -3
  146. mindspore/nn/optim/adafactor.py +4 -5
  147. mindspore/nn/optim/adam.py +126 -62
  148. mindspore/nn/optim/adamax.py +3 -4
  149. mindspore/nn/optim/adasum.py +6 -6
  150. mindspore/nn/optim/asgd.py +2 -2
  151. mindspore/nn/optim/ftrl.py +67 -38
  152. mindspore/nn/optim/lamb.py +4 -5
  153. mindspore/nn/optim/lars.py +2 -2
  154. mindspore/nn/optim/lazyadam.py +43 -4
  155. mindspore/nn/optim/momentum.py +6 -5
  156. mindspore/nn/optim/optimizer.py +3 -1
  157. mindspore/nn/optim/proximal_ada_grad.py +2 -2
  158. mindspore/nn/optim/rmsprop.py +1 -1
  159. mindspore/nn/optim/rprop.py +8 -9
  160. mindspore/nn/optim/sgd.py +19 -13
  161. mindspore/nn/optim/thor.py +10 -15
  162. mindspore/nn/probability/__init__.py +0 -2
  163. mindspore/nn/probability/bijector/bijector.py +4 -4
  164. mindspore/nn/probability/bijector/invert.py +1 -1
  165. mindspore/nn/probability/bijector/softplus.py +2 -2
  166. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  167. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  168. mindspore/nn/probability/distribution/_utils/utils.py +9 -15
  169. mindspore/nn/probability/distribution/bernoulli.py +3 -3
  170. mindspore/nn/probability/distribution/beta.py +1 -1
  171. mindspore/nn/probability/distribution/categorical.py +5 -7
  172. mindspore/nn/probability/distribution/cauchy.py +3 -3
  173. mindspore/nn/probability/distribution/distribution.py +2 -2
  174. mindspore/nn/probability/distribution/exponential.py +2 -2
  175. mindspore/nn/probability/distribution/gamma.py +3 -3
  176. mindspore/nn/probability/distribution/geometric.py +1 -1
  177. mindspore/nn/probability/distribution/gumbel.py +3 -3
  178. mindspore/nn/probability/distribution/half_normal.py +15 -11
  179. mindspore/nn/probability/distribution/laplace.py +16 -13
  180. mindspore/nn/probability/distribution/logistic.py +2 -2
  181. mindspore/nn/probability/distribution/normal.py +1 -1
  182. mindspore/nn/probability/distribution/poisson.py +1 -1
  183. mindspore/nn/probability/distribution/student_t.py +20 -15
  184. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  185. mindspore/nn/probability/distribution/uniform.py +2 -2
  186. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  187. mindspore/nn/reinforcement/tensor_array.py +2 -2
  188. mindspore/nn/sparse/sparse.py +2 -2
  189. mindspore/nn/wrap/cell_wrapper.py +27 -10
  190. mindspore/nn/wrap/grad_reducer.py +2 -2
  191. mindspore/nn/wrap/loss_scale.py +40 -24
  192. mindspore/numpy/array_creations.py +33 -22
  193. mindspore/numpy/array_ops.py +35 -30
  194. mindspore/numpy/logic_ops.py +6 -27
  195. mindspore/numpy/math_ops.py +22 -19
  196. mindspore/numpy/utils.py +1 -1
  197. mindspore/numpy/utils_const.py +108 -58
  198. mindspore/opencv_core452.dll +0 -0
  199. mindspore/opencv_imgcodecs452.dll +0 -0
  200. mindspore/opencv_imgproc452.dll +0 -0
  201. mindspore/ops/_constants.py +0 -6
  202. mindspore/ops/_grad/__init__.py +2 -1
  203. mindspore/ops/_grad/grad_array_ops.py +86 -117
  204. mindspore/ops/_grad/grad_base.py +23 -1
  205. mindspore/ops/_grad/grad_clip_ops.py +2 -3
  206. mindspore/ops/_grad/grad_comm_ops.py +34 -24
  207. mindspore/ops/_grad/grad_implementations.py +9 -45
  208. mindspore/ops/_grad/grad_inner_ops.py +47 -4
  209. mindspore/ops/_grad/grad_math_ops.py +142 -117
  210. mindspore/ops/_grad/grad_nn_ops.py +71 -165
  211. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  212. mindspore/ops/_grad/grad_sparse.py +7 -6
  213. mindspore/ops/_grad_experimental/__init__.py +1 -0
  214. mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
  215. mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
  216. mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
  217. mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
  218. mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
  219. mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
  220. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  221. mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
  222. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  223. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
  224. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  225. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  226. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
  227. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  228. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  229. mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
  230. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  231. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
  232. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  233. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  234. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  235. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  236. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  237. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  238. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  239. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  240. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  241. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  242. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  243. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  244. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  245. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  246. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  247. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  248. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  249. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  250. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
  251. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  252. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  253. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  254. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  255. mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
  256. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  257. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
  258. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  259. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  260. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  261. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  262. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  263. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
  264. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  265. mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
  266. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  267. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  268. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
  269. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  270. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  271. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
  272. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  273. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  274. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  275. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  276. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  277. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  278. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  279. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  280. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  281. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
  282. mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
  283. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  284. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  285. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  286. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  287. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  288. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  289. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  290. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  291. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  292. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  293. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  294. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  295. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  296. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  297. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  298. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  299. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  300. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  301. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  302. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  303. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  304. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  305. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  306. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  307. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  308. mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
  309. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
  310. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  311. mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
  312. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  313. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  314. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  315. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  316. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  317. mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
  318. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
  319. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  320. mindspore/ops/_op_impl/tbe/__init__.py +27 -611
  321. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  322. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  323. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  324. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
  325. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  326. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  327. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  328. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  329. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  330. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  331. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  332. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  333. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  334. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
  335. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  336. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  337. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  338. mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
  339. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  340. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  341. mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
  342. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
  343. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  344. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  345. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  346. mindspore/ops/_register_for_op.py +1 -0
  347. mindspore/ops/_utils/__init__.py +1 -2
  348. mindspore/ops/_utils/utils.py +19 -40
  349. mindspore/ops/_vmap/vmap_array_ops.py +116 -38
  350. mindspore/ops/_vmap/vmap_base.py +16 -9
  351. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  352. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  353. mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
  354. mindspore/ops/_vmap/vmap_image_ops.py +12 -5
  355. mindspore/ops/_vmap/vmap_math_ops.py +46 -5
  356. mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
  357. mindspore/ops/_vmap/vmap_random_ops.py +1 -1
  358. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  359. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  360. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  361. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  362. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  363. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  364. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  365. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  366. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
  367. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  368. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  369. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  370. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  371. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  372. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  373. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  374. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  375. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  376. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  377. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
  378. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
  379. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  380. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  381. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  382. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  383. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  384. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  385. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
  386. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  387. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
  388. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
  389. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  390. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  391. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  392. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  393. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  394. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  395. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  396. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  397. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  398. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  399. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  400. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  401. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  402. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  403. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  404. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  405. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  406. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  407. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  408. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
  409. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  410. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  411. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  412. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  413. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  414. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  415. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  416. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
  417. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
  418. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  419. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
  420. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  421. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  422. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  423. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  424. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  425. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  426. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  427. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  428. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  429. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
  430. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
  431. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  432. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  433. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  434. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  435. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
  436. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  437. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  438. mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
  439. mindspore/ops/composite/__init__.py +7 -8
  440. mindspore/ops/composite/base.py +101 -47
  441. mindspore/ops/composite/math_ops.py +188 -158
  442. mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
  443. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
  444. mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
  445. mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
  446. mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
  447. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  448. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  449. mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
  450. mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
  451. mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
  452. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
  453. mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
  454. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  455. mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
  456. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
  457. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
  458. mindspore/ops/function/__init__.py +152 -8
  459. mindspore/ops/function/array_func.py +2555 -674
  460. mindspore/ops/function/clip_func.py +209 -13
  461. mindspore/ops/function/debug_func.py +2 -2
  462. mindspore/ops/function/grad/__init__.py +2 -1
  463. mindspore/ops/function/grad/grad_func.py +147 -62
  464. mindspore/ops/function/image_func.py +54 -38
  465. mindspore/ops/function/linalg_func.py +167 -16
  466. mindspore/ops/function/math_func.py +4849 -1492
  467. mindspore/ops/function/nn_func.py +2573 -988
  468. mindspore/ops/function/other_func.py +115 -0
  469. mindspore/ops/function/parameter_func.py +3 -3
  470. mindspore/ops/function/random_func.py +790 -73
  471. mindspore/ops/function/sparse_func.py +98 -78
  472. mindspore/ops/function/sparse_unary_func.py +54 -53
  473. mindspore/ops/function/spectral_func.py +27 -24
  474. mindspore/ops/function/vmap_func.py +22 -2
  475. mindspore/ops/functional.py +97 -37
  476. mindspore/ops/op_info_register.py +70 -28
  477. mindspore/ops/operations/__init__.py +47 -14
  478. mindspore/ops/operations/_csr_ops.py +7 -7
  479. mindspore/ops/operations/_embedding_cache_ops.py +5 -5
  480. mindspore/ops/operations/_grad_ops.py +276 -187
  481. mindspore/ops/operations/_inner_ops.py +319 -113
  482. mindspore/ops/operations/_ms_kernel.py +10 -8
  483. mindspore/ops/operations/_ocr_ops.py +9 -9
  484. mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
  485. mindspore/ops/operations/_quant_ops.py +137 -102
  486. mindspore/ops/operations/_rl_inner_ops.py +121 -60
  487. mindspore/ops/operations/_scalar_ops.py +466 -0
  488. mindspore/ops/operations/_sequence_ops.py +1004 -2
  489. mindspore/ops/operations/_tensor_array.py +10 -11
  490. mindspore/ops/operations/_thor_ops.py +1 -1
  491. mindspore/ops/operations/array_ops.py +801 -466
  492. mindspore/ops/operations/comm_ops.py +51 -49
  493. mindspore/ops/operations/control_ops.py +2 -2
  494. mindspore/ops/operations/custom_ops.py +123 -44
  495. mindspore/ops/operations/debug_ops.py +24 -24
  496. mindspore/ops/operations/image_ops.py +240 -153
  497. mindspore/ops/operations/inner_ops.py +34 -50
  498. mindspore/ops/operations/linalg_ops.py +31 -9
  499. mindspore/ops/operations/math_ops.py +988 -757
  500. mindspore/ops/operations/nn_ops.py +965 -819
  501. mindspore/ops/operations/other_ops.py +51 -40
  502. mindspore/ops/operations/random_ops.py +204 -122
  503. mindspore/ops/operations/rl_ops.py +8 -9
  504. mindspore/ops/operations/sparse_ops.py +254 -93
  505. mindspore/ops/operations/spectral_ops.py +35 -3
  506. mindspore/ops/primitive.py +111 -9
  507. mindspore/parallel/_auto_parallel_context.py +189 -83
  508. mindspore/parallel/_offload_context.py +185 -0
  509. mindspore/parallel/_parallel_serialization.py +99 -7
  510. mindspore/parallel/_ps_context.py +9 -5
  511. mindspore/parallel/_recovery_context.py +1 -1
  512. mindspore/parallel/_tensor.py +7 -1
  513. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  514. mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
  515. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  516. mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
  517. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  518. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
  519. mindspore/parallel/_utils.py +1 -2
  520. mindspore/parallel/algo_parameter_config.py +1 -1
  521. mindspore/parallel/checkpoint_transform.py +37 -34
  522. mindspore/parallel/shard.py +17 -18
  523. mindspore/profiler/common/validator/validate_path.py +2 -2
  524. mindspore/profiler/envprofiling.py +69 -47
  525. mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
  526. mindspore/profiler/parser/base_timeline_generator.py +49 -56
  527. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
  528. mindspore/profiler/parser/hwts_log_parser.py +1 -1
  529. mindspore/profiler/parser/integrator.py +15 -14
  530. mindspore/profiler/parser/minddata_analyzer.py +2 -2
  531. mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
  532. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  533. mindspore/profiler/parser/optime_parser.py +17 -18
  534. mindspore/profiler/parser/profiler_info.py +2 -1
  535. mindspore/profiler/profiling.py +218 -186
  536. mindspore/rewrite/__init__.py +3 -1
  537. mindspore/rewrite/api/node.py +1 -114
  538. mindspore/rewrite/api/node_type.py +3 -0
  539. mindspore/rewrite/api/pattern_engine.py +31 -1
  540. mindspore/rewrite/api/scoped_value.py +4 -4
  541. mindspore/rewrite/api/symbol_tree.py +3 -78
  542. mindspore/rewrite/api/tree_node_helper.py +1 -1
  543. mindspore/rewrite/ast_creator_register.py +1 -0
  544. mindspore/rewrite/ast_helpers/__init__.py +2 -2
  545. mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
  546. mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
  547. mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
  548. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
  549. mindspore/rewrite/namespace.py +0 -2
  550. mindspore/rewrite/node.py +157 -11
  551. mindspore/rewrite/parsers/assign_parser.py +231 -53
  552. mindspore/rewrite/parsers/class_def_parser.py +187 -109
  553. mindspore/rewrite/parsers/for_parser.py +24 -14
  554. mindspore/rewrite/parsers/function_def_parser.py +21 -4
  555. mindspore/rewrite/parsers/if_parser.py +6 -2
  556. mindspore/rewrite/sparsify/__init__.py +0 -0
  557. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  558. mindspore/rewrite/sparsify/sparsify.py +109 -0
  559. mindspore/rewrite/sparsify/utils.py +173 -0
  560. mindspore/rewrite/symbol_tree.py +256 -133
  561. mindspore/rewrite/symbol_tree_builder.py +38 -1
  562. mindspore/run_check/_check_version.py +69 -63
  563. mindspore/run_check/run_check.py +2 -1
  564. mindspore/tinyxml2.dll +0 -0
  565. mindspore/train/__init__.py +1 -1
  566. mindspore/train/_utils.py +28 -5
  567. mindspore/train/amp.py +273 -102
  568. mindspore/train/callback/_backup_and_restore.py +5 -5
  569. mindspore/train/callback/_callback.py +2 -2
  570. mindspore/train/callback/_checkpoint.py +3 -3
  571. mindspore/train/callback/_early_stop.py +3 -3
  572. mindspore/train/callback/_lambda_callback.py +2 -2
  573. mindspore/train/callback/_landscape.py +29 -31
  574. mindspore/train/callback/_loss_monitor.py +3 -3
  575. mindspore/train/callback/_on_request_exit.py +3 -3
  576. mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
  577. mindspore/train/callback/_summary_collector.py +23 -16
  578. mindspore/train/callback/_time_monitor.py +3 -3
  579. mindspore/train/checkpoint_pb2.py +68 -8
  580. mindspore/train/data_sink.py +15 -3
  581. mindspore/train/dataset_helper.py +10 -15
  582. mindspore/train/loss_scale_manager.py +8 -11
  583. mindspore/train/metrics/__init__.py +1 -1
  584. mindspore/train/metrics/bleu_score.py +1 -1
  585. mindspore/train/metrics/confusion_matrix.py +1 -1
  586. mindspore/train/metrics/cosine_similarity.py +1 -1
  587. mindspore/train/metrics/dice.py +2 -2
  588. mindspore/train/metrics/fbeta.py +1 -1
  589. mindspore/train/metrics/hausdorff_distance.py +4 -3
  590. mindspore/train/metrics/mean_surface_distance.py +2 -2
  591. mindspore/train/metrics/occlusion_sensitivity.py +1 -1
  592. mindspore/train/metrics/perplexity.py +1 -1
  593. mindspore/train/metrics/precision.py +1 -1
  594. mindspore/train/metrics/recall.py +1 -1
  595. mindspore/train/metrics/roc.py +2 -2
  596. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  597. mindspore/train/mind_ir_pb2.py +116 -37
  598. mindspore/train/model.py +45 -28
  599. mindspore/train/serialization.py +295 -188
  600. mindspore/train/summary/_summary_adapter.py +1 -1
  601. mindspore/train/summary/summary_record.py +43 -13
  602. mindspore/train/train_thor/convert_utils.py +2 -2
  603. mindspore/train/train_thor/dataset_helper.py +3 -3
  604. mindspore/turbojpeg.dll +0 -0
  605. mindspore/version.py +1 -1
  606. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
  607. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
  608. mindspore/compression/__init__.py +0 -19
  609. mindspore/compression/common/constant.py +0 -124
  610. mindspore/compression/export/__init__.py +0 -19
  611. mindspore/compression/export/quant_export.py +0 -515
  612. mindspore/compression/quant/__init__.py +0 -28
  613. mindspore/compression/quant/qat.py +0 -634
  614. mindspore/compression/quant/quant_utils.py +0 -462
  615. mindspore/compression/quant/quantizer.py +0 -68
  616. mindspore/nn/layer/quant.py +0 -1868
  617. mindspore/nn/layer/rnn_utils.py +0 -90
  618. mindspore/nn/probability/dpn/__init__.py +0 -22
  619. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  620. mindspore/nn/probability/dpn/vae/cvae.py +0 -140
  621. mindspore/nn/probability/dpn/vae/vae.py +0 -124
  622. mindspore/nn/probability/infer/__init__.py +0 -22
  623. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  624. mindspore/nn/probability/infer/variational/svi.py +0 -84
  625. mindspore/nn/probability/toolbox/__init__.py +0 -22
  626. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  627. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
  628. mindspore/nn/probability/transforms/__init__.py +0 -22
  629. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  630. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  631. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  632. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  633. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  634. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  635. mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
  636. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  637. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
  638. mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
  639. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
  640. mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
  642. mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
  643. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
  644. mindspore/ops/composite/array_ops.py +0 -241
  645. mindspore/ops/composite/clip_ops.py +0 -134
  646. mindspore/ops/composite/random_ops.py +0 -426
  647. mindspore/ops/composite/vmap_ops.py +0 -38
  648. mindspore/parallel/nn/__init__.py +0 -42
  649. mindspore/parallel/nn/loss.py +0 -22
  650. mindspore/parallel/nn/moe.py +0 -21
  651. mindspore/parallel/nn/op_parallel_config.py +0 -22
  652. mindspore/parallel/nn/transformer.py +0 -31
  653. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  654. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  655. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/_checkparam.py CHANGED
@@ -18,10 +18,8 @@ from __future__ import absolute_import
  import re
  import inspect
  import math
- from enum import Enum
  from functools import reduce, wraps
- from itertools import repeat, zip_longest
- from collections import deque
+ from itertools import repeat
  from collections.abc import Iterable
  import numpy as np

@@ -31,71 +29,92 @@ from mindspore.common import dtype as mstype
  from mindspore._c_expression import Tensor as Tensor_


- class Rel(Enum):
-
-     """Numerical relationship between variables, logical relationship enumeration definition of range."""
-     # scalar compare
-     EQ = 1 # ==
-     NE = 2 # !=
-     LT = 3 # <
-     LE = 4 # <=
-     GT = 5 # >
-     GE = 6 # >=
-     # scalar range check
-     INC_NEITHER = 7 # (), include neither
-     INC_LEFT = 8 # [), include left
-     INC_RIGHT = 9 # (], include right
-     INC_BOTH = 10 # [], include both
-     # collection in, not in
-     IN = 11
-     NOT_IN = 12
-
-     @staticmethod
-     def get_strs(rel):
-         """Get value from rel_strs."""
-         return rel_strs.get(rel, "")
-
-     @staticmethod
-     def get_fns(rel):
-         """Get value from rel_fns."""
-         return rel_fns.get(rel, lambda *args: False)
-
-
- rel_fns = {
-     # scalar compare
-     Rel.EQ: lambda x, y: x == y,
-     Rel.NE: lambda x, y: x != y,
-     Rel.LT: lambda x, y: x < y,
-     Rel.LE: lambda x, y: x <= y,
-     Rel.GT: lambda x, y: x > y,
-     Rel.GE: lambda x, y: x >= y,
-     # scalar range check
-     Rel.INC_NEITHER: lambda x, lower, upper: (lower < x < upper),
-     Rel.INC_LEFT: lambda x, lower, upper: (lower <= x < upper),
-     Rel.INC_RIGHT: lambda x, lower, upper: (lower < x <= upper),
-     Rel.INC_BOTH: lambda x, lower, upper: (lower <= x <= upper),
-     # collection in, not in
-     Rel.IN: lambda x, y: x in y,
-     Rel.NOT_IN: lambda x, y: x not in y,
- }
-
- rel_strs = {
-     # scalar compare
-     Rel.EQ: "= {}",
-     Rel.NE: "!= {}",
-     Rel.LT: "< {}",
-     Rel.LE: "<= {}",
-     Rel.GT: "> {}",
-     Rel.GE: ">= {}",
-     # scalar range check
-     Rel.INC_NEITHER: "({}, {})",
-     Rel.INC_LEFT: "[{}, {})",
-     Rel.INC_RIGHT: "({}, {}]",
-     Rel.INC_BOTH: "[{}, {}]",
-     # collection in, not in
-     Rel.IN: "in {}",
-     Rel.NOT_IN: "not in {}",
- }
+ EQ = 1 # ==
+ NE = 2 # !=
+ LT = 3 # <
+ LE = 4 # <=
+ GT = 5 # >
+ GE = 6 # >=
+ # scalar range check
+ INC_NEITHER = 7 # (), include neither
+ INC_LEFT = 8 # [), include left
+ INC_RIGHT = 9 # (], include right
+ INC_BOTH = 10 # [], include both
+ # collection in, not in
+ IN = 11
+ NOT_IN = 12
+
+
+ def _check_binary_rel(val1, val2, rel):
+     """check binary relation"""
+     if rel == EQ:
+         return val1 == val2
+     if rel == NE:
+         return val1 != val2
+     if rel == LT:
+         return val1 < val2
+     if rel == LE:
+         return val1 <= val2
+     if rel == GT:
+         return val1 > val2
+     if rel == GE:
+         return val1 >= val2
+     if rel == IN:
+         return val1 in val2
+     if rel == NOT_IN:
+         return val1 not in val2
+
+     return False
+
+
+ def _check_inc_rel(val, lower, upper, rel):
+     """check include relation"""
+     if rel == INC_NEITHER:
+         return not (val <= lower or val >= upper)
+     if rel == INC_LEFT:
+         return not (val < lower or val >= upper)
+     if rel == INC_RIGHT:
+         return not (val <= lower or val > upper)
+     if rel == INC_BOTH:
+         return not (val < lower or val > upper)
+
+     return False
+
+
+ def _format_str_one_value(value, rel):
+     """format string"""
+     if rel == EQ:
+         return "= {}".format(value)
+     if rel == NE:
+         return "!= {}".format(value)
+     if rel == LT:
+         return "< {}".format(value)
+     if rel == LE:
+         return "<= {}".format(value)
+     if rel == GT:
+         return "> {}".format(value)
+     if rel == GE:
+         return ">= {}".format(value)
+     if rel == IN:
+         return "in {}".format(value)
+     if rel == NOT_IN:
+         return "not in {}".format(value)
+
+     return ""
+
+
+ def _format_str_two_value(val1, val2, rel):
+     """format string"""
+     if rel == INC_NEITHER:
+         return "({}, {})".format(val1, val2)
+     if rel == INC_LEFT:
+         return "[{}, {})".format(val1, val2)
+     if rel == INC_RIGHT:
+         return "({}, {}]".format(val1, val2)
+     if rel == INC_BOTH:
+         return "[{}, {}]".format(val1, val2)
+
+     return ""


  def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False, ret_five=False,
@@ -106,71 +125,99 @@ def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False, ret

      def _raise_message(third_one_flag=False, three_input_flag=False):
          if third_one_flag:
-             raise ValueError(f"For '{prim_name}', the depth of parameter '{arg_name}' must be 1, "
-                              f"but got {ret_value[-3]}.")
+             raise ValueError("For '{}', the depth of parameter '{}' must be 1, but got {}." \
+                              .format(prim_name, arg_name, ret_value[-3]))
          if three_input_flag:
-             raise ValueError(f"For '{prim_name}', the parameter '{arg_name}' must be an positive integer "
-                              f"or a tuple of three positive integer, but got {arg_value}.")
-         raise ValueError(f"For '{prim_name}', the parameter '{arg_name}' must be an positive integer "
-                          f"or a tuple of three {'or five ' if allow_five else ''}positive integer, but got {arg_value}")
+             raise ValueError("For '{}', the parameter '{}' must be an positive integer " \
+                              "or a tuple of three positive integer, but got {}.".format(prim_name, arg_name, arg_value))
+         raise ValueError("For '{}', the parameter '{}' must be an positive integer " \
+                          "or a tuple of three {}positive integer, but got {}" \
+                          .format(prim_name, arg_name, 'or five ' if allow_five else '', arg_value))

      def _get_return_value():
+         def _check():
+             if not isinstance(arg_value, int):
+                 if len(arg_value) == 5:
+                     if not allow_five:
+                         _raise_message()
+                 elif not len(arg_value) == 3:
+                     _raise_message()
+
+         _check()
          if isinstance(arg_value, int):
              ret = (1, 1, arg_value, arg_value, arg_value) if ret_five else (arg_value, arg_value, arg_value)
          elif len(arg_value) == 3:
              ret = (1, 1, arg_value[0], arg_value[1], arg_value[2]) if ret_five else arg_value
-         elif len(arg_value) == 5:
-             if not allow_five:
-                 _raise_message()
+         else: # case: len(arg_value) == 5
              ret = arg_value if ret_five else (arg_value[2], arg_value[3], arg_value[4])
-         else:
-             _raise_message()
+
          return ret

-     Validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name)
+     def _check_value(ret_value):
+         for item in ret_value:
+             if isinstance(item, int) and not isinstance(item, bool):
+                 if greater_zero and item > 0:
+                     continue
+                 if not greater_zero and item >= 0:
+                     continue
+             _raise_message()
+
+     def _check_third_one(ret_value):
+         if third_one:
+             if ret_value[-3] != 1:
+                 _raise_message(third_one_flag=third_one)
+
+     check_value_type(arg_name, arg_value, (int, tuple), prim_name)
      if three_input and isinstance(arg_value, tuple):
          if len(arg_value) != 3:
              _raise_message(three_input_flag=three_input)
      ret_value = _get_return_value()
-     for item in ret_value:
-         if isinstance(item, int) and not isinstance(item, bool):
-             if greater_zero and item > 0:
-                 continue
-             if not greater_zero and item >= 0:
-                 continue
-         _raise_message()
-
-     if third_one:
-         if ret_value[-3] != 1:
-             _raise_message(third_one_flag=third_one)
+     _check_value(ret_value)
+     _check_third_one(ret_value)

      return tuple(ret_value)


- def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None):
+ def _check_dup(axes):
+     for item in axes:
+         count = 0
+         for item2 in axes:
+             if item == item2:
+                 count += 1
+
+         if count > 1:
+             raise ValueError(f"The element of parameter 'axis' can not be duplicate, but got {axes}.")
+
+
+ def _check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None):
      """
      Check argument integer.

      Usage:
-     - arg_value = check_number(arg_value, 2, Rel.GT, int, "value", None)
+     - arg_value = _check_number(arg_value, 2, GT, int, "value", None)
      """
-     rel_fn = Rel.get_fns(rel)
      prim_name = f"For \'{prim_name}\', the " if prim_name else 'The '
      arg_name = f"\'{arg_name}\'" if arg_name else 'input value'
-     prim_info = f'{prim_name}' + f'{arg_name}'
-     if isinstance(arg_value, arg_type):
-         if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value):
-             raise ValueError(f"{prim_info} must be a legal value, but got '{arg_value}'.")
-     else:
-         raise TypeError(f"{prim_info} must be {arg_type.__name__}, but got '{type(arg_value).__name__}'")

-     type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool)
-     type_except = TypeError if type_mismatch else ValueError
-     if type_mismatch or not rel_fn(arg_value, value):
-         rel_str = Rel.get_strs(rel).format(value)
-         raise type_except(f"{prim_info} must be {arg_type.__name__} and must {rel_str}, "
-                           f"but got '{arg_value}' with type '{type(arg_value).__name__}'.")
+     def _check_param():
+         prim_info = f'{prim_name}' + f'{arg_name}'
+         if isinstance(arg_value, arg_type):
+             if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value):
+                 raise ValueError(f"{prim_info} must be a legal value, but got '{arg_value}'.")
+         else:
+             raise TypeError(f"{prim_info} must be {arg_type.__name__}, but got '{type(arg_value).__name__}'")
+
+         type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool)
+         rel_ret = _check_binary_rel(arg_value, value, rel)
+         if type_mismatch or not rel_ret:
+             rel_str = _format_str_one_value(value, rel)
+             msg = f"{prim_info} must be {arg_type.__name__} and must {rel_str}, " \
+                   f"but got '{arg_value}' with type '{type(arg_value).__name__}'."
+             if type_mismatch:
+                 raise TypeError(msg)
+             raise ValueError(msg)

+     _check_param()
      return arg_value


@@ -185,11 +232,16 @@ def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None):
      """
      prim_name = f"For \'{prim_name}\', the" if prim_name else 'The'
      arg_name = f"\'{arg_name}\'" if arg_name else 'input value'
-     if isinstance(arg_value, arg_type) and not isinstance(arg_value, bool):
-         if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value):
-             raise ValueError(f"{prim_name} {arg_name} must be a legal float, but got '{arg_value}'.")
-         return arg_value
-     raise TypeError(f"{prim_name} type of {arg_name} must be {arg_type.__name__}, but got '{type(arg_value).__name__}'")
+
+     def _check_param():
+         if isinstance(arg_value, arg_type) and not isinstance(arg_value, bool):
+             if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value):
+                 raise ValueError(f"{prim_name} {arg_name} must be a legal float, but got '{arg_value}'.")
+         else:
+             raise TypeError("{} type of {} must be {}, but got '{}'".format(
+                 prim_name, arg_name, arg_type.__name__, type(arg_value).__name__))
+     _check_param()
+     return arg_value


  def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None):
@@ -197,899 +249,940 @@ def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg
      Method for checking whether an int value is in some range.

      Usage:
-     - number = check_number_range(number, 0.0, 1.0, Rel.INC_NEITHER, "number", float) # number in [0.0, 1.0]
-     - number = check_number_range(number, 0, 1, Rel.INC_NEITHER, "number", int) # number in [0, 1]
+     - number = check_number_range(number, 0.0, 1.0, INC_NEITHER, "number", float) # number in [0.0, 1.0]
+     - number = check_number_range(number, 0, 1, INC_NEITHER, "number", int) # number in [0, 1]
      """
-     rel_fn = Rel.get_fns(rel)
      prim_name = f"For \'{prim_name}\', the" if prim_name else 'The'
      arg_name = f"\'{arg_name}\'" if arg_name else 'input value'
-     type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool)
-     if type_mismatch:
-         raise TypeError("{} {} must be '{}', but got '{}'.".format(
-             prim_name, arg_name, value_type.__name__, type(arg_value).__name__))
-     if not rel_fn(arg_value, lower_limit, upper_limit):
-         rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
-         raise ValueError("{} {} must be in range of {}, but got {} with type '{}'.".format(
-             prim_name, arg_name, rel_str, arg_value, type(arg_value).__name__))
+
+     def _check_param():
+         type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool)
+         if type_mismatch:
+             raise TypeError("{} {} must be '{}', but got '{}'.".format(
+                 prim_name, arg_name, value_type.__name__, type(arg_value).__name__))
+
+         if not _check_inc_rel(arg_value, lower_limit, upper_limit, rel):
+             rel_str = _format_str_two_value(lower_limit, upper_limit, rel)
+             raise ValueError("{} {} must be in range of {}, but got {} with type '{}'.".format(
+                 prim_name, arg_name, rel_str, arg_value, type(arg_value).__name__))
+     _check_param()
      return arg_value


- class Validator:
-     """validator for checking input parameters"""
+ def is_stub_tensor(tensor):
+     return hasattr(tensor, "stub")
+

-     @staticmethod
-     def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError):
-         """
-         Method for judging relation between two int values or list/tuple made up of ints.
-         This method is not suitable for judging relation between floats, since it does not consider float error.
-         """
-         rel_fn = Rel.get_fns(rel)
-         if not rel_fn(arg_value, value):
-             rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
+ def check(arg_name, arg_value, value_name, value, rel=EQ, prim_name=None, excp_cls=ValueError):
+     """
+     Method for judging relation between two int values or list/tuple made up of ints.
+     This method is not suitable for judging relation between floats, since it does not consider float error.
+     """
+     def _check():
+         if not _check_binary_rel(arg_value, value, rel):
+             rel_str = _format_str_one_value(f'{value_name}: {value}', rel)
              msg_prefix = f'For \'{prim_name}\', the' if prim_name else "The"
-             raise excp_cls(f'{msg_prefix} \'{arg_name}\' should be {rel_str}, but got {arg_value}.')
-         return arg_value
-
-     @staticmethod
-     def check_int(arg_value, value, rel, arg_name=None, prim_name=None):
-         """
-         Checks input integer value `arg_value` compare to `value`.
-
-         Usage:
-         - number = check_int(number, 0, Rel.GE, "number", None) # number >= 0
-         """
-         return check_number(arg_value, value, rel, int, arg_name, prim_name)
-
-     @staticmethod
-     def check_is_int(arg_value, arg_name=None, prim_name=None):
-         """
-         Checks input value is float type or not.
-
-         Usage:
-         - number = check_is_int(number, int)
-         - number = check_is_int(number, int, "bias")
-         - number = check_is_int(number, int, "bias", "bias_class")
-         """
-         return check_is_number(arg_value, int, arg_name, prim_name)
-
-     @staticmethod
-     def check_equal_int(arg_value, value, arg_name=None, prim_name=None):
-         """
-         Checks input integer value `arg_value` compare to `value`.
-
-         Usage:
-         - number = check_int(number, 0, Rel.GE, "number", None) # number >= 0
-         """
-         return check_number(arg_value, value, Rel.EQ, int, arg_name, prim_name)
-
-     @staticmethod
-     def check_positive_int(arg_value, arg_name=None, prim_name=None):
-         """
-         Check argument is positive integer, which mean arg_value > 0.
-
-         Usage:
-         - number = check_positive_int(number)
-         - number = check_positive_int(number, "bias")
-         """
-         return check_number(arg_value, 0, Rel.GT, int, arg_name, prim_name)
-
-     @staticmethod
-     def check_positive_int_sequence(sequence, arg_name=None, prim_name=None):
-         """
-         Check argument is positive int sequence, which mean all element > 0 in sequence.
-
-         Usage:
-         - sequence = check_positive_int_sequence(sequence)
-         - sequence = check_positive_int_sequence(sequence, "dims")
-         """
-         for idx, element in enumerate(sequence):
-             arg_idx = '{}[{}]'.format(arg_name if arg_name else 'arg_name', idx)
-             check_number(element, 0, Rel.GT, int, arg_idx, prim_name)
-         return sequence
-
-     @staticmethod
-     def check_negative_int(arg_value, arg_name=None, prim_name=None):
-         """
-         Check argument is negative integer, which mean arg_value < 0.
-
-         Usage:
-         - number = check_negative_int(number)
-         - number = check_negative_int(number, "bias")
-         """
-         return check_number(arg_value, 0, Rel.LT, int, arg_name, prim_name)
-
-     @staticmethod
-     def check_non_positive_int(arg_value, arg_name=None, prim_name=None):
-         """
-         Check argument is non-negative integer, which mean arg_value <= 0.
-
-         Usage:
-         - number = check_non_positive_int(number)
-         - number = check_non_positive_int(number, "bias")
-         """
-         return check_number(arg_value, 0, Rel.LE, int, arg_name, prim_name)
-
-     @staticmethod
-     def check_non_negative_int(arg_value, arg_name=None, prim_name=None):
-         """
-         Check argument is non-negative integer, which mean arg_value >= 0.
-
-         Usage:
-         - number = check_non_negative_int(number)
-         - number = check_non_negative_int(number, "bias")
-         """
-         return check_number(arg_value, 0, Rel.GE, int, arg_name, prim_name)
-
-     @staticmethod
-     def check_non_negative_int_sequence(sequence, arg_name=None, prim_name=None):
-         """
-         Check argument is positive sequence, which mean all element >= 0 in sequence.
-
-         Usage:
-         - sequence = check_non_negative_int_sequence(sequence)
-         - sequence = check_non_negative_int_sequence(sequence, "dims")
-         """
-         for idx, element in enumerate(sequence):
-             arg_idx = '{}[{}]'.format(arg_name if arg_name else 'arg_name', idx)
-             check_number(element, 0, Rel.GE, int, arg_idx, prim_name)
-         return sequence
-
-     @staticmethod
-     def check_float(arg_value, value, rel, arg_name=None, prim_name=None):
-         """
-         Checks input float value `arg_value` compare to `value`.
-
-         Usage:
-         - number = check_float(number, 0.0, Rel.GE, "number", None) # number >= 0
-         """
-         return check_number(arg_value, value, rel, float, arg_name, prim_name)
-
-     @staticmethod
-     def check_is_float(arg_value, arg_name=None, prim_name=None):
-         """
-         Checks input value is float type or not.
-
-         Usage:
-         - number = check_is_float(number, int)
-         - number = check_is_float(number, int, "bias")
-         - number = check_is_float(number, int, "bias", "bias_class")
-         """
-         return check_is_number(arg_value, float, arg_name, prim_name)
-
-     @staticmethod
-     def check_positive_float(arg_value, arg_name=None, prim_name=None):
-         """
-         Check argument is positive float, which mean arg_value > 0.
-
-         Usage:
-         - number = check_positive_float(number)
-         - number = check_positive_float(number, "bias")
-         - number = check_positive_float(number, "bias", "bias_class")
-         """
-         return check_number(arg_value, 0, Rel.GT, float, arg_name, prim_name)
-
-     @staticmethod
-     def check_positive_float_sequence(sequence, arg_name=None, prim_name=None):
-         """
-         Check argument is positive sequence, which mean all element > 0 in sequence.
-
-         Usage:
-         - sequence = check_positive_float_sequence(sequence)
-         - sequence = check_positive_float_sequence(sequence, "dims")
-         """
-         for idx, element in enumerate(sequence):
-             arg_idx = '{}[{}]'.format(arg_name if arg_name else 'arg_name', idx)
-             check_number(element, 0, Rel.GT, float, arg_idx, prim_name)
-         return sequence
-
-     @staticmethod
-     def check_negative_float(arg_value, arg_name=None, prim_name=None):
-         """
-         Check argument is negative float, which mean arg_value < 0.
-
-         Usage:
-         - number = check_negative_float(number)
-         - number = check_negative_float(number, "bias")
-         """
-         return check_number(arg_value, 0, Rel.LT, float, arg_name, prim_name)
-
-     @staticmethod
-     def check_non_positive_float(arg_value, arg_name=None, prim_name=None):
-         """
-         Check argument is non-negative float, which mean arg_value <= 0.
-
-         Usage:
-         - number = check_non_positive_float(number)
-         - number = check_non_positive_float(number, "bias")
-         """
-         return check_number(arg_value, 0, Rel.LE, float, arg_name, prim_name)
-
-     @staticmethod
-     def check_non_negative_float(arg_value, arg_name=None, prim_name=None):
-         """
-         Check argument is non-negative float, which mean arg_value >= 0.
-
-         Usage:
-         - number = check_non_negative_float(number)
-         - number = check_non_negative_float(number, "bias")
-         """
-         return check_number(arg_value, 0, Rel.GE, float, arg_name, prim_name)
-
-     @staticmethod
-     def check_number(arg_name, arg_value, value, rel, prim_name):
-         """Number value judgment."""
-         rel_fn = Rel.get_fns(rel)
-         if not rel_fn(arg_value, value):
-             rel_str = Rel.get_strs(rel).format(value)
-             raise ValueError(f'For \'{prim_name}\', the argument \'{arg_name}\' must {rel_str}, but got {arg_value}.')
425
- return arg_value
426
-
427
- @staticmethod
428
- def check_isinstance(arg_name, arg_value, classes):
429
- """Check arg isinstance of classes"""
285
+ msg_subject = f"{msg_prefix} \'{arg_name}\'" if " " not in arg_name else f"{msg_prefix} {arg_name}"
286
+ raise excp_cls(f'{msg_subject} should be {rel_str}, but got {arg_value}.')
287
+ _check()
288
+ return arg_value
289
+
290
+
291
+ def check_int(arg_value, value, rel, arg_name=None, prim_name=None):
292
+ """
293
+ Checks the input integer value `arg_value` against `value` according to `rel`.
294
+
295
+ Usage:
296
+ - number = check_int(number, 0, GE, "number", None) # number >= 0
297
+ """
298
+ return _check_number(arg_value, value, rel, int, arg_name, prim_name)
299
+
300
+
301
+ def check_is_int(arg_value, arg_name=None, prim_name=None):
302
+ """
303
+ Checks whether the input value is an integer.
304
+
305
+ Usage:
306
+ - number = check_is_int(number)
307
+ - number = check_is_int(number, "bias")
308
+ - number = check_is_int(number, "bias", "bias_class")
309
+ """
310
+ return check_is_number(arg_value, int, arg_name, prim_name)
311
+
312
+
313
+ def check_equal_int(arg_value, value, arg_name=None, prim_name=None):
314
+ """
315
+ Checks that the input integer value `arg_value` is equal to `value`.
316
+
317
+ Usage:
318
+ - number = check_equal_int(number, 0, "number", None) # number == 0
319
+ """
320
+ return _check_number(arg_value, value, EQ, int, arg_name, prim_name)
321
+
322
+
323
+ def check_positive_int(arg_value, arg_name=None, prim_name=None):
324
+ """
325
+ Check the argument is a positive integer, which means arg_value > 0.
326
+
327
+ Usage:
328
+ - number = check_positive_int(number)
329
+ - number = check_positive_int(number, "bias")
330
+ """
331
+ return _check_number(arg_value, 0, GT, int, arg_name, prim_name)
332
+
333
+
334
+ def check_positive_int_sequence(sequence, arg_name=None, prim_name=None):
335
+ """
336
+ Check the argument is a positive int sequence, which means every element in the sequence is > 0.
337
+
338
+ Usage:
339
+ - sequence = check_positive_int_sequence(sequence)
340
+ - sequence = check_positive_int_sequence(sequence, "dims")
341
+ """
342
+ for idx, element in enumerate(sequence):
343
+ arg_idx = '{}[{}]'.format(arg_name if arg_name else 'arg_name', idx)
344
+ _check_number(element, 0, GT, int, arg_idx, prim_name)
345
+ return sequence
346
+
347
+
348
+ def check_negative_int(arg_value, arg_name=None, prim_name=None):
349
+ """
350
+ Check the argument is a negative integer, which means arg_value < 0.
351
+
352
+ Usage:
353
+ - number = check_negative_int(number)
354
+ - number = check_negative_int(number, "bias")
355
+ """
356
+ return _check_number(arg_value, 0, LT, int, arg_name, prim_name)
357
+
358
+
359
+ def check_non_positive_int(arg_value, arg_name=None, prim_name=None):
360
+ """
361
+ Check the argument is a non-positive integer, which means arg_value <= 0.
362
+
363
+ Usage:
364
+ - number = check_non_positive_int(number)
365
+ - number = check_non_positive_int(number, "bias")
366
+ """
367
+ return _check_number(arg_value, 0, LE, int, arg_name, prim_name)
368
+
369
+
370
+ def check_non_negative_int(arg_value, arg_name=None, prim_name=None):
371
+ """
372
+ Check the argument is a non-negative integer, which means arg_value >= 0.
373
+
374
+ Usage:
375
+ - number = check_non_negative_int(number)
376
+ - number = check_non_negative_int(number, "bias")
377
+ """
378
+ return _check_number(arg_value, 0, GE, int, arg_name, prim_name)
379
+
380
+
381
+ def check_non_negative_int_sequence(sequence, arg_name=None, prim_name=None):
382
+ """
383
+ Check the argument is a non-negative int sequence, which means every element in the sequence is >= 0.
384
+
385
+ Usage:
386
+ - sequence = check_non_negative_int_sequence(sequence)
387
+ - sequence = check_non_negative_int_sequence(sequence, "dims")
388
+ """
389
+ for idx, element in enumerate(sequence):
390
+ arg_idx = '{}[{}]'.format(arg_name if arg_name else 'arg_name', idx)
391
+ _check_number(element, 0, GE, int, arg_idx, prim_name)
392
+ return sequence
393
+
394
+
395
+ def check_float(arg_value, value, rel, arg_name=None, prim_name=None):
396
+ """
397
+ Checks the input float value `arg_value` against `value` according to `rel`.
398
+
399
+ Usage:
400
+ - number = check_float(number, 0.0, GE, "number", None) # number >= 0
401
+ """
402
+ return _check_number(arg_value, value, rel, float, arg_name, prim_name)
403
+
404
+
405
+ def check_is_float(arg_value, arg_name=None, prim_name=None):
406
+ """
407
+ Checks whether the input value is a float.
408
+
409
+ Usage:
410
+ - number = check_is_float(number)
411
+ - number = check_is_float(number, "bias")
412
+ - number = check_is_float(number, "bias", "bias_class")
413
+ """
414
+ return check_is_number(arg_value, float, arg_name, prim_name)
415
+
416
+
417
+ def check_positive_float(arg_value, arg_name=None, prim_name=None):
418
+ """
419
+ Check the argument is a positive float, which means arg_value > 0.
420
+
421
+ Usage:
422
+ - number = check_positive_float(number)
423
+ - number = check_positive_float(number, "bias")
424
+ - number = check_positive_float(number, "bias", "bias_class")
425
+ """
426
+ return _check_number(arg_value, 0, GT, float, arg_name, prim_name)
427
+
428
+
429
+ def check_positive_float_sequence(sequence, arg_name=None, prim_name=None):
430
+ """
431
+ Check the argument is a positive float sequence, which means every element in the sequence is > 0.
432
+
433
+ Usage:
434
+ - sequence = check_positive_float_sequence(sequence)
435
+ - sequence = check_positive_float_sequence(sequence, "dims")
436
+ """
437
+ for idx, element in enumerate(sequence):
438
+ arg_idx = '{}[{}]'.format(arg_name if arg_name else 'arg_name', idx)
439
+ _check_number(element, 0, GT, float, arg_idx, prim_name)
440
+ return sequence
441
+
442
+
443
+ def check_negative_float(arg_value, arg_name=None, prim_name=None):
444
+ """
445
+ Check the argument is a negative float, which means arg_value < 0.
446
+
447
+ Usage:
448
+ - number = check_negative_float(number)
449
+ - number = check_negative_float(number, "bias")
450
+ """
451
+ return _check_number(arg_value, 0, LT, float, arg_name, prim_name)
452
+
453
+
454
+ def check_non_positive_float(arg_value, arg_name=None, prim_name=None):
455
+ """
456
+ Check the argument is a non-positive float, which means arg_value <= 0.
457
+
458
+ Usage:
459
+ - number = check_non_positive_float(number)
460
+ - number = check_non_positive_float(number, "bias")
461
+ """
462
+ return _check_number(arg_value, 0, LE, float, arg_name, prim_name)
463
+
464
+
465
+ def check_non_negative_float(arg_value, arg_name=None, prim_name=None):
466
+ """
467
+ Check the argument is a non-negative float, which means arg_value >= 0.
468
+
469
+ Usage:
470
+ - number = check_non_negative_float(number)
471
+ - number = check_non_negative_float(number, "bias")
472
+ """
473
+ return _check_number(arg_value, 0, GE, float, arg_name, prim_name)
474
+
475
+
476
+ def check_number(arg_name, arg_value, value, rel, prim_name):
477
+ """Number value judgment."""
478
+ def _check():
479
+ if not _check_binary_rel(arg_value, value, rel):
480
+ rel_str = _format_str_one_value(value, rel)
481
+ raise ValueError(f'For \'{prim_name}\', the argument \'{arg_name}\' ' \
482
+ f'must {rel_str}, but got {arg_value}.')
483
+ _check()
484
+ return arg_value
485
+
486
+
487
+ def check_isinstance(arg_name, arg_value, classes):
488
+ """Check arg isinstance of classes"""
489
+ def _check():
430
490
  if not isinstance(arg_value, classes):
431
491
  raise ValueError(f'The parameter \'{arg_name}\' must be isinstance of {classes}, but got {arg_value}.')
432
- return arg_value
492
+ _check()
493
+ return arg_value
433
494
 
434
- @staticmethod
435
- def check_bool(arg_value, arg_name=None, prim_name=None):
436
- """
437
- Check argument is instance of bool.
438
495
 
439
- Usage:
440
- - has_bias = check_bool(has_bias)
441
- - has_bias = check_bool(has_bias, "has_bias")
442
- """
496
+ def check_bool(arg_value, arg_name=None, prim_name=None):
497
+ """
498
+ Check argument is instance of bool.
499
+
500
+ Usage:
501
+ - has_bias = check_bool(has_bias)
502
+ - has_bias = check_bool(has_bias, "has_bias")
503
+ """
504
+ prim_name = f"For '{prim_name}', the" if prim_name else 'The'
505
+ arg_name = f"'{arg_name}'" if arg_name else 'input value'
506
+
507
+ def _check():
443
508
  if not isinstance(arg_value, bool):
444
- prim_name = f"For '{prim_name}', the" if prim_name else 'The'
445
- arg_name = f"'{arg_name}'" if arg_name else 'input value'
446
509
  raise TypeError(f"{prim_name} {arg_name} must be a bool, but got {type(arg_value).__name__}.")
447
- return arg_value
448
-
449
- @staticmethod
450
- def check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
451
- """
452
- Method for checking whether input value is in int range.
453
-
454
- Usage:
455
- - number = check_int_range(number, 0, 1, Rel.INC_NEITHER) # number in [0, 1]
456
- - number = check_int_range(number, 0, 1, Rel.INC_NEITHER, "number") # number in [0, 1]
457
- """
458
- return check_number_range(arg_value, lower_limit, upper_limit, rel, int, arg_name, prim_name)
459
-
460
- @staticmethod
461
- def check_float_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
462
- """
463
- Method for checking whether input value is in float range.
464
-
465
- Usage:
466
- - number = check_float_range(number, 0.0, 1.0, Rel.INC_NEITHER) # number in [0.0, 1.0]
467
- - number = check_float_range(number, 0.0, 1.0, Rel.INC_NEITHER, "number") # number in [0.0, 1.0]
468
- """
469
- return check_number_range(arg_value, lower_limit, upper_limit, rel, float, arg_name, prim_name)
470
-
471
- @staticmethod
472
- def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
473
- """
474
- Check whether string is in some value list.
475
-
476
- Usage:
477
- - method = check_string(method, ["string1", "string2", "string3"], "method")
478
- """
479
- if isinstance(arg_value, str) and arg_value in valid_values:
480
- return arg_value
481
- arg_name = arg_name if arg_name else "parameter"
482
- msg_prefix = f'For \'{prim_name}\', the' if prim_name else "The"
483
- raise ValueError(f"{msg_prefix} '{arg_name}' must be str and must be in '{valid_values}',"
484
- f" but got '{arg_value}'.")
485
-
486
- @staticmethod
487
- def check_str_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
488
- if reg is None:
489
- # Named string regular expression
490
- reg = r"^\w+[0-9a-zA-Z\_\.]*$"
491
- if re.match(reg, target, flag) is None:
492
- prim_name = f"For '{prim_name}', the" if prim_name else "The"
493
- raise ValueError("{} '{}' is illegal, it must be match regular'{}' by flags'{}.'".format(
494
- prim_name, target, reg, flag))
495
- return True
510
+ _check()
511
+ return arg_value
496
512
 
497
- @staticmethod
498
- def check_file_name_by_regular(target, reg=None, prim_name=None):
499
- """Check whether file name is legitimate."""
500
- if not isinstance(target, str):
501
- prim_name = f"For '{prim_name}', the" if prim_name else "The"
502
- raise TypeError("{} '{}' must be string, but got {}.".format(prim_name, target, type(target)))
503
- if target.endswith("\\") or target.endswith("/"):
504
- prim_name = f"For '{prim_name}', the" if prim_name else "The"
505
- raise ValueError(f"{prim_name} '{target}' cannot be a directory path.")
506
- if reg is None:
507
- reg = r"^[0-9a-zA-Z\_\-\.\:\/\\]+$"
508
- if re.match(reg, target) is None:
509
- prim_name = f"For '{prim_name}', the" if prim_name else "The"
510
- raise ValueError("{} '{}' is illegal, it must be match regular '{}'.".format(
511
- prim_name, target, reg))
512
513
 
513
- return True
514
+ def check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
515
+ """
516
+ Method for checking whether input value is in int range.
517
+
518
+ Usage:
519
+ - number = check_int_range(number, 0, 1, INC_NEITHER) # number in (0, 1)
520
+ - number = check_int_range(number, 0, 1, INC_NEITHER, "number") # number in (0, 1)
521
+ """
522
+ return check_number_range(arg_value, lower_limit, upper_limit, rel, int, arg_name, prim_name)
523
+
524
+
525
+ def check_float_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
526
+ """
527
+ Method for checking whether input value is in float range.
528
+
529
+ Usage:
530
+ - number = check_float_range(number, 0.0, 1.0, INC_NEITHER) # number in (0.0, 1.0)
531
+ - number = check_float_range(number, 0.0, 1.0, INC_NEITHER, "number") # number in (0.0, 1.0)
532
+ """
533
+ return check_number_range(arg_value, lower_limit, upper_limit, rel, float, arg_name, prim_name)
534
+
535
+
536
+ def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
537
+ """
538
+ Check whether string is in some value list.
539
+
540
+ Usage:
541
+ - method = check_string(method, ["string1", "string2", "string3"], "method")
542
+ """
543
+ arg_name = arg_name if arg_name else "parameter"
544
+ msg_prefix = f'For \'{prim_name}\', the' if prim_name else "The"
545
+
546
+ def _check():
547
+ if not (isinstance(arg_value, str) and arg_value in valid_values):
548
+ raise ValueError(f"{msg_prefix} '{arg_name}' must be str and must be in '{valid_values}'," \
549
+ f" but got '{arg_value}'.")
550
+ _check()
551
+ return arg_value
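A quick usage sketch for the refactored check_string helper above (illustrative only, not part of the diff; it assumes the function remains importable from mindspore._checkparam, as the file being diffed suggests):

    from mindspore._checkparam import check_string  # assumed import path

    reduction = check_string("mean", ["mean", "sum", "none"], "reduction", "SoftmaxCrossEntropy")
    # returns "mean"; passing "avg" instead would raise roughly:
    # ValueError: For 'SoftmaxCrossEntropy', the 'reduction' must be str and must be in
    # '['mean', 'sum', 'none']', but got 'avg'.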
552
+
553
+
554
+ def check_str_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
555
+ if reg is None:
556
+ # Named string regular expression
557
+ reg = r"^\w+[0-9a-zA-Z\_\.]*$"
558
+ if re.match(reg, target, flag) is None:
559
+ prim_name = f"For '{prim_name}', the" if prim_name else "The"
560
+ raise ValueError("{} '{}' is illegal, it must match the regular expression '{}' with flags '{}'.".format(
561
+ prim_name, target, reg, flag))
562
+ return True
563
+
564
+
565
+ def check_file_name_by_regular(target, reg=None, prim_name=None):
566
+ """Check whether file name is legitimate."""
567
+ if not isinstance(target, str):
568
+ prim_name = f"For '{prim_name}', the" if prim_name else "The"
569
+ raise TypeError("{} '{}' must be string, but got {}.".format(prim_name, target, type(target)))
570
+ if target.endswith("\\") or target.endswith("/"):
571
+ prim_name = f"For '{prim_name}', the" if prim_name else "The"
572
+ raise ValueError(f"{prim_name} '{target}' cannot be a directory path.")
573
+ if reg is None:
574
+ reg = r"^[0-9a-zA-Z@\_\-\.\:\/\\]+$"
575
+ if re.match(reg, target) is None:
576
+ raise ValueError("{} '{}' is illegal, it must match the regular expression '{}'.".format(
577
+ raise ValueError("{} '{}' is illegal, it must be match regular '{}'.".format(
578
+ prim_name, target, reg))
579
+
580
+ return True
581
+
514
582
 
515
- @staticmethod
516
- def check_pad_value_by_mode(pad_mode, padding, prim_name):
517
- """Validates value of padding according to pad_mode"""
518
- if pad_mode != 'pad' and padding != 0:
519
- raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}',"
520
- f" but got {padding}.")
521
- return padding
522
-
523
- @staticmethod
524
- def check_subclass(arg_name, type_, template_types, prim_name, addition_error_info=None):
525
- """Checks whether some type is subclass of another type"""
526
- if not isinstance(template_types, Iterable):
527
- template_types = (template_types,)
528
- hit = False
529
- for template_type in template_types:
530
- if isinstance(template_type, mstype.Type):
531
- if mstype._issubclass_(type_, template_type): # pylint: disable=W0212
532
- hit = True
533
- break
534
- elif type_ is template_type:
583
+ def check_pad_value_by_mode(pad_mode, padding, prim_name):
584
+ """Validates value of padding according to pad_mode"""
585
+ if pad_mode != 'pad' and padding != 0:
586
+ raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'," \
587
+ f" but got {padding}.")
588
+ return padding
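A minimal sketch of how the padding/pad_mode consistency check above behaves (illustrative only; the import path is an assumption based on the file name):

    from mindspore._checkparam import check_pad_value_by_mode  # assumed import path

    check_pad_value_by_mode("pad", 2, "Conv2d")    # returns 2: explicit padding is only allowed with pad_mode='pad'
    check_pad_value_by_mode("same", 0, "Conv2d")   # returns 0
    # check_pad_value_by_mode("same", 2, "Conv2d") would raise ValueError, since padding must be zero here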
589
+
590
+
591
+ def check_subclass(arg_name, type_, template_types, prim_name, addition_error_info=None):
592
+ """Checks whether some type is subclass of another type"""
593
+ if not isinstance(template_types, Iterable):
594
+ template_types = (template_types,)
595
+ hit = False
596
+ for template_type in template_types:
597
+ if isinstance(template_type, mstype.Type):
598
+ if mstype._issubclass_(type_, template_type): # pylint: disable=W0212
535
599
  hit = True
536
600
  break
537
- if not hit:
538
- if addition_error_info is None:
539
- addition_error_info = ''
540
- else:
541
- addition_error_info = ' ' + addition_error_info
542
- type_str = (f"type '{type(type_).__name__}'" if isinstance(type_, (tuple, list)) else str(type_))
543
- raise TypeError(f"For '{prim_name}', the type of '{arg_name}'"
544
- f" must be {'one of ' if len(template_types) > 1 else ''}"
545
- f"{', '.join((str(x) for x in template_types))}, but got {type_str}"
546
- f"{addition_error_info}.The supported data types depend on the hardware that"
547
- f" executes the operator, for more details, please refer to the MindSpore official "
548
- f"website to get more information about the data type.")
549
-
550
- @staticmethod
551
- def check_valid_input(arg_name, arg_value, prim_name):
552
- """Checks valid value."""
601
+ elif type_ is template_type:
602
+ hit = True
603
+ break
604
+ if not hit:
605
+ if addition_error_info is None:
606
+ addition_error_info = ''
607
+ else:
608
+ addition_error_info = ' ' + addition_error_info
609
+ type_str = (f"type '{type(type_).__name__}'" if isinstance(type_, (tuple, list)) else str(type_))
610
+ raise TypeError(f"For '{prim_name}', the element of '{arg_name}'" \
611
+ f" must be {'one of ' if len(template_types) > 1 else ''}" \
612
+ f"{', '.join((str(x) for x in template_types))}, but got {type_str}" \
613
+ f"{addition_error_info}. The supported data types depend on the hardware that" \
614
+ f" executes the operator, for more details, please refer to the MindSpore official " \
615
+ f"website to get more information about the data type.")
616
+
617
+
618
+ def check_valid_input(arg_name, arg_value, prim_name):
619
+ """Checks valid value."""
620
+ def _check():
553
621
  if arg_value is None:
554
- raise ValueError(f"For \'{prim_name}\', the argument '{arg_name}' can not be None, but got {arg_value}.")
555
- return arg_value
556
-
557
- @staticmethod
558
- def check_types_same_and_valid(args, valid_values, prim_name):
559
- """Checks whether the types of inputs are the same and valid."""
560
-
561
- def _check_type_valid(arg):
562
- arg_key, arg_val = arg
563
- elem_type = arg_val
564
- Validator.check_subclass(arg_key, elem_type, valid_values, prim_name)
565
- return (arg_key, elem_type)
566
-
567
- def _check_types_same(arg1, arg2):
568
- arg1_name, arg1_type = arg1
569
- arg2_name, arg2_type = arg2
570
- if arg1_type != arg2_type:
571
- raise TypeError(f"For '{prim_name}', the type of '{arg2_name}' should be same as '{arg1_name}',"
572
- f" but got '{arg1_name}' with type {arg1_type}"
573
- f" and '{arg2_name}' with type {arg2_type}.")
574
- return arg1
575
-
576
- elem_types = map(_check_type_valid, args.items())
577
- reduce(_check_types_same, elem_types)
578
-
579
- @staticmethod
580
- def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name):
581
- """Checks whether the element types of input tensors are the same and valid."""
582
- valid_dtypes = valid_dtypes if isinstance(valid_dtypes, Iterable) else [valid_dtypes]
583
- tensor_types = [mstype.tensor_type(t) for t in valid_dtypes]
584
- Validator.check_types_same_and_valid(args, tensor_types, prim_name)
585
-
586
- @staticmethod
587
- def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name):
588
- """Checks whether the element types of input tensors are valid."""
589
- valid_dtypes = valid_dtypes if isinstance(valid_dtypes, Iterable) else [valid_dtypes]
590
- tensor_types = [mstype.tensor_type(t) for t in valid_dtypes]
591
- Validator.check_subclass(arg_name, arg_type, tensor_types, prim_name)
592
-
593
- @staticmethod
594
- def check_scalar_or_tensor_types_same(args, valid_values, prim_name, allow_mix=False):
595
- """
596
- Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.
597
- If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.
598
- """
599
-
600
- def _check_argument_type(arg):
601
- arg_key, arg_val = arg
602
- if isinstance(arg_val, type(mstype.tensor)):
603
- arg_val = arg_val.element_type()
604
- if arg_val not in valid_values:
605
- raise TypeError(f'For \'{prim_name}\', the type of \'{arg_key}\' must be in {valid_values},'
606
- f' but got {arg_val}.')
607
- return arg
608
-
609
- def _check_types_same(arg1, arg2):
610
- arg1_name, arg1_type = arg1
611
- arg2_name, arg2_type = arg2
612
- except_flag = False
613
- if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)):
614
- arg1_type = arg1_type.element_type()
615
- arg2_type = arg2_type.element_type()
616
- elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))):
617
- pass
618
- elif allow_mix:
619
- arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type
620
- arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type
621
- else:
622
- except_flag = True
623
-
624
- if except_flag or arg1_type != arg2_type:
625
- raise TypeError(f"For '{prim_name}', the type of '{arg2_name}' must be same as '{arg1_name}',"
626
- f" but got '{arg1_name}' with type {arg1_type}"
627
- f" and '{arg2_name}' with type {arg2_type}.")
628
- return arg1
629
-
630
- args_map = map(_check_argument_type, args.items())
631
- reduce(_check_types_same, args_map)
632
-
633
- @staticmethod
634
- def check_value_type(arg_name, arg_value, valid_types, prim_name=None):
635
- """Checks whether a value is instance of some types."""
636
- valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
637
-
638
- def raise_error_msg():
639
- """func for raising error message when check failed"""
640
- type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
641
- num_types = len(valid_types)
642
- msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
643
- raise TypeError(f'{msg_prefix} type of \'{arg_name}\' should be {"one of " if num_types > 1 else ""}'
644
- f'\'{type_names if num_types > 1 else type_names[0]}\', '
645
- f'but got type \'{type(arg_value).__name__}\'.')
646
-
647
- # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
648
- # `check_value_type('x', True, [bool, int])` will check pass
649
- if isinstance(arg_value, bool) and bool not in tuple(valid_types):
650
- raise_error_msg()
651
- if isinstance(arg_value, float) and float not in tuple(valid_types):
652
- arg_value = round(arg_value, 6)
653
- if not isinstance(arg_value, tuple(valid_types)):
654
- raise_error_msg()
655
- return arg_value
656
-
657
- @staticmethod
658
- def check_type_name(arg_name, arg_type, valid_types, prim_name):
659
- """Checks whether a type in some specified types"""
660
- valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
661
-
662
- def raise_error_msg():
663
- """func for raising error message when check failed"""
664
- type_names = [t.__name__ if hasattr(t, '__name__') else t for t in valid_types]
665
- num_types = len(valid_types)
666
- msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
667
- raise TypeError(f"{msg_prefix} '{arg_name}' should be {'one of ' if num_types > 1 else ''}"
668
- f"{type_names if num_types > 1 else type_names[0]}, "
669
- f"but got '{arg_type.__name__ if hasattr(arg_type, '__name__') else repr(arg_type)}'.")
670
-
671
- if isinstance(arg_type, type(mstype.tensor)):
672
- arg_type = arg_type.element_type()
673
- if arg_type not in valid_types:
674
- raise_error_msg()
675
- return arg_type
676
-
677
- @staticmethod
678
- def check_reduce_shape(ori_shape, shape, axis, prim_name, arg_name1, arg_name2):
679
- """Checks whether shape is ori_shape reduced on axis"""
680
- axis_origin = axis
681
- axis = axis if isinstance(axis, Iterable) else (axis,)
682
- exp_shape = [ori_shape[i] for i in range(len(ori_shape)) if i not in axis]
683
- if list(shape) != exp_shape:
684
- raise ValueError(f"For '{prim_name}', "
685
- f"the shape of parameter '{arg_name1}' reduce on 'axis': {axis_origin} must "
686
- f"be equal to the shape of '{arg_name2}': {shape}, but got {ori_shape}.")
687
-
688
- @staticmethod
689
- def check_astype_dtype(dtype):
690
- """Check whether dtype is a valid input, and convert to mstype"""
691
- all_types = mstype.__dtype__ + ["int", "float", "bool"]
692
- if isinstance(dtype, str):
693
- if dtype.lower() not in all_types:
694
- raise TypeError(f"For Tensor.astype, the input type must be one of {all_types}, but got '{dtype}'.")
695
- dtype = mstype.pytype_to_dtype(np.dtype(dtype.lower()))
696
- elif isinstance(dtype, type):
697
- dtype = mstype.pytype_to_dtype(dtype)
698
- elif not dtype in mstype.number_type + (mstype.bool_,):
699
- raise TypeError(f"For Tensor.astype, the input type must be one of {mstype.number_type + (mstype.bool_,)},"
700
- f" but got '{dtype}'.")
701
- return dtype
702
-
703
- @staticmethod
704
- def check_transpose_axis(axes, ndim):
705
- """Check the axis argument for tensor.transpose"""
706
- if not axes or (len(axes) == 1 and axes[0] is None):
707
- return tuple(range(ndim-1, -1, -1))
708
-
709
- if len(axes) == 1:
710
- perm = axes[0]
711
- # if only one argument provided, it must be tuple or list
712
- if isinstance(perm, list):
713
- perm = tuple(perm)
714
- else:
715
- if not isinstance(perm, tuple):
716
- raise TypeError(f"For Tensor.transpose, the parameter 'axes' must be a tuple/list, "
717
- f"or series of integer, but got {type(axes[0])}")
718
- return perm
622
+ raise ValueError(f"For \'{prim_name}\', the argument '{arg_name}'" \
623
+ f" cannot be None, but got {arg_value}.")
624
+ _check()
625
+ return arg_value
626
+
627
+
628
+ def check_types_same_and_valid(args, valid_values, prim_name):
629
+ """Checks whether the types of inputs are the same and valid."""
630
+
631
+ def _check_type_valid(arg):
632
+ arg_key, arg_val = arg
633
+ elem_type = arg_val
634
+ check_subclass(arg_key, elem_type, valid_values, prim_name)
635
+ return (arg_key, elem_type)
636
+
637
+ def _check_types_same(arg1, arg2):
638
+ arg1_name, arg1_type = arg1
639
+ arg2_name, arg2_type = arg2
640
+ if arg1_type != arg2_type:
641
+ raise TypeError(f"For '{prim_name}', the type of '{arg2_name}' should be same as '{arg1_name}'," \
642
+ f" but got '{arg1_name}' with type {arg1_type}" \
643
+ f" and '{arg2_name}' with type {arg2_type}.")
644
+ return arg1
645
+
646
+ elem_types = map(_check_type_valid, args.items())
647
+ reduce(_check_types_same, elem_types)
648
+
649
+
650
+ def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name):
651
+ """Checks whether the element types of input tensors are the same and valid."""
652
+ valid_dtypes = valid_dtypes if isinstance(valid_dtypes, Iterable) else [valid_dtypes]
653
+ tensor_types = [mstype.tensor_type(t) for t in valid_dtypes]
654
+ check_types_same_and_valid(args, tensor_types, prim_name)
655
+
656
+
657
+ def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name):
658
+ """Checks whether the element types of input tensors are valid."""
659
+ valid_dtypes = valid_dtypes if isinstance(valid_dtypes, Iterable) else [valid_dtypes]
660
+ tensor_types = [mstype.tensor_type(t) for t in valid_dtypes]
661
+ check_subclass(arg_name, arg_type, tensor_types, prim_name)
662
+
663
+
664
+ def check_scalar_or_tensor_types_same(args, valid_values, prim_name, allow_mix=False):
665
+ """
666
+ Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.
667
+ If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.
668
+ """
669
+
670
+ def _check_argument_type(arg):
671
+ arg_key, arg_val = arg
672
+ if isinstance(arg_val, type(mstype.tensor)):
673
+ arg_val = arg_val.element_type()
674
+ if arg_val not in valid_values:
675
+ raise TypeError(f'For \'{prim_name}\', the type of \'{arg_key}\' must be in {valid_values},' \
676
+ f' but got {arg_val}.')
677
+ return arg
678
+
679
+ def _check_types_same(arg1, arg2):
680
+ arg1_name, arg1_type = arg1
681
+ arg2_name, arg2_type = arg2
682
+ except_flag = False
683
+ if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)):
684
+ arg1_type = arg1_type.element_type()
685
+ arg2_type = arg2_type.element_type()
686
+ elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))):
687
+ pass
688
+ elif allow_mix:
689
+ arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type
690
+ arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type
691
+ else:
692
+ except_flag = True
693
+
694
+ if except_flag or arg1_type != arg2_type:
695
+ raise TypeError(f"For '{prim_name}', the type of '{arg2_name}' must be same as '{arg1_name}'," \
696
+ f" but got '{arg1_name}' with type {arg1_type}" \
697
+ f" and '{arg2_name}' with type {arg2_type}.")
698
+ return arg1
699
+
700
+ args_map = map(_check_argument_type, args.items())
701
+ reduce(_check_types_same, args_map)
702
+
703
+
704
+ def check_value_type(arg_name, arg_value, valid_types, prim_name=None):
705
+ """Checks whether a value is instance of some types."""
706
+ valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
707
+
708
+ def raise_error_msg(cond, arg_value):
709
+ """func for raising error message when check failed"""
710
+ if not cond:
711
+ return
712
+ type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
713
+ num_types = len(valid_types)
714
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
715
+ raise TypeError(f'{msg_prefix} type of \'{arg_name}\' should be {"one of " if num_types > 1 else ""}' \
716
+ f'\'{type_names if num_types > 1 else type_names[0]}\', ' \
717
+ f'but got type \'{type(arg_value).__name__}\'.')
718
+
719
+ # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
720
+ # `check_value_type('x', True, [bool, int])` will check pass
721
+ cond = isinstance(arg_value, bool) and bool not in tuple(valid_types)
722
+ raise_error_msg(cond, arg_value)
723
+ if isinstance(arg_value, float) and float not in tuple(valid_types):
724
+ arg_value = round(arg_value, 6)
725
+ cond = not isinstance(arg_value, tuple(valid_types))
726
+ raise_error_msg(cond, arg_value)
727
+ return arg_value
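The bool-vs-int subtlety noted in the comments above can be seen in a small sketch (illustrative only; the import path is an assumption based on the file name):

    from mindspore._checkparam import check_value_type  # assumed import path

    check_value_type('x', True, [bool, int])   # passes, returns True
    check_value_type('x', 3, [int])            # passes, returns 3
    # check_value_type('x', True, [int]) raises TypeError: although bool is a subclass of int,
    # booleans are rejected unless bool itself is listed in valid_types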
728
+
729
+
730
+ def check_type_name(arg_name, arg_type, valid_types, prim_name):
731
+ """Checks whether a type in some specified types"""
732
+ valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
719
733
 
734
+ def raise_error_msg(cond, arg_type):
735
+ """func for raising error message when check failed"""
736
+ if not cond:
737
+ return
738
+ type_names = [t.__name__ if hasattr(t, '__name__') else t for t in valid_types]
739
+ num_types = len(valid_types)
740
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
741
+ raise TypeError(f"{msg_prefix} '{arg_name}' should be {'one of ' if num_types > 1 else ''}" \
742
+ f"{type_names if num_types > 1 else type_names[0]}, " \
743
+ f"but got '{arg_type.__name__ if hasattr(arg_type, '__name__') else repr(arg_type)}'.")
744
+
745
+ if isinstance(arg_type, type(mstype.tensor)):
746
+ arg_type = arg_type.element_type()
747
+ cond = arg_type not in valid_types
748
+ raise_error_msg(cond, arg_type)
749
+ return arg_type
750
+
751
+
752
+ def check_reduce_shape(ori_shape, shape, axis, prim_name, arg_name1, arg_name2):
753
+ """Checks whether shape is ori_shape reduced on axis"""
754
+ axis_origin = axis
755
+ axis = axis if isinstance(axis, Iterable) else (axis,)
756
+ exp_shape = [ori_shape[i] for i in range(len(ori_shape)) if i not in axis]
757
+ if list(shape) != exp_shape:
758
+ raise ValueError(f"For '{prim_name}', " \
759
+ f"the shape of parameter '{arg_name1}' reduce on 'axis': {axis_origin} must " \
760
+ f"be equal to the shape of '{arg_name2}': {shape}, but got {ori_shape}.")
761
+
762
+
763
+ def check_astype_dtype(dtype):
764
+ """Check whether dtype is a valid input, and convert to mstype"""
765
+ all_types = mstype.__dtype__ + ["int", "float", "bool"]
766
+ if isinstance(dtype, str):
767
+ if dtype.lower() not in all_types:
768
+ raise TypeError(f"For Tensor.astype, the input type must be one of {all_types}, but got '{dtype}'.")
769
+ dtype = mstype.pytype_to_dtype(np.dtype(dtype.lower()))
770
+ elif isinstance(dtype, type):
771
+ dtype = mstype.pytype_to_dtype(dtype)
772
+ elif not dtype in mstype.number_type + (mstype.bool_,):
773
+ raise TypeError(f"For Tensor.astype, the input type must be one of {mstype.number_type + (mstype.bool_,)}," \
774
+ f" but got '{dtype}'.")
775
+ return dtype
776
+
777
+
778
+ def check_transpose_axis(axes, ndim):
779
+ """Check the axis argument for tensor.transpose"""
780
+ def _check_dim():
720
781
  # if multiple arguments provided, it must be `ndim` number of ints
721
782
  if len(axes) != ndim:
722
- raise ValueError(f"For Tensor.transpose, the number of axes must be equal to the dimension of Tensor, "
783
+ raise ValueError(f"For Tensor.transpose, the number of axes must be equal to the dimension of Tensor, " \
723
784
  f"but got {len(axes)} in the number of axes.")
724
- return axes
725
785
 
726
- @staticmethod
727
- def check_reshape_shp(shp):
728
- """Check the shape argument for tensor.reshape"""
729
-
730
- if len(shp) == 1:
731
- new_shape = shp[0]
732
- # if only one argument provided, it must be int, tuple or list
733
- if isinstance(new_shape, int):
734
- return shp
735
- if isinstance(new_shape, list):
736
- new_shape = tuple(new_shape)
737
- else:
738
- if not isinstance(new_shape, tuple):
739
- raise TypeError(
740
- f"For Tensor.reshape, the parameter 'shape' must be an integer, or tuple/list, "
741
- f"or series of integer, but got {type(shp[0])}")
742
- return new_shape
743
-
744
- return shp
745
-
746
- @staticmethod
747
- def check_flatten_order(order):
748
- """Check flatten function input order"""
749
- if not isinstance(order, str):
750
- raise TypeError(f"For Tensor.flatten, the parameter 'order' must be a string, but got {type(order)}")
751
- if order not in ('C', 'F'):
752
- raise ValueError(f"For Tensor.flatten, the parameter 'order' must be 'C' or 'F', but got '{order}'")
753
- return order
754
-
755
- @staticmethod
756
- def check_swapaxes_axis(axes, ndim):
757
- """Check all the axes argument for tensor.swapaxes"""
758
- if isinstance(axes, int):
759
- Validator.check_axis_in_range(axes, ndim)
760
- return axes % ndim
761
- if isinstance(axes, (tuple, list)):
762
- for axis in axes:
763
- if not isinstance(axis, int):
764
- raise TypeError(f"For Tensor.swapaxes, the axis argument must be integer, but got {type(axis)}.")
765
- Validator.check_axis_in_range(axis, ndim)
766
- axes = tuple(map(lambda x: x % ndim, axes))
767
- return axes
768
- raise TypeError(f"For Tensor.swapaxes, the argument 'axes' must be integer, list or tuple for check, "
769
- f"but got {type(axes)}.")
770
-
771
- @staticmethod
772
- def prepare_shape_for_squeeze(shape, axes):
773
- """
774
- Creates the squeezed new shape based on the tensor and given axes.
775
-
776
- Args:
777
- shape (tuple): the shape of the tensor
778
- axes Union[int, tuple(int), list(int)]: the axes with dimensions need to
779
- be squeezed.
780
-
781
- Returns:
782
- new_shape(tuple): the shape with dimensions squeezed.
783
- """
784
- new_shape = []
785
- ndim = len(shape)
786
-
787
- # Convert to set
788
- if isinstance(axes, int):
789
- if axes >= ndim or axes < -ndim:
790
- raise ValueError(f"For Tensor.squeeze, "
791
- f"the 'axis' must be in the range of [-{ndim}, {ndim}), but got {axes}.")
792
- axes = {axes}
793
-
794
- elif isinstance(axes, (list, tuple)):
795
- for axis in axes:
796
- if axis >= ndim or axis < -ndim:
797
- raise ValueError(f"For Tensor.squeeze, "
798
- f"the 'axis' must be in the range of [-{ndim}, {ndim}), but got {axis}.")
799
- axes = set(axes)
786
+ if not axes or (len(axes) == 1 and axes[0] is None):
787
+ return tuple(range(ndim-1, -1, -1))
788
+
789
+ if len(axes) == 1:
790
+ perm = axes[0]
791
+ # if only one argument provided, it must be tuple or list
792
+ if isinstance(perm, list):
793
+ perm = tuple(perm)
794
+ else:
795
+ if not isinstance(perm, tuple):
796
+ raise TypeError(f"For Tensor.transpose, the parameter 'axes' must be a tuple/list, " \
797
+ f"or series of integer, but got {type(axes[0])}")
798
+ return perm
799
+
800
+ _check_dim()
801
+ return axes
802
+
803
+
804
+ def check_reshape_shp(shp):
805
+ """Check the shape argument for tensor.reshape"""
800
806
 
807
+ if len(shp) == 1:
808
+ new_shape = shp[0]
809
+ # if only one argument provided, it must be int, tuple or list
810
+ if isinstance(new_shape, int):
811
+ return shp
812
+ if isinstance(new_shape, list):
813
+ new_shape = tuple(new_shape)
801
814
  else:
802
- raise TypeError(f"For Tensor.squeeze, the parameter 'axes' must be one of [int, tuple, list], "
803
- f"but got {type(axes)}")
804
-
805
- for idx, s in enumerate(shape):
806
- if s != 1 or (idx not in axes) and (idx - ndim not in axes):
807
- new_shape.append(s)
808
- # if an axis is selected with shape entry greater than one, an error is raised.
809
- if s != 1 and ((idx in axes) or (idx - ndim in axes)):
810
- raise ValueError(f"For Tensor.squeeze, the shape of parameter 'axis' {axes} must be 1, but got {s}.")
811
- return tuple(new_shape)
812
-
813
- @staticmethod
814
- def check_axis_in_range(axis, ndim):
815
- """Checks axes are with the bounds of ndim"""
815
+ if not isinstance(new_shape, tuple):
816
+ raise TypeError(
817
+ f"For Tensor.reshape, the parameter 'shape' must be an integer, or tuple/list, " \
818
+ f"or series of integer, but got {type(shp[0])}")
819
+ return new_shape
820
+
821
+ return shp
822
+
823
+
824
+ def check_flatten_order(order):
825
+ """Check flatten function input order"""
826
+ if not isinstance(order, str):
827
+ raise TypeError(f"For Tensor.flatten, the parameter 'order' must be a string, but got {type(order)}")
828
+ if order not in ('C', 'F'):
829
+ raise ValueError(f"For Tensor.flatten, the parameter 'order' must be 'C' or 'F', but got '{order}'")
830
+
831
+
832
+ def check_swapaxes_axis(axes, ndim):
833
+ """Check all the axes argument for ops.swapaxes"""
834
+ if isinstance(axes, int):
835
+ return check_axis_in_range(axes, ndim)
836
+ if isinstance(axes, (tuple, list)):
837
+ for axis in axes:
838
+ if not isinstance(axis, int):
839
+ raise TypeError(f"For ops.swapaxes, the axis argument must be integer, but got {type(axis)}.")
840
+ check_axis_in_range(axis, ndim)
841
+ tmp = ()
842
+ for x in axes:
843
+ tmp = tmp + ((x + ndim) % ndim,)
844
+ return tmp
845
+ raise TypeError(f"For ops.swapaxes, the argument 'axes' must be integer, list or tuple for check, " \
846
+ f"but got {type(axes)}.")
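A sketch of the axis normalization performed above (illustrative only; assumed import path):

    from mindspore._checkparam import check_swapaxes_axis  # assumed import path

    check_swapaxes_axis(-2, 4)        # returns 2: negative axes are wrapped into [0, ndim)
    check_swapaxes_axis((0, -1), 3)   # returns (0, 2)
    # check_swapaxes_axis(4, 4) raises ValueError because 4 lies outside [-4, 4)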
847
+
848
+
849
+ def prepare_shape_for_squeeze(shape, axes):
850
+ """
851
+ Creates the squeezed new shape based on the tensor and given axes.
852
+
853
+ Args:
854
+ shape (tuple): the shape of the tensor
855
+ axes Union[int, tuple(int), list(int)]: the axes with dimensions need to
856
+ be squeezed.
857
+
858
+ Returns:
859
+ new_shape(tuple): the shape with dimensions squeezed.
860
+ """
861
+ new_shape = ()
862
+ ndim = len(shape)
863
+
864
+ def _check(axes, ndim):
865
+ if axes >= ndim or axes < -ndim:
866
+ raise ValueError("For Tensor.squeeze, the 'axis' must be in the range of [-{0}, {0}), but got {1}." \
867
+ .format(ndim, axes))
868
+
869
+ def _check_for(axes, ndim):
870
+ for axis in axes:
871
+ _check(axis, ndim)
872
+
873
+ if isinstance(axes, int):
874
+ _check(axes, ndim)
875
+ axes = (axes,)
876
+ elif isinstance(axes, (list, tuple)):
877
+ _check_for(axes, ndim)
878
+ new_axes = ()
879
+ for item in axes:
880
+ if item not in new_axes:
881
+ new_axes += (item,)
882
+ axes = new_axes
883
+ else:
884
+ raise TypeError("For Tensor.squeeze, the parameter 'axes' must be one of [int, tuple, list], but got {}" \
885
+ .format(type(axes)))
886
+
887
+ def _check_axis(s, idx, axes, ndim):
888
+ # if an axis is selected with shape entry greater than one, an error is raised.
889
+ if s != 1 and ((idx in axes) or (idx - ndim in axes)):
890
+ raise ValueError(f"For Tensor.squeeze, the shape of parameter 'axis' {axes} must be 1, but got {s}.")
891
+
892
+ for idx, s in enumerate(shape):
893
+ _check_axis(s, idx, axes, ndim)
894
+ if s != 1 or (idx not in axes) and (idx - ndim not in axes):
895
+ new_shape = new_shape + (s,)
896
+
897
+ return new_shape
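A small sketch of the squeeze-shape computation above (illustrative only; assumed import path):

    from mindspore._checkparam import prepare_shape_for_squeeze  # assumed import path

    prepare_shape_for_squeeze((1, 3, 1, 4), 0)       # returns (3, 1, 4): only axis 0 is squeezed
    prepare_shape_for_squeeze((1, 3, 1, 4), (0, 2))  # returns (3, 4)
    # prepare_shape_for_squeeze((1, 3, 1, 4), 1) raises ValueError because dimension 1 has size 3, not 1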
898
+
899
+
900
+ def check_axis_in_range(axis, ndim):
901
+ """Checks axes are with the bounds of ndim"""
902
+ def _check():
816
903
  if not isinstance(axis, int):
817
904
  raise TypeError(f'The axes must be integers, but got {type(axis)}')
818
- if not -ndim <= axis < ndim:
905
+
906
+ if axis >= ndim or axis < -ndim:
819
907
  raise ValueError(f"The 'axis' must be in the range of [-{ndim}, {ndim}), but got {axis}.")
820
- return axis % ndim
821
-
822
- @staticmethod
823
- def check_axis_valid(axes, ndim):
824
- """
825
- Checks axes are valid given ndim, and returns axes that can be passed
826
- to the built-in operator (non-negative, int or tuple)
827
- """
828
- if axes is None:
829
- axes = tuple(range(ndim))
830
- return axes
831
- if isinstance(axes, (tuple, list)):
832
- for axis in axes:
833
- Validator.check_axis_in_range(axis, ndim)
834
- axes = tuple(map(lambda x: x % ndim, axes))
835
- if any(axes.count(el) > 1 for el in axes):
836
- raise ValueError(f"The element of parameter 'axis' can not be duplicate, but got {axes}.")
837
- return axes
838
- Validator.check_axis_in_range(axes, ndim)
839
- return (axes % ndim,)
840
-
841
- @staticmethod
842
- def max_(*args):
843
- return max(*args)
844
-
845
- @staticmethod
846
- def min_(*args):
847
- return min(*args)
848
-
849
- @staticmethod
850
- def expanded_shape(ndim, axis_size, axis):
851
- """
852
- Returns a shape with size = 1 for all dimensions
853
- except at axis.
854
- """
855
- return tuple(axis_size if i == axis else 1 for i in range(ndim))
856
-
857
- @staticmethod
858
- def tuple_slice(tup, start, end):
859
- """get sliced tuple from start and end."""
860
- return tup[start:end]
861
-
862
- @staticmethod
863
- def infer_out_shape(*shapes):
864
- """
865
- Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
866
- """
867
- shape_out = deque()
868
- reversed_shapes = map(reversed, shapes)
869
- for items in zip_longest(*reversed_shapes, fillvalue=1):
870
- max_size = 0 if 0 in items else max(items)
871
- if any(item not in (1, max_size) for item in items):
872
- raise ValueError(f'For Tensor, the dimension on each axis must be 1 or the max on the axis'
873
- f'to support broadcast, but got shapes {*shapes,}')
874
- shape_out.appendleft(max_size)
875
- return tuple(shape_out)
876
-
877
- @staticmethod
878
- def get_log2_size(size):
879
- return math.ceil(math.log2(size))
880
-
881
- @staticmethod
882
- def check_axis_type(axis, type_int=True, type_tuple=True, type_list=True):
883
- """Check axis argument type."""
884
- if type_int and isinstance(axis, int):
885
- return True
886
- if (type_tuple and isinstance(axis, tuple)) or (type_list and isinstance(axis, list)):
887
- for ax in axis:
888
- if not isinstance(ax, int):
889
- raise TypeError(f"For Tensor.ptp, each axis must be integer, but got {type(ax)} in {axis}.")
890
- return True
891
-
892
- type_str = ""
893
- if type_int:
894
- type_str += "int, "
895
- if type_tuple:
896
- type_str += "tuple, "
897
- if type_list:
898
- type_str += "list, "
899
- raise TypeError(f"For Tensor.ptp, the axis should be {type_str}, but got {type(axis)}.")
900
-
901
- @staticmethod
902
- def check_and_canonicalize_axes(axes, ndim):
903
- """Check whether the types and values of input axes are valid."""
904
- axes = axes if isinstance(axes, tuple) else (axes,)
905
- new_axes = ()
906
- for ax in axes:
908
+
909
+ _check()
910
+ return (axis + ndim) % ndim
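The range check above also canonicalizes negative axes; a quick sketch (illustrative only; assumed import path):

    from mindspore._checkparam import check_axis_in_range  # assumed import path

    check_axis_in_range(2, 4)    # returns 2
    check_axis_in_range(-1, 4)   # returns 3, i.e. (-1 + 4) % 4
    # check_axis_in_range(4, 4) raises ValueError: the axis must lie in [-4, 4)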
911
+
912
+
913
+ def check_axis_valid(axes, ndim):
914
+ """
915
+ Checks axes are valid given ndim, and returns axes that can be passed
916
+ to the built-in operator (non-negative, int or tuple)
917
+ """
918
+ def _check_range(axes):
919
+ for axis in axes:
920
+ check_axis_in_range(axis, ndim)
921
+
922
+ if axes is None:
923
+ axes = tuple(range(ndim))
924
+ return axes
925
+ if isinstance(axes, (tuple, list)):
926
+ _check_range(axes)
927
+ tmp = ()
928
+ for x in axes:
929
+ tmp = tmp + ((x + ndim) % ndim,)
930
+ _check_dup(tmp)
931
+ return tmp
932
+ check_axis_in_range(axes, ndim)
933
+ return (axes % ndim,)
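A sketch of the axis validation and canonicalization above (illustrative only; assumed import path; duplicate axes are rejected by the module-internal _check_dup helper, which is defined elsewhere in this file):

    from mindspore._checkparam import check_axis_valid  # assumed import path

    check_axis_valid(None, 3)      # returns (0, 1, 2): None means all axes
    check_axis_valid(-1, 3)        # returns (2,)
    check_axis_valid((0, -1), 3)   # returns (0, 2)
    # check_axis_valid((0, -3), 3) would be rejected as a duplicate, since both entries map to axis 0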
934
+
935
+
936
+ def max_(*args):
937
+ return max(*args)
938
+
939
+
940
+ def min_(*args):
941
+ return min(*args)
942
+
943
+
944
+ def expanded_shape(ndim, axis_size, axis):
945
+ """
946
+ Returns a shape with size = 1 for all dimensions
947
+ except at axis.
948
+ """
949
+ return tuple(axis_size if i == axis else 1 for i in range(ndim))
950
+
951
+
952
+ def tuple_slice(tup, start, end):
953
+ """get sliced tuple from start and end."""
954
+ return tup[start:end]
955
+
956
+
957
+ def infer_out_shape(*shapes):
958
+ """
959
+ Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
960
+ """
961
+ def _check(items, max_size, shapes):
962
+ for item in items:
963
+ if item not in (1, max_size):
964
+ raise ValueError(f'For Tensor, the dimension on each axis must be 1 or the max on the axis' \
965
+ f' to support broadcast, but got shapes {shapes,}')
966
+ shape_out = ()
967
+ max_len = max([len(it) for it in shapes])
968
+ for i in range(max_len):
969
+ items = [it[i-(max_len-len(it))] if i - (max_len - len(it))
970
+ >= 0 else 1 for it in shapes]
971
+ max_size = 0 if 0 in items else max(items)
972
+ _check(items, max_size, shapes)
973
+ shape_out = shape_out + (max_size,)
974
+ return shape_out
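A sketch of the broadcast-shape inference above (illustrative only; assumed import path):

    from mindspore._checkparam import infer_out_shape  # assumed import path

    infer_out_shape((2, 1, 3), (1, 4, 3))   # returns (2, 4, 3)
    infer_out_shape((5, 4), (1,))           # returns (5, 4): shorter shapes are treated as left-padded with 1
    # infer_out_shape((2, 3), (4, 3)) raises ValueError because 2 and 4 cannot broadcast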
975
+
976
+
977
+ def check_axis_type(axis, type_int=True, type_tuple=True, type_list=True):
978
+ """Check axis argument type."""
979
+ if type_int and isinstance(axis, int):
980
+ return True
981
+ if (type_tuple and isinstance(axis, tuple)) or (type_list and isinstance(axis, list)):
982
+ for ax in axis:
907
983
  if not isinstance(ax, int):
908
- raise TypeError(f"Each axis should be integer, but got {type(ax)} in {axes}.")
909
- if not -ndim <= ax < ndim:
910
- raise ValueError(f"The 'axis' must be in the range of [-{ndim}, {ndim}), but got {ax}.")
911
- ax = ax if ax >= 0 else ax + ndim
912
- new_axes += (ax,)
913
- if any(new_axes.count(el) > 1 for el in new_axes):
914
- raise ValueError(f"The element of parameter 'axis' can not be duplicate, but got {new_axes}.")
915
- return new_axes
916
-
917
- @staticmethod
918
- def empty_compile(dtype, shape):
919
- """Returns an empty Tensor."""
920
- return Tensor_(dtype, shape)
921
-
922
- @staticmethod
923
- def check_type_support(dtype, device, supported_dtypes):
924
- """Checks whether the data type is supported."""
925
- return dtype in supported_dtypes or not context.get_context('device_target') == device
926
-
927
- @staticmethod
928
- def check_sparse_tensor_input(indices, values, shape):
929
- """Common input check for SparseTensors."""
930
- if not isinstance(indices, Tensor_):
931
- raise TypeError(f"For SparseTensors, 'indices' must be Tensor, but got {type(indices)}.")
932
- if not isinstance(values, Tensor_):
933
- raise TypeError(f"For SparseTensors, 'values' must be Tensor, but got {type(values)}.")
934
- if not isinstance(shape, tuple):
935
- raise TypeError(f"For SparseTensors, 'shape' must be tuple, but got {type(shape)}.")
936
-
937
- @staticmethod
938
- def check_csr_tensor_input(indptr, indices, values, shape):
939
- """Checks inputs type for CSRTensor."""
940
- if not isinstance(indptr, Tensor_):
941
- raise TypeError(f"For CSRTensor, 'indptr' must be Tensor, but got {type(indptr)}.")
942
- Validator.check_sparse_tensor_input(indices, values, shape)
943
-
944
- @staticmethod
945
- def check_csr_tensor_shape(indptr_shp, indices_shp, values_shp, csr_shp):
946
- """Checks input tensors' shapes for CSRTensor."""
947
- shape_size = 1
948
- val_shp_size = 1
949
- for item in csr_shp:
950
- if item <= 0:
951
- raise ValueError(f"For CSRTensor, the element of shape must be positive, but got {item}")
952
- if not isinstance(item, int):
953
- raise TypeError(f"For CSRTensor, the element type of shape must be int, but got {type(item)}")
954
- shape_size *= item
955
- for item in values_shp:
956
- if item <= 0:
957
- raise ValueError(f"The element of shape must be positive, but got {item}")
958
- val_shp_size *= item
959
- if shape_size < val_shp_size:
960
- raise ValueError(f"Shape total size: {shape_size} is too small to hold {val_shp_size} non-zero values.")
961
- if len(indices_shp) != 1:
962
- raise ValueError(f"For CSRTensor, indices must be a 1-dimensional tensor, "
963
- f"but got a {len(indices_shp)} dimension tensor.")
964
- if len(indptr_shp) != 1:
965
- raise ValueError(f"For CSRTensor, indptr must be a 1-dimensional tensor, "
966
- f"but got a {len(indptr_shp)} dimension tensor.")
967
- if csr_shp[0] + 1 != indptr_shp[0]:
968
- raise ValueError(f"For CSRTensor, indptr must have length (1 + shape[0]), "
969
- f"but got: {indptr_shp[0]}")
970
- if indices_shp[0] != values_shp[0]:
971
- err_msg1 = "For CSRTensor, indices and values must equal in their shape, "
972
- err_msg2 = f"but got indices shape: {indices_shp[0]}, values shape: {values_shp[0]}."
973
- raise ValueError(err_msg1 + err_msg2)
974
- if len(values_shp) + 1 != len(csr_shp):
975
- raise ValueError(f"Values' dimension should equal to CSRTensor's dimension - 1, but got"\
976
- f"Values' dimension: {len(values_shp)} , CSRTensor's dimension: "\
977
- f"{len(csr_shp)}")
978
- if values_shp[1: ] != csr_shp[2: ]:
979
- raise ValueError(f"CSRTensor's shape[2: ] must be equal to value's shape[1: ],"\
980
- f"but CSRTensor's shape[2: ] got: {csr_shp[2: ]} and value's shape[1: ]"\
981
- f"got: {values_shp[1: ]}")
982
-
983
- @staticmethod
984
- def check_csr_tensor_dtype(indptr_dtype, indices_dtype):
985
- """Checks input tensors' data types for CSRTensor."""
986
- if indptr_dtype not in (mstype.int16, mstype.int32, mstype.int64):
987
- raise TypeError(f"For CSRTensor, indptr must have int16 or int32 or int64 data type, "
988
- f"but got {indptr_dtype}.")
989
- if indices_dtype not in (mstype.int16, mstype.int32, mstype.int64):
990
- raise TypeError(f"For CSRTensor, indices must have int16 or int32 or int64 data type, "
991
- f"but got {indices_dtype}.")
992
-
993
- @staticmethod
994
- def check_coo_tensor_input(indices, values, shape):
995
- """Checks inputs type for COOTensor."""
996
- Validator.check_sparse_tensor_input(indices, values, shape)
997
-
998
- @staticmethod
999
- def check_coo_tensor_shape(indices_shp, values_shp, coo_shp):
1000
- """Checks input tensors' shapes for COOTensor."""
1001
- if len(coo_shp) != 2:
1002
- raise ValueError(f"For COOTensor, the length of 'shape' must be 2, but got {coo_shp}.")
1003
- shp_mul = 1
1004
- for sh in coo_shp:
1005
- if sh <= 0:
1006
- raise ValueError(f"For COOTensor, the element of 'shape' must be positive, but got {sh} in {coo_shp}.")
1007
- if not isinstance(sh, int):
1008
- raise TypeError(f"For COOTensor, the element type of 'shape' must be int, but got {type(sh)}")
1009
- shp_mul *= sh
1010
- if shp_mul < values_shp[0]:
1011
- raise ValueError(f"For COOTensor, shape is too small: ({shp_mul}) to hold all values({values_shp[0]}).")
1012
- if len(indices_shp) != 2:
1013
- raise ValueError(f"For COOTensor, 'indices' must be a 2-dimensional tensor, but got a {len(indices_shp)}"
1014
- f"-dimensional tensor.")
1015
- if len(values_shp) != 1:
1016
- raise ValueError(f"For COOTensor, 'values' must be a 1-dimensional tensor, but got a {len(values_shp)}"
1017
- f"-dimensional tensor.")
1018
- if indices_shp[0] != values_shp[0]:
1019
- raise ValueError(f"For COOTensor, 'indices.shape[0]' must be euqal to 'values.shape[0]', but got "
1020
- f"'indices.shape[0]' = {indices_shp[0]} and 'values.shape[0]' = {values_shp[0]}.")
1021
- if indices_shp[1] != 2:
1022
- raise ValueError(f"For COOTensor, 'indices.shape[1]' must be 2, but got {indices_shp[1]}.")
1023
-
1024
- @staticmethod
1025
- def check_coo_tensor_dtype(indices_dtype):
1026
- """Checks input tensors' data types for COOTensor."""
1027
- if indices_dtype not in (mstype.int16, mstype.int32, mstype.int64):
1028
- raise TypeError(f"For COOTensor, the type of 'indices' must be one of [int16, int32, int64], but got "
1029
- f"{indices_dtype}.")
1030
-
1031
- @staticmethod
1032
- def check_dynamic_shape(dyn_inputs, actual_inputs):
1033
- """Check the consistency of dynamic shape tensors and actual input tensors."""
1034
- if len(dyn_inputs) != len(actual_inputs):
1035
- raise ValueError(f"The number of actual input tensors: {len(actual_inputs)} is not equal to the number of "
1036
- f"dynamic shape tensors: {len(dyn_inputs)}.")
1037
- for i, dyn_elem in enumerate(dyn_inputs):
1038
- if dyn_elem.dtype is not actual_inputs[i].dtype:
1039
- raise TypeError(f"The data type of `{i}`th args in actual input tensors should be `{dyn_elem.dtype}`, "
1040
- f"but got `{actual_inputs[i].dtype}`.")
1041
- if dyn_elem.ndim != actual_inputs[i].ndim:
1042
- raise ValueError(f"The dimension of `{i}`th args in actual input tensors should be `{dyn_elem.ndim}`, "
1043
- f"but got `{actual_inputs[i].ndim}`.")
1044
- check_dyn_shape_value_equal(i, dyn_elem.shape, actual_inputs[i].shape)
1045
-
1046
- @staticmethod
1047
- def check_element_type_of_iterable(arg_name, arg_value, valid_types, prim_name=None):
1048
- """Check type of the element of a iterabel object, execpt dict."""
1049
- Validator.check_value_type(arg_name, arg_value, [list, tuple], prim_name)
1050
- type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
1051
- num_types = len(valid_types)
1052
- msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
1053
- for element in arg_value:
1054
- if not isinstance(element, tuple(valid_types)):
1055
- raise TypeError(f"{msg_prefix} type of '{arg_name}' should be {'one of ' if num_types > 1 else ''}"
1056
- f"{type_names if num_types > 1 else type_names[0]}, "
1057
- f"but got '{element}' with type '{type(element).__name__}'.")
1058
-
1059
- @staticmethod
1060
- def check_element_type_of_dict(arg_name, arg_value, key_types, value_types, prim_name=None):
1061
- """Check the type of key and value of a dict."""
1062
- Validator.check_value_type(arg_name, arg_value, [dict], prim_name)
1063
- msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
1064
- type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in key_types]
1065
- num_types = len(key_types)
1066
- for element in arg_value.keys():
1067
- if not isinstance(element, tuple(key_types)):
1068
- raise TypeError(f"{msg_prefix} type of '{arg_name}' should be {'one of ' if num_types > 1 else ''}"
1069
- f"{type_names if num_types > 1 else type_names[0]}, "
1070
- f"but got '{element}' with type '{type(element).__name__}'.")
1071
-
1072
- type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in value_types]
1073
- num_types = len(value_types)
1074
- for element in arg_value.values():
1075
- if not isinstance(element, tuple(value_types)):
1076
- raise TypeError(f"{msg_prefix} type of '{arg_name}' should be {'one of ' if num_types > 1 else ''}"
1077
- f"{type_names if num_types > 1 else type_names[0]}, "
1078
- f"but got '{element}' with type '{type(element).__name__}'.")
1079
-
1080
- @staticmethod
1081
- def check_size_and_element_type_of_tuple(arg_name, arg_value, expect_size, expect_element_type, prim_name=None):
1082
- """Check the size and element type of a tuple."""
1083
- Validator.check_value_type(arg_name, arg_value, [tuple], prim_name)
1084
- Validator.check_equal_int(len(arg_value), expect_size, arg_name + ' size', prim_name)
1085
- Validator.check_element_type_of_iterable('arg_name', arg_value, [expect_element_type], prim_name)
984
+ raise TypeError(f"For Tensor.ptp, each axis must be integer, but got {type(ax)} in {axis}.")
985
+ return True
986
+
987
+ type_str = ""
988
+ if type_int:
989
+ type_str += "int, "
990
+ if type_tuple:
991
+ type_str += "tuple, "
992
+ if type_list:
993
+ type_str += "list, "
994
+ raise TypeError(f"For Tensor.ptp, the axis should be {type_str}, but got {type(axis)}.")
995
+
996
+
997
+ def check_and_canonicalize_axes(axes, ndim):
998
+ """Check whether the types and values of input axes are valid."""
999
+ def _check(axes, ax, ndim):
1000
+ if not isinstance(ax, int):
1001
+ raise TypeError(f"Each axis should be integer, but got {type(ax)} in {axes}.")
1002
+ if ax >= ndim or ax < -ndim:
1003
+ raise ValueError(f"The 'axis' must be in the range of [-{ndim}, {ndim}), but got {ax}.")
1004
+
1005
+ axes = axes if isinstance(axes, tuple) else (axes,)
1006
+ new_axes = ()
1007
+ for ax in axes:
1008
+ _check(axes, ax, ndim)
1009
+ ax = ax if ax >= 0 else ax + ndim
1010
+ new_axes += (ax,)
1011
+ _check_dup(new_axes)
1012
+ return new_axes
1013
+
1014
+
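Not part of the diff: a minimal usage sketch of the new module-level check_and_canonicalize_axes helper, assuming it stays importable from mindspore._checkparam in 2.0.0rc1 (the axis values are hypothetical).

    from mindspore import _checkparam as validator

    # Negative axes are wrapped into [0, ndim); duplicates raise a ValueError.
    print(validator.check_and_canonicalize_axes((0, -1), ndim=3))  # (0, 2)
    # validator.check_and_canonicalize_axes(3, ndim=3)  # ValueError: axis out of [-3, 3)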
1015
+ def check_type_support(dtype, device, supported_dtypes):
1016
+ """Checks whether the data type is supported."""
1017
+ return dtype in supported_dtypes or not context.get_context('device_target') == device
1018
+
1019
+
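Not part of the diff: check_type_support only fails a dtype check when the running backend matches the given device. A hedged sketch, assuming the usual mindspore dtype and _checkparam import paths:

    from mindspore.common import dtype as mstype
    from mindspore import _checkparam as validator

    # True if float16 is in the supported set, or if the current
    # device_target is not "Ascend" (the check is skipped on other backends).
    ok = validator.check_type_support(mstype.float16, "Ascend", (mstype.float16, mstype.float32))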
1020
+ def check_sparse_tensor_input(indices, values, shape):
1021
+ """Common input check for SparseTensors."""
1022
+ if not isinstance(indices, Tensor_) and not is_stub_tensor(indices):
1023
+ raise TypeError(f"For SparseTensors, 'indices' must be Tensor, but got {type(indices)}.")
1024
+ if not isinstance(values, Tensor_) and not is_stub_tensor(values):
1025
+ raise TypeError(f"For SparseTensors, 'values' must be Tensor, but got {type(values)}.")
1026
+ if not isinstance(shape, tuple):
1027
+ raise TypeError(f"For SparseTensors, 'shape' must be tuple, but got {type(shape)}.")
1028
+
1029
+
1030
+ def check_csr_tensor_input(indptr, indices, values, shape):
1031
+ """Checks inputs type for CSRTensor."""
1032
+ if not isinstance(indptr, Tensor_) and not is_stub_tensor(indptr):
1033
+ raise TypeError(f"For CSRTensor, 'indptr' must be Tensor, but got {type(indptr)}.")
1034
+ check_sparse_tensor_input(indices, values, shape)
1035
+
1036
+
1037
+ def check_csr_tensor_shape(indptr_shp, indices_shp, values_shp, csr_shp):
1038
+ """Checks input tensors' shapes for CSRTensor."""
1039
+ # Support empty sparse tensor
1040
+ if (indptr_shp == (0,)) and (indices_shp == (0,)) and (values_shp == (0,)):
1041
+ return
1042
+ shape_size = 1
1043
+ val_shp_size = 1
1044
+ for item in csr_shp:
1045
+ if item <= 0:
1046
+ raise ValueError(f"For CSRTensor, the element of shape must be positive, but got {item}")
1047
+ if not isinstance(item, int):
1048
+ raise TypeError(f"For CSRTensor, the element type of shape must be int, but got {type(item)}")
1049
+ shape_size *= item
1050
+ for item in values_shp:
1051
+ if item <= 0:
1052
+ raise ValueError(f"The element of shape must be positive, but got {item}")
1053
+ val_shp_size *= item
1054
+ if shape_size < val_shp_size:
1055
+ raise ValueError(f"Shape total size: {shape_size} is too small to hold {val_shp_size} non-zero values.")
1056
+ if len(indices_shp) != 1:
1057
+ raise ValueError(f"For CSRTensor, indices must be a 1-dimensional tensor, " \
1058
+ f"but got a {len(indices_shp)} dimension tensor.")
1059
+ if len(indptr_shp) != 1:
1060
+ raise ValueError(f"For CSRTensor, indptr must be a 1-dimensional tensor, " \
1061
+ f"but got a {len(indptr_shp)} dimension tensor.")
1062
+ if csr_shp[0] + 1 != indptr_shp[0]:
1063
+ raise ValueError(f"For CSRTensor, indptr must have length (1 + shape[0]), " \
1064
+ f"but got: {indptr_shp[0]}")
1065
+ if indices_shp[0] != values_shp[0]:
1066
+ err_msg1 = "For CSRTensor, indices and values must equal in their shape, "
1067
+ err_msg2 = f"but got indices shape: {indices_shp[0]}, values shape: {values_shp[0]}."
1068
+ raise ValueError(err_msg1 + err_msg2)
1069
+ if len(values_shp) + 1 != len(csr_shp):
1070
+ raise ValueError(f"Values' dimension should equal to CSRTensor's dimension - 1, but got" \
1071
+ f"Values' dimension: {len(values_shp)} , CSRTensor's dimension: " \
1072
+ f"{len(csr_shp)}")
1073
+ if values_shp[1:] != csr_shp[2:]:
1074
+ raise ValueError(f"CSRTensor's shape[2: ] must be equal to value's shape[1: ]," \
1075
+ f"but CSRTensor's shape[2: ] got: {csr_shp[2: ]} and value's shape[1: ]" \
1076
+ f"got: {values_shp[1: ]}")
1077
+
1078
+
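Not part of the diff: an illustrative set of CSR shapes that satisfies the checks above, and one that does not, assuming the helper is importable from mindspore._checkparam (all shape values are made up).

    from mindspore import _checkparam as validator

    # 3 x 4 matrix with 6 non-zero values: indptr has shape[0] + 1 entries,
    # indices and values have one entry per non-zero element.
    validator.check_csr_tensor_shape(indptr_shp=(4,), indices_shp=(6,), values_shp=(6,), csr_shp=(3, 4))

    # indptr length does not equal shape[0] + 1 -> ValueError
    # validator.check_csr_tensor_shape((3,), (6,), (6,), (3, 4))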
1079
+ def check_csr_tensor_dtype(indptr_dtype, indices_dtype):
1080
+ """Checks input tensors' data types for CSRTensor."""
1081
+ if indptr_dtype not in (mstype.int16, mstype.int32, mstype.int64):
1082
+ raise TypeError(f"For CSRTensor, indptr must have int16 or int32 or int64 data type, " \
1083
+ f"but got {indptr_dtype}.")
1084
+ if indices_dtype not in (mstype.int16, mstype.int32, mstype.int64):
1085
+ raise TypeError(f"For CSRTensor, indices must have int16 or int32 or int64 data type, " \
1086
+ f"but got {indices_dtype}.")
1087
+
1088
+
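Not part of the diff: the dtype check accepts int16/int32/int64 index types and rejects anything else. A hedged sketch, assuming the usual import paths:

    from mindspore.common import dtype as mstype
    from mindspore import _checkparam as validator

    validator.check_csr_tensor_dtype(mstype.int32, mstype.int32)      # passes
    # validator.check_csr_tensor_dtype(mstype.float32, mstype.int32)  # TypeError: indptr must be int16/int32/int64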
1089
+ def check_coo_tensor_input(indices, values, shape):
1090
+ """Checks inputs type for COOTensor."""
1091
+ check_sparse_tensor_input(indices, values, shape)
1092
+
1093
+
1094
+ def check_coo_tensor_shape(indices_shp, values_shp, coo_shp):
1095
+ """Checks input tensors' shapes for COOTensor."""
1096
+ if len(coo_shp) != 2:
1097
+ raise ValueError(f"For COOTensor, the length of 'shape' must be 2, but got {coo_shp}.")
1098
+ if (indices_shp == (0,)) and (values_shp == (0,)):
1099
+ return
1100
+ shp_mul = 1
1101
+ for sh in coo_shp:
1102
+ if sh <= 0:
1103
+ raise ValueError(f"For COOTensor, the element of 'shape' must be positive, but got {sh} in {coo_shp}.")
1104
+ if not isinstance(sh, int):
1105
+ raise TypeError(f"For COOTensor, the element type of 'shape' must be int, but got {type(sh)}")
1106
+ shp_mul *= sh
1107
+ if shp_mul < values_shp[0]:
1108
+ raise ValueError(f"For COOTensor, shape is too small: ({shp_mul}) to hold all values({values_shp[0]}).")
1109
+ if len(indices_shp) != 2:
1110
+ raise ValueError(f"For COOTensor, 'indices' must be a 2-dimensional tensor, but got a {len(indices_shp)}" \
1111
+ f"-dimensional tensor.")
1112
+ if len(values_shp) != 1:
1113
+ raise ValueError(f"For COOTensor, 'values' must be a 1-dimensional tensor, but got a {len(values_shp)}" \
1114
+ f"-dimensional tensor.")
1115
+ if indices_shp[0] != values_shp[0]:
1116
+ raise ValueError(f"For COOTensor, 'indices.shape[0]' must be euqal to 'values.shape[0]', but got " \
1117
+ f"'indices.shape[0]' = {indices_shp[0]} and 'values.shape[0]' = {values_shp[0]}.")
1118
+ if indices_shp[1] != 2:
1119
+ raise ValueError(f"For COOTensor, 'indices.shape[1]' must be 2, but got {indices_shp[1]}.")
1120
+
1121
+
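Not part of the diff: COO indices are expected as an (N, 2) tensor paired with N values inside a 2-D dense shape. An illustrative call with made-up shapes, assuming mindspore._checkparam is importable:

    from mindspore import _checkparam as validator

    # 4 non-zero elements in a 3 x 3 matrix: indices is (4, 2), values is (4,).
    validator.check_coo_tensor_shape(indices_shp=(4, 2), values_shp=(4,), coo_shp=(3, 3))

    # A 3-D dense shape is rejected; COOTensor only supports 2-D here.
    # validator.check_coo_tensor_shape((4, 2), (4,), (2, 3, 4))  # ValueError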
1122
+ def check_coo_tensor_dtype(indices_dtype):
1123
+ """Checks input tensors' data types for COOTensor."""
1124
+ if indices_dtype not in (mstype.int16, mstype.int32, mstype.int64):
1125
+ raise TypeError(f"For COOTensor, the type of 'indices' must be one of [int16, int32, int64], but got " \
1126
+ f"{indices_dtype}.")
1127
+
1128
+
1129
+ def check_dynamic_shape(dyn_elem, actual_input, i):
1130
+ """Check the consistency of dynamic shape tensors and actual input tensors."""
1131
+ if dyn_elem.dtype != actual_input.dtype:
1132
+ raise TypeError(f"The data type of '{i}'th args in actual input tensors should be '{dyn_elem.dtype}', " \
1133
+ f"but got '{actual_input.dtype}'.")
1134
+ if dyn_elem.ndim != actual_input.ndim:
1135
+ raise ValueError(f"The dimension of '{i}'th args in actual input tensors should be '{dyn_elem.ndim}', " \
1136
+ f"but got '{actual_input.ndim}'.")
1137
+ check_dyn_shape_value_equal(i, dyn_elem.shape, actual_input.shape)
1138
+
1139
+
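Not part of the diff: check_dynamic_shape compares one declared dynamic-shape tensor against one actual input and delegates the per-dimension comparison to check_dyn_shape_value_equal (shown further down), where -1 acts as a wildcard. A sketch of the shape-level helper with hypothetical shapes:

    from mindspore import _checkparam as validator

    # -1 in the declared shape matches any actual size; fixed sizes must match exactly.
    validator.check_dyn_shape_value_equal(0, (-1, 3), (8, 3))    # passes
    # validator.check_dyn_shape_value_equal(0, (-1, 3), (8, 4))  # ValueError on the second dimension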
1140
+ def check_element_type_of_iterable(arg_name, arg_value, valid_types, prim_name=None):
1141
+ """Check type of the element of a iterabel object, execpt dict."""
1142
+ check_value_type(arg_name, arg_value, [list, tuple], prim_name)
1143
+ type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
1144
+ num_types = len(valid_types)
1145
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
1146
+ for element in arg_value:
1147
+ if not isinstance(element, tuple(valid_types)):
1148
+ raise TypeError(f"{msg_prefix} type of '{arg_name}' should be {'one of ' if num_types > 1 else ''}" \
1149
+ f"{type_names if num_types > 1 else type_names[0]}, " \
1150
+ f"but got '{element}' with type '{type(element).__name__}'.")
1151
+
1152
+
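Not part of the diff: the iterable check validates both the container type and every element. The argument and primitive names below are hypothetical.

    from mindspore import _checkparam as validator

    validator.check_element_type_of_iterable('kernel_size', [3, 3], [int], prim_name='Conv2d')  # passes
    # validator.check_element_type_of_iterable('kernel_size', [3, 3.0], [int], prim_name='Conv2d')
    # -> TypeError: For 'Conv2d', the type of 'kernel_size' should be int, but got '3.0' with type 'float'.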
1153
+ def check_element_type_of_dict(arg_name, arg_value, key_types, value_types, prim_name=None):
1154
+ """Check the type of key and value of a dict."""
1155
+ check_value_type(arg_name, arg_value, [dict], prim_name)
1156
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
1157
+ type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in key_types]
1158
+ num_types = len(key_types)
1159
+ for element in arg_value.keys():
1160
+ if not isinstance(element, tuple(key_types)):
1161
+ raise TypeError(f"{msg_prefix} type of '{arg_name}' should be {'one of ' if num_types > 1 else ''}" \
1162
+ f"{type_names if num_types > 1 else type_names[0]}, " \
1163
+ f"but got '{element}' with type '{type(element).__name__}'.")
1164
+
1165
+ type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in value_types]
1166
+ num_types = len(value_types)
1167
+ for element in arg_value.values():
1168
+ if not isinstance(element, tuple(value_types)):
1169
+ raise TypeError(f"{msg_prefix} type of '{arg_name}' should be {'one of ' if num_types > 1 else ''}" \
1170
+ f"{type_names if num_types > 1 else type_names[0]}, " \
1171
+ f"but got '{element}' with type '{type(element).__name__}'.")
1172
+
1173
+
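Not part of the diff: the dict variant applies the same per-element check to keys and values separately. The names used here are hypothetical.

    from mindspore import _checkparam as validator

    validator.check_element_type_of_dict('options', {'epoch': 5, 'batch_size': 32}, [str], [int])  # passes
    # validator.check_element_type_of_dict('options', {'epoch': 'x'}, [str], [int])  # TypeError on value 'x'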
1174
+ def check_size_and_element_type_of_tuple(arg_name, arg_value, expect_size, expect_element_type, prim_name=None):
1175
+ """Check the size and element type of a tuple."""
1176
+ check_value_type(arg_name, arg_value, [tuple], prim_name)
1177
+ check_equal_int(len(arg_value), expect_size, arg_name + ' size', prim_name)
1178
+ check_element_type_of_iterable(arg_name, arg_value, [expect_element_type], prim_name)
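Not part of the diff: this helper combines a type, length, and element-type check on a tuple argument. An illustrative call with hypothetical names:

    from mindspore import _checkparam as validator

    validator.check_size_and_element_type_of_tuple('strides', (1, 1), 2, int, prim_name='Conv2d')  # passes
    # validator.check_size_and_element_type_of_tuple('strides', (1, 1, 1), 2, int)  # ValueError: wrong size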
1086
1179
 
1087
1180
 
1088
1181
  def check_dyn_shape_value_equal(index, dyn_shape, actual_shape):
1089
1182
  """Check the consistency of dynamic shape and actual input shape."""
1090
1183
  for i, x in enumerate(dyn_shape):
1091
1184
  if x not in (-1, actual_shape[i]):
1092
- raise ValueError(f"The {i}th shape value of `{index}`th actual input args should be `{x}`, but got "
1185
+ raise ValueError(f"The {i}th shape value of `{index}`th actual input args should be `{x}`, but got " \
1093
1186
  f"`{actual_shape[i]}`.")
1094
1187
 
1095
1188
 
@@ -1107,17 +1200,17 @@ def _expand_tuple(n_dimensions):
1107
1200
  if not isinstance(m, tuple):
1108
1201
  if isinstance(m, int) and not isinstance(m, bool):
1109
1202
  return tuple(repeat(m, n_dimensions))
1110
- raise TypeError(f"When expanding an int number to tuple, input type must be integer or tuple[int], "
1203
+ raise TypeError(f"When expanding an int number to tuple, input type must be integer or tuple[int], " \
1111
1204
  f"but got {type(m)}")
1112
1205
 
1113
1206
  if not len(m) is n_dimensions:
1114
- raise TypeError(f"When expanding an int number to tuple, input tuple dimension must be {n_dimensions}, "
1207
+ raise TypeError(f"When expanding an int number to tuple, input tuple dimension must be {n_dimensions}, " \
1115
1208
  f"but got {m}")
1116
1209
 
1117
1210
  for i in m:
1118
1211
  if not isinstance(i, int) or isinstance(i, bool):
1119
- raise TypeError(f"When expanding an int number to tuple, "
1120
- f"the type of element in input tuple must be a integer, but got {type(i)}.")
1212
+ raise TypeError(f"When expanding an int number to tuple, " \
1213
+ f"the type of element in input tuple must be an integer, but got {type(i)}.")
1121
1214
  return m
1122
1215
 
1123
1216
  return convert
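Not part of the diff: _expand_tuple is a converter factory used internally; the returned converter repeats an int n_dimensions times and validates a tuple of the right length. A hedged sketch, assuming _expand_tuple remains a module-level helper in mindspore._checkparam:

    from mindspore._checkparam import _expand_tuple

    twice = _expand_tuple(2)
    twice(3)            # (3, 3)
    twice((3, 5))       # (3, 5), passed through after validation
    # twice((3, 5, 7))  # TypeError: input tuple dimension must be 2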
@@ -1153,8 +1246,8 @@ def check_input_data(*data, data_class):
1153
1246
  if not ret:
1154
1247
  data_class_str = tuple(i.__name__ if hasattr(i, '__name__') else i for i in data_class) if isinstance(
1155
1248
  data_class, (tuple, list)) else (data_class if data_class is None else data_class.__name__)
1156
- raise TypeError(f'The type of input data must be in the Union({data_class_str}, '
1157
- f'tuple[{data_class_str}], list[{data_class_str}], dict[{data_class_str}]), '
1249
+ raise TypeError(f'The type of input data must be in the Union({data_class_str}, ' \
1250
+ f'tuple[{data_class_str}], list[{data_class_str}], dict[{data_class_str}]), ' \
1158
1251
  f'but got type {item if item is None else type(item).__name__}.')
1159
1252
 
1160
1253
 
@@ -1208,31 +1301,3 @@ def args_type_check(*type_args, **type_kwargs):
1208
1301
 
1209
1302
 
1210
1303
  _set_record = {}
1211
-
1212
-
1213
- def args_unreset_check(*unreset_args, **unreset_kwargs):
1214
- """Check the entered non repeatable setting properties."""
1215
-
1216
- def unreset_check(func):
1217
- sig = inspect.signature(func)
1218
- bound_unreset = sig.bind_partial(*unreset_args, **unreset_kwargs).arguments
1219
-
1220
- @wraps(func)
1221
- def wrapper(*args, **kwargs):
1222
- nonlocal bound_unreset
1223
- bound_values = sig.bind(*args, **kwargs)
1224
- argument_dict = bound_values.arguments
1225
- if "kwargs" in bound_unreset:
1226
- bound_unreset = bound_unreset["kwargs"]
1227
- if "kwargs" in argument_dict:
1228
- argument_dict = argument_dict["kwargs"]
1229
- for name, value in argument_dict.items():
1230
- if name in _set_record.keys():
1231
- raise TypeError("For 'set_context', the parameter '{}' can not be set repeatedly.".format(name))
1232
- if name in bound_unreset:
1233
- _set_record[name] = value
1234
- return func(*args, **kwargs)
1235
-
1236
- return wrapper
1237
-
1238
- return unreset_check