mindspore 2.0.0a0__cp39-cp39-win_amd64.whl → 2.0.0rc1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (655) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -2
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +102 -0
  7. mindspore/_checkparam.py +1066 -1001
  8. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
  9. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
  10. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
  11. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
  12. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  13. mindspore/_extends/parse/__init__.py +5 -3
  14. mindspore/_extends/parse/namespace.py +16 -1
  15. mindspore/_extends/parse/parser.py +107 -22
  16. mindspore/_extends/parse/resources.py +0 -7
  17. mindspore/_extends/parse/standard_method.py +885 -413
  18. mindspore/amp.py +52 -57
  19. mindspore/boost/boost.py +2 -2
  20. mindspore/boost/boost_cell_wrapper.py +38 -20
  21. mindspore/boost/dim_reduce.py +3 -3
  22. mindspore/boost/group_loss_scale_manager.py +1 -1
  23. mindspore/common/__init__.py +4 -6
  24. mindspore/common/_decorator.py +2 -0
  25. mindspore/common/_register_for_adapter.py +55 -0
  26. mindspore/common/_stub_tensor.py +201 -0
  27. mindspore/common/_utils.py +41 -7
  28. mindspore/common/api.py +215 -141
  29. mindspore/common/dtype.py +8 -1
  30. mindspore/common/dump.py +2 -2
  31. mindspore/common/initializer.py +4 -2
  32. mindspore/common/jit_config.py +17 -13
  33. mindspore/common/mutable.py +33 -13
  34. mindspore/common/parameter.py +23 -21
  35. mindspore/common/seed.py +8 -24
  36. mindspore/common/sparse_tensor.py +62 -41
  37. mindspore/common/tensor.py +852 -1154
  38. mindspore/communication/__init__.py +2 -2
  39. mindspore/communication/_comm_helper.py +11 -4
  40. mindspore/communication/management.py +22 -21
  41. mindspore/config/op_info.config +501 -1008
  42. mindspore/context.py +201 -23
  43. mindspore/dataset/__init__.py +6 -6
  44. mindspore/dataset/audio/__init__.py +7 -7
  45. mindspore/dataset/audio/transforms.py +670 -30
  46. mindspore/dataset/audio/utils.py +47 -4
  47. mindspore/dataset/audio/validators.py +223 -1
  48. mindspore/dataset/callback/ds_callback.py +2 -2
  49. mindspore/dataset/core/config.py +210 -14
  50. mindspore/dataset/core/validator_helpers.py +2 -2
  51. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  52. mindspore/dataset/debug/debug_hook.py +65 -0
  53. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  54. mindspore/dataset/engine/__init__.py +7 -3
  55. mindspore/dataset/engine/cache_client.py +1 -1
  56. mindspore/dataset/engine/datasets.py +322 -66
  57. mindspore/dataset/engine/datasets_audio.py +80 -76
  58. mindspore/dataset/engine/datasets_standard_format.py +51 -38
  59. mindspore/dataset/engine/datasets_text.py +232 -118
  60. mindspore/dataset/engine/datasets_user_defined.py +41 -17
  61. mindspore/dataset/engine/datasets_vision.py +746 -225
  62. mindspore/dataset/engine/graphdata.py +75 -10
  63. mindspore/dataset/engine/iterators.py +45 -5
  64. mindspore/dataset/engine/offload.py +48 -28
  65. mindspore/dataset/engine/validators.py +117 -8
  66. mindspore/dataset/text/__init__.py +6 -5
  67. mindspore/dataset/text/transforms.py +86 -3
  68. mindspore/dataset/text/utils.py +6 -4
  69. mindspore/dataset/text/validators.py +25 -0
  70. mindspore/dataset/transforms/__init__.py +3 -2
  71. mindspore/dataset/transforms/c_transforms.py +1 -1
  72. mindspore/dataset/transforms/transforms.py +2 -2
  73. mindspore/dataset/utils/__init__.py +2 -1
  74. mindspore/dataset/utils/line_reader.py +121 -0
  75. mindspore/dataset/vision/__init__.py +2 -3
  76. mindspore/dataset/vision/c_transforms.py +9 -9
  77. mindspore/dataset/vision/py_transforms.py +5 -5
  78. mindspore/dataset/vision/py_transforms_util.py +2 -0
  79. mindspore/dataset/vision/transforms.py +160 -161
  80. mindspore/dataset/vision/utils.py +3 -3
  81. mindspore/experimental/map_parameter.py +38 -26
  82. mindspore/include/OWNERS +0 -1
  83. mindspore/include/api/callback/callback.h +9 -13
  84. mindspore/include/api/callback/ckpt_saver.h +2 -2
  85. mindspore/include/api/callback/loss_monitor.h +2 -2
  86. mindspore/include/api/callback/lr_scheduler.h +5 -5
  87. mindspore/include/api/callback/time_monitor.h +2 -2
  88. mindspore/include/api/callback/train_accuracy.h +4 -6
  89. mindspore/include/api/cfg.h +19 -6
  90. mindspore/include/api/context.h +44 -9
  91. mindspore/include/api/delegate.h +1 -1
  92. mindspore/include/api/metrics/accuracy.h +2 -2
  93. mindspore/include/api/metrics/metrics.h +4 -3
  94. mindspore/include/api/model.h +9 -4
  95. mindspore/include/api/model_parallel_runner.h +2 -2
  96. mindspore/include/api/net.h +12 -11
  97. mindspore/include/api/serialization.h +19 -3
  98. mindspore/include/api/types.h +3 -3
  99. mindspore/include/dataset/constants.h +7 -0
  100. mindspore/include/dataset/text.h +59 -0
  101. mindspore/jpeg62.dll +0 -0
  102. mindspore/log.py +1 -1
  103. mindspore/mindrecord/filereader.py +18 -0
  104. mindspore/mindrecord/filewriter.py +197 -34
  105. mindspore/mindrecord/shardreader.py +9 -0
  106. mindspore/mindrecord/shardwriter.py +1 -1
  107. mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
  108. mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
  109. mindspore/mindrecord/tools/csv_to_mr.py +3 -3
  110. mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
  111. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  112. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  113. mindspore/mindspore_backend.dll +0 -0
  114. mindspore/mindspore_common.dll +0 -0
  115. mindspore/mindspore_core.dll +0 -0
  116. mindspore/mindspore_glog.dll +0 -0
  117. mindspore/mindspore_shared_lib.dll +0 -0
  118. mindspore/nn/__init__.py +0 -4
  119. mindspore/nn/cell.py +204 -132
  120. mindspore/nn/dynamic_lr.py +1 -1
  121. mindspore/nn/grad/cell_grad.py +7 -6
  122. mindspore/nn/layer/__init__.py +5 -4
  123. mindspore/nn/layer/activation.py +40 -89
  124. mindspore/nn/layer/basic.py +255 -624
  125. mindspore/nn/layer/channel_shuffle.py +7 -6
  126. mindspore/nn/layer/combined.py +1 -1
  127. mindspore/nn/layer/container.py +41 -4
  128. mindspore/nn/layer/conv.py +64 -28
  129. mindspore/nn/layer/dense.py +9 -8
  130. mindspore/nn/layer/embedding.py +27 -25
  131. mindspore/nn/layer/image.py +53 -46
  132. mindspore/nn/layer/math.py +97 -105
  133. mindspore/nn/layer/normalization.py +117 -86
  134. mindspore/nn/layer/padding.py +185 -95
  135. mindspore/nn/layer/pooling.py +817 -414
  136. mindspore/nn/layer/rnn_cells.py +10 -15
  137. mindspore/nn/layer/rnns.py +37 -38
  138. mindspore/nn/layer/thor_layer.py +11 -12
  139. mindspore/nn/layer/timedistributed.py +5 -5
  140. mindspore/nn/layer/transformer.py +701 -0
  141. mindspore/nn/learning_rate_schedule.py +8 -8
  142. mindspore/nn/loss/__init__.py +5 -4
  143. mindspore/nn/loss/loss.py +334 -199
  144. mindspore/nn/optim/ada_grad.py +6 -6
  145. mindspore/nn/optim/adadelta.py +2 -3
  146. mindspore/nn/optim/adafactor.py +4 -5
  147. mindspore/nn/optim/adam.py +126 -62
  148. mindspore/nn/optim/adamax.py +3 -4
  149. mindspore/nn/optim/adasum.py +6 -6
  150. mindspore/nn/optim/asgd.py +2 -2
  151. mindspore/nn/optim/ftrl.py +67 -38
  152. mindspore/nn/optim/lamb.py +4 -5
  153. mindspore/nn/optim/lars.py +2 -2
  154. mindspore/nn/optim/lazyadam.py +43 -4
  155. mindspore/nn/optim/momentum.py +6 -5
  156. mindspore/nn/optim/optimizer.py +3 -1
  157. mindspore/nn/optim/proximal_ada_grad.py +2 -2
  158. mindspore/nn/optim/rmsprop.py +1 -1
  159. mindspore/nn/optim/rprop.py +8 -9
  160. mindspore/nn/optim/sgd.py +19 -13
  161. mindspore/nn/optim/thor.py +10 -15
  162. mindspore/nn/probability/__init__.py +0 -2
  163. mindspore/nn/probability/bijector/bijector.py +4 -4
  164. mindspore/nn/probability/bijector/invert.py +1 -1
  165. mindspore/nn/probability/bijector/softplus.py +2 -2
  166. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  167. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  168. mindspore/nn/probability/distribution/_utils/utils.py +9 -15
  169. mindspore/nn/probability/distribution/bernoulli.py +3 -3
  170. mindspore/nn/probability/distribution/beta.py +1 -1
  171. mindspore/nn/probability/distribution/categorical.py +5 -7
  172. mindspore/nn/probability/distribution/cauchy.py +3 -3
  173. mindspore/nn/probability/distribution/distribution.py +2 -2
  174. mindspore/nn/probability/distribution/exponential.py +2 -2
  175. mindspore/nn/probability/distribution/gamma.py +3 -3
  176. mindspore/nn/probability/distribution/geometric.py +1 -1
  177. mindspore/nn/probability/distribution/gumbel.py +3 -3
  178. mindspore/nn/probability/distribution/half_normal.py +15 -11
  179. mindspore/nn/probability/distribution/laplace.py +16 -13
  180. mindspore/nn/probability/distribution/logistic.py +2 -2
  181. mindspore/nn/probability/distribution/normal.py +1 -1
  182. mindspore/nn/probability/distribution/poisson.py +1 -1
  183. mindspore/nn/probability/distribution/student_t.py +20 -15
  184. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  185. mindspore/nn/probability/distribution/uniform.py +2 -2
  186. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  187. mindspore/nn/reinforcement/tensor_array.py +2 -2
  188. mindspore/nn/sparse/sparse.py +2 -2
  189. mindspore/nn/wrap/cell_wrapper.py +27 -10
  190. mindspore/nn/wrap/grad_reducer.py +2 -2
  191. mindspore/nn/wrap/loss_scale.py +40 -24
  192. mindspore/numpy/array_creations.py +33 -22
  193. mindspore/numpy/array_ops.py +35 -30
  194. mindspore/numpy/logic_ops.py +6 -27
  195. mindspore/numpy/math_ops.py +22 -19
  196. mindspore/numpy/utils.py +1 -1
  197. mindspore/numpy/utils_const.py +108 -58
  198. mindspore/opencv_core452.dll +0 -0
  199. mindspore/opencv_imgcodecs452.dll +0 -0
  200. mindspore/opencv_imgproc452.dll +0 -0
  201. mindspore/ops/_constants.py +0 -6
  202. mindspore/ops/_grad/__init__.py +2 -1
  203. mindspore/ops/_grad/grad_array_ops.py +86 -117
  204. mindspore/ops/_grad/grad_base.py +23 -1
  205. mindspore/ops/_grad/grad_clip_ops.py +2 -3
  206. mindspore/ops/_grad/grad_comm_ops.py +34 -24
  207. mindspore/ops/_grad/grad_implementations.py +9 -45
  208. mindspore/ops/_grad/grad_inner_ops.py +47 -4
  209. mindspore/ops/_grad/grad_math_ops.py +142 -117
  210. mindspore/ops/_grad/grad_nn_ops.py +71 -165
  211. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  212. mindspore/ops/_grad/grad_sparse.py +7 -6
  213. mindspore/ops/_grad_experimental/__init__.py +1 -0
  214. mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
  215. mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
  216. mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
  217. mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
  218. mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
  219. mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
  220. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  221. mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
  222. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  223. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
  224. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  225. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  226. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
  227. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  228. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  229. mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
  230. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  231. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
  232. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  233. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  234. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  235. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  236. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  237. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  238. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  239. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  240. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  241. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  242. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  243. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  244. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  245. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  246. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  247. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  248. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  249. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  250. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
  251. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  252. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  253. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  254. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  255. mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
  256. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  257. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
  258. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  259. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  260. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  261. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  262. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  263. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
  264. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  265. mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
  266. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  267. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  268. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
  269. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  270. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  271. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
  272. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  273. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  274. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  275. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  276. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  277. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  278. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  279. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  280. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  281. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
  282. mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
  283. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  284. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  285. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  286. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  287. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  288. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  289. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  290. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  291. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  292. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  293. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  294. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  295. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  296. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  297. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  298. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  299. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  300. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  301. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  302. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  303. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  304. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  305. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  306. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  307. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  308. mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
  309. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
  310. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  311. mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
  312. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  313. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  314. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  315. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  316. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  317. mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
  318. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
  319. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  320. mindspore/ops/_op_impl/tbe/__init__.py +27 -611
  321. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  322. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  323. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  324. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
  325. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  326. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  327. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  328. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  329. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  330. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  331. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  332. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  333. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  334. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
  335. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  336. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  337. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  338. mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
  339. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  340. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  341. mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
  342. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
  343. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  344. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  345. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  346. mindspore/ops/_register_for_op.py +1 -0
  347. mindspore/ops/_utils/__init__.py +1 -2
  348. mindspore/ops/_utils/utils.py +19 -40
  349. mindspore/ops/_vmap/vmap_array_ops.py +116 -38
  350. mindspore/ops/_vmap/vmap_base.py +16 -9
  351. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  352. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  353. mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
  354. mindspore/ops/_vmap/vmap_image_ops.py +12 -5
  355. mindspore/ops/_vmap/vmap_math_ops.py +46 -5
  356. mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
  357. mindspore/ops/_vmap/vmap_random_ops.py +1 -1
  358. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  359. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  360. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  361. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  362. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  363. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  364. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  365. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  366. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
  367. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  368. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  369. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  370. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  371. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  372. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  373. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  374. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  375. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  376. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  377. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
  378. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
  379. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  380. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  381. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  382. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  383. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  384. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  385. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
  386. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  387. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
  388. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
  389. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  390. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  391. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  392. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  393. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  394. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  395. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  396. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  397. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  398. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  399. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  400. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  401. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  402. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  403. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  404. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  405. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  406. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  407. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  408. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
  409. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  410. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  411. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  412. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  413. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  414. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  415. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  416. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
  417. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
  418. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  419. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
  420. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  421. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  422. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  423. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  424. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  425. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  426. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  427. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  428. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  429. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
  430. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
  431. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  432. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  433. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  434. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  435. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
  436. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  437. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  438. mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
  439. mindspore/ops/composite/__init__.py +7 -8
  440. mindspore/ops/composite/base.py +101 -47
  441. mindspore/ops/composite/math_ops.py +188 -158
  442. mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
  443. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
  444. mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
  445. mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
  446. mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
  447. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  448. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  449. mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
  450. mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
  451. mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
  452. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
  453. mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
  454. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  455. mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
  456. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
  457. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
  458. mindspore/ops/function/__init__.py +152 -8
  459. mindspore/ops/function/array_func.py +2555 -674
  460. mindspore/ops/function/clip_func.py +209 -13
  461. mindspore/ops/function/debug_func.py +2 -2
  462. mindspore/ops/function/grad/__init__.py +2 -1
  463. mindspore/ops/function/grad/grad_func.py +147 -62
  464. mindspore/ops/function/image_func.py +54 -38
  465. mindspore/ops/function/linalg_func.py +167 -16
  466. mindspore/ops/function/math_func.py +4849 -1492
  467. mindspore/ops/function/nn_func.py +2573 -988
  468. mindspore/ops/function/other_func.py +115 -0
  469. mindspore/ops/function/parameter_func.py +3 -3
  470. mindspore/ops/function/random_func.py +790 -73
  471. mindspore/ops/function/sparse_func.py +98 -78
  472. mindspore/ops/function/sparse_unary_func.py +54 -53
  473. mindspore/ops/function/spectral_func.py +27 -24
  474. mindspore/ops/function/vmap_func.py +22 -2
  475. mindspore/ops/functional.py +97 -37
  476. mindspore/ops/op_info_register.py +70 -28
  477. mindspore/ops/operations/__init__.py +47 -14
  478. mindspore/ops/operations/_csr_ops.py +7 -7
  479. mindspore/ops/operations/_embedding_cache_ops.py +5 -5
  480. mindspore/ops/operations/_grad_ops.py +276 -187
  481. mindspore/ops/operations/_inner_ops.py +319 -113
  482. mindspore/ops/operations/_ms_kernel.py +10 -8
  483. mindspore/ops/operations/_ocr_ops.py +9 -9
  484. mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
  485. mindspore/ops/operations/_quant_ops.py +137 -102
  486. mindspore/ops/operations/_rl_inner_ops.py +121 -60
  487. mindspore/ops/operations/_scalar_ops.py +466 -0
  488. mindspore/ops/operations/_sequence_ops.py +1004 -2
  489. mindspore/ops/operations/_tensor_array.py +10 -11
  490. mindspore/ops/operations/_thor_ops.py +1 -1
  491. mindspore/ops/operations/array_ops.py +801 -466
  492. mindspore/ops/operations/comm_ops.py +51 -49
  493. mindspore/ops/operations/control_ops.py +2 -2
  494. mindspore/ops/operations/custom_ops.py +123 -44
  495. mindspore/ops/operations/debug_ops.py +24 -24
  496. mindspore/ops/operations/image_ops.py +240 -153
  497. mindspore/ops/operations/inner_ops.py +34 -50
  498. mindspore/ops/operations/linalg_ops.py +31 -9
  499. mindspore/ops/operations/math_ops.py +988 -757
  500. mindspore/ops/operations/nn_ops.py +965 -819
  501. mindspore/ops/operations/other_ops.py +51 -40
  502. mindspore/ops/operations/random_ops.py +204 -122
  503. mindspore/ops/operations/rl_ops.py +8 -9
  504. mindspore/ops/operations/sparse_ops.py +254 -93
  505. mindspore/ops/operations/spectral_ops.py +35 -3
  506. mindspore/ops/primitive.py +111 -9
  507. mindspore/parallel/_auto_parallel_context.py +189 -83
  508. mindspore/parallel/_offload_context.py +185 -0
  509. mindspore/parallel/_parallel_serialization.py +99 -7
  510. mindspore/parallel/_ps_context.py +9 -5
  511. mindspore/parallel/_recovery_context.py +1 -1
  512. mindspore/parallel/_tensor.py +7 -1
  513. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  514. mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
  515. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  516. mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
  517. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  518. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
  519. mindspore/parallel/_utils.py +1 -2
  520. mindspore/parallel/algo_parameter_config.py +1 -1
  521. mindspore/parallel/checkpoint_transform.py +37 -34
  522. mindspore/parallel/shard.py +17 -18
  523. mindspore/profiler/common/validator/validate_path.py +2 -2
  524. mindspore/profiler/envprofiling.py +69 -47
  525. mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
  526. mindspore/profiler/parser/base_timeline_generator.py +49 -56
  527. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
  528. mindspore/profiler/parser/hwts_log_parser.py +1 -1
  529. mindspore/profiler/parser/integrator.py +15 -14
  530. mindspore/profiler/parser/minddata_analyzer.py +2 -2
  531. mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
  532. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  533. mindspore/profiler/parser/optime_parser.py +17 -18
  534. mindspore/profiler/parser/profiler_info.py +2 -1
  535. mindspore/profiler/profiling.py +218 -186
  536. mindspore/rewrite/__init__.py +3 -1
  537. mindspore/rewrite/api/node.py +1 -114
  538. mindspore/rewrite/api/node_type.py +3 -0
  539. mindspore/rewrite/api/pattern_engine.py +31 -1
  540. mindspore/rewrite/api/scoped_value.py +4 -4
  541. mindspore/rewrite/api/symbol_tree.py +3 -78
  542. mindspore/rewrite/api/tree_node_helper.py +1 -1
  543. mindspore/rewrite/ast_creator_register.py +1 -0
  544. mindspore/rewrite/ast_helpers/__init__.py +2 -2
  545. mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
  546. mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
  547. mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
  548. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
  549. mindspore/rewrite/namespace.py +0 -2
  550. mindspore/rewrite/node.py +157 -11
  551. mindspore/rewrite/parsers/assign_parser.py +231 -53
  552. mindspore/rewrite/parsers/class_def_parser.py +187 -109
  553. mindspore/rewrite/parsers/for_parser.py +24 -14
  554. mindspore/rewrite/parsers/function_def_parser.py +21 -4
  555. mindspore/rewrite/parsers/if_parser.py +6 -2
  556. mindspore/rewrite/sparsify/__init__.py +0 -0
  557. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  558. mindspore/rewrite/sparsify/sparsify.py +109 -0
  559. mindspore/rewrite/sparsify/utils.py +173 -0
  560. mindspore/rewrite/symbol_tree.py +256 -133
  561. mindspore/rewrite/symbol_tree_builder.py +38 -1
  562. mindspore/run_check/_check_version.py +69 -63
  563. mindspore/run_check/run_check.py +2 -1
  564. mindspore/tinyxml2.dll +0 -0
  565. mindspore/train/__init__.py +1 -1
  566. mindspore/train/_utils.py +28 -5
  567. mindspore/train/amp.py +273 -102
  568. mindspore/train/callback/_backup_and_restore.py +5 -5
  569. mindspore/train/callback/_callback.py +2 -2
  570. mindspore/train/callback/_checkpoint.py +3 -3
  571. mindspore/train/callback/_early_stop.py +3 -3
  572. mindspore/train/callback/_lambda_callback.py +2 -2
  573. mindspore/train/callback/_landscape.py +29 -31
  574. mindspore/train/callback/_loss_monitor.py +3 -3
  575. mindspore/train/callback/_on_request_exit.py +3 -3
  576. mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
  577. mindspore/train/callback/_summary_collector.py +23 -16
  578. mindspore/train/callback/_time_monitor.py +3 -3
  579. mindspore/train/checkpoint_pb2.py +68 -8
  580. mindspore/train/data_sink.py +15 -3
  581. mindspore/train/dataset_helper.py +10 -15
  582. mindspore/train/loss_scale_manager.py +8 -11
  583. mindspore/train/metrics/__init__.py +1 -1
  584. mindspore/train/metrics/bleu_score.py +1 -1
  585. mindspore/train/metrics/confusion_matrix.py +1 -1
  586. mindspore/train/metrics/cosine_similarity.py +1 -1
  587. mindspore/train/metrics/dice.py +2 -2
  588. mindspore/train/metrics/fbeta.py +1 -1
  589. mindspore/train/metrics/hausdorff_distance.py +4 -3
  590. mindspore/train/metrics/mean_surface_distance.py +2 -2
  591. mindspore/train/metrics/occlusion_sensitivity.py +1 -1
  592. mindspore/train/metrics/perplexity.py +1 -1
  593. mindspore/train/metrics/precision.py +1 -1
  594. mindspore/train/metrics/recall.py +1 -1
  595. mindspore/train/metrics/roc.py +2 -2
  596. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  597. mindspore/train/mind_ir_pb2.py +116 -37
  598. mindspore/train/model.py +45 -28
  599. mindspore/train/serialization.py +295 -188
  600. mindspore/train/summary/_summary_adapter.py +1 -1
  601. mindspore/train/summary/summary_record.py +43 -13
  602. mindspore/train/train_thor/convert_utils.py +2 -2
  603. mindspore/train/train_thor/dataset_helper.py +3 -3
  604. mindspore/turbojpeg.dll +0 -0
  605. mindspore/version.py +1 -1
  606. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
  607. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
  608. mindspore/compression/__init__.py +0 -19
  609. mindspore/compression/common/constant.py +0 -124
  610. mindspore/compression/export/__init__.py +0 -19
  611. mindspore/compression/export/quant_export.py +0 -515
  612. mindspore/compression/quant/__init__.py +0 -28
  613. mindspore/compression/quant/qat.py +0 -634
  614. mindspore/compression/quant/quant_utils.py +0 -462
  615. mindspore/compression/quant/quantizer.py +0 -68
  616. mindspore/nn/layer/quant.py +0 -1868
  617. mindspore/nn/layer/rnn_utils.py +0 -90
  618. mindspore/nn/probability/dpn/__init__.py +0 -22
  619. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  620. mindspore/nn/probability/dpn/vae/cvae.py +0 -140
  621. mindspore/nn/probability/dpn/vae/vae.py +0 -124
  622. mindspore/nn/probability/infer/__init__.py +0 -22
  623. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  624. mindspore/nn/probability/infer/variational/svi.py +0 -84
  625. mindspore/nn/probability/toolbox/__init__.py +0 -22
  626. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  627. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
  628. mindspore/nn/probability/transforms/__init__.py +0 -22
  629. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  630. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  631. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  632. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  633. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  634. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  635. mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
  636. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  637. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
  638. mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
  639. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
  640. mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
  642. mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
  643. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
  644. mindspore/ops/composite/array_ops.py +0 -241
  645. mindspore/ops/composite/clip_ops.py +0 -134
  646. mindspore/ops/composite/random_ops.py +0 -426
  647. mindspore/ops/composite/vmap_ops.py +0 -38
  648. mindspore/parallel/nn/__init__.py +0 -42
  649. mindspore/parallel/nn/loss.py +0 -22
  650. mindspore/parallel/nn/moe.py +0 -21
  651. mindspore/parallel/nn/op_parallel_config.py +0 -22
  652. mindspore/parallel/nn/transformer.py +0 -31
  653. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  654. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  655. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/amp.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Huawei Technologies Co., Ltd
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -16,17 +16,20 @@
16
16
  from __future__ import absolute_import
