mindspore 1.10.0__cp39-cp39-win_amd64.whl → 2.0.0rc1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (966) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/ConcurrencyCheck.dll +0 -0
  3. mindspore/CppBuildInsights.dll +0 -0
  4. mindspore/CppCoreCheck.dll +0 -0
  5. mindspore/EnumIndex.dll +0 -0
  6. mindspore/EspXEngine.dll +0 -0
  7. mindspore/HResultCheck.dll +0 -0
  8. mindspore/KernelTraceControl.dll +0 -0
  9. mindspore/LocalESPC.dll +0 -0
  10. mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
  11. mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
  12. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  13. mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
  14. mindspore/Newtonsoft.Json.dll +0 -0
  15. mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
  16. mindspore/VariantClear.dll +0 -0
  17. mindspore/__init__.py +9 -4
  18. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  19. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  20. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  21. mindspore/_check_jit_forbidden_api.py +102 -0
  22. mindspore/_checkparam.py +1066 -1001
  23. mindspore/_extends/builtin_operations.py +32 -4
  24. mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
  25. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
  26. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
  27. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
  28. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
  29. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
  30. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  31. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
  32. mindspore/_extends/parse/__init__.py +5 -3
  33. mindspore/_extends/parse/namespace.py +17 -2
  34. mindspore/_extends/parse/parser.py +193 -34
  35. mindspore/_extends/parse/resources.py +7 -8
  36. mindspore/_extends/parse/standard_method.py +1780 -435
  37. mindspore/_extends/parse/trope.py +3 -1
  38. mindspore/amp.py +53 -58
  39. mindspore/atlprov.dll +0 -0
  40. mindspore/boost/adasum.py +3 -2
  41. mindspore/boost/boost.py +2 -2
  42. mindspore/boost/boost_cell_wrapper.py +46 -26
  43. mindspore/boost/dim_reduce.py +6 -5
  44. mindspore/boost/grad_accumulation.py +2 -1
  45. mindspore/boost/group_loss_scale_manager.py +1 -1
  46. mindspore/c1.dll +0 -0
  47. mindspore/c1xx.dll +0 -0
  48. mindspore/c2.dll +0 -0
  49. mindspore/cfgpersist.dll +0 -0
  50. mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
  51. mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
  52. mindspore/common/__init__.py +11 -10
  53. mindspore/common/_decorator.py +2 -0
  54. mindspore/common/_register_for_adapter.py +55 -0
  55. mindspore/common/_stub_tensor.py +201 -0
  56. mindspore/common/_utils.py +57 -0
  57. mindspore/common/api.py +582 -297
  58. mindspore/common/dtype.py +66 -18
  59. mindspore/common/dump.py +2 -2
  60. mindspore/common/initializer.py +38 -1
  61. mindspore/common/jit_config.py +25 -13
  62. mindspore/common/mutable.py +53 -24
  63. mindspore/common/parameter.py +60 -37
  64. mindspore/common/seed.py +8 -24
  65. mindspore/common/sparse_tensor.py +927 -0
  66. mindspore/common/tensor.py +1627 -3900
  67. mindspore/communication/__init__.py +10 -5
  68. mindspore/communication/_comm_helper.py +78 -214
  69. mindspore/communication/_hccl_management.py +2 -1
  70. mindspore/communication/management.py +136 -47
  71. mindspore/config/op_info.config +501 -1008
  72. mindspore/context.py +291 -56
  73. mindspore/d3dcompiler_47.dll +0 -0
  74. mindspore/dataset/__init__.py +12 -8
  75. mindspore/dataset/audio/__init__.py +9 -9
  76. mindspore/dataset/audio/transforms.py +1090 -228
  77. mindspore/dataset/audio/utils.py +87 -39
  78. mindspore/dataset/audio/validators.py +223 -1
  79. mindspore/dataset/callback/ds_callback.py +17 -15
  80. mindspore/dataset/core/config.py +246 -17
  81. mindspore/dataset/core/py_util_helpers.py +4 -3
  82. mindspore/dataset/core/validator_helpers.py +10 -10
  83. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  84. mindspore/dataset/debug/debug_hook.py +65 -0
  85. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  86. mindspore/dataset/engine/__init__.py +7 -3
  87. mindspore/dataset/engine/cache_client.py +9 -9
  88. mindspore/dataset/engine/datasets.py +648 -477
  89. mindspore/dataset/engine/datasets_audio.py +165 -167
  90. mindspore/dataset/engine/datasets_standard_format.py +93 -67
  91. mindspore/dataset/engine/datasets_text.py +492 -342
  92. mindspore/dataset/engine/datasets_user_defined.py +85 -50
  93. mindspore/dataset/engine/datasets_vision.py +1224 -699
  94. mindspore/dataset/engine/graphdata.py +134 -69
  95. mindspore/dataset/engine/iterators.py +50 -9
  96. mindspore/dataset/engine/offload.py +52 -31
  97. mindspore/dataset/engine/samplers.py +27 -24
  98. mindspore/dataset/engine/serializer_deserializer.py +14 -15
  99. mindspore/dataset/engine/validators.py +213 -52
  100. mindspore/dataset/text/__init__.py +10 -8
  101. mindspore/dataset/text/transforms.py +152 -57
  102. mindspore/dataset/text/utils.py +98 -49
  103. mindspore/dataset/text/validators.py +25 -0
  104. mindspore/dataset/transforms/__init__.py +4 -2
  105. mindspore/dataset/transforms/c_transforms.py +11 -13
  106. mindspore/dataset/transforms/py_transforms.py +2 -2
  107. mindspore/dataset/transforms/py_transforms_util.py +10 -0
  108. mindspore/dataset/transforms/transforms.py +13 -15
  109. mindspore/dataset/transforms/validators.py +7 -7
  110. mindspore/dataset/utils/__init__.py +2 -1
  111. mindspore/dataset/utils/browse_dataset.py +13 -13
  112. mindspore/dataset/utils/line_reader.py +121 -0
  113. mindspore/dataset/vision/__init__.py +8 -7
  114. mindspore/dataset/vision/c_transforms.py +125 -126
  115. mindspore/dataset/vision/py_transforms.py +37 -37
  116. mindspore/dataset/vision/py_transforms_util.py +23 -20
  117. mindspore/dataset/vision/transforms.py +316 -315
  118. mindspore/dataset/vision/utils.py +313 -17
  119. mindspore/dataset/vision/validators.py +6 -6
  120. mindspore/default_config.py +0 -1
  121. mindspore/dpcmi.dll +0 -0
  122. mindspore/{compression → experimental}/__init__.py +6 -5
  123. mindspore/experimental/map_parameter.py +275 -0
  124. mindspore/include/OWNERS +0 -1
  125. mindspore/include/api/callback/callback.h +9 -13
  126. mindspore/include/api/callback/ckpt_saver.h +2 -2
  127. mindspore/include/api/callback/loss_monitor.h +2 -2
  128. mindspore/include/api/callback/lr_scheduler.h +5 -5
  129. mindspore/include/api/callback/time_monitor.h +2 -2
  130. mindspore/include/api/callback/train_accuracy.h +4 -6
  131. mindspore/include/api/cfg.h +19 -6
  132. mindspore/include/api/context.h +70 -9
  133. mindspore/include/api/delegate.h +8 -1
  134. mindspore/include/api/dual_abi_helper.h +8 -24
  135. mindspore/include/api/metrics/accuracy.h +2 -2
  136. mindspore/include/api/metrics/metrics.h +4 -3
  137. mindspore/include/api/model.h +9 -4
  138. mindspore/include/api/model_group.h +68 -0
  139. mindspore/include/api/model_parallel_runner.h +17 -17
  140. mindspore/include/api/net.h +12 -11
  141. mindspore/include/api/serialization.h +20 -4
  142. mindspore/include/api/status.h +7 -1
  143. mindspore/include/api/types.h +25 -21
  144. mindspore/include/api/visible.h +4 -0
  145. mindspore/include/c_api/model_c.h +5 -0
  146. mindspore/include/c_api/status_c.h +1 -1
  147. mindspore/include/dataset/config.h +1 -1
  148. mindspore/include/dataset/constants.h +14 -0
  149. mindspore/include/dataset/text.h +59 -0
  150. mindspore/include/dataset/vision.h +56 -117
  151. mindspore/include/dataset/vision_lite.h +102 -0
  152. mindspore/jpeg62.dll +0 -0
  153. mindspore/log.py +28 -28
  154. mindspore/mindrecord/common/exceptions.py +2 -4
  155. mindspore/mindrecord/filereader.py +19 -1
  156. mindspore/mindrecord/filewriter.py +250 -88
  157. mindspore/mindrecord/mindpage.py +13 -13
  158. mindspore/mindrecord/shardheader.py +15 -15
  159. mindspore/mindrecord/shardreader.py +9 -0
  160. mindspore/mindrecord/shardwriter.py +29 -29
  161. mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
  162. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
  163. mindspore/mindrecord/tools/csv_to_mr.py +4 -4
  164. mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
  165. mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
  166. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  167. mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
  168. mindspore/mindspore_common.dll +0 -0
  169. mindspore/mindspore_core.dll +0 -0
  170. mindspore/mindspore_glog.dll +0 -0
  171. mindspore/mindspore_shared_lib.dll +0 -0
  172. mindspore/msobj140.dll +0 -0
  173. mindspore/mspdb140.dll +0 -0
  174. mindspore/mspdbcore.dll +0 -0
  175. mindspore/mspdbst.dll +0 -0
  176. mindspore/mspft140.dll +0 -0
  177. mindspore/msvcdis140.dll +0 -0
  178. mindspore/msvcp140_1.dll +0 -0
  179. mindspore/msvcp140_2.dll +0 -0
  180. mindspore/msvcp140_atomic_wait.dll +0 -0
  181. mindspore/msvcp140_codecvt_ids.dll +0 -0
  182. mindspore/nn/__init__.py +1 -5
  183. mindspore/nn/cell.py +297 -234
  184. mindspore/nn/dynamic_lr.py +1 -1
  185. mindspore/nn/grad/cell_grad.py +17 -42
  186. mindspore/nn/layer/__init__.py +7 -4
  187. mindspore/nn/layer/activation.py +131 -88
  188. mindspore/nn/layer/basic.py +313 -613
  189. mindspore/nn/layer/channel_shuffle.py +103 -0
  190. mindspore/nn/layer/combined.py +1 -1
  191. mindspore/nn/layer/container.py +52 -6
  192. mindspore/nn/layer/conv.py +112 -43
  193. mindspore/nn/layer/dense.py +10 -9
  194. mindspore/nn/layer/embedding.py +36 -34
  195. mindspore/nn/layer/image.py +123 -27
  196. mindspore/nn/layer/math.py +108 -107
  197. mindspore/nn/layer/normalization.py +212 -366
  198. mindspore/nn/layer/padding.py +370 -42
  199. mindspore/nn/layer/pooling.py +1443 -219
  200. mindspore/nn/layer/rnn_cells.py +11 -16
  201. mindspore/nn/layer/rnns.py +38 -39
  202. mindspore/nn/layer/thor_layer.py +24 -25
  203. mindspore/nn/layer/timedistributed.py +5 -5
  204. mindspore/nn/layer/transformer.py +701 -0
  205. mindspore/nn/learning_rate_schedule.py +8 -8
  206. mindspore/nn/loss/__init__.py +9 -6
  207. mindspore/nn/loss/loss.py +678 -142
  208. mindspore/nn/metrics.py +53 -0
  209. mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
  210. mindspore/nn/optim/ada_grad.py +8 -8
  211. mindspore/nn/optim/adadelta.py +2 -3
  212. mindspore/nn/optim/adafactor.py +18 -14
  213. mindspore/nn/optim/adam.py +429 -87
  214. mindspore/nn/optim/adamax.py +5 -6
  215. mindspore/nn/optim/adasum.py +10 -8
  216. mindspore/nn/optim/asgd.py +7 -7
  217. mindspore/nn/optim/ftrl.py +81 -11
  218. mindspore/nn/optim/lamb.py +7 -8
  219. mindspore/nn/optim/lars.py +4 -4
  220. mindspore/nn/optim/lazyadam.py +82 -7
  221. mindspore/nn/optim/momentum.py +8 -7
  222. mindspore/nn/optim/optimizer.py +19 -10
  223. mindspore/nn/optim/proximal_ada_grad.py +6 -5
  224. mindspore/nn/optim/rmsprop.py +3 -3
  225. mindspore/nn/optim/rprop.py +20 -16
  226. mindspore/nn/optim/sgd.py +21 -15
  227. mindspore/nn/optim/thor.py +23 -21
  228. mindspore/nn/probability/__init__.py +0 -2
  229. mindspore/nn/probability/bijector/bijector.py +7 -6
  230. mindspore/nn/probability/bijector/invert.py +4 -2
  231. mindspore/nn/probability/bijector/softplus.py +2 -2
  232. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  233. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  234. mindspore/nn/probability/distribution/__init__.py +6 -0
  235. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
  236. mindspore/nn/probability/distribution/_utils/utils.py +11 -17
  237. mindspore/nn/probability/distribution/bernoulli.py +6 -6
  238. mindspore/nn/probability/distribution/beta.py +1 -1
  239. mindspore/nn/probability/distribution/categorical.py +9 -9
  240. mindspore/nn/probability/distribution/cauchy.py +8 -8
  241. mindspore/nn/probability/distribution/distribution.py +12 -6
  242. mindspore/nn/probability/distribution/exponential.py +5 -5
  243. mindspore/nn/probability/distribution/gamma.py +3 -3
  244. mindspore/nn/probability/distribution/geometric.py +6 -5
  245. mindspore/nn/probability/distribution/gumbel.py +5 -5
  246. mindspore/nn/probability/distribution/half_normal.py +133 -0
  247. mindspore/nn/probability/distribution/laplace.py +128 -0
  248. mindspore/nn/probability/distribution/log_normal.py +0 -1
  249. mindspore/nn/probability/distribution/logistic.py +4 -5
  250. mindspore/nn/probability/distribution/normal.py +11 -15
  251. mindspore/nn/probability/distribution/poisson.py +6 -2
  252. mindspore/nn/probability/distribution/student_t.py +150 -0
  253. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  254. mindspore/nn/probability/distribution/uniform.py +5 -5
  255. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  256. mindspore/nn/reinforcement/tensor_array.py +2 -2
  257. mindspore/nn/sparse/sparse.py +8 -1
  258. mindspore/nn/wrap/cell_wrapper.py +55 -27
  259. mindspore/nn/wrap/grad_reducer.py +20 -11
  260. mindspore/nn/wrap/loss_scale.py +47 -30
  261. mindspore/numpy/array_creations.py +33 -22
  262. mindspore/numpy/array_ops.py +46 -42
  263. mindspore/numpy/logic_ops.py +6 -27
  264. mindspore/numpy/math_ops.py +26 -19
  265. mindspore/numpy/utils.py +1 -8
  266. mindspore/numpy/utils_const.py +112 -62
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/opencv_imgproc452.dll +0 -0
  270. mindspore/ops/__init__.py +6 -3
  271. mindspore/ops/_constants.py +0 -6
  272. mindspore/ops/_grad/__init__.py +2 -1
  273. mindspore/ops/_grad/grad_array_ops.py +209 -152
  274. mindspore/ops/_grad/grad_base.py +55 -17
  275. mindspore/ops/_grad/grad_clip_ops.py +11 -3
  276. mindspore/ops/_grad/grad_comm_ops.py +58 -47
  277. mindspore/ops/_grad/grad_implementations.py +21 -61
  278. mindspore/ops/_grad/grad_inner_ops.py +48 -6
  279. mindspore/ops/_grad/grad_math_ops.py +306 -161
  280. mindspore/ops/_grad/grad_nn_ops.py +192 -181
  281. mindspore/ops/_grad/grad_other_ops.py +1 -1
  282. mindspore/ops/_grad/grad_quant_ops.py +5 -5
  283. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  284. mindspore/ops/_grad/grad_sparse.py +15 -9
  285. mindspore/ops/_grad_experimental/__init__.py +1 -0
  286. mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
  287. mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
  288. mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
  289. mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
  290. mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
  291. mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
  292. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  293. mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
  294. mindspore/ops/_op_impl/__init__.py +3 -3
  295. mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
  296. mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
  297. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  298. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
  299. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  300. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  301. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
  302. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  303. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  304. mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
  305. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  306. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
  307. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  308. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  309. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  310. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  311. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  312. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  313. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  314. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  315. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  316. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  317. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  318. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  319. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  320. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  321. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  322. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  323. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  325. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
  326. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
  327. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  328. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  329. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  330. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  331. mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
  332. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  333. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  334. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  335. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  336. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  337. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  338. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  339. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  340. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  341. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  342. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  343. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  344. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  345. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  346. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  347. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  348. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  349. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  350. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  351. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  352. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
  353. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  354. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
  355. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  356. mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
  357. mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
  358. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  359. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  360. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  361. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  362. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  363. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  364. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  365. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
  366. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  367. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  368. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  369. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  370. mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
  371. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  372. mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
  373. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  374. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  375. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  376. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  377. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  378. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  379. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  380. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  381. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  382. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  383. mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
  384. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  385. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  386. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  387. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  388. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  389. mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
  390. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  391. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  392. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  393. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  394. mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
  395. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  396. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  397. mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
  398. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  399. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  400. mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
  401. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  402. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  403. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  404. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  405. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  406. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  407. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  408. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  409. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  410. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  411. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  412. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  413. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  414. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  415. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  416. mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
  417. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  418. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  419. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  420. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  421. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  422. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  423. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  424. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  425. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  426. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  427. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  428. mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
  429. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  430. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  431. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  432. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  433. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  434. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  435. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  436. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  437. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  438. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  439. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  440. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  441. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  442. mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
  443. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  444. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  445. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  446. mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
  447. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
  448. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
  449. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  450. mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
  451. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  452. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  453. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  454. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  455. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  456. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  457. mindspore/ops/_op_impl/cpu/__init__.py +1 -2
  458. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  459. mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
  460. mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
  461. mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
  462. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  463. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  464. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  465. mindspore/ops/_op_impl/tbe/__init__.py +27 -608
  466. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
  467. mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
  468. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  469. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  470. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  471. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
  472. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  473. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  474. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
  475. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
  476. mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
  477. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  478. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
  479. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  480. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  481. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  482. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  483. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  484. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
  485. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
  486. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  487. mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
  488. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
  489. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  490. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  491. mindspore/ops/_op_impl/tbe/greater.py +2 -0
  492. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  493. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
  494. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  495. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  496. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
  497. mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
  498. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
  499. mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
  500. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
  501. mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
  502. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
  503. mindspore/ops/_op_impl/tbe/slice.py +26 -15
  504. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  505. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  506. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
  507. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  508. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
  509. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
  510. mindspore/ops/_primitive_cache.py +3 -2
  511. mindspore/ops/_register_for_op.py +11 -0
  512. mindspore/ops/_utils/__init__.py +1 -1
  513. mindspore/ops/_utils/utils.py +20 -41
  514. mindspore/ops/_vmap/__init__.py +2 -2
  515. mindspore/ops/_vmap/vmap_array_ops.py +170 -78
  516. mindspore/ops/_vmap/vmap_base.py +24 -10
  517. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  518. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  519. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
  520. mindspore/ops/_vmap/vmap_image_ops.py +52 -0
  521. mindspore/ops/_vmap/vmap_math_ops.py +77 -6
  522. mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
  523. mindspore/ops/_vmap/vmap_other_ops.py +3 -1
  524. mindspore/ops/_vmap/vmap_random_ops.py +55 -3
  525. mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
  526. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  527. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  528. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
  529. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
  530. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
  531. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
  532. mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
  533. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  534. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  535. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
  537. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  538. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
  539. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  541. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
  542. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
  543. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  544. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  545. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  546. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  547. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  548. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  549. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  550. mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
  551. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  552. mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
  553. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
  554. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  555. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
  556. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
  559. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
  560. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  561. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  563. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
  565. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  566. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  567. mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
  568. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
  569. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  570. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
  571. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
  572. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
  573. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
  574. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  575. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
  576. mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
  577. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  578. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  579. mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
  580. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  581. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
  582. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
  583. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
  584. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  585. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  586. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  587. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  588. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  589. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
  590. mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
  591. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
  592. mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
  593. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  594. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
  595. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
  596. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
  597. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  598. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  599. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  600. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  601. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  602. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  603. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  604. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  605. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  606. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  607. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  608. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
  609. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
  610. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
  611. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
  612. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  613. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  614. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  615. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  616. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  617. mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
  618. mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
  619. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  620. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  621. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
  622. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
  623. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
  624. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
  625. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  626. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
  627. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
  628. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
  629. mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
  630. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  631. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  632. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
  633. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
  634. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
  635. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  636. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  637. mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
  638. mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
  639. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  640. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  642. mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
  643. mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
  644. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  645. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  646. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  647. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  648. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  649. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
  650. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
  651. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  652. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  653. mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
  654. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
  655. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
  656. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
  657. mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
  658. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  659. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  660. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
  661. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
  662. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
  663. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  664. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  665. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
  666. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
  667. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
  668. mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
  669. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
  670. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  671. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  672. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
  673. mindspore/ops/bprop_mindir/__init__.py +1 -4
  674. mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
  675. mindspore/ops/composite/__init__.py +12 -13
  676. mindspore/ops/composite/base.py +261 -254
  677. mindspore/ops/composite/env_ops.py +41 -0
  678. mindspore/ops/composite/math_ops.py +197 -156
  679. mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
  680. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
  681. mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
  682. mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
  683. mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
  684. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
  685. mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
  686. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  687. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  688. mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
  689. mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
  690. mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
  691. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
  692. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  693. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
  694. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
  695. mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
  696. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  697. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  698. mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
  699. mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
  700. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
  701. mindspore/ops/function/__init__.py +323 -8
  702. mindspore/ops/function/array_func.py +3511 -780
  703. mindspore/ops/function/clip_func.py +329 -0
  704. mindspore/ops/function/debug_func.py +6 -6
  705. mindspore/ops/function/grad/__init__.py +5 -1
  706. mindspore/ops/function/grad/grad_func.py +736 -65
  707. mindspore/ops/function/image_func.py +270 -0
  708. mindspore/ops/function/linalg_func.py +268 -8
  709. mindspore/ops/function/math_func.py +8032 -3164
  710. mindspore/ops/function/nn_func.py +5619 -1855
  711. mindspore/ops/function/other_func.py +115 -0
  712. mindspore/ops/function/parameter_func.py +11 -10
  713. mindspore/ops/function/random_func.py +939 -77
  714. mindspore/ops/function/sparse_func.py +249 -84
  715. mindspore/ops/function/sparse_unary_func.py +2303 -0
  716. mindspore/ops/function/spectral_func.py +146 -0
  717. mindspore/ops/function/vmap_func.py +114 -0
  718. mindspore/ops/functional.py +182 -254
  719. mindspore/ops/op_info_register.py +79 -34
  720. mindspore/ops/operations/__init__.py +210 -118
  721. mindspore/ops/operations/_csr_ops.py +7 -7
  722. mindspore/ops/operations/_embedding_cache_ops.py +25 -15
  723. mindspore/ops/operations/_grad_ops.py +447 -322
  724. mindspore/ops/operations/_inner_ops.py +547 -176
  725. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  726. mindspore/ops/operations/_ms_kernel.py +29 -27
  727. mindspore/ops/operations/_ocr_ops.py +11 -11
  728. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  729. mindspore/ops/operations/_quant_ops.py +186 -101
  730. mindspore/ops/operations/_rl_inner_ops.py +122 -61
  731. mindspore/ops/operations/_scalar_ops.py +466 -0
  732. mindspore/ops/operations/_sequence_ops.py +1047 -0
  733. mindspore/ops/operations/_tensor_array.py +10 -11
  734. mindspore/ops/operations/_thor_ops.py +4 -4
  735. mindspore/ops/operations/array_ops.py +1428 -1226
  736. mindspore/ops/operations/comm_ops.py +180 -117
  737. mindspore/ops/operations/control_ops.py +4 -2
  738. mindspore/ops/operations/custom_ops.py +185 -98
  739. mindspore/ops/operations/debug_ops.py +92 -54
  740. mindspore/ops/operations/image_ops.py +406 -211
  741. mindspore/ops/operations/inner_ops.py +42 -53
  742. mindspore/ops/operations/linalg_ops.py +32 -29
  743. mindspore/ops/operations/math_ops.py +2076 -897
  744. mindspore/ops/operations/nn_ops.py +1282 -1252
  745. mindspore/ops/operations/other_ops.py +124 -278
  746. mindspore/ops/operations/random_ops.py +345 -178
  747. mindspore/ops/operations/rl_ops.py +8 -9
  748. mindspore/ops/operations/sparse_ops.py +502 -157
  749. mindspore/ops/operations/spectral_ops.py +107 -0
  750. mindspore/ops/primitive.py +192 -15
  751. mindspore/ops/vm_impl_registry.py +23 -2
  752. mindspore/parallel/__init__.py +6 -1
  753. mindspore/parallel/_auto_parallel_context.py +199 -92
  754. mindspore/parallel/_cell_wrapper.py +4 -2
  755. mindspore/parallel/_cost_model_context.py +3 -0
  756. mindspore/parallel/_dp_allreduce_fusion.py +2 -1
  757. mindspore/parallel/_offload_context.py +185 -0
  758. mindspore/parallel/_parallel_serialization.py +167 -28
  759. mindspore/parallel/_ps_context.py +9 -5
  760. mindspore/parallel/_recovery_context.py +1 -1
  761. mindspore/parallel/_tensor.py +9 -1
  762. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  763. mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
  764. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  765. mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
  766. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  767. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
  768. mindspore/parallel/_utils.py +47 -7
  769. mindspore/parallel/algo_parameter_config.py +5 -1
  770. mindspore/parallel/checkpoint_transform.py +329 -0
  771. mindspore/parallel/shard.py +229 -0
  772. mindspore/perf_msvcbuildinsights.dll +0 -0
  773. mindspore/pgodb140.dll +0 -0
  774. mindspore/pgort140.dll +0 -0
  775. mindspore/profiler/__init__.py +2 -1
  776. mindspore/profiler/common/util.py +4 -3
  777. mindspore/profiler/common/validator/validate_path.py +2 -2
  778. mindspore/profiler/envprofiling.py +249 -0
  779. mindspore/profiler/parser/aicpu_data_parser.py +38 -39
  780. mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
  781. mindspore/profiler/parser/base_timeline_generator.py +471 -0
  782. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
  783. mindspore/profiler/parser/framework_parser.py +42 -16
  784. mindspore/profiler/parser/hccl_parser.py +158 -158
  785. mindspore/profiler/parser/hwts_log_parser.py +7 -6
  786. mindspore/profiler/parser/integrator.py +18 -1579
  787. mindspore/profiler/parser/minddata_analyzer.py +8 -8
  788. mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
  789. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  790. mindspore/profiler/parser/optime_parser.py +17 -18
  791. mindspore/profiler/parser/profiler_info.py +108 -0
  792. mindspore/profiler/parser/step_trace_parser.py +1 -1
  793. mindspore/profiler/profiling.py +396 -194
  794. mindspore/rewrite/__init__.py +6 -2
  795. mindspore/rewrite/api/node.py +51 -110
  796. mindspore/rewrite/api/node_type.py +10 -6
  797. mindspore/rewrite/api/pattern_engine.py +51 -7
  798. mindspore/rewrite/api/scoped_value.py +64 -53
  799. mindspore/rewrite/api/symbol_tree.py +108 -61
  800. mindspore/rewrite/api/tree_node_helper.py +2 -3
  801. mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
  802. mindspore/rewrite/ast_helpers/__init__.py +6 -3
  803. mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
  804. mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
  805. mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
  806. mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
  807. mindspore/rewrite/ast_transformers/__init__.py +0 -1
  808. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
  809. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
  810. mindspore/rewrite/common/__init__.py +2 -0
  811. mindspore/rewrite/common/event.py +1 -1
  812. mindspore/rewrite/common/observable.py +1 -1
  813. mindspore/rewrite/common/observer.py +1 -1
  814. mindspore/rewrite/common/rewrite_elog.py +35 -0
  815. mindspore/rewrite/namer.py +2 -2
  816. mindspore/rewrite/namespace.py +14 -4
  817. mindspore/rewrite/node.py +161 -13
  818. mindspore/rewrite/parser.py +0 -1
  819. mindspore/rewrite/parser_register.py +0 -1
  820. mindspore/rewrite/parsers/arguments_parser.py +3 -2
  821. mindspore/rewrite/parsers/assign_parser.py +267 -67
  822. mindspore/rewrite/parsers/attribute_parser.py +56 -0
  823. mindspore/rewrite/parsers/class_def_parser.py +191 -108
  824. mindspore/rewrite/parsers/constant_parser.py +101 -0
  825. mindspore/rewrite/parsers/container_parser.py +88 -0
  826. mindspore/rewrite/parsers/for_parser.py +28 -15
  827. mindspore/rewrite/parsers/function_def_parser.py +21 -5
  828. mindspore/rewrite/parsers/if_parser.py +11 -28
  829. mindspore/rewrite/parsers/module_parser.py +9 -6
  830. mindspore/rewrite/parsers/return_parser.py +3 -2
  831. mindspore/rewrite/sparsify/__init__.py +0 -0
  832. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  833. mindspore/rewrite/sparsify/sparsify.py +109 -0
  834. mindspore/rewrite/sparsify/utils.py +173 -0
  835. mindspore/rewrite/symbol_tree.py +322 -109
  836. mindspore/rewrite/symbol_tree_builder.py +45 -8
  837. mindspore/rewrite/symbol_tree_dumper.py +0 -1
  838. mindspore/rewrite/topological_manager.py +1 -2
  839. mindspore/run_check/_check_version.py +209 -112
  840. mindspore/run_check/run_check.py +2 -1
  841. mindspore/tbbmalloc.dll +0 -0
  842. mindspore/tinyxml2.dll +0 -0
  843. mindspore/train/__init__.py +6 -4
  844. mindspore/train/_utils.py +28 -5
  845. mindspore/train/amp.py +321 -50
  846. mindspore/train/callback/__init__.py +3 -1
  847. mindspore/train/callback/_backup_and_restore.py +120 -0
  848. mindspore/train/callback/_callback.py +8 -8
  849. mindspore/train/callback/_checkpoint.py +12 -9
  850. mindspore/train/callback/_early_stop.py +13 -7
  851. mindspore/train/callback/_history.py +8 -8
  852. mindspore/train/callback/_lambda_callback.py +6 -6
  853. mindspore/train/callback/_landscape.py +36 -38
  854. mindspore/train/callback/_loss_monitor.py +12 -6
  855. mindspore/train/callback/_lr_scheduler_callback.py +2 -4
  856. mindspore/train/callback/_on_request_exit.py +212 -0
  857. mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
  858. mindspore/train/callback/_summary_collector.py +27 -19
  859. mindspore/train/callback/_time_monitor.py +13 -7
  860. mindspore/train/checkpoint_pb2.py +68 -8
  861. mindspore/train/data_sink.py +122 -33
  862. mindspore/train/dataset_helper.py +28 -87
  863. mindspore/train/loss_scale_manager.py +4 -7
  864. mindspore/{nn → train}/metrics/__init__.py +20 -20
  865. mindspore/{nn → train}/metrics/accuracy.py +12 -10
  866. mindspore/{nn → train}/metrics/auc.py +4 -4
  867. mindspore/{nn → train}/metrics/bleu_score.py +4 -4
  868. mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
  869. mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
  870. mindspore/{nn → train}/metrics/dice.py +6 -5
  871. mindspore/{nn → train}/metrics/error.py +7 -5
  872. mindspore/{nn → train}/metrics/fbeta.py +9 -7
  873. mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
  874. mindspore/{nn → train}/metrics/loss.py +4 -3
  875. mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
  876. mindspore/{nn → train}/metrics/metric.py +6 -5
  877. mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
  878. mindspore/{nn → train}/metrics/perplexity.py +5 -4
  879. mindspore/{nn → train}/metrics/precision.py +5 -4
  880. mindspore/{nn → train}/metrics/recall.py +5 -4
  881. mindspore/{nn → train}/metrics/roc.py +7 -6
  882. mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
  883. mindspore/{nn → train}/metrics/topk.py +7 -5
  884. mindspore/train/mind_ir_pb2.py +339 -32
  885. mindspore/train/model.py +113 -84
  886. mindspore/train/serialization.py +547 -167
  887. mindspore/train/summary/_summary_adapter.py +1 -1
  888. mindspore/train/summary/summary_record.py +43 -12
  889. mindspore/train/train_thor/convert_utils.py +7 -1
  890. mindspore/train/train_thor/dataset_helper.py +3 -3
  891. mindspore/train/train_thor/model_thor.py +0 -4
  892. mindspore/turbojpeg.dll +0 -0
  893. mindspore/vcmeta.dll +0 -0
  894. mindspore/vcruntime140.dll +0 -0
  895. mindspore/vcruntime140_1.dll +0 -0
  896. mindspore/version.py +1 -1
  897. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
  898. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
  899. mindspore/compression/common/constant.py +0 -124
  900. mindspore/compression/export/__init__.py +0 -19
  901. mindspore/compression/export/quant_export.py +0 -514
  902. mindspore/compression/quant/qat.py +0 -636
  903. mindspore/compression/quant/quant_utils.py +0 -462
  904. mindspore/compression/quant/quantizer.py +0 -68
  905. mindspore/libatomic-1.dll +0 -0
  906. mindspore/libgcc_s_seh-1.dll +0 -0
  907. mindspore/libgfortran-4.dll +0 -0
  908. mindspore/libgomp-1.dll +0 -0
  909. mindspore/libjpeg-62.dll +0 -0
  910. mindspore/libmindspore.dll +0 -0
  911. mindspore/libmindspore_common.dll +0 -0
  912. mindspore/libmindspore_core.dll +0 -0
  913. mindspore/libmindspore_glog.dll +0 -0
  914. mindspore/libnnacl.dll +0 -0
  915. mindspore/libopencv_core452.dll +0 -0
  916. mindspore/libopencv_imgcodecs452.dll +0 -0
  917. mindspore/libopencv_imgproc452.dll +0 -0
  918. mindspore/libquadmath-0.dll +0 -0
  919. mindspore/libsqlite3.dll +0 -0
  920. mindspore/libssp-0.dll +0 -0
  921. mindspore/libstdc++-6.dll +0 -0
  922. mindspore/libtinyxml2.dll +0 -0
  923. mindspore/libturbojpeg.dll +0 -0
  924. mindspore/libwinpthread-1.dll +0 -0
  925. mindspore/nn/layer/quant.py +0 -1868
  926. mindspore/nn/layer/rnn_utils.py +0 -90
  927. mindspore/nn/probability/dpn/__init__.py +0 -22
  928. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  929. mindspore/nn/probability/dpn/vae/cvae.py +0 -138
  930. mindspore/nn/probability/dpn/vae/vae.py +0 -122
  931. mindspore/nn/probability/infer/__init__.py +0 -22
  932. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  933. mindspore/nn/probability/infer/variational/svi.py +0 -84
  934. mindspore/nn/probability/toolbox/__init__.py +0 -22
  935. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  936. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
  937. mindspore/nn/probability/transforms/__init__.py +0 -22
  938. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  939. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  940. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  941. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  942. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  943. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  944. mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
  945. mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
  946. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
  947. mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
  948. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
  949. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
  950. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
  951. mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
  952. mindspore/ops/composite/array_ops.py +0 -210
  953. mindspore/ops/composite/clip_ops.py +0 -238
  954. mindspore/ops/composite/random_ops.py +0 -426
  955. mindspore/ops/composite/vmap_ops.py +0 -38
  956. mindspore/ops/operations/sponge_ops.py +0 -3531
  957. mindspore/ops/operations/sponge_update_ops.py +0 -2546
  958. mindspore/parallel/nn/__init__.py +0 -42
  959. mindspore/parallel/nn/loss.py +0 -22
  960. mindspore/parallel/nn/moe.py +0 -21
  961. mindspore/parallel/nn/op_parallel_config.py +0 -22
  962. mindspore/parallel/nn/transformer.py +0 -31
  963. mindspore/run_check/_check_deps_version.py +0 -84
  964. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  965. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  966. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -18,12 +18,21 @@ from __future__ import absolute_import
