mindspore-2.0.0a0-cp37-cp37m-win_amd64.whl → mindspore-2.0.0rc1-cp37-cp37m-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (655)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -2
  3. mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +102 -0
  7. mindspore/_checkparam.py +1066 -1001
  8. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
  9. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
  10. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
  11. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
  12. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  13. mindspore/_extends/parse/__init__.py +5 -3
  14. mindspore/_extends/parse/namespace.py +16 -1
  15. mindspore/_extends/parse/parser.py +107 -22
  16. mindspore/_extends/parse/resources.py +0 -7
  17. mindspore/_extends/parse/standard_method.py +885 -413
  18. mindspore/amp.py +52 -57
  19. mindspore/boost/boost.py +2 -2
  20. mindspore/boost/boost_cell_wrapper.py +38 -20
  21. mindspore/boost/dim_reduce.py +3 -3
  22. mindspore/boost/group_loss_scale_manager.py +1 -1
  23. mindspore/common/__init__.py +4 -6
  24. mindspore/common/_decorator.py +2 -0
  25. mindspore/common/_register_for_adapter.py +55 -0
  26. mindspore/common/_stub_tensor.py +201 -0
  27. mindspore/common/_utils.py +41 -7
  28. mindspore/common/api.py +215 -141
  29. mindspore/common/dtype.py +8 -1
  30. mindspore/common/dump.py +2 -2
  31. mindspore/common/initializer.py +4 -2
  32. mindspore/common/jit_config.py +17 -13
  33. mindspore/common/mutable.py +33 -13
  34. mindspore/common/parameter.py +23 -21
  35. mindspore/common/seed.py +8 -24
  36. mindspore/common/sparse_tensor.py +62 -41
  37. mindspore/common/tensor.py +852 -1154
  38. mindspore/communication/__init__.py +2 -2
  39. mindspore/communication/_comm_helper.py +11 -4
  40. mindspore/communication/management.py +22 -21
  41. mindspore/config/op_info.config +501 -1008
  42. mindspore/context.py +201 -23
  43. mindspore/dataset/__init__.py +6 -6
  44. mindspore/dataset/audio/__init__.py +7 -7
  45. mindspore/dataset/audio/transforms.py +670 -30
  46. mindspore/dataset/audio/utils.py +47 -4
  47. mindspore/dataset/audio/validators.py +223 -1
  48. mindspore/dataset/callback/ds_callback.py +2 -2
  49. mindspore/dataset/core/config.py +210 -14
  50. mindspore/dataset/core/validator_helpers.py +2 -2
  51. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  52. mindspore/dataset/debug/debug_hook.py +65 -0
  53. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  54. mindspore/dataset/engine/__init__.py +7 -3
  55. mindspore/dataset/engine/cache_client.py +1 -1
  56. mindspore/dataset/engine/datasets.py +322 -66
  57. mindspore/dataset/engine/datasets_audio.py +80 -76
  58. mindspore/dataset/engine/datasets_standard_format.py +51 -38
  59. mindspore/dataset/engine/datasets_text.py +232 -118
  60. mindspore/dataset/engine/datasets_user_defined.py +41 -17
  61. mindspore/dataset/engine/datasets_vision.py +746 -225
  62. mindspore/dataset/engine/graphdata.py +75 -10
  63. mindspore/dataset/engine/iterators.py +45 -5
  64. mindspore/dataset/engine/offload.py +48 -28
  65. mindspore/dataset/engine/validators.py +117 -8
  66. mindspore/dataset/text/__init__.py +6 -5
  67. mindspore/dataset/text/transforms.py +86 -3
  68. mindspore/dataset/text/utils.py +6 -4
  69. mindspore/dataset/text/validators.py +25 -0
  70. mindspore/dataset/transforms/__init__.py +3 -2
  71. mindspore/dataset/transforms/c_transforms.py +1 -1
  72. mindspore/dataset/transforms/transforms.py +2 -2
  73. mindspore/dataset/utils/__init__.py +2 -1
  74. mindspore/dataset/utils/line_reader.py +121 -0
  75. mindspore/dataset/vision/__init__.py +2 -3
  76. mindspore/dataset/vision/c_transforms.py +9 -9
  77. mindspore/dataset/vision/py_transforms.py +5 -5
  78. mindspore/dataset/vision/py_transforms_util.py +2 -0
  79. mindspore/dataset/vision/transforms.py +160 -161
  80. mindspore/dataset/vision/utils.py +3 -3
  81. mindspore/experimental/map_parameter.py +38 -26
  82. mindspore/include/OWNERS +0 -1
  83. mindspore/include/api/callback/callback.h +9 -13
  84. mindspore/include/api/callback/ckpt_saver.h +2 -2
  85. mindspore/include/api/callback/loss_monitor.h +2 -2
  86. mindspore/include/api/callback/lr_scheduler.h +5 -5
  87. mindspore/include/api/callback/time_monitor.h +2 -2
  88. mindspore/include/api/callback/train_accuracy.h +4 -6
  89. mindspore/include/api/cfg.h +19 -6
  90. mindspore/include/api/context.h +44 -9
  91. mindspore/include/api/delegate.h +1 -1
  92. mindspore/include/api/metrics/accuracy.h +2 -2
  93. mindspore/include/api/metrics/metrics.h +4 -3
  94. mindspore/include/api/model.h +9 -4
  95. mindspore/include/api/model_parallel_runner.h +2 -2
  96. mindspore/include/api/net.h +12 -11
  97. mindspore/include/api/serialization.h +19 -3
  98. mindspore/include/api/types.h +3 -3
  99. mindspore/include/dataset/constants.h +7 -0
  100. mindspore/include/dataset/text.h +59 -0
  101. mindspore/jpeg62.dll +0 -0
  102. mindspore/log.py +1 -1
  103. mindspore/mindrecord/filereader.py +18 -0
  104. mindspore/mindrecord/filewriter.py +197 -34
  105. mindspore/mindrecord/shardreader.py +9 -0
  106. mindspore/mindrecord/shardwriter.py +1 -1
  107. mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
  108. mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
  109. mindspore/mindrecord/tools/csv_to_mr.py +3 -3
  110. mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
  111. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  112. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  113. mindspore/mindspore_backend.dll +0 -0
  114. mindspore/mindspore_common.dll +0 -0
  115. mindspore/mindspore_core.dll +0 -0
  116. mindspore/mindspore_glog.dll +0 -0
  117. mindspore/mindspore_shared_lib.dll +0 -0
  118. mindspore/nn/__init__.py +0 -4
  119. mindspore/nn/cell.py +204 -132
  120. mindspore/nn/dynamic_lr.py +1 -1
  121. mindspore/nn/grad/cell_grad.py +7 -6
  122. mindspore/nn/layer/__init__.py +5 -4
  123. mindspore/nn/layer/activation.py +40 -89
  124. mindspore/nn/layer/basic.py +255 -624
  125. mindspore/nn/layer/channel_shuffle.py +7 -6
  126. mindspore/nn/layer/combined.py +1 -1
  127. mindspore/nn/layer/container.py +41 -4
  128. mindspore/nn/layer/conv.py +64 -28
  129. mindspore/nn/layer/dense.py +9 -8
  130. mindspore/nn/layer/embedding.py +27 -25
  131. mindspore/nn/layer/image.py +53 -46
  132. mindspore/nn/layer/math.py +97 -105
  133. mindspore/nn/layer/normalization.py +117 -86
  134. mindspore/nn/layer/padding.py +185 -95
  135. mindspore/nn/layer/pooling.py +817 -414
  136. mindspore/nn/layer/rnn_cells.py +10 -15
  137. mindspore/nn/layer/rnns.py +37 -38
  138. mindspore/nn/layer/thor_layer.py +11 -12
  139. mindspore/nn/layer/timedistributed.py +5 -5
  140. mindspore/nn/layer/transformer.py +701 -0
  141. mindspore/nn/learning_rate_schedule.py +8 -8
  142. mindspore/nn/loss/__init__.py +5 -4
  143. mindspore/nn/loss/loss.py +334 -199
  144. mindspore/nn/optim/ada_grad.py +6 -6
  145. mindspore/nn/optim/adadelta.py +2 -3
  146. mindspore/nn/optim/adafactor.py +4 -5
  147. mindspore/nn/optim/adam.py +126 -62
  148. mindspore/nn/optim/adamax.py +3 -4
  149. mindspore/nn/optim/adasum.py +6 -6
  150. mindspore/nn/optim/asgd.py +2 -2
  151. mindspore/nn/optim/ftrl.py +67 -38
  152. mindspore/nn/optim/lamb.py +4 -5
  153. mindspore/nn/optim/lars.py +2 -2
  154. mindspore/nn/optim/lazyadam.py +43 -4
  155. mindspore/nn/optim/momentum.py +6 -5
  156. mindspore/nn/optim/optimizer.py +3 -1
  157. mindspore/nn/optim/proximal_ada_grad.py +2 -2
  158. mindspore/nn/optim/rmsprop.py +1 -1
  159. mindspore/nn/optim/rprop.py +8 -9
  160. mindspore/nn/optim/sgd.py +19 -13
  161. mindspore/nn/optim/thor.py +10 -15
  162. mindspore/nn/probability/__init__.py +0 -2
  163. mindspore/nn/probability/bijector/bijector.py +4 -4
  164. mindspore/nn/probability/bijector/invert.py +1 -1
  165. mindspore/nn/probability/bijector/softplus.py +2 -2
  166. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  167. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  168. mindspore/nn/probability/distribution/_utils/utils.py +9 -15
  169. mindspore/nn/probability/distribution/bernoulli.py +3 -3
  170. mindspore/nn/probability/distribution/beta.py +1 -1
  171. mindspore/nn/probability/distribution/categorical.py +5 -7
  172. mindspore/nn/probability/distribution/cauchy.py +3 -3
  173. mindspore/nn/probability/distribution/distribution.py +2 -2
  174. mindspore/nn/probability/distribution/exponential.py +2 -2
  175. mindspore/nn/probability/distribution/gamma.py +3 -3
  176. mindspore/nn/probability/distribution/geometric.py +1 -1
  177. mindspore/nn/probability/distribution/gumbel.py +3 -3
  178. mindspore/nn/probability/distribution/half_normal.py +15 -11
  179. mindspore/nn/probability/distribution/laplace.py +16 -13
  180. mindspore/nn/probability/distribution/logistic.py +2 -2
  181. mindspore/nn/probability/distribution/normal.py +1 -1
  182. mindspore/nn/probability/distribution/poisson.py +1 -1
  183. mindspore/nn/probability/distribution/student_t.py +20 -15
  184. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  185. mindspore/nn/probability/distribution/uniform.py +2 -2
  186. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  187. mindspore/nn/reinforcement/tensor_array.py +2 -2
  188. mindspore/nn/sparse/sparse.py +2 -2
  189. mindspore/nn/wrap/cell_wrapper.py +27 -10
  190. mindspore/nn/wrap/grad_reducer.py +2 -2
  191. mindspore/nn/wrap/loss_scale.py +40 -24
  192. mindspore/numpy/array_creations.py +33 -22
  193. mindspore/numpy/array_ops.py +35 -30
  194. mindspore/numpy/logic_ops.py +6 -27
  195. mindspore/numpy/math_ops.py +22 -19
  196. mindspore/numpy/utils.py +1 -1
  197. mindspore/numpy/utils_const.py +108 -58
  198. mindspore/opencv_core452.dll +0 -0
  199. mindspore/opencv_imgcodecs452.dll +0 -0
  200. mindspore/opencv_imgproc452.dll +0 -0
  201. mindspore/ops/_constants.py +0 -6
  202. mindspore/ops/_grad/__init__.py +2 -1
  203. mindspore/ops/_grad/grad_array_ops.py +86 -117
  204. mindspore/ops/_grad/grad_base.py +23 -1
  205. mindspore/ops/_grad/grad_clip_ops.py +2 -3
  206. mindspore/ops/_grad/grad_comm_ops.py +34 -24
  207. mindspore/ops/_grad/grad_implementations.py +9 -45
  208. mindspore/ops/_grad/grad_inner_ops.py +47 -4
  209. mindspore/ops/_grad/grad_math_ops.py +142 -117
  210. mindspore/ops/_grad/grad_nn_ops.py +71 -165
  211. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  212. mindspore/ops/_grad/grad_sparse.py +7 -6
  213. mindspore/ops/_grad_experimental/__init__.py +1 -0
  214. mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
  215. mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
  216. mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
  217. mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
  218. mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
  219. mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
  220. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  221. mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
  222. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  223. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
  224. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  225. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  226. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
  227. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  228. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  229. mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
  230. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  231. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
  232. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  233. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  234. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  235. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  236. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  237. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  238. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  239. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  240. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  241. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  242. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  243. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  244. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  245. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  246. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  247. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  248. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  249. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  250. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
  251. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  252. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  253. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  254. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  255. mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
  256. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  257. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
  258. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  259. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  260. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  261. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  262. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  263. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
  264. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  265. mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
  266. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  267. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  268. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
  269. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  270. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  271. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
  272. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  273. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  274. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  275. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  276. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  277. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  278. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  279. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  280. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  281. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
  282. mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
  283. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  284. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  285. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  286. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  287. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  288. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  289. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  290. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  291. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  292. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  293. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  294. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  295. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  296. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  297. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  298. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  299. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  300. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  301. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  302. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  303. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  304. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  305. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  306. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  307. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  308. mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
  309. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
  310. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  311. mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
  312. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  313. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  314. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  315. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  316. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  317. mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
  318. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
  319. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  320. mindspore/ops/_op_impl/tbe/__init__.py +27 -611
  321. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  322. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  323. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  324. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
  325. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  326. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  327. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  328. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  329. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  330. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  331. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  332. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  333. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  334. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
  335. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  336. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  337. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  338. mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
  339. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  340. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  341. mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
  342. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
  343. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  344. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  345. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  346. mindspore/ops/_register_for_op.py +1 -0
  347. mindspore/ops/_utils/__init__.py +1 -2
  348. mindspore/ops/_utils/utils.py +19 -40
  349. mindspore/ops/_vmap/vmap_array_ops.py +116 -38
  350. mindspore/ops/_vmap/vmap_base.py +16 -9
  351. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  352. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  353. mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
  354. mindspore/ops/_vmap/vmap_image_ops.py +12 -5
  355. mindspore/ops/_vmap/vmap_math_ops.py +46 -5
  356. mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
  357. mindspore/ops/_vmap/vmap_random_ops.py +1 -1
  358. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  359. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  360. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  361. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  362. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  363. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  364. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  365. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  366. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
  367. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  368. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  369. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  370. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  371. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  372. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  373. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  374. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  375. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  376. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  377. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
  378. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
  379. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  380. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  381. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  382. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  383. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  384. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  385. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
  386. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  387. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
  388. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
  389. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  390. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  391. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  392. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  393. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  394. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  395. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  396. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  397. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  398. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  399. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  400. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  401. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  402. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  403. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  404. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  405. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  406. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  407. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  408. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
  409. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  410. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  411. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  412. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  413. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  414. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  415. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  416. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
  417. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
  418. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  419. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
  420. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  421. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  422. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  423. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  424. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  425. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  426. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  427. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  428. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  429. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
  430. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
  431. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  432. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  433. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  434. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  435. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
  436. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  437. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  438. mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
  439. mindspore/ops/composite/__init__.py +7 -8
  440. mindspore/ops/composite/base.py +101 -47
  441. mindspore/ops/composite/math_ops.py +188 -158
  442. mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
  443. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
  444. mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
  445. mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
  446. mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
  447. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  448. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  449. mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
  450. mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
  451. mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
  452. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
  453. mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
  454. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  455. mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
  456. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
  457. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
  458. mindspore/ops/function/__init__.py +152 -8
  459. mindspore/ops/function/array_func.py +2555 -674
  460. mindspore/ops/function/clip_func.py +209 -13
  461. mindspore/ops/function/debug_func.py +2 -2
  462. mindspore/ops/function/grad/__init__.py +2 -1
  463. mindspore/ops/function/grad/grad_func.py +147 -62
  464. mindspore/ops/function/image_func.py +54 -38
  465. mindspore/ops/function/linalg_func.py +167 -16
  466. mindspore/ops/function/math_func.py +4849 -1492
  467. mindspore/ops/function/nn_func.py +2573 -988
  468. mindspore/ops/function/other_func.py +115 -0
  469. mindspore/ops/function/parameter_func.py +3 -3
  470. mindspore/ops/function/random_func.py +790 -73
  471. mindspore/ops/function/sparse_func.py +98 -78
  472. mindspore/ops/function/sparse_unary_func.py +54 -53
  473. mindspore/ops/function/spectral_func.py +27 -24
  474. mindspore/ops/function/vmap_func.py +22 -2
  475. mindspore/ops/functional.py +97 -37
  476. mindspore/ops/op_info_register.py +70 -28
  477. mindspore/ops/operations/__init__.py +47 -14
  478. mindspore/ops/operations/_csr_ops.py +7 -7
  479. mindspore/ops/operations/_embedding_cache_ops.py +5 -5
  480. mindspore/ops/operations/_grad_ops.py +276 -187
  481. mindspore/ops/operations/_inner_ops.py +319 -113
  482. mindspore/ops/operations/_ms_kernel.py +10 -8
  483. mindspore/ops/operations/_ocr_ops.py +9 -9
  484. mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
  485. mindspore/ops/operations/_quant_ops.py +137 -102
  486. mindspore/ops/operations/_rl_inner_ops.py +121 -60
  487. mindspore/ops/operations/_scalar_ops.py +466 -0
  488. mindspore/ops/operations/_sequence_ops.py +1004 -2
  489. mindspore/ops/operations/_tensor_array.py +10 -11
  490. mindspore/ops/operations/_thor_ops.py +1 -1
  491. mindspore/ops/operations/array_ops.py +801 -466
  492. mindspore/ops/operations/comm_ops.py +51 -49
  493. mindspore/ops/operations/control_ops.py +2 -2
  494. mindspore/ops/operations/custom_ops.py +123 -44
  495. mindspore/ops/operations/debug_ops.py +24 -24
  496. mindspore/ops/operations/image_ops.py +240 -153
  497. mindspore/ops/operations/inner_ops.py +34 -50
  498. mindspore/ops/operations/linalg_ops.py +31 -9
  499. mindspore/ops/operations/math_ops.py +988 -757
  500. mindspore/ops/operations/nn_ops.py +965 -819
  501. mindspore/ops/operations/other_ops.py +51 -40
  502. mindspore/ops/operations/random_ops.py +204 -122
  503. mindspore/ops/operations/rl_ops.py +8 -9
  504. mindspore/ops/operations/sparse_ops.py +254 -93
  505. mindspore/ops/operations/spectral_ops.py +35 -3
  506. mindspore/ops/primitive.py +111 -9
  507. mindspore/parallel/_auto_parallel_context.py +189 -83
  508. mindspore/parallel/_offload_context.py +185 -0
  509. mindspore/parallel/_parallel_serialization.py +99 -7
  510. mindspore/parallel/_ps_context.py +9 -5
  511. mindspore/parallel/_recovery_context.py +1 -1
  512. mindspore/parallel/_tensor.py +7 -1
  513. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  514. mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
  515. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  516. mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
  517. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  518. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
  519. mindspore/parallel/_utils.py +1 -2
  520. mindspore/parallel/algo_parameter_config.py +1 -1
  521. mindspore/parallel/checkpoint_transform.py +37 -34
  522. mindspore/parallel/shard.py +17 -18
  523. mindspore/profiler/common/validator/validate_path.py +2 -2
  524. mindspore/profiler/envprofiling.py +69 -47
  525. mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
  526. mindspore/profiler/parser/base_timeline_generator.py +49 -56
  527. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
  528. mindspore/profiler/parser/hwts_log_parser.py +1 -1
  529. mindspore/profiler/parser/integrator.py +15 -14
  530. mindspore/profiler/parser/minddata_analyzer.py +2 -2
  531. mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
  532. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  533. mindspore/profiler/parser/optime_parser.py +17 -18
  534. mindspore/profiler/parser/profiler_info.py +2 -1
  535. mindspore/profiler/profiling.py +218 -186
  536. mindspore/rewrite/__init__.py +3 -1
  537. mindspore/rewrite/api/node.py +1 -114
  538. mindspore/rewrite/api/node_type.py +3 -0
  539. mindspore/rewrite/api/pattern_engine.py +31 -1
  540. mindspore/rewrite/api/scoped_value.py +4 -4
  541. mindspore/rewrite/api/symbol_tree.py +3 -78
  542. mindspore/rewrite/api/tree_node_helper.py +1 -1
  543. mindspore/rewrite/ast_creator_register.py +1 -0
  544. mindspore/rewrite/ast_helpers/__init__.py +2 -2
  545. mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
  546. mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
  547. mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
  548. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
  549. mindspore/rewrite/namespace.py +0 -2
  550. mindspore/rewrite/node.py +157 -11
  551. mindspore/rewrite/parsers/assign_parser.py +231 -53
  552. mindspore/rewrite/parsers/class_def_parser.py +187 -109
  553. mindspore/rewrite/parsers/for_parser.py +24 -14
  554. mindspore/rewrite/parsers/function_def_parser.py +21 -4
  555. mindspore/rewrite/parsers/if_parser.py +6 -2
  556. mindspore/rewrite/sparsify/__init__.py +0 -0
  557. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  558. mindspore/rewrite/sparsify/sparsify.py +109 -0
  559. mindspore/rewrite/sparsify/utils.py +173 -0
  560. mindspore/rewrite/symbol_tree.py +256 -133
  561. mindspore/rewrite/symbol_tree_builder.py +38 -1
  562. mindspore/run_check/_check_version.py +69 -63
  563. mindspore/run_check/run_check.py +2 -1
  564. mindspore/tinyxml2.dll +0 -0
  565. mindspore/train/__init__.py +1 -1
  566. mindspore/train/_utils.py +28 -5
  567. mindspore/train/amp.py +273 -102
  568. mindspore/train/callback/_backup_and_restore.py +5 -5
  569. mindspore/train/callback/_callback.py +2 -2
  570. mindspore/train/callback/_checkpoint.py +3 -3
  571. mindspore/train/callback/_early_stop.py +3 -3
  572. mindspore/train/callback/_lambda_callback.py +2 -2
  573. mindspore/train/callback/_landscape.py +29 -31
  574. mindspore/train/callback/_loss_monitor.py +3 -3
  575. mindspore/train/callback/_on_request_exit.py +3 -3
  576. mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
  577. mindspore/train/callback/_summary_collector.py +23 -16
  578. mindspore/train/callback/_time_monitor.py +3 -3
  579. mindspore/train/checkpoint_pb2.py +68 -8
  580. mindspore/train/data_sink.py +15 -3
  581. mindspore/train/dataset_helper.py +10 -15
  582. mindspore/train/loss_scale_manager.py +8 -11
  583. mindspore/train/metrics/__init__.py +1 -1
  584. mindspore/train/metrics/bleu_score.py +1 -1
  585. mindspore/train/metrics/confusion_matrix.py +1 -1
  586. mindspore/train/metrics/cosine_similarity.py +1 -1
  587. mindspore/train/metrics/dice.py +2 -2
  588. mindspore/train/metrics/fbeta.py +1 -1
  589. mindspore/train/metrics/hausdorff_distance.py +4 -3
  590. mindspore/train/metrics/mean_surface_distance.py +2 -2
  591. mindspore/train/metrics/occlusion_sensitivity.py +1 -1
  592. mindspore/train/metrics/perplexity.py +1 -1
  593. mindspore/train/metrics/precision.py +1 -1
  594. mindspore/train/metrics/recall.py +1 -1
  595. mindspore/train/metrics/roc.py +2 -2
  596. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  597. mindspore/train/mind_ir_pb2.py +116 -37
  598. mindspore/train/model.py +45 -28
  599. mindspore/train/serialization.py +295 -188
  600. mindspore/train/summary/_summary_adapter.py +1 -1
  601. mindspore/train/summary/summary_record.py +43 -13
  602. mindspore/train/train_thor/convert_utils.py +2 -2
  603. mindspore/train/train_thor/dataset_helper.py +3 -3
  604. mindspore/turbojpeg.dll +0 -0
  605. mindspore/version.py +1 -1
  606. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
  607. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
  608. mindspore/compression/__init__.py +0 -19
  609. mindspore/compression/common/constant.py +0 -124
  610. mindspore/compression/export/__init__.py +0 -19
  611. mindspore/compression/export/quant_export.py +0 -515
  612. mindspore/compression/quant/__init__.py +0 -28
  613. mindspore/compression/quant/qat.py +0 -634
  614. mindspore/compression/quant/quant_utils.py +0 -462
  615. mindspore/compression/quant/quantizer.py +0 -68
  616. mindspore/nn/layer/quant.py +0 -1868
  617. mindspore/nn/layer/rnn_utils.py +0 -90
  618. mindspore/nn/probability/dpn/__init__.py +0 -22
  619. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  620. mindspore/nn/probability/dpn/vae/cvae.py +0 -140
  621. mindspore/nn/probability/dpn/vae/vae.py +0 -124
  622. mindspore/nn/probability/infer/__init__.py +0 -22
  623. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  624. mindspore/nn/probability/infer/variational/svi.py +0 -84
  625. mindspore/nn/probability/toolbox/__init__.py +0 -22
  626. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  627. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
  628. mindspore/nn/probability/transforms/__init__.py +0 -22
  629. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  630. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  631. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  632. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  633. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  634. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  635. mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
  636. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  637. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
  638. mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
  639. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
  640. mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
  642. mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
  643. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
  644. mindspore/ops/composite/array_ops.py +0 -241
  645. mindspore/ops/composite/clip_ops.py +0 -134
  646. mindspore/ops/composite/random_ops.py +0 -426
  647. mindspore/ops/composite/vmap_ops.py +0 -38
  648. mindspore/parallel/nn/__init__.py +0 -42
  649. mindspore/parallel/nn/loss.py +0 -22
  650. mindspore/parallel/nn/moe.py +0 -21
  651. mindspore/parallel/nn/op_parallel_config.py +0 -22
  652. mindspore/parallel/nn/transformer.py +0 -31
  653. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  654. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  655. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/ops/operations/spectral_ops.py

@@ -14,7 +14,7 @@
 # ============================================================================
 
 """Spectral operators."""
-from mindspore._checkparam import Validator as validator
+from mindspore import _checkparam as validator
 from mindspore.common import dtype as mstype
 from mindspore.ops.primitive import Primitive, prim_attr_register
 
@@ -23,10 +23,26 @@ class BartlettWindow(Primitive):
     r"""
     Bartlett window function.
 
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
     Refer to :func:`mindspore.ops.bartlett_window` for more details.
 
+    Args:
+        periodic (bool, optional): If True, returns a window to be used as periodic function.
+            If False, return a symmetric window. Default: True.
+        dtype (mindspore.dtype, optional): The desired datatype of returned tensor.
+            Only float16, float32 and float64 are allowed. Default: mstype.float32.
+
+    Inputs:
+        - **window_length** (Tensor) - The size of returned window, with data type int32, int64.
+          The input data should be an integer with a value of [0, 1000000].
+
+    Outputs:
+        A 1-D tensor of size `window_length` containing the window. Its datatype is set by the attr `dtype`.
+
     Supported Platforms:
-        ``GPU`` ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> window_length = Tensor(5, mstype.int32)
@@ -50,10 +66,26 @@ class BlackmanWindow(Primitive):
     r"""
     Blackman window function.
 
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
     Refer to :func:`mindspore.ops.blackman_window` for more details.
 
+    Args:
+        periodic (bool, optional): If True, returns a window to be used as periodic function.
+            If False, return a symmetric window. Default: True.
+        dtype (mindspore.dtype, optional): the desired data type of returned tensor.
+            Only float16, float32 and float64 is allowed. Default: mstype.float32.
+
+    Inputs:
+        - **window_length** (Tensor) - the size of returned window, with data type int32, int64.
+          The input data should be an integer with a value of [0, 1000000].
+
+    Outputs:
+        A 1-D tensor of size `window_length` containing the window. Its datatype is set by the attr `dtype`.
+
     Supported Platforms:
-        ``GPU`` ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
        >>> window_length = Tensor(10, mindspore.int32)
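
Note: a runnable sketch of the window primitives documented above (mirrors the truncated Examples blocks; the ops.BartlettWindow/ops.BlackmanWindow usage is inferred from the docstrings, and output values are omitted):

import mindspore as ms
from mindspore import Tensor, ops

# Both primitives take the window length as a scalar int32/int64 Tensor
# and return a 1-D window of that length.
bartlett = ops.BartlettWindow(periodic=True, dtype=ms.float32)
print(bartlett(Tensor(5, ms.int32)).shape)    # (5,)

blackman = ops.BlackmanWindow(periodic=True, dtype=ms.float32)
print(blackman(Tensor(10, ms.int32)).shape)   # (10,)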
mindspore/ops/primitive.py

@@ -24,8 +24,9 @@ from mindspore.parallel._utils import _is_in_auto_parallel_mode, _is_in_data_par
 from mindspore.parallel._ps_context import _is_ps_mode, _is_role_sched
 from mindspore.common.parameter import Parameter
 from mindspore.common.api import _pynative_executor
+from mindspore.common._stub_tensor import _convert_stub
 from mindspore._c_expression import Primitive_, prim_type
-from mindspore._checkparam import Validator
+from mindspore import _checkparam as Validator
 from mindspore.ops import signature as sig
 
 
@@ -486,10 +487,10 @@ class PrimitiveWithCheck(Primitive):
         ...     def __init__(self):
         ...         pass
         ...     def check_shape(self, input_x):
-        ...         validator.check_int(len(input_x), 1, Rel.GE, 'input_x rank', self.name)
+        ...         Validator.check_int(len(input_x), 1, validator.GE, 'input_x rank', self.name)
         ...
         ...     def check_dtype(self, input_x):
-        ...         validator.check_subclass("input_x", input_x, mstype.tensor, self.name)
+        ...         Validator.check_subclass("input_x", input_x, mstype.tensor, self.name)
         ...
         >>> # init a Primitive obj
         >>> add = Flatten()
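
Note: these hunks track the 2.0.0rc1 validator migration, in which module-level helpers in mindspore._checkparam replace the old Validator class and Rel enum. A minimal sketch of the new call style, using only names that appear in the hunks (assumed to run against the rc1 wheel):

from mindspore import _checkparam as validator

# Relational constants such as validator.GE replace the old Rel enum;
# check_int returns the value when the relation holds, else raises.
validator.check_int(3, 1, validator.GE, 'input_x rank', 'Flatten')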
@@ -501,10 +502,18 @@ class PrimitiveWithCheck(Primitive):
 
     def __check__(self, *args):
         """Checking the input shape and the input type of ops is valid """
-        tracks = ['dtype', 'shape']
-        for track in tracks:
-            fn = getattr(self, 'check_' + track)
-            fn(*(x[track] for x in args))
+        check_dtype_fn = getattr(self, 'check_dtype')
+        check_dtype_fn(*(x['dtype'] for x in args))
+
+        is_shape_known = True
+        for x in args:
+            shape = x['shape']
+            if shape is None or -1 in shape or -2 in shape:
+                is_shape_known = False
+                break
+        if is_shape_known:
+            check_shape_fn = getattr(self, 'check_shape')
+            check_shape_fn(*(x['shape'] for x in args))
 
     def _clone(self):
         """
@@ -731,6 +740,24 @@ def prim_attr_register(fn):
     return deco
 
 
+def _check_contains_variable(item_dtype, item_value):
+    """
+    Check whether the item is or contains variable.
+    """
+    if isinstance(item_value, (list, tuple)):
+        for i, element in enumerate(item_value):
+            if _check_contains_variable(item_dtype[i], element):
+                return True
+    elif isinstance(item_value, dict):
+        for i in range(len(item_value)):
+            if _check_contains_variable(item_dtype[i], list(item_value.keys())[i]):
+                return True
+        for i in range(len(item_value)):
+            if _check_contains_variable(item_dtype[i], list(item_value.values())[i]):
+                return True
+    return item_dtype is not None and item_value is None
+
+
 def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=True):
     """
     Creates a PrimitiveWithInfer operator that can infer the value at compile time. We can use it to define a function
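
Note: per the hunk, an item counts as a variable when its dtype is known but its compile-time value is None, and lists, tuples and dicts are searched recursively. A few illustrative cases (the dtype arguments below are stand-ins; the real callers pass MindSpore type descriptors):

assert _check_contains_variable(int, None)               # variable leaf
assert not _check_contains_variable(int, 3)              # constant leaf
assert _check_contains_variable([int, int], (3, None))   # variable inside a tuple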
@@ -778,13 +805,14 @@ def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=Tr
                 PrimitiveWithInfer.__init__(self, op_name)
                 self.set_const_prim(True)
                 self.fn = fn
+                self.add_prim_attr('constexpr_prim', True)
                 if not reuse_result:
                     self.add_prim_attr('forbid_reuse_result', True)
 
             def __infer__(self, *args):
                 value_args = []
                 for item in args:
-                    if (item["dtype"] is not None and item["value"] is None and check):
+                    if _check_contains_variable(item["dtype"], item["value"]) and check:
                         logger.warning("The \"" + self.name + "\" is a constexpr function." \
                                        " The input arguments must be all constant value.")
                     value_args.append(item["value"])
@@ -802,8 +830,82 @@ def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=Tr
     return deco
 
 
-@_wrap_func
+def _primexpr(fn=None, get_instance=True, name=None, reuse_result=True):
+    """
+    _primexpr is similar as constexpr except that when the input to the function decorated by _primexpr contains
+    variable, the function will be compiled as graph.
+
+    _primexpr is only for internal use.
+
+    Args:
+        fn (function): A `fn` use as the infer_value of the output operator. Default: None.
+        get_instance (bool): If true, return the instance of operator,
+            otherwise return the operator class. Default: True.
+        name (str): Defines the operator name. If `name` is None, use the function name as op name. Default: None.
+        reuse_result (bool): If true, the operator will be executed once and reuse the result next time,
+            otherwise the operator will always be executed. Default: True.
+    """
+    def deco(fn):
+        """Decorator for CompileOp."""
+
+        class CompileOp(PrimitiveWithInfer):
+            """
+            CompileOp is a temporary operator used to execute the constexpr function.
+            """
+
+            def __init__(self):
+                op_name = name if name else fn.__name__
+                PrimitiveWithInfer.__init__(self, op_name)
+                self.set_const_prim(True)
+                self.fn = fn
+                self.add_prim_attr('constexpr_prim', True)
+                if not reuse_result:
+                    self.add_prim_attr('forbid_reuse_result', True)
+
+            def __infer__(self, *args):
+                value_args = []
+                for item in args:
+                    if _check_contains_variable(item["dtype"], item["value"]):
+                        return {'dtype': None, 'shape': None, 'value': None, 'fn': (fn,)}
+                    value_args.append(item["value"])
+                return {'dtype': None, 'shape': None, 'value': fn(*value_args)}
+
+            def __call__(self, *args, **kwargs):
+                return fn(*args, **kwargs)
+
+        if get_instance:
+            return CompileOp()
+        return CompileOp
+
+    if fn is not None:
+        return deco(fn)
+    return deco
+
+
+_RUN_OP_ASYNC = True
+
+
 def _run_op(obj, op_name, args):
     """Single op execution function supported by ge in PyNative mode."""
+    if _RUN_OP_ASYNC:
+        stub = _pynative_executor.run_op_async(obj, args)
+        return _convert_stub(stub)
+    return _run_op_sync(obj, op_name, args)
+
+
+@_wrap_func
+def _run_op_sync(obj, op_name, args):
+    """Single op execution function in synchronous mode."""
     output = _pynative_executor.real_run_op(obj, op_name, args)
     return output
+
+
+class _PrimitiveC(Primitive):
+    def __init__(self, name, attrs):
+        super().__init__(name)
+        for key, value in attrs.items():
+            super().add_prim_attr(key, value)
+
+
+def _get_primitivec(name, attrs):
+    return _PrimitiveC(name, attrs)
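
Note: a hedged sketch of the internal _primexpr decorator added above. With all-constant inputs it folds the result at compile time like constexpr; when any input contains a variable, __infer__ returns the function itself ('fn': (fn,)) so the caller compiles it as a graph. The decorated helper below is hypothetical:

@_primexpr
def _last_dim(shape):
    # Constant-folded when `shape` is fully known at compile time;
    # compiled as a graph when `shape` carries variables.
    return shape[-1]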
mindspore/parallel/_auto_parallel_context.py

@@ -15,13 +15,15 @@
 """Context of auto parallel"""
 from __future__ import absolute_import
 import os
+import copy
 import threading
 from mindspore import context
 import mindspore.log as logger
 from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
 from mindspore.parallel._ps_context import _is_role_pserver
 from mindspore._c_expression import AutoParallelContext
-from mindspore._checkparam import args_type_check, Validator
+from mindspore._checkparam import args_type_check
+from mindspore import _checkparam as Validator
 
 _MAX_GROUP_NAME_LEN = 127
 _DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"
@@ -40,6 +42,18 @@ class _ParallelFusionConfig:
     AUTO = "auto"
     INDEX = "index"
     SIZE = "size"
+    OPENSTATE = "openstate"
+    CONFIG = {"openstate": True,
+              "allreduce": {"mode": "auto", "config": None},
+              "allgather": {"mode": "auto", "config": None},
+              "reducescatter": {"mode": "auto", "config": None}}
+
+    @classmethod
+    def reset(cls):
+        cls.CONFIG = {"openstate": True,
+                      "allreduce": {"mode": "auto", "config": None},
+                      "allgather": {"mode": "auto", "config": None},
+                      "reducescatter": {"mode": "auto", "config": None}}
 
 
 class _ParallelOptimizerConfig:
@@ -117,6 +131,9 @@ class _AutoParallelContext:
             KeyError: When key of comm_fusion is not 'allreduce'.
         """
         self.check_context_handle()
+        config = copy.deepcopy(config)
+        if _ParallelFusionConfig.OPENSTATE not in config.keys():
+            config[_ParallelFusionConfig.OPENSTATE] = True
         for key in list(config.keys()):
             if key == _ParallelFusionConfig.ALLREDUCE:
                 self._set_allreduce_comm_fusion(config[key])
@@ -124,91 +141,18 @@ class _AutoParallelContext:
                 self._set_allgather_comm_fusion(config[key], key)
             elif key == _ParallelFusionConfig.REDUCESCATTER:
                 self._set_allgather_comm_fusion(config[key], key)
+            elif key == _ParallelFusionConfig.OPENSTATE:
+                self._set_openstate_comm_fusion(config[key])
             else:
-                raise KeyError("comm fusion type must be allreduce, allgather or reducescatter, but got {}".format(key))
+                raise KeyError("comm fusion type must be openstate,"
+                               "allreduce, allgather or reducescatter, but got {}".format(key))
+            if key in _ParallelFusionConfig.CONFIG:
+                _ParallelFusionConfig.CONFIG[key] = config[key]
 
     def get_comm_fusion(self):
         """Get comm fusion config."""
         self.check_context_handle()
-        mode = self._context_handle.get_fusion_mode()
-        if mode in (_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE):
-            config = self.fusion_threshold_mb()
-        if mode == _ParallelFusionConfig.INDEX:
-            config = self.get_all_reduce_fusion_split_indices()
-        return {_ParallelFusionConfig.ALLREDUCE: {_ParallelFusionConfig.MODE: mode,
-                                                  _ParallelFusionConfig.FUSION_CONFIG: config}}
-
-    def _set_allgather_comm_fusion(self, comm_fusion, comm_type="allgather"):
-        """
-        Set allgather and reducescatter fusion method for auto parallel.
-
-        Args:
-            comm_fusion (dict): A dict contains the methods and values for setting the fusion method. Currently it
-                supports four fusion methods: `auto` and `size`.
-            comm_type (str): The name of the communication operator, `allgather` or `reducescatter`.
-
-        Raises:
-            KeyError: When key of comm_fusion is not 'mode' or 'config'.
-            KeyError: When `mode` is not 'auto', 'size'.
-        """
-        self.check_context_handle()
-        if comm_type == "allgather" and not self.get_enable_all_gather_fusion():
-            return
-        if comm_type == "reducescatter" and not self.get_enable_reduce_scatter_fusion():
-            return
-        if not isinstance(comm_fusion, dict):
-            raise TypeError("For 'comm_fusion', {} config must be dict, but got the type : {}.".format(
-                comm_type, type(comm_fusion)))
-        if _ParallelFusionConfig.MODE not in comm_fusion:
-            raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
-        if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
-            raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
-        check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE]
-        if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
-            self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
-        else:
-            raise KeyError("fusion method mode must be auto or size, but got {}".format(
-                comm_fusion[_ParallelFusionConfig.MODE]))
-
-        fusion_threshold = 64
-        if comm_fusion[_ParallelFusionConfig.MODE] != _ParallelFusionConfig.AUTO:
-            fusion_threshold = comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]
-        self.set_fusion_threshold_mb(fusion_threshold, comm_type)
-
-    def _set_allreduce_comm_fusion(self, comm_fusion):
-        """
-        Set fusion method for auto parallel.
-
-        Args:
-            comm_fusion (dict): A dict contains the methods and values for setting the fusion method. Currently it
-                supports four fusion methods: `auto`, `size` and `index`.
-
-        Raises:
-            KeyError: When key of comm_fusion is not 'mode' or 'config'.
-            KeyError: When `mode` is not 'auto', 'size' or 'index'.
-        """
-        self.check_context_handle()
-        if not self.get_enable_all_reduce_fusion():
-            return
-        if not isinstance(comm_fusion, dict):
-            raise TypeError("For 'comm_fusion', the 'allreduce' config must be dict, but got the type : {}.".format(
-                type(comm_fusion)))
-        if _ParallelFusionConfig.MODE not in comm_fusion:
-            raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
-        if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
-            raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
-        check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.INDEX, _ParallelFusionConfig.SIZE]
-        if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
-            self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
-        else:
-            raise KeyError("fusion method mode must be auto, index or size, but got {}".format(
-                comm_fusion[_ParallelFusionConfig.MODE]))
-        if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO:
-            self.set_fusion_threshold_mb(fusion_threshold=64)
-        if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.SIZE:
-            self.set_fusion_threshold_mb(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
-        if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.INDEX:
-            self.set_all_reduce_fusion_split_indices(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
+        return _ParallelFusionConfig.CONFIG
 
     def set_fusion_threshold_mb(self, fusion_threshold=64, comm_type="allreduce"):
         """
@@ -521,6 +465,9 @@ class _AutoParallelContext:
                 if not isinstance(dim, int):
                     raise TypeError("For 'set_auto_parallel_context', the element of argument "
                                     "'dataset_strategy' must be int type, but got the type : {} .".format(type(dim)))
+        if context.get_context('mode') == context.PYNATIVE_MODE:
+            raise ValueError("In PyNative mode, the setting value of 'dataset_strategy' must be either 'full_batch' "
+                             f"or 'data_parallel', but got {dataset_strategy}.")
         self._dataset_strategy_using_str = False
         self._context_handle.set_dataset_strategy(dataset_strategy)
 
@@ -531,7 +478,11 @@ class _AutoParallelContext:
             if self._context_handle.get_full_batch():
                 return "full_batch"
             return "data_parallel"
-        return self._context_handle.get_dataset_strategy()
+        dataset_strategy = self._context_handle.get_dataset_strategy()
+        if context.get_context('mode') == context.PYNATIVE_MODE:
+            raise ValueError("In PyNative mode, the value of 'dataset_strategy' must be either 'full_batch' "
+                             f"or 'data_parallel', but got the setting value is {dataset_strategy}.")
+        return dataset_strategy
 
     def set_grad_accumulation_step(self, grad_accumulation_step):
         """
@@ -567,6 +518,52 @@ class _AutoParallelContext:
        self.check_context_handle()
        return self._context_handle.get_strategy_ckpt_save_file()

+    def set_strategy_ckpt_config(self, strategy_ckpt_config):
+        """
+        Set the strategy checkpoint configuration.
+
+        Args:
+            strategy_ckpt_config (dict): The strategy checkpoint configuration.
+        """
+        self.check_context_handle()
+        if not isinstance(strategy_ckpt_config, dict):
+            raise TypeError("For 'set_auto_parallel_context', the argument 'strategy_ckpt_config' "
+                            "must be dict, but got the type: {}.".format(type(strategy_ckpt_config)))
+        unknown_config = []
+        for config_name in strategy_ckpt_config:
+            if config_name not in ["load_file", "save_file", "only_trainable_params"]:
+                unknown_config.append(config_name)
+
+        if unknown_config:
+            raise ValueError("Unknown config: {}".format(unknown_config))
+        if "load_file" in strategy_ckpt_config:
+            load_file = strategy_ckpt_config.get("load_file")
+            if not isinstance(load_file, str):
+                raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', "
+                                "the argument 'load_file' must be str, but got the type: {}.".format(type(load_file)))
+            self._context_handle.set_strategy_ckpt_load_file(load_file)
+        if "save_file" in strategy_ckpt_config:
+            save_file = strategy_ckpt_config.get("save_file")
+            if not isinstance(save_file, str):
+                raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', "
+                                "the argument 'save_file' must be str, but got the type: {}.".format(type(save_file)))
+            self._context_handle.set_strategy_ckpt_save_file(save_file)
+        if "only_trainable_params" in strategy_ckpt_config:
+            only_trainable_params = strategy_ckpt_config.get("only_trainable_params")
+            if not isinstance(only_trainable_params, bool):
+                raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', "
+                                "the argument 'only_trainable_params' must be bool,"
+                                " but got the type: {}.".format(type(only_trainable_params)))
+            self._context_handle.set_stra_file_only_trainable_params(only_trainable_params)
+
+    def get_strategy_ckpt_config(self):
+        """Get the strategy checkpoint configuration."""
+        self.check_context_handle()
+        load_file = self._context_handle.get_strategy_ckpt_load_file()
+        save_file = self._context_handle.get_strategy_ckpt_save_file()
+        only_trainable_param = self._context_handle.get_stra_file_only_trainable_params()
+        return {"load_file": load_file, "save_file": save_file, "only_trainable_params": only_trainable_param}
+
    def set_group_ckpt_save_file(self, group_ckpt_save_file):
        """Set group checkpoint save path."""
        self.check_context_handle()
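A short usage sketch of the new option as it is exposed through `set_auto_parallel_context` below (paths are placeholders):

```python
from mindspore import context

# Persist the sharding strategy of trainable parameters only.
context.set_auto_parallel_context(strategy_ckpt_config={
    "save_file": "./strategy.ckpt",
    "only_trainable_params": True,
})
# Returns a dict with the keys "load_file", "save_file" and "only_trainable_params".
print(context.get_auto_parallel_context("strategy_ckpt_config"))
```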
@@ -912,6 +909,7 @@ class _AutoParallelContext:
        return self._context_handle.get_optimizer_weight_shard_aggregated_save()

    def get_full_batch_is_set(self):
+        """Get whether 'full_batch' has been set."""
        self.check_context_handle()
        return self._context_handle.get_full_batch_is_set()

@@ -919,6 +917,7 @@ class _AutoParallelContext:
        """Reset all settings."""
        self.check_context_handle()
        self._context_handle.reset()
+        _ParallelFusionConfig.reset()

    def _check_and_default_group(self, group):
        """Validate the given group; if group is empty, return a default fusion group."""
@@ -936,6 +935,99 @@ class _AutoParallelContext:
            group = _DEFAULT_NCCL_FUSION_GROUP_NAME
        return group

+    def _set_allgather_comm_fusion(self, comm_fusion, comm_type="allgather"):
+        """
+        Set allgather and reducescatter fusion method for auto parallel.
+
+        Args:
+            comm_fusion (dict): A dict that contains the mode and config for setting the fusion method.
+                Currently it supports two fusion modes: `auto` and `size`.
+            comm_type (str): The name of the communication operator, `allgather` or `reducescatter`.
+
+        Raises:
+            KeyError: When comm_fusion does not contain the key 'mode' or 'config'.
+            KeyError: When `mode` is neither 'auto' nor 'size'.
+        """
+        self.check_context_handle()
+        if comm_type == "allgather" and not self.get_enable_all_gather_fusion():
+            return
+        if comm_type == "reducescatter" and not self.get_enable_reduce_scatter_fusion():
+            return
+        if not isinstance(comm_fusion, dict):
+            raise TypeError("For 'comm_fusion', the {} config must be dict, but got the type: {}.".format(
+                comm_type, type(comm_fusion)))
+        if _ParallelFusionConfig.MODE not in comm_fusion:
+            raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
+        if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
+            raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
+        check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE]
+        if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
+            self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
+        else:
+            raise KeyError("fusion method mode must be auto or size, but got {}".format(
+                comm_fusion[_ParallelFusionConfig.MODE]))
+
+        fusion_threshold = 64
+        if comm_fusion[_ParallelFusionConfig.MODE] != _ParallelFusionConfig.AUTO:
+            fusion_threshold = comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]
+        self.set_fusion_threshold_mb(fusion_threshold, comm_type)
+
+    def _set_allreduce_comm_fusion(self, comm_fusion):
+        """
+        Set allreduce fusion method for auto parallel.
+
+        Args:
+            comm_fusion (dict): A dict that contains the mode and config for setting the fusion method.
+                Currently it supports three fusion modes: `auto`, `size` and `index`.
+
+        Raises:
+            KeyError: When comm_fusion does not contain the key 'mode' or 'config'.
+            KeyError: When `mode` is not 'auto', 'size' or 'index'.
+        """
+        self.check_context_handle()
+        if not self.get_enable_all_reduce_fusion():
+            return
+        if not isinstance(comm_fusion, dict):
+            raise TypeError("For 'comm_fusion', the 'allreduce' config must be dict, but got the type: {}.".format(
+                type(comm_fusion)))
+        if _ParallelFusionConfig.MODE not in comm_fusion:
+            raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
+        if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
+            raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
+        check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.INDEX, _ParallelFusionConfig.SIZE]
+        if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
+            self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
+        else:
+            raise KeyError("fusion method mode must be auto, index or size, but got {}".format(
+                comm_fusion[_ParallelFusionConfig.MODE]))
+        if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO:
+            self.set_fusion_threshold_mb(fusion_threshold=64)
+        if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.SIZE:
+            self.set_fusion_threshold_mb(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
+        if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.INDEX:
+            self.set_all_reduce_fusion_split_indices(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
+
+    def _set_openstate_comm_fusion(self, openstate):
+        """
+        Set the open state for comm fusion.
+
+        Args:
+            openstate (bool): Whether to turn the fusion method on or off. Currently it
+                supports two states: `True` or `False`.
+
+        Raises:
+            TypeError: When the value is not bool.
+        """
+        self.check_context_handle()
+        if not isinstance(openstate, bool):
+            raise TypeError("For 'comm_fusion', the 'openstate' must be bool, but got the type: {}.".format(
+                type(openstate)))
+        if not openstate:
+            self.set_enable_all_reduce_fusion(openstate)
+            self.set_enable_all_gather_fusion(openstate)
+            self.set_enable_reduce_scatter_fusion(openstate)
+
+

  _AUTO_PARALLEL_CONTEXT = None
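To make the mode semantics of the helpers above concrete, here is a hedged sketch of the sub-configs they consume; the `_ParallelFusionConfig.MODE` and `FUSION_CONFIG` constants correspond to the literal keys "mode" and "config", and all values are illustrative:

```python
# Illustrative comm_fusion sub-configs dispatched to the helpers above.
allreduce_size = {"mode": "size", "config": 32}          # fuse allreduce once ~32 MB of gradients accumulate
allreduce_index = {"mode": "index", "config": [20, 35]}  # equivalent to all_reduce_fusion_config
allgather_auto = {"mode": "auto", "config": None}        # "auto" falls back to the 64 MB default
```

Note that the "config" key is required even in `auto` mode: both helpers raise `KeyError` for a missing key before they inspect the mode.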
 
@@ -978,6 +1070,7 @@ _set_auto_parallel_context_func_map = {
    "optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save,
    "sharding_propagation": auto_parallel_context().set_sharding_propagation,
    "enable_alltoall": auto_parallel_context().set_enable_alltoall,
+    "strategy_ckpt_config": auto_parallel_context().set_strategy_ckpt_config,
    "comm_fusion": auto_parallel_context().set_comm_fusion}


@@ -1005,6 +1098,7 @@ _get_auto_parallel_context_func_map = {
    "sharding_propagation": auto_parallel_context().get_sharding_propagation,
    "enable_alltoall": auto_parallel_context().get_enable_alltoall,
    "comm_fusion": auto_parallel_context().get_comm_fusion,
+    "strategy_ckpt_config": auto_parallel_context().get_strategy_ckpt_config,
    "full_batch_is_set": auto_parallel_context().get_full_batch_is_set}


@@ -1014,7 +1108,8 @@ _get_auto_parallel_context_func_map = {
                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                 grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
                 communi_parallel_mode=str, optimizer_weight_shard_size=int, sharding_propagation=bool,
-                optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool, comm_fusion=dict)
+                optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool, comm_fusion=dict,
+                strategy_ckpt_config=dict)

def _set_auto_parallel_context(**kwargs):
    """
@@ -1091,12 +1186,23 @@ def _set_auto_parallel_context(**kwargs):
        communication fusion config has two keys: "mode" and "config".
        It supports the following communication fusion types and configurations:

+        - openstate: Whether to turn on communication fusion. If `openstate` is `True`, communication
+          fusion is turned on; otherwise it is turned off. Default: `True`.
+
        - allreduce: If communication fusion type is `allreduce`. The `mode` contains: `auto`, `size`
          and `index`. In `auto` mode, allreduce fusion is configured by gradients size, and the default
          fusion threshold is `64` MB. In `size` mode, allreduce fusion is configured by gradients size
          manually, and the fusion threshold must be larger than `0` MB. In `index` mode, it is the same
          as `all_reduce_fusion_config`.

+        - allgather: If communication fusion type is `allgather`. The `mode` contains: `auto` and `size`.
+          In `auto` mode, AllGather fusion is configured by gradients size, and the default fusion
+          threshold is `64` MB. In `size` mode, AllGather fusion is configured by gradients size
+          manually, and the fusion threshold must be larger than `0` MB.
+
+        - reducescatter: If communication fusion type is `reducescatter`. The `mode` contains: `auto`
+          and `size`. The configuration is the same as for `allgather`.
+

        Raises:
            ValueError: If the input key is not an attribute of the auto parallel context.
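As a worked example of the documented schema (thresholds are illustrative):

```python
from mindspore import context

context.set_auto_parallel_context(comm_fusion={
    "openstate": True,                                # master switch for all fusion types
    "allreduce": {"mode": "auto", "config": None},    # auto: default 64 MB threshold
    "allgather": {"mode": "size", "config": 32},      # size: fuse once 32 MB accumulate
    "reducescatter": {"mode": "size", "config": 32},  # same schema as allgather
})
```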