17
17
 
18
18
  from abc import ABC, abstractmethod
19
-
20
- from ._checkparam import Validator as validator
19
+ from mindspore.common import mutable
20
+ from mindspore.ops._primitive_cache import _get_cache_prim
21
+ from mindspore.ops.operations.math_ops import NPUGetFloatStatusV2, NPUClearFloatStatusV2
22
+ from mindspore import _checkparam as validator
21
23
  from .common import dtype as mstype
22
24
  from . import context
23
25
  from . import ops
24
26
  from .ops import constexpr
25
- from .common.api import jit_class
27
+ from .common.api import jit_class, jit
26
28
  from .common.parameter import Parameter
27
29
  from .common.tensor import Tensor
28
30
  from .train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager, FixedLossScaleManager
29
- from .train.amp import build_train_network, auto_mixed_precision
31
+ from .train.amp import build_train_network, auto_mixed_precision, custom_mixed_precision,\
32
+ get_white_list, get_black_list
30
33
 
31
34
 
32
35
  _hypermap = ops.HyperMap()
@@ -51,46 +54,29 @@ def _grad_scale(scale, grad):
51
54
  return grad * scale.astype(grad.dtype)
52
55
 
53
56
 
54
- def _is_finite(inputs):
55
- if _gpu_target():
56
- return ops.FloatStatus()(inputs)[0] == 0
57
- status = ops.isfinite(inputs)
58
- return status.all()
59
-
60
-
61
- def init_status():
62
- r"""
63
- Returns a Tensor indicating initialized status for overflow detection.
64
-
65
- Note:
66
- Only Ascend need status to capture overflow status, you can also call
67
- this function on GPU or CPU, but the return value is useless.
57
+ @jit
58
+ def _grad_scale_map(scale_value, inputs):
59
+ return _hypermap(_partial(_grad_scale, scale_value), inputs)
68
60
 
69
- Returns:
70
- Tensor, has the shape of `(8,)`.
71
61
 
72
- Supported Platforms:
73
- ``Ascend`` ``GPU`` ``CPU``
62
+ @jit
63
+ def _grad_unscale_map(scale_value, inputs):
64
+ return _hypermap(_partial(_grad_unscale, scale_value), inputs)
74
65
 
75
- Examples:
76
- >>> status = amp.init_status()
77
- """
78
- if _ascend_target():
79
- status = ops.NPUAllocFloatStatus()()
80
- clear_status = ops.NPUClearFloatStatus()(status)
81
- status = ops.depend(status, clear_status)
82
- else:
83
- status = Tensor([0, 0, 0, 0, 0, 0, 0, 0], mstype.float32)
84
66
 