18
18
  from functools import partial
19
19
 
20
20
  import mindspore.context as context
21
- from mindspore._checkparam import Validator as validator
22
- from mindspore._checkparam import Rel
21
+ from mindspore import _checkparam as validator
23
22
  from mindspore.ops.primitive import Primitive, PrimitiveWithInfer, prim_attr_register
24
23
  from mindspore.common import dtype as mstype
24
+ from mindspore.common.dtype import QuantDtype
25
25
 
26
- if context.get_context('device_target') == "Ascend":
26
+
27
+ def _support_te():
28
+ try:
29
+ import te # pylint: disable=unused-import
30
+ return True
31
+ # pylint: disable=broad-except
32
+ except Exception:
33
+ return False
34
+
35
+ if context.get_context('device_target') == "Ascend" and _support_te():
27
36
  import mindspore.ops._op_impl._custom_op
28
37
 
29
38
  __all__ = ["MinMaxUpdatePerLayer",
@@ -61,10 +70,71 @@ __all__ = ["MinMaxUpdatePerLayer",
61
70
  "ActsULQInputGrad",
62
71
  "ActULQClampMinGrad",
63
72
  "ActULQClampMaxGrad",
64
- "WtsARQ"
73
+ "WtsARQ",
74
+ "FakeQuantParam",
65
75
  ]
66
76
 
67
77
 
78
+ class FakeQuantParam(Primitive):
79
+ r"""
80
+ Define the operation for storing quant parameter. This operation passes through input tensor to output tensor
81
+ without any calculation.
82
+
83
+ Args:
84
+ quant_dtype (QuantDtype) - The valid data type of the input tensor.
85
+ quant_algo_name (str) - Define the name of quant algorithm. Use
86
+ `FakeQuantParam.attr_value_linear_quant_algo_name` for linear quantization specially.
87
+ is_per_channel (bool) - Define whether quant parameter is per-channel or per-layer.
88
+ kwargs (dict): Other quant parameter in key-value form. Please use classmethod `linear_quant_param` to create a
89
+ linear quantization specially because key of scale and zero-point is pre-defined by MindSpore.
90
+
91
+ Inputs:
92
+ - *input_x* (Tensor) : Input tensor.
93
+
94
+ Outputs:
95
+ - Tensor: Output tensor same with `input_x`.
96
+
97
+ Examples:
98
+ >>> input_tensor = mindspore.Tensor(numpy.random.rand(1, 16, 5, 5), mindspore.dtype.float32)
99
+ >>> fake_quant_param_op = FakeQuantParam.linear_quant_param(mindspore.common.dtype.QuantDtype.INT8,
100
+ >>> 0.5, 1)
101
+ >>> output_tensor = fake_quant_param_op(input_tensor)
102
+ """
103
+
104
+ attr_key_linear_quant_scale = "linear_quant_scale"
105
+ attr_key_linear_quant_zero_point = "linear_quant_zero_point"
106
+
107
+ attr_value_linear_quant_algo_name = "linear_quant_algo"
108
+
109
+ @prim_attr_register
110
+ def __init__(self, quant_dtype: QuantDtype, quant_algo_name: str, is_per_channel: bool, **kwargs):
111
+ self.add_prim_attr("quant_algo_name", quant_algo_name)
112
+ self.add_prim_attr("is_per_channel", is_per_channel)
113
+ self.add_prim_attr("quant_dtype", quant_dtype.value())
114
+ for key, value in kwargs.items():
115
+ self.add_prim_attr(key, value)
116
+
117
+ @classmethod
118
+ def linear_quant_param(cls, quant_dtype, scale, zp, is_per_channel=False, **kwargs):
119
+ """
120
+ Create a linear quantization operator based on scale and zero-point parameter.
121
+ """
122
+ validator.check_value_type("scale", scale, [float, tuple, list], "FakeQuantParam")
123
+ if isinstance(scale, float):
124
+ scale_list = [scale]
125
+ else:
126
+ scale_list = scale
127
+ validator.check_value_type("zero_point", zp, [int, tuple, list], "FakeQuantParam")
128
+ if isinstance(zp, int):
129
+ zp_list = [zp]
130
+ else:
131
+ zp_list = zp
132
+ validator.check_value_type("is_per_channel", is_per_channel, [bool], "FakeQuantParam")
133
+ kwargs[FakeQuantParam.attr_key_linear_quant_scale] = scale_list
134
+ kwargs[FakeQuantParam.attr_key_linear_quant_zero_point] = zp_list
135
+ return cls(quant_dtype, FakeQuantParam.attr_value_linear_quant_algo_name, is_per_channel, **kwargs)
136
+
137
+
68
138
  class MinMaxUpdatePerLayer(PrimitiveWithInfer):
69
139
  r"""
70
140
  Updates min and max per layer.
@@ -99,14 +169,14 @@ class MinMaxUpdatePerLayer(PrimitiveWithInfer):
99
169
  f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
100
170
 
101
171
  self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
102
- self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
172
+ self.ema_decay = validator.check_float_range(ema_decay, 0, 1, validator.INC_BOTH, 'ema_decay', self.name)
103
173
  self.init_prim_io_names(inputs=['x', 'min', 'max'],
104
174
  outputs=['min_up', 'max_up'])
105
175
 
106
176
  def infer_shape(self, x_shape, min_shape, max_shape):
107
- validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
177
+ validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
108
178
  validator.check("min shape", min_shape, "max shape",
109
- max_shape, Rel.EQ, self.name)
179
+ max_shape, validator.EQ, self.name)
110
180
  validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
111
181
  return min_shape, max_shape
112
182
 
@@ -155,9 +225,10 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
155
225
  f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
156
226
 
157
227
  self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
158
- self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
228
+ self.ema_decay = validator.check_float_range(ema_decay, 0, 1, validator.INC_BOTH, 'ema_decay', self.name)
159
229
  if self.is_ascend:
160
- self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
230
+ self.channel_axis = validator.check_int_range(channel_axis, 0, 1, validator.INC_BOTH,
231
+ 'channel_axis', self.name)
161
232
  else:
162
233
  self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
163
234
  self.init_prim_io_names(
@@ -167,9 +238,9 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
167
238
  if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
168
239
  raise ValueError(f"For '{self.name}' x rank must be in '{self.ascend_support_x_rank}'")
169
240
  if not self.is_ascend:
170
- validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
241
+ validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
171
242
  validator.check("min shape", min_shape, "max shape",
172
- max_shape, Rel.EQ, self.name)
243
+ max_shape, validator.EQ, self.name)
173
244
  validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
174
245
  return min_shape, max_shape
175
246
 
@@ -225,9 +296,9 @@ class FakeLearnedScaleQuantPerLayer(PrimitiveWithInfer):
225
296
  outputs=['out'])
226
297
 
227
298
  def infer_shape(self, input_x_shape, alpha_shape, quant_max_shape):
228
- validator.check_int(len(input_x_shape), 1, Rel.GE, "input_x rank", self.name)
229
- validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name)
230
- validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
299
+ validator.check_int(len(input_x_shape), 1, validator.GE, "input_x rank", self.name)
300
+ validator.check_int(len(alpha_shape), 1, validator.GE, "alpha rank", self.name)
301
+ validator.check_int(len(quant_max_shape), 1, validator.GE, "quant max rank", self.name)
231
302
  return input_x_shape
232
303
 
233
304
  def infer_dtype(self, input_x_type, alpha_type, quant_max_type):
@@ -266,9 +337,9 @@ class FakeLearnedScaleQuantPerLayerGrad(PrimitiveWithInfer):
266
337
  inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])
267
338
 
268
339
  def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
269
- validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name)
270
- validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name)
271
- validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
340
+ validator.check("dout shape", dout_shape, "x_shape", x_shape, validator.EQ, self.name)
341
+ validator.check_int(len(alpha_shape), 1, validator.GE, "alpha rank", self.name)
342
+ validator.check_int(len(quant_max_shape), 1, validator.GE, "quant max rank", self.name)
272
343
  return dout_shape, alpha_shape
273
344
 
274
345
  def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
@@ -297,9 +368,9 @@ class FakeLearnedScaleQuantPerLayerGradD(PrimitiveWithInfer):
297
368
  inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])
298
369
 
299
370
  def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
300
- validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name)
301
- validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name)
302
- validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
371
+ validator.check("dout shape", dout_shape, "x_shape", x_shape, validator.EQ, self.name)
372
+ validator.check_int(len(alpha_shape), 1, validator.GE, "alpha rank", self.name)
373
+ validator.check_int(len(quant_max_shape), 1, validator.GE, "quant max rank", self.name)
303
374
  return dout_shape, dout_shape
304
375
 
305
376
  def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
@@ -375,7 +446,8 @@ class FakeLearnedScaleQuantPerChannel(PrimitiveWithInfer):
375
446
  self.training = validator.check_value_type(
376
447
  'training', training, (bool,), self.name)
377
448
  if self.is_ascend:
378
- self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
449
+ self.channel_axis = validator.check_int_range(channel_axis, 0, 1, validator.INC_BOTH,
450
+ 'channel_axis', self.name)
379
451
  else:
380
452
  self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
381
453
  self.init_prim_io_names(inputs=['input_x', 'alpha', 'quant_max'],
@@ -385,12 +457,12 @@ class FakeLearnedScaleQuantPerChannel(PrimitiveWithInfer):
385
457
  if self.is_ascend and len(input_x_shape) not in self.ascend_support_x_rank:
386
458
  raise ValueError(f"For '{self.name}' x rank must be in '{self.ascend_support_x_rank}'")
387
459
  if not self.is_ascend:
388
- validator.check_int(len(input_x_shape), 1, Rel.GE, "input_x rank", self.name)
460
+ validator.check_int(len(input_x_shape), 1, validator.GE, "input_x rank", self.name)
389
461
  if len(input_x_shape) == 1:
390
462
  self.channel_axis = 0
391
463
 
392
464
  validator.check_equal_int(alpha_shape[0], input_x_shape[self.channel_axis], "alpha rank", self.name)
393
- validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
465
+ validator.check_int(len(quant_max_shape), 1, validator.GE, "quant max rank", self.name)
394
466
  return input_x_shape
395
467
 
396
468
  def infer_dtype(self, input_x_type, alpha_type, quant_max_type):
@@ -431,7 +503,7 @@ class FakeLearnedScaleQuantPerChannelGrad(PrimitiveWithInfer):
431
503
  inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])
432
504
 
433
505
  def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
434
- validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name)
506
+ validator.check("dout shape", dout_shape, "x_shape", x_shape, validator.EQ, self.name)
435
507
  return dout_shape, alpha_shape
436
508
 
437
509
  def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
@@ -462,9 +534,9 @@ class FakeLearnedScaleQuantPerChannelGradD(PrimitiveWithInfer):
462
534
  inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])
463
535
 
464
536
  def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
465
- validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name)
466
- validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name)
467
- validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
537
+ validator.check("dout shape", dout_shape, "x_shape", x_shape, validator.EQ, self.name)
538
+ validator.check_int(len(alpha_shape), 1, validator.GE, "alpha rank", self.name)
539
+ validator.check_int(len(quant_max_shape), 1, validator.GE, "quant max rank", self.name)
468
540
  return dout_shape, dout_shape
469
541
 
470
542
  def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
@@ -528,7 +600,7 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
528
600
  num_bits=8,
529
601
  narrow_range=False):
530
602
  self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
531
- self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
603
+ self.num_bits = validator.check_int_range(self.num_bits, 2, 16, validator.INC_BOTH, 'num_bits', self.name)
532
604
  self.narrow_range = validator.check_value_type(
533
605
  'narrow_range', narrow_range, (bool,), self.name)
534
606
 
@@ -540,9 +612,9 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
540
612
  raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.")
541
613
 
542
614
  def infer_shape(self, x_shape, min_shape, max_shape):
543
- validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
544
- validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
545
- validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
615
+ validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
616
+ validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
617
+ validator.check_int(len(min_shape), 1, validator.EQ, "min shape", self.name)
546
618
  self.check_broadcast(min_shape, x_shape)
547
619
  return x_shape
548
620
 
@@ -592,7 +664,7 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
592
664
  num_bits=8,
593
665
  narrow_range=False):
594
666
  self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
595
- self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
667
+ self.num_bits = validator.check_int_range(self.num_bits, 2, 16, validator.INC_BOTH, 'num_bits', self.name)
596
668
  self.narrow_range = validator.check_value_type(
597
669
  'narrow_range', narrow_range, (bool,), self.name)
598
670
 
@@ -604,10 +676,10 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
604
676
  raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.")
605
677
 
606
678
  def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
607
- validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
608
- validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
609
- validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
610
- validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
679
+ validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
680
+ validator.check("dout shape", dout_shape, "x shape", x_shape, validator.EQ, self.name)
681
+ validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
682
+ validator.check_int(len(min_shape), 1, validator.EQ, "min shape", self.name)
611
683
  self.check_broadcast(min_shape, x_shape)
612
684
  return x_shape, min_shape, max_shape
613
685
 
@@ -651,15 +723,15 @@ class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer):
651
723
  num_bits=8,
652
724
  narrow_range=False):
653
725
  self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
654
- self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
726
+ self.num_bits = validator.check_int_range(self.num_bits, 2, 16, validator.INC_BOTH, 'num_bits', self.name)
655
727
  self.narrow_range = validator.check_value_type(
656
728
  'narrow_range', narrow_range, (bool,), self.name)
657
729
 
658
730
  def infer_shape(self, x_shape, min_shape, max_shape):
659
- validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
660
- validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
661
- validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
662
- validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name)
731
+ validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
732
+ validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
733
+ validator.check_int(len(min_shape), 1, validator.EQ, "min shape", self.name)
734
+ validator.check("min shape", min_shape[0], "x shape", x_shape[-1], validator.EQ, self.name)
663
735
  return x_shape
664
736
 
665
737
  def infer_dtype(self, x_type, min_type, max_type):
@@ -709,16 +781,16 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
709
781
  num_bits=8,
710
782
  narrow_range=False):
711
783
  self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
712
- self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
784
+ self.num_bits = validator.check_int_range(self.num_bits, 2, 16, validator.INC_BOTH, 'num_bits', self.name)
713
785
  self.narrow_range = validator.check_value_type(
714
786
  'narrow_range', narrow_range, (bool,), self.name)
715
787
 
716
788
  def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
717
- validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
718
- validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
719
- validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
720
- validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
721
- validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name)
789
+ validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
790
+ validator.check("dout shape", dout_shape, "x shape", x_shape, validator.EQ, self.name)
791
+ validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
792
+ validator.check_int(len(min_shape), 1, validator.EQ, "min shape", self.name)
793
+ validator.check("min shape", min_shape[0], "x shape", x_shape[-1], validator.EQ, self.name)
722
794
  return x_shape, min_shape, max_shape
723
795
 
724
796
  def infer_dtype(self, dout_type, x_type, min_type, max_type):
@@ -807,15 +879,15 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
807
879
  self.narrow_range = validator.check_value_type(
808
880
  'narrow_range', narrow_range, (bool,), self.name)
809
881
  self.training = validator.check_value_type('training', training, (bool,), self.name)
810
- self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
882
+ self.ema_decay = validator.check_float_range(ema_decay, 0, 1, validator.INC_BOTH, 'ema_decay', self.name)
811
883
  self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
812
884
  self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
813
885
  self.init_prim_io_names(inputs=['x', 'min', 'max'],
814
886
  outputs=['out'])
815
887
 
816
888
  def infer_shape(self, x_shape, min_shape, max_shape):
817
- validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
818
- validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
889
+ validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
890
+ validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
819
891
  validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
820
892
  return x_shape
821
893
 
@@ -861,9 +933,9 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
861
933
 
862
934
  def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
863
935
  validator.check("dout shape", dout_shape, "x shape",
864
- x_shape, Rel.EQ, self.name)
936
+ x_shape, validator.EQ, self.name)
865
937
  validator.check("min shape", min_shape, "max shape",
866
- max_shape, Rel.EQ, self.name)
938
+ max_shape, validator.EQ, self.name)
867
939
  validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
868
940
  return dout_shape
869
941
 
@@ -933,11 +1005,12 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
933
1005
  'narrow_range', narrow_range, (bool,), self.name)
934
1006
  self.training = validator.check_value_type(
935
1007
  'training', training, (bool,), self.name)
936
- self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
1008
+ self.ema_decay = validator.check_float_range(ema_decay, 0, 1, validator.INC_BOTH, 'ema_decay', self.name)
937
1009
  self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
938
1010
  self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
939
1011
  if self.is_ascend:
940
- self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
1012
+ self.channel_axis = validator.check_int_range(channel_axis, 0, 1, validator.INC_BOTH,
1013
+ 'channel_axis', self.name)
941
1014
  else:
942
1015
  self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
943
1016
  self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])
@@ -946,10 +1019,10 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
946
1019
  if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
947
1020
  raise ValueError(f"For '{self.name}' x rank must be in '{self.ascend_support_x_rank}'")
948
1021
  if not self.is_ascend:
949
- validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
1022
+ validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
950
1023
  if len(x_shape) == 1:
951
1024
  self.channel_axis = 0
952
- validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
1025
+ validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
953
1026
  validator.check_equal_int(min_shape[0], x_shape[self.channel_axis], "min shape", self.name)
954
1027
  validator.check_equal_int(max_shape[0], x_shape[self.channel_axis], "max shape", self.name)
955
1028
  return x_shape
@@ -1045,7 +1118,7 @@ class BatchNormFold(PrimitiveWithInfer):
1045
1118
  @prim_attr_register
1046
1119
  def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
1047
1120
  """Initialize batch norm fold layer"""
1048
- self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
1121
+ self.momentum = validator.check_float_range(momentum, 0, 1, validator.INC_BOTH, 'momentum', self.name)
1049
1122
  self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
1050
1123
  self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
1051
1124
  self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
@@ -1054,8 +1127,9 @@ class BatchNormFold(PrimitiveWithInfer):
1054
1127
  outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std'])
1055
1128
 
1056
1129
  def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
1057
- validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
1058
- validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
1130
+ validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, validator.EQ, self.name)
1131
+ validator.check("mean_shape[0]", mean_shape[0], "input channel",
1132
+ x_shape[self.channel_axis], validator.EQ, self.name)
1059
1133
  validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
1060
1134
  return mean_shape, mean_shape, mean_shape, mean_shape
1061
1135
 
@@ -1096,13 +1170,13 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
1096
1170
  def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape,
1097
1171
  global_step_shape):
1098
1172
  validator.check("d_batch_mean shape", d_batch_mean_shape,
1099
- "d_batch_std shape", d_batch_std_shape, Rel.EQ, self.name)
1173
+ "d_batch_std shape", d_batch_std_shape, validator.EQ, self.name)
1100
1174
  validator.check("d_batch_mean shape", d_batch_mean_shape,
1101
- "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
1175
+ "batch_mean shape", batch_mean_shape, validator.EQ, self.name)
1102
1176
  validator.check("d_batch_mean shape", d_batch_mean_shape,
1103
- "batch_std shape", batch_std_shape, Rel.EQ, self.name)
1177
+ "batch_std shape", batch_std_shape, validator.EQ, self.name)
1104
1178
  validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0],
1105
- "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
1179
+ "input channel", x_shape[self.channel_axis], validator.EQ, self.name)
1106
1180
  validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
1107
1181
  return x_shape
1108
1182
 
@@ -1147,9 +1221,10 @@ class CorrectionMul(PrimitiveWithInfer):
1147
1221
  outputs=['out'])
1148
1222
 
1149
1223
  def infer_shape(self, x_shape, batch_std_shape, running_std_shape):
1150
- validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
1224
+ validator.check("batch_std shape", batch_std_shape, "running_std shape",
1225
+ running_std_shape, validator.EQ, self.name)
1151
1226
  validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
1152
- Rel.EQ, self.name)
1227
+ validator.EQ, self.name)
1153
1228
  return x_shape
1154
1229
 
1155
1230
  def infer_dtype(self, x_type, batch_std_type, running_std_type):
@@ -1181,11 +1256,11 @@ class CorrectionMulGrad(PrimitiveWithInfer):
1181
1256
  outputs=['dx', 'mul_dx'])
1182
1257
 
1183
1258
  def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape):
1184
- validator.check("dout shape", dout_shape, "x_shape x", x_shape, Rel.EQ, self.name)
1259
+ validator.check("dout shape", dout_shape, "x_shape x", x_shape, validator.EQ, self.name)
1185
1260
  validator.check("gamma_shape[0]", gamma_shape[0], "dout channel size", dout_shape[self.channel_axis],
1186
- Rel.EQ, self.name)
1261
+ validator.EQ, self.name)
1187
1262
  validator.check("running_std_shape[0]", running_std_shape[0],
1188
- "dout channel size", dout_shape[self.channel_axis], Rel.EQ, self.name)
1263
+ "dout channel size", dout_shape[self.channel_axis], validator.EQ, self.name)
1189
1264
  if context.get_context('device_target') == "Ascend":
1190
1265
  return x_shape, x_shape
1191
1266
  return x_shape, gamma_shape
@@ -1271,14 +1346,16 @@ class BatchNormFold2(PrimitiveWithInfer):
1271
1346
 
1272
1347
  def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape,
1273
1348
  running_mean_shape, global_step_shape):
1274
- validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
1275
- validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
1276
- validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
1349
+ validator.check("batch_std shape", batch_std_shape, "running_std shape",
1350
+ running_std_shape, validator.EQ, self.name)
1351
+ validator.check("batch_std shape", batch_std_shape, "batch_mean shape",
1352
+ batch_mean_shape, validator.EQ, self.name)
1353
+ validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, validator.EQ, self.name)
1277
1354
  validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape,
1278
- Rel.EQ, self.name)
1279
- validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
1355
+ validator.EQ, self.name)
1356
+ validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, validator.EQ, self.name)
1280
1357
  validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
1281
- Rel.EQ, self.name)
1358
+ validator.EQ, self.name)
1282
1359
  validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
1283
1360
  return x_shape
1284
1361
 
@@ -1321,13 +1398,15 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
1321
1398
  def infer_shape(self, dout_shape, x_shape, gamma_shape,
1322
1399
  batch_std_shape, batch_mean_shape,
1323
1400
  running_std_shape, running_mean_shape, global_step_shape):
1324
- validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
1325
- validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
1401
+ validator.check("batch_std shape", batch_std_shape, "batch_mean shape",
1402
+ batch_mean_shape, validator.EQ, self.name)
1403
+ validator.check("batch_std shape", batch_std_shape, "running_std shape",
1404
+ running_std_shape, validator.EQ, self.name)
1326
1405
  validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape,
1327
- Rel.EQ, self.name)
1328
- validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
1406
+ validator.EQ, self.name)
1407
+ validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, validator.EQ, self.name)
1329
1408
  validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
1330
- Rel.EQ, self.name)
1409
+ validator.EQ, self.name)
1331
1410
  validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
1332
1411
  return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape
1333
1412
 
@@ -1358,7 +1437,7 @@ class BatchNormFoldD(PrimitiveWithInfer):
1358
1437
  def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
1359
1438
  """Initialize _BatchNormFold layer"""
1360
1439
  from mindspore.ops._op_impl._custom_op import batchnorm_fold
1361
- self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
1440
+ self.momentum = validator.check_float_range(momentum, 0, 1, validator.INC_BOTH, 'momentum', self.name)
1362
1441
  self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
1363
1442
  self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
1364
1443
  self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
@@ -1368,8 +1447,8 @@ class BatchNormFoldD(PrimitiveWithInfer):
1368
1447
  'mean_updated', 'variance_updated'])
1369
1448
 
1370
1449
  def infer_shape(self, x_shape, x_sum_shape, x_square_sum_shape, mean_shape, variance_shape):
1371
- validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
1372
- validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[1], Rel.EQ, self.name)
1450
+ validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, validator.EQ, self.name)
1451
+ validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[1], validator.EQ, self.name)
1373
1452
  return x_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape
1374
1453
 
1375
1454
  def infer_dtype(self, x_type, x_sum_type, x_square_sum_type, mean_type, variance_type):
@@ -1439,12 +1518,14 @@ class BatchNormFold2D(PrimitiveWithInfer):
1439
1518
  outputs=['y'])
1440
1519
 
1441
1520
  def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape):
1442
- validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
1443
- validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
1444
- validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
1445
- validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
1521
+ validator.check("batch_std shape", batch_std_shape, "running_std shape",
1522
+ running_std_shape, validator.EQ, self.name)
1523
+ validator.check("batch_std shape", batch_std_shape, "batch_mean shape",
1524
+ batch_mean_shape, validator.EQ, self.name)
1525
+ validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, validator.EQ, self.name)
1526
+ validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, validator.EQ, self.name)
1446
1527
  validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
1447
- Rel.EQ, self.name)
1528
+ validator.EQ, self.name)
1448
1529
  return x_shape
1449
1530
 
1450
1531
  def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type):
@@ -1469,11 +1550,13 @@ class BatchNormFold2GradD(PrimitiveWithInfer):
1469
1550
 
1470
1551
  def infer_shape(self, dout_shape, dout_reduce_shape, dout_x_reduce_shape, gamma_shape, batch_std_shape,
1471
1552
  batch_mean_shape, running_std_shape):
1472
- validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
1473
- validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
1474
- validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
1553
+ validator.check("batch_std shape", batch_std_shape, "batch_mean shape",
1554
+ batch_mean_shape, validator.EQ, self.name)
1555
+ validator.check("batch_std shape", batch_std_shape, "running_std shape",
1556
+ running_std_shape, validator.EQ, self.name)
1557
+ validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, validator.EQ, self.name)
1475
1558
  validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
1476
- Rel.EQ, self.name)
1559
+ validator.EQ, self.name)
1477
1560
  return gamma_shape, gamma_shape, gamma_shape, dout_shape
1478
1561
 
1479
1562
  def infer_dtype(self, dout_type, dout_reduce_type, dout_x_reduce_type, gamma_type, batch_std_type,
@@ -1505,7 +1588,7 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
1505
1588
  outputs=['dout_reduce', 'dout_x_reduce'])
1506
1589
 
1507
1590
  def infer_shape(self, dout_shape, x_shape):
1508
- validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
1591
+ validator.check("dout shape", dout_shape, "x shape", x_shape, validator.EQ, self.name)
1509
1592
  return (dout_shape[self.channel_axis],), (dout_shape[self.channel_axis],)
1510
1593
 
1511
1594
  def infer_dtype(self, dout_type, x_type):
@@ -1547,17 +1630,17 @@ class ActsULQ(PrimitiveWithInfer):
1547
1630
  def __init__(self, fixed_min=False, num_bits=8):
1548
1631
  validator.check_value_type("fixed_min", fixed_min, [bool], self.name)
1549
1632
  validator.check_value_type("num_bits", num_bits, [int], self.name)
1550
- validator.check_int(num_bits, 8, Rel.EQ, "value of num_bits", self.name)
1633
+ validator.check_int(num_bits, 8, validator.EQ, "value of num_bits", self.name)
1551
1634
 
1552
1635
  def infer_shape(self, x_shape, clamp_min_shape, clamp_max_shape):
1553
1636
  """infer shape of primitive"""
1554
- validator.check_int(len(clamp_min_shape), len(x_shape), Rel.EQ, "dims of clamp_min", self.name)
1555
- validator.check_int(len(clamp_max_shape), len(x_shape), Rel.EQ, "dims of clamp_max", self.name)
1637
+ validator.check_int(len(clamp_min_shape), len(x_shape), validator.EQ, "dims of clamp_min", self.name)
1638
+ validator.check_int(len(clamp_max_shape), len(x_shape), validator.EQ, "dims of clamp_max", self.name)
1556
1639
 
1557
1640
  x_shape_len = len(x_shape)
1558
1641
  for i in range(x_shape_len):
1559
- validator.check_int(clamp_min_shape[i], 1, Rel.EQ, "dims of clamp_min", self.name)
1560
- validator.check_int(clamp_max_shape[i], 1, Rel.EQ, "dims of clamp_max", self.name)
1642
+ validator.check_int(clamp_min_shape[i], 1, validator.EQ, "dims of clamp_min", self.name)
1643
+ validator.check_int(clamp_max_shape[i], 1, validator.EQ, "dims of clamp_max", self.name)
1561
1644
 
1562
1645
  return x_shape, x_shape, x_shape, x_shape
1563
1646
 
@@ -1698,12 +1781,12 @@ class WtsARQ(PrimitiveWithInfer):
1698
1781
  @prim_attr_register
1699
1782
  def __init__(self, num_bits, offset_flag):
1700
1783
  validator.check_value_type("num_bits", num_bits, [int], self.name)
1701
- validator.check_int(num_bits, 8, Rel.EQ, "value of num_bits", self.name)
1784
+ validator.check_int(num_bits, 8, validator.EQ, "value of num_bits", self.name)
1702
1785
  validator.check_value_type("offset_flag", offset_flag, [bool], self.name)
1703
1786
 
1704
1787
  def infer_shape(self, w_shape, w_min_shape, w_max_shape):
1705
- validator.check_int(len(w_min_shape), len(w_shape), Rel.EQ, "dims of w_min", self.name)
1706
- validator.check_int(len(w_max_shape), len(w_shape), Rel.EQ, "dims of w_max", self.name)
1788
+ validator.check_int(len(w_min_shape), len(w_shape), validator.EQ, "dims of w_min", self.name)
1789
+ validator.check_int(len(w_max_shape), len(w_shape), validator.EQ, "dims of w_max", self.name)
1707
1790
  return w_shape
1708
1791
 
1709
1792
  def infer_dtype(self, w_dtype, w_min_dtype, w_max_dtype):
@@ -1753,11 +1836,13 @@ class IFMR(Primitive):
1753
1836
  @prim_attr_register
1754
1837
  def __init__(self, min_percentile=0.999999, max_percentile=0.999999, search_range=(0.7, 1.3), search_step=0.01,
1755
1838
  with_offset=True):
1839
+ self.init_prim_io_names(
1840
+ inputs=['data', 'data_min', 'data_max', 'cumsum'], outputs=['scale', 'offset'])
1756
1841
  validator.check_value_type("min_percentile", min_percentile, [float], self.name)
1757
1842
  validator.check_value_type("max_percentile", max_percentile, [float], self.name)
1758
1843
  validator.check_value_type("search_range", search_range, [list, tuple], self.name)
1759
1844
  for item in search_range:
1760
1845
  validator.check_positive_float(item, "item of search_range", self.name)
1761
- validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], Rel.GE, self.name)
1846
+ validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], validator.GE, self.name)
1762
1847
  validator.check_value_type("search_step", search_step, [float], self.name)
1763
1848
  validator.check_value_type("offset_flag", with_offset, [bool], self.name)