85
- return status
67
+ def _overflow(inputs):
68
+ if _gpu_target():
69
+ return ops.FloatStatus()(inputs)
70
+ status = ops.isfinite(inputs)
71
+ return 1 - status.all()
86
72
 
87
73
 
88
- def all_finite(inputs, status=None):
74
+ def all_finite(inputs):
89
75
  r"""
90
76
  Returns a scalar Tensor indicating whether the inputs are finite.
91
77
 
92
- Note:
93
- This is an experimental interface that is subject to change or deletion.
78
+ .. warning::
79
+ This is an experimental API that is subject to change or deletion.
94
80
 
95
81
  The interface must be used in whole network training scenario to detect
96
82
  whether grads are finite, and the results may be different on different
@@ -98,8 +84,6 @@ def all_finite(inputs, status=None):
98
84
 
99
85
  Args:
100
86
  inputs (Union(tuple(Tensor), list(Tensor))): a iterable Tensor.
101
- status (Tensor): the status Tensor for overflow detection, only required on
102
- Ascend. Default: None.
103
87
 
104
88
  Returns:
105
89
  Tensor, a scalar Tensor and the dtype is bool.
@@ -112,16 +96,18 @@ def all_finite(inputs, status=None):
112
96
  >>> output = amp.all_finite(x)
113
97
  """
114
98
  if _ascend_target():
115
- if status is None:
116
- raise ValueError("The status must be initialized on Ascend, but get 'None'.")
99
+ status = Tensor([0] * 8, mstype.int32)
117
100
  status = ops.depend(status, inputs)
118
- get_status = ops.NPUGetFloatStatus()(status)
101
+ get_status = _get_cache_prim(NPUGetFloatStatusV2)()(status)
119
102
  status = ops.depend(status, get_status)
120
- status_finite = status.sum() == 0
121
- _ = ops.NPUClearFloatStatus()(status)
103
+ clear_status = _get_cache_prim(NPUClearFloatStatusV2)()(status)
104
+ get_status = ops.depend(get_status, clear_status)
105
+ status_finite = get_status.equal(Tensor(0, mstype.int32)).all()
122
106
  return status_finite
123
- outputs = _hypermap(_partial(_is_finite), inputs)
124
- return ops.stack(outputs).all()
107
+ outputs = _hypermap(_partial(_overflow), inputs)
108
+ flag_sum = ops.addn(outputs).reshape(())
109
+ _all_finite = ops.less(flag_sum, 1)
110
+ return _all_finite
125
111
 
126
112
 
127
113
  @jit_class
@@ -133,8 +119,11 @@ class LossScaler(ABC):
133
119
  to scale and unscale the loss value and gradients to avoid overflow, `adjust` is used to update the
134
120
  loss scale value.
135
121
 
136
- Note:
137
- This is an experimental interface that is subject to change or deletion.
122
+ For more information, refer to the `tutorials <https://mindspore.cn/tutorials/en/r2.0/advanced/
123
+ mixed_precision.html#loss-scaling>`_.
124
+
125
+ .. warning::
126
+ This is an experimental API that is subject to change or deletion.
138
127
  """
139
128
  @abstractmethod
140
129
  def scale(self, inputs):
@@ -173,8 +162,8 @@ class StaticLossScaler(LossScaler):
173
162
 
174
163
  Scales and unscales loss or gradients by a fixed constant.
175
164
 
176
- Note:
177
- This is an experimental interface that is subject to change or deletion.
165
+ .. warning::
166
+ This is an experimental API that is subject to change or deletion.
178
167
 
179
168
  Args:
180
169
  scale_value (Union(float, int)): The initial loss scale value.
@@ -211,7 +200,8 @@ class StaticLossScaler(LossScaler):
211
200
  Returns:
212
201
  Union(Tensor, tuple(Tensor)), the scaled value.
213
202
  """
214
- return _hypermap(_partial(_grad_scale, self.scale_value), inputs)
203
+ inputs = mutable(inputs)
204
+ return _grad_scale_map(self.scale_value, inputs)
215
205
 
216
206
  def unscale(self, inputs):
217
207
  """
@@ -223,7 +213,8 @@ class StaticLossScaler(LossScaler):
223
213
  Returns:
224
214
  Union(Tensor, tuple(Tensor)), the unscaled value.
225
215
  """
226
- return _hypermap(_partial(_grad_unscale, self.scale_value), inputs)
216
+ inputs = mutable(inputs)
217
+ return _grad_unscale_map(self.scale_value, inputs)
227
218
 
228
219
  def adjust(self, grads_finite):
229
220
  """
@@ -244,8 +235,8 @@ class DynamicLossScaler(LossScaler):
244
235
  `scale_window` steps by `factor` if the grads remain finite, otherwise it reduces
245
236
  the loss scale by `1 / factor` and resets the counter.
246
237
 
247
- Note:
248
- This is an experimental interface that is subject to change or deletion.
238
+ .. warning::
239
+ This is an experimental API that is subject to change or deletion.
249
240
 
250
241
  Args:
251
242
  scale_value (Union(float, int)): The initial loss scale value.
@@ -286,7 +277,8 @@ class DynamicLossScaler(LossScaler):
286
277
  Returns:
287
278
  Union(Tensor, tuple(Tensor)), the scaled value.
288
279
  """
289
- return _hypermap(_partial(_grad_scale, self.scale_value), inputs)
280
+ inputs = mutable(inputs)
281
+ return _grad_scale_map(self.scale_value, inputs)
290
282
 
291
283
  def unscale(self, inputs):
292
284
  """
@@ -298,8 +290,10 @@ class DynamicLossScaler(LossScaler):
298
290
  Returns:
299
291
  Union(Tensor, tuple(Tensor)), the unscaled value.
300
292
  """
301
- return _hypermap(_partial(_grad_unscale, self.scale_value), inputs)
293
+ inputs = mutable(inputs)
294
+ return _grad_unscale_map(self.scale_value, inputs)
302
295
 
296
+ @jit
303
297
  def adjust(self, grads_finite):
304
298
  """
305
299
  Adjust the `scale_value` dependent on whether grads are finite.
@@ -313,7 +307,7 @@ class DynamicLossScaler(LossScaler):
313
307
  grads_finite,
314
308
  ops.select(
315
309
  self.counter == (self.scale_window - 1),
316
- ops.select(_is_finite(scale_mul_factor),
310
+ ops.select(ops.isfinite(scale_mul_factor),
317
311
  scale_mul_factor,
318
312
  self.scale_value),
319
313
  self.scale_value),
@@ -327,5 +321,6 @@ class DynamicLossScaler(LossScaler):
327
321
  __all__ = [
328
322
  "DynamicLossScaleManager", "LossScaleManager", "FixedLossScaleManager",
329
323
  "build_train_network", "DynamicLossScaler", "StaticLossScaler", "LossScaler",
330
- "auto_mixed_precision", "init_status", "all_finite"
324
+ "auto_mixed_precision", "all_finite", "custom_mixed_precision",
325
+ "get_white_list", "get_black_list"
331
326
  ]
mindspore/boost/boost.py CHANGED
@@ -156,8 +156,8 @@ class AutoBoost:
156
156
 
157
157
  Here:
158
158
 
159
- - pca_mat (array): Shape (k*n), k is part of n_components, n is the size of weight.
160
- - bk (array): Shape (k*k), is the symmetric positive definite matrix in Quasi-Newton method.
159
+ - pca_mat (array): Shape :math:`(k*n)`, k is part of n_components, n is the size of weight.
160
+ - bk (array): Shape :math:`(k*k)`, is the symmetric positive definite matrix in Quasi-Newton method.
161
161
 
162
162
  we need to find the m satisfy:
163
163
 
@@ -27,6 +27,7 @@ from mindspore.common import Tensor
27
27
  from mindspore.common.sparse_tensor import RowTensorInner
28
28
  from mindspore.common.parameter import Parameter, ParameterTuple
29
29
  from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
30
+ from mindspore.ops.operations.math_ops import NPUGetFloatStatusV2, NPUClearFloatStatusV2
30
31
  from mindspore.ops import functional as F
31
32
  from mindspore.ops import composite as C
32
33
  from mindspore.ops import operations as P
@@ -115,7 +116,7 @@ class BoostTrainOneStepCell(TrainOneStepCell):
115
116
  sens (numbers.Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
116
117
 
117
118
  Inputs:
118
- - **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
119
+ - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
119
120
 
120
121
  Outputs:
121
122
  Tensor, a tensor means the loss value, the shape of which is usually :math:`()`.
@@ -392,7 +393,7 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
392
393
  is Tensor type, Tensor with shape :math:`()` or :math:`(1,)`.
393
394
 
394
395
  Inputs:
395
- - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
396
+ - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
396
397
 
397
398
  Outputs:
398
399
  Tuple of 3 Tensor, the loss, overflow flag and current loss scaling value.
@@ -460,6 +461,11 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
460
461
  self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
461
462
  self.gpu_target = (context.get_context("device_target") == "GPU")
462
463
  self.loss_scaling_manager = None
464
+ self.base0 = Tensor(0, mstype.int32)
465
+ self.reduce_all = P.ReduceAll(keep_dims=False)
466
+ self.reduce_any = P.ReduceAny(keep_dims=False)
467
+ self.equal = P.Equal()
468
+ self.not_equal = P.NotEqual()
463
469
 
464
470
  if self.auto_boost.boost_config.get("loss_scale_group", False):
465
471
  self.enable_enhanced_amp = True
@@ -535,12 +541,13 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
535
541
  bool, overflow value.
536
542
  float, update ratio.
537
543
  """
538
- flag_sum = self.reduce_sum(param, (0,))
544
+ flag_sum = self.equal(self.base0, param)
539
545
  if self.reducer_flag:
540
546
  flag_reduce = self.allreduce(flag_sum)
541
- overflow = self.less_equal(self.base, flag_reduce)
547
+ overflow = not self.reduce_all(flag_reduce)
542
548
  else:
543
- overflow = self.less_equal(self.base, flag_sum)
549
+ overflow = not self.reduce_all(flag_sum)
550
+
544
551
  if overflow:
545
552
  update_ratio = self.reduce_ratio
546
553
  else:
@@ -609,13 +616,11 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
609
616
  The second value is the same as the input of `compute_input`, but contains some information about the
610
617
  execution order.
611
618
  """
612
- status = False
619
+ status = Tensor([0] * 8, mstype.int32)
613
620
  if not self.gpu_target:
614
- # init overflow buffer
615
- status = P.NPUAllocFloatStatus()()
616
621
  status = F.depend(status, pre_cond)
617
622
  # clear overflow buffer
618
- clear_status = P.NPUClearFloatStatus()(status)
623
+ clear_status = NPUClearFloatStatusV2()(status)
619
624
  compute_input = F.depend(compute_input, clear_status)
620
625
  return status, compute_input
621
626
 
@@ -636,22 +641,35 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
636
641
  """
637
642
  if not self.gpu_target:
638
643
  status = F.depend(status, compute_output)
639
- get_status = P.NPUGetFloatStatus()(status)
640
- status = F.depend(status, get_status)
641
- # sum overflow buffer elements, 0:not overflow , >0:overflow
642
- flag_sum = self.reduce_sum(status, (0,))
644
+ get_status = NPUGetFloatStatusV2()(status)
645
+
646
+ if self.is_distributed:
647
+ # sum overflow flag over devices
648
+ flag_reduce = self.allreduce(get_status)
649
+ # get_status not equal to [0]*8 means overflow
650
+ flag = self.not_equal(self.base0, flag_reduce)
651
+ status = F.depend(status, flag)
652
+ # distributed needs to skip allreduce to avoid its overflow affecting the next step
653
+ clear_status = NPUClearFloatStatusV2()(status)
654
+ flag = F.depend(flag, clear_status)
655
+ else:
656
+ status = F.depend(status, get_status)
657
+ clear_status = NPUClearFloatStatusV2()(status)
658
+ get_status = F.depend(get_status, clear_status)
659
+ flag = self.not_equal(self.base0, get_status)
660
+ overflow = self.reduce_any(flag)
643
661
  else:
644
662
  flag_sum = self.hyper_map(F.partial(_grad_overflow), compute_output)
645
663
  flag_sum = P.AddN()(flag_sum)
646
664
  # convert flag_sum to scalar
647
665
  flag_sum = P.Reshape()(flag_sum, (()))
648
666
 
649
- if self.is_distributed:
650
- # sum overflow flag over devices
651
- flag_reduce = self.allreduce(flag_sum)
652
- overflow = self.less_equal(self.base, flag_reduce)
653
- else:
654
- overflow = self.less_equal(self.base, flag_sum)
667
+ if self.is_distributed:
668
+ # sum overflow flag over devices
669
+ flag_reduce = self.allreduce(flag_sum)
670
+ overflow = self.less_equal(self.base, flag_reduce)
671
+ else:
672
+ overflow = self.less_equal(self.base, flag_sum)
655
673
  return overflow
656
674
 
657
675
  def _process_loss_scale(self, overflow):
@@ -688,7 +706,7 @@ class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
688
706
  self.optimizer_loss_scale = [self.parent.count(x) for x in parent_set]
689
707
  self.reduce_ratio = Tensor(1.0 / (2 ** 0.5), mstype.float32)
690
708
  self.growth_ratio = Tensor(2 ** (1.0 / 1000.0), mstype.float32)
691
- self.overflow_status_list = ParameterTuple(Parameter(Tensor(np.zeros(shape=[8]), mstype.float32),
709
+ self.overflow_status_list = ParameterTuple(Parameter(Tensor(np.zeros(shape=[8]), mstype.int32),
692
710
  name='mix_layer_status_{}'.format(x), requires_grad=False)
693
711
  for x in range(loss_scale_number))
694
712
  self.loss_scaling_manager.set_loss_scale_status(loss_scale_number, self.loss_scaling_manager.get_loss_scale())
@@ -102,8 +102,8 @@ class DimReduce(Cell):
102
102
 
103
103
  Here:
104
104
 
105
- - pca_mat (array): Shape (k*n), k is part of n_components, n is the size of weight.
106
- - bk (array): Shape (k*k), is the symmetric positive definite matrix in Quasi-Newton method.
105
+ - pca_mat (array): Shape :math:`(k*n)`, k is part of n_components, n is the size of weight.
106
+ - bk (array): Shape :math:`(k*k)`, is the symmetric positive definite matrix in Quasi-Newton method.
107
107
 
108
108
  we need to find the m satisfy:
109
109
 
@@ -138,7 +138,7 @@ class DimReduce(Cell):
138
138
  - **old_grad** (Tuple(Tensor)) - Tuple of gradient tensors.
139
139
  - **weight** (Tuple(Tensor)) - Tuple of parameters.
140
140
  - **weight_clone** (Tuple(Tensor)) - clone of weight
141
- - **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
141
+ - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
142
142
 
143
143
  Outputs:
144
144
  - **loss** (Tensor) - Tensor with shape :math:`()`.
@@ -93,7 +93,7 @@ class GroupLossScaleManager(Cell):
93
93
  >>> boost_level="O1", boost_config_dict=boost_config_dict)
94
94
  >>> # For details about how to build the dataset, please refer to the variable `dataset_train` in tutorial
95
95
  >>> # document on the official website:
96
- >>> # https://www.mindspore.cn/tutorials/zh-CN/r2.0.0-alpha/beginner/quick_start.html
96
+ >>> # https://www.mindspore.cn/tutorials/zh-CN/r2.0/beginner/quick_start.html
97
97
  >>> dataset = create_custom_dataset()
98
98
  >>> model.train(2, dataset)
99
99
  """
@@ -15,13 +15,13 @@
15
15
  """Top-level reference to dtype of common module."""
16
16
  from __future__ import absolute_import
17
17
  from mindspore.common import dtype
18
- from mindspore.common.api import no_recursive, ms_function, ms_memory_recycle, ms_class, jit, jit_class
18
+ from mindspore.common.api import ms_function, ms_memory_recycle, ms_class, jit, jit_class
19
19
  from mindspore.common.dtype import Type, int8, byte, int16, short, int32, intc, int64, intp, \
20
20
  uint8, ubyte, uint16, ushort, uint32, uintc, uint64, uintp, float16, half, \
21
21
  float32, single, float64, double, bool_, float_, list_, tuple_, int_, \
22
22
  uint, number, tensor, string, type_none, tensor_type, Int, \
23
23
  complex64, complex128, dtype_to_nptype, _null, _null_type, \
24
- dtype_to_pytype, pytype_to_dtype, get_py_obj_dtype
24
+ dtype_to_pytype, pytype_to_dtype, get_py_obj_dtype, QuantDtype
25
25
  from mindspore.common.dump import set_dump
26
26
  from mindspore.common.parameter import Parameter, ParameterTuple
27
27
  from mindspore.common.seed import set_seed, get_seed
@@ -29,7 +29,6 @@ from mindspore.common.tensor import Tensor
29
29
  from mindspore.common.sparse_tensor import RowTensor, RowTensorInner, SparseTensor, COOTensor, CSRTensor
30
30
  from mindspore.common.mutable import mutable
31
31
  from mindspore.common.jit_config import JitConfig
32
- from mindspore.common._utils import update_and_return_dict
33
32
 
34
33
  # symbols from dtype
35
34
  __all__ = [
@@ -50,7 +49,7 @@ __all__ = [
50
49
  "number", "tensor",
51
50
  "string", "type_none",
52
51
  "_null",
53
- "tensor_type",
52
+ "tensor_type", "QuantDtype",
54
53
  "Type", "Int", "_null_type",
55
54
  "complex64", "complex128",
56
55
  # __method__ from dtype
@@ -60,12 +59,11 @@ __all__ = [
60
59
 
61
60
  __all__.extend([
62
61
  "Tensor", "RowTensor", "SparseTensor", "COOTensor", "CSRTensor", # tensor
63
- "no_recursive", "ms_function", "ms_class", 'jit', 'jit_class', # api
62
+ "ms_function", "ms_class", 'jit', 'jit_class', # api
64
63
  "Parameter", "ParameterTuple", # parameter
65
64
  "dtype",
66
65
  "set_seed", "get_seed", # random seed
67
66
  "set_dump",
68
67
  "ms_memory_recycle",
69
68
  "mutable", "JitConfig",
70
- "update_and_return_dict",
71
69
  ])
@@ -15,6 +15,7 @@
15
15
  """Providing decorators."""
16
16
 
17
17
  from __future__ import absolute_import
18
+ from functools import wraps
18
19
  from mindspore import log
19
20
 
20
21
 
@@ -31,6 +32,7 @@ def deprecated(version, substitute, use_substitute_name=False):
31
32
  """
32
33
 
33
34
  def decorate(func):
35
+ @wraps(func)
34
36
  def wrapper(*args, **kwargs):
35
37
  cls = getattr(args[0], "__class__", None) if args else None
36
38
  name = cls.__name__ if cls else func.__name__
@@ -0,0 +1,55 @@
1
+ # Copyright 2022 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ """Registry MSAdapter config."""
17
+
18
+ from mindspore.common.tensor import Tensor
19
+
20
+
21
+ class Registry:
22
+ """Registry class for ms adapter."""
23
+
24
+ def __init__(self):
25
+ self._tensor = None
26
+ self._convert_map = {}
27
+
28
+ @property
29
+ def tensor(self):
30
+ """Return the registered tensor."""
31
+ if self._tensor is None:
32
+ raise ValueError("Before using Tensor in MSAdapter, please call 'set_adapter_config'.")
33
+ return self._tensor
34
+
35
+ @property
36
+ def convert_map(self):
37
+ """Return the registered convert map."""
38
+ return self._convert_map
39
+
40
+ def register_tensor(self, value):
41
+ """Register the tensor of ms adapter."""
42
+ if self._tensor is not None:
43
+ raise ValueError("Repeated registration of tensor in ms adapter config.")
44
+ if not issubclass(value, Tensor):
45
+ raise ValueError(f"The tensor definition here should be a subclass of ms.Tensor, but got {value}.")
46
+ self._tensor = value
47
+
48
+ def register_convert_map(self, value):
49
+ """Register the convert map of ms adapter."""
50
+ if not isinstance(value, dict):
51
+ raise ValueError(f"Expect a dict type, but got {type(value)}.")
52
+ self._convert_map = value
53
+
54
+
55
+ ms_adapter_registry = Registry()