mindspore 2.0.0a0__cp37-cp37m-win_amd64.whl → 2.0.0rc1__cp37-cp37m-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the package's registry page for more details.

Files changed (655)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -2
  3. mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +102 -0
  7. mindspore/_checkparam.py +1066 -1001
  8. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
  9. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
  10. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
  11. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
  12. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  13. mindspore/_extends/parse/__init__.py +5 -3
  14. mindspore/_extends/parse/namespace.py +16 -1
  15. mindspore/_extends/parse/parser.py +107 -22
  16. mindspore/_extends/parse/resources.py +0 -7
  17. mindspore/_extends/parse/standard_method.py +885 -413
  18. mindspore/amp.py +52 -57
  19. mindspore/boost/boost.py +2 -2
  20. mindspore/boost/boost_cell_wrapper.py +38 -20
  21. mindspore/boost/dim_reduce.py +3 -3
  22. mindspore/boost/group_loss_scale_manager.py +1 -1
  23. mindspore/common/__init__.py +4 -6
  24. mindspore/common/_decorator.py +2 -0
  25. mindspore/common/_register_for_adapter.py +55 -0
  26. mindspore/common/_stub_tensor.py +201 -0
  27. mindspore/common/_utils.py +41 -7
  28. mindspore/common/api.py +215 -141
  29. mindspore/common/dtype.py +8 -1
  30. mindspore/common/dump.py +2 -2
  31. mindspore/common/initializer.py +4 -2
  32. mindspore/common/jit_config.py +17 -13
  33. mindspore/common/mutable.py +33 -13
  34. mindspore/common/parameter.py +23 -21
  35. mindspore/common/seed.py +8 -24
  36. mindspore/common/sparse_tensor.py +62 -41
  37. mindspore/common/tensor.py +852 -1154
  38. mindspore/communication/__init__.py +2 -2
  39. mindspore/communication/_comm_helper.py +11 -4
  40. mindspore/communication/management.py +22 -21
  41. mindspore/config/op_info.config +501 -1008
  42. mindspore/context.py +201 -23
  43. mindspore/dataset/__init__.py +6 -6
  44. mindspore/dataset/audio/__init__.py +7 -7
  45. mindspore/dataset/audio/transforms.py +670 -30
  46. mindspore/dataset/audio/utils.py +47 -4
  47. mindspore/dataset/audio/validators.py +223 -1
  48. mindspore/dataset/callback/ds_callback.py +2 -2
  49. mindspore/dataset/core/config.py +210 -14
  50. mindspore/dataset/core/validator_helpers.py +2 -2
  51. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  52. mindspore/dataset/debug/debug_hook.py +65 -0
  53. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  54. mindspore/dataset/engine/__init__.py +7 -3
  55. mindspore/dataset/engine/cache_client.py +1 -1
  56. mindspore/dataset/engine/datasets.py +322 -66
  57. mindspore/dataset/engine/datasets_audio.py +80 -76
  58. mindspore/dataset/engine/datasets_standard_format.py +51 -38
  59. mindspore/dataset/engine/datasets_text.py +232 -118
  60. mindspore/dataset/engine/datasets_user_defined.py +41 -17
  61. mindspore/dataset/engine/datasets_vision.py +746 -225
  62. mindspore/dataset/engine/graphdata.py +75 -10
  63. mindspore/dataset/engine/iterators.py +45 -5
  64. mindspore/dataset/engine/offload.py +48 -28
  65. mindspore/dataset/engine/validators.py +117 -8
  66. mindspore/dataset/text/__init__.py +6 -5
  67. mindspore/dataset/text/transforms.py +86 -3
  68. mindspore/dataset/text/utils.py +6 -4
  69. mindspore/dataset/text/validators.py +25 -0
  70. mindspore/dataset/transforms/__init__.py +3 -2
  71. mindspore/dataset/transforms/c_transforms.py +1 -1
  72. mindspore/dataset/transforms/transforms.py +2 -2
  73. mindspore/dataset/utils/__init__.py +2 -1
  74. mindspore/dataset/utils/line_reader.py +121 -0
  75. mindspore/dataset/vision/__init__.py +2 -3
  76. mindspore/dataset/vision/c_transforms.py +9 -9
  77. mindspore/dataset/vision/py_transforms.py +5 -5
  78. mindspore/dataset/vision/py_transforms_util.py +2 -0
  79. mindspore/dataset/vision/transforms.py +160 -161
  80. mindspore/dataset/vision/utils.py +3 -3
  81. mindspore/experimental/map_parameter.py +38 -26
  82. mindspore/include/OWNERS +0 -1
  83. mindspore/include/api/callback/callback.h +9 -13
  84. mindspore/include/api/callback/ckpt_saver.h +2 -2
  85. mindspore/include/api/callback/loss_monitor.h +2 -2
  86. mindspore/include/api/callback/lr_scheduler.h +5 -5
  87. mindspore/include/api/callback/time_monitor.h +2 -2
  88. mindspore/include/api/callback/train_accuracy.h +4 -6
  89. mindspore/include/api/cfg.h +19 -6
  90. mindspore/include/api/context.h +44 -9
  91. mindspore/include/api/delegate.h +1 -1
  92. mindspore/include/api/metrics/accuracy.h +2 -2
  93. mindspore/include/api/metrics/metrics.h +4 -3
  94. mindspore/include/api/model.h +9 -4
  95. mindspore/include/api/model_parallel_runner.h +2 -2
  96. mindspore/include/api/net.h +12 -11
  97. mindspore/include/api/serialization.h +19 -3
  98. mindspore/include/api/types.h +3 -3
  99. mindspore/include/dataset/constants.h +7 -0
  100. mindspore/include/dataset/text.h +59 -0
  101. mindspore/jpeg62.dll +0 -0
  102. mindspore/log.py +1 -1
  103. mindspore/mindrecord/filereader.py +18 -0
  104. mindspore/mindrecord/filewriter.py +197 -34
  105. mindspore/mindrecord/shardreader.py +9 -0
  106. mindspore/mindrecord/shardwriter.py +1 -1
  107. mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
  108. mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
  109. mindspore/mindrecord/tools/csv_to_mr.py +3 -3
  110. mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
  111. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  112. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  113. mindspore/mindspore_backend.dll +0 -0
  114. mindspore/mindspore_common.dll +0 -0
  115. mindspore/mindspore_core.dll +0 -0
  116. mindspore/mindspore_glog.dll +0 -0
  117. mindspore/mindspore_shared_lib.dll +0 -0
  118. mindspore/nn/__init__.py +0 -4
  119. mindspore/nn/cell.py +204 -132
  120. mindspore/nn/dynamic_lr.py +1 -1
  121. mindspore/nn/grad/cell_grad.py +7 -6
  122. mindspore/nn/layer/__init__.py +5 -4
  123. mindspore/nn/layer/activation.py +40 -89
  124. mindspore/nn/layer/basic.py +255 -624
  125. mindspore/nn/layer/channel_shuffle.py +7 -6
  126. mindspore/nn/layer/combined.py +1 -1
  127. mindspore/nn/layer/container.py +41 -4
  128. mindspore/nn/layer/conv.py +64 -28
  129. mindspore/nn/layer/dense.py +9 -8
  130. mindspore/nn/layer/embedding.py +27 -25
  131. mindspore/nn/layer/image.py +53 -46
  132. mindspore/nn/layer/math.py +97 -105
  133. mindspore/nn/layer/normalization.py +117 -86
  134. mindspore/nn/layer/padding.py +185 -95
  135. mindspore/nn/layer/pooling.py +817 -414
  136. mindspore/nn/layer/rnn_cells.py +10 -15
  137. mindspore/nn/layer/rnns.py +37 -38
  138. mindspore/nn/layer/thor_layer.py +11 -12
  139. mindspore/nn/layer/timedistributed.py +5 -5
  140. mindspore/nn/layer/transformer.py +701 -0
  141. mindspore/nn/learning_rate_schedule.py +8 -8
  142. mindspore/nn/loss/__init__.py +5 -4
  143. mindspore/nn/loss/loss.py +334 -199
  144. mindspore/nn/optim/ada_grad.py +6 -6
  145. mindspore/nn/optim/adadelta.py +2 -3
  146. mindspore/nn/optim/adafactor.py +4 -5
  147. mindspore/nn/optim/adam.py +126 -62
  148. mindspore/nn/optim/adamax.py +3 -4
  149. mindspore/nn/optim/adasum.py +6 -6
  150. mindspore/nn/optim/asgd.py +2 -2
  151. mindspore/nn/optim/ftrl.py +67 -38
  152. mindspore/nn/optim/lamb.py +4 -5
  153. mindspore/nn/optim/lars.py +2 -2
  154. mindspore/nn/optim/lazyadam.py +43 -4
  155. mindspore/nn/optim/momentum.py +6 -5
  156. mindspore/nn/optim/optimizer.py +3 -1
  157. mindspore/nn/optim/proximal_ada_grad.py +2 -2
  158. mindspore/nn/optim/rmsprop.py +1 -1
  159. mindspore/nn/optim/rprop.py +8 -9
  160. mindspore/nn/optim/sgd.py +19 -13
  161. mindspore/nn/optim/thor.py +10 -15
  162. mindspore/nn/probability/__init__.py +0 -2
  163. mindspore/nn/probability/bijector/bijector.py +4 -4
  164. mindspore/nn/probability/bijector/invert.py +1 -1
  165. mindspore/nn/probability/bijector/softplus.py +2 -2
  166. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  167. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  168. mindspore/nn/probability/distribution/_utils/utils.py +9 -15
  169. mindspore/nn/probability/distribution/bernoulli.py +3 -3
  170. mindspore/nn/probability/distribution/beta.py +1 -1
  171. mindspore/nn/probability/distribution/categorical.py +5 -7
  172. mindspore/nn/probability/distribution/cauchy.py +3 -3
  173. mindspore/nn/probability/distribution/distribution.py +2 -2
  174. mindspore/nn/probability/distribution/exponential.py +2 -2
  175. mindspore/nn/probability/distribution/gamma.py +3 -3
  176. mindspore/nn/probability/distribution/geometric.py +1 -1
  177. mindspore/nn/probability/distribution/gumbel.py +3 -3
  178. mindspore/nn/probability/distribution/half_normal.py +15 -11
  179. mindspore/nn/probability/distribution/laplace.py +16 -13
  180. mindspore/nn/probability/distribution/logistic.py +2 -2
  181. mindspore/nn/probability/distribution/normal.py +1 -1
  182. mindspore/nn/probability/distribution/poisson.py +1 -1
  183. mindspore/nn/probability/distribution/student_t.py +20 -15
  184. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  185. mindspore/nn/probability/distribution/uniform.py +2 -2
  186. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  187. mindspore/nn/reinforcement/tensor_array.py +2 -2
  188. mindspore/nn/sparse/sparse.py +2 -2
  189. mindspore/nn/wrap/cell_wrapper.py +27 -10
  190. mindspore/nn/wrap/grad_reducer.py +2 -2
  191. mindspore/nn/wrap/loss_scale.py +40 -24
  192. mindspore/numpy/array_creations.py +33 -22
  193. mindspore/numpy/array_ops.py +35 -30
  194. mindspore/numpy/logic_ops.py +6 -27
  195. mindspore/numpy/math_ops.py +22 -19
  196. mindspore/numpy/utils.py +1 -1
  197. mindspore/numpy/utils_const.py +108 -58
  198. mindspore/opencv_core452.dll +0 -0
  199. mindspore/opencv_imgcodecs452.dll +0 -0
  200. mindspore/opencv_imgproc452.dll +0 -0
  201. mindspore/ops/_constants.py +0 -6
  202. mindspore/ops/_grad/__init__.py +2 -1
  203. mindspore/ops/_grad/grad_array_ops.py +86 -117
  204. mindspore/ops/_grad/grad_base.py +23 -1
  205. mindspore/ops/_grad/grad_clip_ops.py +2 -3
  206. mindspore/ops/_grad/grad_comm_ops.py +34 -24
  207. mindspore/ops/_grad/grad_implementations.py +9 -45
  208. mindspore/ops/_grad/grad_inner_ops.py +47 -4
  209. mindspore/ops/_grad/grad_math_ops.py +142 -117
  210. mindspore/ops/_grad/grad_nn_ops.py +71 -165
  211. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  212. mindspore/ops/_grad/grad_sparse.py +7 -6
  213. mindspore/ops/_grad_experimental/__init__.py +1 -0
  214. mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
  215. mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
  216. mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
  217. mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
  218. mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
  219. mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
  220. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  221. mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
  222. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  223. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
  224. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  225. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  226. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
  227. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  228. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  229. mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
  230. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  231. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
  232. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  233. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  234. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  235. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  236. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  237. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  238. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  239. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  240. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  241. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  242. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  243. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  244. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  245. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  246. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  247. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  248. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  249. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  250. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
  251. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  252. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  253. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  254. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  255. mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
  256. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  257. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
  258. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  259. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  260. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  261. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  262. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  263. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
  264. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  265. mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
  266. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  267. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  268. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
  269. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  270. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  271. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
  272. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  273. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  274. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  275. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  276. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  277. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  278. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  279. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  280. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  281. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
  282. mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
  283. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  284. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  285. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  286. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  287. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  288. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  289. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  290. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  291. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  292. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  293. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  294. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  295. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  296. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  297. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  298. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  299. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  300. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  301. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  302. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  303. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  304. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  305. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  306. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  307. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  308. mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
  309. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
  310. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  311. mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
  312. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  313. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  314. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  315. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  316. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  317. mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
  318. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
  319. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  320. mindspore/ops/_op_impl/tbe/__init__.py +27 -611
  321. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  322. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  323. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  324. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
  325. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  326. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  327. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  328. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  329. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  330. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  331. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  332. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  333. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  334. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
  335. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  336. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  337. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  338. mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
  339. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  340. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  341. mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
  342. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
  343. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  344. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  345. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  346. mindspore/ops/_register_for_op.py +1 -0
  347. mindspore/ops/_utils/__init__.py +1 -2
  348. mindspore/ops/_utils/utils.py +19 -40
  349. mindspore/ops/_vmap/vmap_array_ops.py +116 -38
  350. mindspore/ops/_vmap/vmap_base.py +16 -9
  351. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  352. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  353. mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
  354. mindspore/ops/_vmap/vmap_image_ops.py +12 -5
  355. mindspore/ops/_vmap/vmap_math_ops.py +46 -5
  356. mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
  357. mindspore/ops/_vmap/vmap_random_ops.py +1 -1
  358. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  359. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  360. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  361. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  362. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  363. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  364. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  365. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  366. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
  367. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  368. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  369. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  370. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  371. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  372. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  373. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  374. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  375. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  376. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  377. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
  378. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
  379. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  380. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  381. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  382. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  383. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  384. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  385. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
  386. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  387. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
  388. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
  389. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  390. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  391. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  392. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  393. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  394. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  395. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  396. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  397. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  398. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  399. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  400. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  401. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  402. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  403. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  404. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  405. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  406. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  407. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  408. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
  409. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  410. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  411. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  412. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  413. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  414. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  415. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  416. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
  417. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
  418. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  419. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
  420. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  421. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  422. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  423. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  424. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  425. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  426. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  427. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  428. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  429. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
  430. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
  431. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  432. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  433. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  434. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  435. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
  436. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  437. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  438. mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
  439. mindspore/ops/composite/__init__.py +7 -8
  440. mindspore/ops/composite/base.py +101 -47
  441. mindspore/ops/composite/math_ops.py +188 -158
  442. mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
  443. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
  444. mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
  445. mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
  446. mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
  447. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  448. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  449. mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
  450. mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
  451. mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
  452. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
  453. mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
  454. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  455. mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
  456. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
  457. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
  458. mindspore/ops/function/__init__.py +152 -8
  459. mindspore/ops/function/array_func.py +2555 -674
  460. mindspore/ops/function/clip_func.py +209 -13
  461. mindspore/ops/function/debug_func.py +2 -2
  462. mindspore/ops/function/grad/__init__.py +2 -1
  463. mindspore/ops/function/grad/grad_func.py +147 -62
  464. mindspore/ops/function/image_func.py +54 -38
  465. mindspore/ops/function/linalg_func.py +167 -16
  466. mindspore/ops/function/math_func.py +4849 -1492
  467. mindspore/ops/function/nn_func.py +2573 -988
  468. mindspore/ops/function/other_func.py +115 -0
  469. mindspore/ops/function/parameter_func.py +3 -3
  470. mindspore/ops/function/random_func.py +790 -73
  471. mindspore/ops/function/sparse_func.py +98 -78
  472. mindspore/ops/function/sparse_unary_func.py +54 -53
  473. mindspore/ops/function/spectral_func.py +27 -24
  474. mindspore/ops/function/vmap_func.py +22 -2
  475. mindspore/ops/functional.py +97 -37
  476. mindspore/ops/op_info_register.py +70 -28
  477. mindspore/ops/operations/__init__.py +47 -14
  478. mindspore/ops/operations/_csr_ops.py +7 -7
  479. mindspore/ops/operations/_embedding_cache_ops.py +5 -5
  480. mindspore/ops/operations/_grad_ops.py +276 -187
  481. mindspore/ops/operations/_inner_ops.py +319 -113
  482. mindspore/ops/operations/_ms_kernel.py +10 -8
  483. mindspore/ops/operations/_ocr_ops.py +9 -9
  484. mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
  485. mindspore/ops/operations/_quant_ops.py +137 -102
  486. mindspore/ops/operations/_rl_inner_ops.py +121 -60
  487. mindspore/ops/operations/_scalar_ops.py +466 -0
  488. mindspore/ops/operations/_sequence_ops.py +1004 -2
  489. mindspore/ops/operations/_tensor_array.py +10 -11
  490. mindspore/ops/operations/_thor_ops.py +1 -1
  491. mindspore/ops/operations/array_ops.py +801 -466
  492. mindspore/ops/operations/comm_ops.py +51 -49
  493. mindspore/ops/operations/control_ops.py +2 -2
  494. mindspore/ops/operations/custom_ops.py +123 -44
  495. mindspore/ops/operations/debug_ops.py +24 -24
  496. mindspore/ops/operations/image_ops.py +240 -153
  497. mindspore/ops/operations/inner_ops.py +34 -50
  498. mindspore/ops/operations/linalg_ops.py +31 -9
  499. mindspore/ops/operations/math_ops.py +988 -757
  500. mindspore/ops/operations/nn_ops.py +965 -819
  501. mindspore/ops/operations/other_ops.py +51 -40
  502. mindspore/ops/operations/random_ops.py +204 -122
  503. mindspore/ops/operations/rl_ops.py +8 -9
  504. mindspore/ops/operations/sparse_ops.py +254 -93
  505. mindspore/ops/operations/spectral_ops.py +35 -3
  506. mindspore/ops/primitive.py +111 -9
  507. mindspore/parallel/_auto_parallel_context.py +189 -83
  508. mindspore/parallel/_offload_context.py +185 -0
  509. mindspore/parallel/_parallel_serialization.py +99 -7
  510. mindspore/parallel/_ps_context.py +9 -5
  511. mindspore/parallel/_recovery_context.py +1 -1
  512. mindspore/parallel/_tensor.py +7 -1
  513. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  514. mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
  515. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  516. mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
  517. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  518. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
  519. mindspore/parallel/_utils.py +1 -2
  520. mindspore/parallel/algo_parameter_config.py +1 -1
  521. mindspore/parallel/checkpoint_transform.py +37 -34
  522. mindspore/parallel/shard.py +17 -18
  523. mindspore/profiler/common/validator/validate_path.py +2 -2
  524. mindspore/profiler/envprofiling.py +69 -47
  525. mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
  526. mindspore/profiler/parser/base_timeline_generator.py +49 -56
  527. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
  528. mindspore/profiler/parser/hwts_log_parser.py +1 -1
  529. mindspore/profiler/parser/integrator.py +15 -14
  530. mindspore/profiler/parser/minddata_analyzer.py +2 -2
  531. mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
  532. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  533. mindspore/profiler/parser/optime_parser.py +17 -18
  534. mindspore/profiler/parser/profiler_info.py +2 -1
  535. mindspore/profiler/profiling.py +218 -186
  536. mindspore/rewrite/__init__.py +3 -1
  537. mindspore/rewrite/api/node.py +1 -114
  538. mindspore/rewrite/api/node_type.py +3 -0
  539. mindspore/rewrite/api/pattern_engine.py +31 -1
  540. mindspore/rewrite/api/scoped_value.py +4 -4
  541. mindspore/rewrite/api/symbol_tree.py +3 -78
  542. mindspore/rewrite/api/tree_node_helper.py +1 -1
  543. mindspore/rewrite/ast_creator_register.py +1 -0
  544. mindspore/rewrite/ast_helpers/__init__.py +2 -2
  545. mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
  546. mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
  547. mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
  548. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
  549. mindspore/rewrite/namespace.py +0 -2
  550. mindspore/rewrite/node.py +157 -11
  551. mindspore/rewrite/parsers/assign_parser.py +231 -53
  552. mindspore/rewrite/parsers/class_def_parser.py +187 -109
  553. mindspore/rewrite/parsers/for_parser.py +24 -14
  554. mindspore/rewrite/parsers/function_def_parser.py +21 -4
  555. mindspore/rewrite/parsers/if_parser.py +6 -2
  556. mindspore/rewrite/sparsify/__init__.py +0 -0
  557. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  558. mindspore/rewrite/sparsify/sparsify.py +109 -0
  559. mindspore/rewrite/sparsify/utils.py +173 -0
  560. mindspore/rewrite/symbol_tree.py +256 -133
  561. mindspore/rewrite/symbol_tree_builder.py +38 -1
  562. mindspore/run_check/_check_version.py +69 -63
  563. mindspore/run_check/run_check.py +2 -1
  564. mindspore/tinyxml2.dll +0 -0
  565. mindspore/train/__init__.py +1 -1
  566. mindspore/train/_utils.py +28 -5
  567. mindspore/train/amp.py +273 -102
  568. mindspore/train/callback/_backup_and_restore.py +5 -5
  569. mindspore/train/callback/_callback.py +2 -2
  570. mindspore/train/callback/_checkpoint.py +3 -3
  571. mindspore/train/callback/_early_stop.py +3 -3
  572. mindspore/train/callback/_lambda_callback.py +2 -2
  573. mindspore/train/callback/_landscape.py +29 -31
  574. mindspore/train/callback/_loss_monitor.py +3 -3
  575. mindspore/train/callback/_on_request_exit.py +3 -3
  576. mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
  577. mindspore/train/callback/_summary_collector.py +23 -16
  578. mindspore/train/callback/_time_monitor.py +3 -3
  579. mindspore/train/checkpoint_pb2.py +68 -8
  580. mindspore/train/data_sink.py +15 -3
  581. mindspore/train/dataset_helper.py +10 -15
  582. mindspore/train/loss_scale_manager.py +8 -11
  583. mindspore/train/metrics/__init__.py +1 -1
  584. mindspore/train/metrics/bleu_score.py +1 -1
  585. mindspore/train/metrics/confusion_matrix.py +1 -1
  586. mindspore/train/metrics/cosine_similarity.py +1 -1
  587. mindspore/train/metrics/dice.py +2 -2
  588. mindspore/train/metrics/fbeta.py +1 -1
  589. mindspore/train/metrics/hausdorff_distance.py +4 -3
  590. mindspore/train/metrics/mean_surface_distance.py +2 -2
  591. mindspore/train/metrics/occlusion_sensitivity.py +1 -1
  592. mindspore/train/metrics/perplexity.py +1 -1
  593. mindspore/train/metrics/precision.py +1 -1
  594. mindspore/train/metrics/recall.py +1 -1
  595. mindspore/train/metrics/roc.py +2 -2
  596. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  597. mindspore/train/mind_ir_pb2.py +116 -37
  598. mindspore/train/model.py +45 -28
  599. mindspore/train/serialization.py +295 -188
  600. mindspore/train/summary/_summary_adapter.py +1 -1
  601. mindspore/train/summary/summary_record.py +43 -13
  602. mindspore/train/train_thor/convert_utils.py +2 -2
  603. mindspore/train/train_thor/dataset_helper.py +3 -3
  604. mindspore/turbojpeg.dll +0 -0
  605. mindspore/version.py +1 -1
  606. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
  607. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
  608. mindspore/compression/__init__.py +0 -19
  609. mindspore/compression/common/constant.py +0 -124
  610. mindspore/compression/export/__init__.py +0 -19
  611. mindspore/compression/export/quant_export.py +0 -515
  612. mindspore/compression/quant/__init__.py +0 -28
  613. mindspore/compression/quant/qat.py +0 -634
  614. mindspore/compression/quant/quant_utils.py +0 -462
  615. mindspore/compression/quant/quantizer.py +0 -68
  616. mindspore/nn/layer/quant.py +0 -1868
  617. mindspore/nn/layer/rnn_utils.py +0 -90
  618. mindspore/nn/probability/dpn/__init__.py +0 -22
  619. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  620. mindspore/nn/probability/dpn/vae/cvae.py +0 -140
  621. mindspore/nn/probability/dpn/vae/vae.py +0 -124
  622. mindspore/nn/probability/infer/__init__.py +0 -22
  623. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  624. mindspore/nn/probability/infer/variational/svi.py +0 -84
  625. mindspore/nn/probability/toolbox/__init__.py +0 -22
  626. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  627. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
  628. mindspore/nn/probability/transforms/__init__.py +0 -22
  629. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  630. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  631. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  632. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  633. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  634. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  635. mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
  636. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  637. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
  638. mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
  639. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
  640. mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
  642. mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
  643. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
  644. mindspore/ops/composite/array_ops.py +0 -241
  645. mindspore/ops/composite/clip_ops.py +0 -134
  646. mindspore/ops/composite/random_ops.py +0 -426
  647. mindspore/ops/composite/vmap_ops.py +0 -38
  648. mindspore/parallel/nn/__init__.py +0 -42
  649. mindspore/parallel/nn/loss.py +0 -22
  650. mindspore/parallel/nn/moe.py +0 -21
  651. mindspore/parallel/nn/op_parallel_config.py +0 -22
  652. mindspore/parallel/nn/transformer.py +0 -31
  653. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  654. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  655. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -21,15 +21,14 @@ import ast
21
21
  import importlib
22
22
  import types
23
23
  import time
24
-
25
24
  import astunparse
26
25
 
27
26
  from mindspore.nn import Cell
28
27
  from mindspore import log as logger
29
28
  from mindspore.rewrite.ast_creator_register import ast_creator_registry
30
- from .node import Node, TreeNode, PASS_THROUGH_METHOD
29
+ from .node import Node, TreeNode
31
30
  from .api.node_type import NodeType
32
- from .ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder
31
+ from .ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder, CheckPropertyIsUsed
33
32
  from .api.scoped_value import ScopedValue, ValueType
34
33
  from .symbol_tree_dumper import SymbolTreeDumper
35
34
  from .topological_manager import TopoManager
@@ -160,7 +159,6 @@ class SymbolTree(Observer, Observable):
160
159
  self._topo_mgr = TopoManager()
161
160
  self._topo_mgr.reg_observer(self)
162
161
 
163
- self._global_vars: {str, object} = {origin_network_key: origin_network}
164
162
  self._nodes: {str, Node} = {}
165
163
  # parameters of forward method
166
164
  self._inputs: [Node] = []
@@ -171,6 +169,10 @@ class SymbolTree(Observer, Observable):
171
169
  self._class_ast: Optional[ast.ClassDef] = None
172
170
  self._root_ast: Optional[ast.FunctionDef] = None
173
171
  self._init_func_ast: Optional[ast.FunctionDef] = None
172
+ self._deleted_field = {}
173
+ self._deleted_node = []
174
+ self._external_func_ast = []
175
+ self._father_class_ast = []
174
176
 
175
177
  # head node is always point to the first node(in source code order) of SymbolTree
176
178
  self._head = None
@@ -263,6 +265,8 @@ class SymbolTree(Observer, Observable):
263
265
  for node in stree.nodes():
264
266
  if not isinstance(node, TreeNode):
265
267
  continue
268
+ if node.symbol_tree._class_ast is None:
269
+ continue
266
270
  sub_stree: SymbolTree = node.symbol_tree
267
271
  SymbolTree._find_all_class_in_symboltree(sub_stree, seen_class, allow_class_name, replacers)
268
272
  # all modified ast.ClassDef should export to code
@@ -281,31 +285,7 @@ class SymbolTree(Observer, Observable):
281
285
  """Add Event.TopologicalChangeEvent event when build is finished."""
282
286
  self.add_event(Event.TopologicalChangeEvent)
283
287
 
284
- def create_assign_node(self, targets, func_name, args, kwargs):
285
- """
286
- Create a ast.Assign type node.
287
-
288
- Args:
289
- targets (list): _description_
290
- func_name (_type_): _description_
291
- args (_type_): _description_
292
- kwargs (_type_): _description_
293
-
294
- Returns:
295
- _type_: _description_
296
- """
297
- # create targets
298
- ast_targets = [ast_creator_registry.get("Name")(targets)]
299
- # create call
300
- ast_func = ast_creator_registry.get("Attribute")(func_name)
301
- ast_args = ast_creator_registry.get("Args")(args)
302
- ast_kwargs = ast_creator_registry.get("KwArgs")(kwargs) if kwargs else []
303
- ast_value = ast_creator_registry.get("Call")(func=ast_func, args=ast_args, keywords=ast_kwargs)
304
- # create assign
305
- ast_node = ast_creator_registry.get("Assign")(targets=ast_targets, value=ast_value)
306
- return ast_node
307
-
308
- def create_call_function(self, func, targets, args, kwargs):
288
+ def _create_call_function(self, func, targets, args, kwargs):
309
289
  """
310
290
  Create a Node object and generate the execution code to insert into the source code.
311
291
  The source code calls the 'func' function with 'args' and' kwargs' as parameters.
@@ -345,6 +325,30 @@ class SymbolTree(Observer, Observable):
345
325
  call_kwargs)
346
326
  return node
347
327
 
328
+ def create_assign_node(self, targets, func_name, args, kwargs):
329
+ """
330
+ Create a ast.Assign type node.
331
+
332
+ Args:
333
+ targets (list): _description_
334
+ func_name (_type_): _description_
335
+ args (_type_): _description_
336
+ kwargs (_type_): _description_
337
+
338
+ Returns:
339
+ _type_: _description_
340
+ """
341
+ # create targets
342
+ ast_targets = [ast_creator_registry.get("Name")(targets)]
343
+ # create call
344
+ ast_func = ast_creator_registry.get("Attribute")(func_name)
345
+ ast_args = ast_creator_registry.get("Args")(args)
346
+ ast_kwargs = ast_creator_registry.get("KwArgs")(kwargs) if kwargs else []
347
+ ast_value = ast_creator_registry.get("Call")(func=ast_func, args=ast_args, keywords=ast_kwargs)
348
+ # create assign
349
+ ast_node = ast_creator_registry.get("Assign")(targets=ast_targets, value=ast_value)
350
+ return ast_node
351
+
348
352
  def inner_create_call_function(self, node_name, ast_node, func_name, func, targets, args, kwargs):
349
353
  '''
350
354
  Instantiate an instance of node whose type is `CallFunction`.
@@ -458,12 +462,6 @@ class SymbolTree(Observer, Observable):
458
462
  self._init_func_ast = ast_node
459
463
 
460
464
  def get_inputs(self):
461
- """
462
- Getter of `_inputs` which represents parameters of current forward method.
463
-
464
- Returns:
465
- A list of instance of Node whose node_type is NodeType.Input as input nodes.
466
- """
467
465
  return self._inputs
468
466
 
469
467
  def get_head_node(self):
@@ -484,17 +482,6 @@ class SymbolTree(Observer, Observable):
484
482
  """
485
483
  return self._origin_network
486
484
 
487
- def get_global_vars(self):
488
- """Get global variables."""
489
- return self._global_vars
490
-
491
- def add_global_vars(self, key: str, value):
492
- """Add global variables."""
493
- if self._global_vars.get(key) is not None:
494
- logger.info(f"The key '{key}' is duplicated")
495
- return
496
- self._global_vars[key] = value
497
-
498
485
  def get_nodes_dict(self):
499
486
  """Get dict of nodes"""
500
487
  return self._nodes
@@ -614,7 +601,6 @@ class SymbolTree(Observer, Observable):
614
601
  RuntimeError: If 'node_or_name' is not belong to this SymbolTree or any sub-SymbolTree of current
615
602
  SymbolTree.
616
603
  """
617
-
618
604
  node = self._get_real_node(node_or_name)
619
605
  if node is None:
620
606
  raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
@@ -653,7 +639,12 @@ class SymbolTree(Observer, Observable):
653
639
  RuntimeError: If 'position' is not in current SymbolTree.
654
640
  RuntimeError: If corresponding ast node is not an ast.Assign when 'insert_to_ast' is True.
655
641
  """
656
-
642
+ if position is not None and hasattr(position.node, "container"):
643
+ cellcontainer = getattr(position.node, "container")
644
+ index = cellcontainer.node_list.index(position.node)
645
+ index = index if position.before_node else index + 1
646
+ cellcontainer.insert(index, node)
647
+ return node
657
648
  # if position in current SymbolTree
658
649
  if position is not None and position.symbol_tree is not self:
659
650
  raise RuntimeError("Position is not in current SymbolTree:", position)
@@ -678,37 +669,7 @@ class SymbolTree(Observer, Observable):
678
669
  self._node_visitor.append_node(node)
679
670
  # update init-function-ast and construct-function-ast
680
671
  if insert_to_ast:
681
- node.set_func(ScopedValue.create_naming_value(node_name, "self"))
682
- node_ast = node.get_ast()
683
- if not isinstance(node_ast, ast.Assign):
684
- raise RuntimeError("Only support insert cell op now")
685
- if isinstance(node, TreeNode):
686
- global_vars_key = node.get_name() + "_args"
687
- self.add_global_vars(global_vars_key, node.symbol_tree.get_global_vars())
688
- args_call = AstModifier.create_call(ScopedValue.create_naming_value("get", "global_vars"),
689
- [ScopedValue.create_variable_value(global_vars_key)])
690
- value = ast.Call(func=ast.Name(node.symbol_tree.get_opt_cls_name(), ast.Store(), lineno=0,
691
- col_offset=0), args=[args_call], keywords=[], lineno=0, col_offset=0)
692
-
693
- ast_target = ast.Name("self." + node.get_name(), ast.Store(), lineno=0, col_offset=0)
694
- assign = ast.Assign(targets=[ast_target], value=value, lineno=0, col_offset=0)
695
- AstModifier.insert_assign_ast_to_function(self._init_func_ast, assign)
696
-
697
- AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
698
- None if position is None else position.node.get_ast(),
699
- position.before_node)
700
- sub_stree: SymbolTree = node.symbol_tree
701
- from .symbol_tree_builder import SymbolTreeBuilder
702
- SymbolTreeBuilder.merge_module_of_subtree(self, sub_stree)
703
- else:
704
- AstModifier.insert_assign_to_function(self._init_func_ast,
705
- targets=[ScopedValue(ValueType.NamingValue, "self", node_name)],
706
- expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"),
707
- args=[ScopedValue(ValueType.StringValue, "", node_name)])
708
- AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
709
- None if position is None else position.node.get_ast(),
710
- position.before_node)
711
- self._global_vars[node_name] = node.get_instance()
672
+ self._insert_to_ast_while_insert_node(node, position)
712
673
  return node
713
674
 
714
675
  def append_node(self, node: Node, append_to_ast: bool = True) -> Node:
@@ -807,8 +768,9 @@ class SymbolTree(Observer, Observable):
807
768
  Returns:
808
769
  An instance of python node which has been appended to SymbolTree.
809
770
  """
810
- logger.warning("Ignoring unsupported node(%s) in %s.", type(ast_node).__name__, type(ast_scope).__name__)
771
+ logger.info("Ignoring unsupported node (%s) (%s).", type(ast_node).__name__, type(ast_scope).__name__)
811
772
  node_name = self._node_name_namer.get_name(type(ast_node).__name__)
773
+ self._update_names_for_unique(ast_node)
812
774
  node = Node.create_python_node(ast_node, node_name)
813
775
  self._insert_node(Position.create(self, self._tail, False), node)
814
776
  return node
@@ -851,6 +813,10 @@ class SymbolTree(Observer, Observable):
851
813
  node = self._get_real_node(node_or_name)
852
814
  if node is None:
853
815
  raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
816
+ if hasattr(node, "container"):
817
+ cellcontainer = getattr(node, "container")
818
+ cellcontainer.erase(node)
819
+ return node
854
820
  ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast())
855
821
  if not ret:
856
822
  raise RuntimeError("node not in function ast tree.")
@@ -860,6 +826,7 @@ class SymbolTree(Observer, Observable):
860
826
  value.isolate()
861
827
  break
862
828
  self._topo_mgr.on_erase_node(node)
829
+ self._deleted_node.append(node.get_name())
863
830
  return node
864
831
 
865
832
  def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
@@ -884,6 +851,9 @@ class SymbolTree(Observer, Observable):
884
851
  RuntimeError: If 'old_node' is not belong to current SymbolTree.
885
852
  """
886
853
 
854
+ if hasattr(old_node, "container"):
855
+ self._replace_container_node(old_node, new_nodes)
856
+ return new_nodes[0]
887
857
  real_old_node = self._get_real_node(old_node)
888
858
  if real_old_node is None:
889
859
  raise RuntimeError("Old node is not belong to current SymbolTree:", old_node)
@@ -981,6 +951,13 @@ class SymbolTree(Observer, Observable):
981
951
  dump_st = SymbolTreeDumper(self)
982
952
  dump_st.dump()
983
953
 
954
+ def update_module_ast(self):
955
+ for node in self._external_func_ast:
956
+ self._module_ast.body.append(node)
957
+ for node in self._father_class_ast:
958
+ index = self._module_ast.body.index(self._class_ast)
959
+ self._module_ast.body.insert(index, node)
960
+
984
961
  def get_code(self) -> str:
985
962
  """
986
963
  Get source code of modified network.
@@ -992,6 +969,7 @@ class SymbolTree(Observer, Observable):
992
969
  if self._init_func_ast:
993
970
  self._remove_unused_field()
994
971
  self._remove_duplicated_import()
972
+ self.update_module_ast()
995
973
  ast.fix_missing_locations(self._module_ast)
996
974
  # Find all ast.ClassDef which can be export to code
997
975
  # Replace duplicated ast.ClassDef reference in main-ClassDef
@@ -1026,21 +1004,20 @@ class SymbolTree(Observer, Observable):
1026
1004
  A network object.
1027
1005
  """
1028
1006
  cls = self._get_cls_through_file()
1029
- return cls(self._global_vars)
1007
+ new_net = cls(self._origin_network)
1008
+ self._merge_origin_property(new_net)
1009
+ return new_net
1030
1010
 
1031
1011
  def set_saved_file_name(self, file_name: str):
1032
- """Sets the filename used to save the network."""
1033
1012
  if file_name.endswith(".py"):
1034
1013
  self._saved_file_name = file_name
1035
1014
  else:
1036
1015
  self._saved_file_name = file_name + ".py"
1037
1016
 
1038
1017
  def get_saved_file_name(self):
1039
- """Gets the filename used to save the network."""
1040
1018
  return self._saved_file_name
1041
1019
 
1042
1020
  def save_network_to_file(self):
1043
- """Save the modified network to a file."""
1044
1021
  abs_path = os.path.abspath(self._saved_file_name)
1045
1022
  if os.path.isfile(abs_path):
1046
1023
  os.remove(abs_path)
@@ -1049,6 +1026,58 @@ class SymbolTree(Observer, Observable):
1049
1026
  f.write(source.encode('utf-8'))
1050
1027
  f.flush()
1051
1028
 
1029
+ def update_scope_for_unique(self, node: Union[ast.Attribute, ast.Call, ast.Subscript]):
1030
+ """ Update scope of ast node because of unique-ing of targets of other nodes. """
1031
+ if isinstance(node, ast.Call):
1032
+ self.update_scope_for_unique(node.func)
1033
+ return
1034
+ if not isinstance(node, (ast.Attribute, ast.Subscript)):
1035
+ logger.warning(f"Cannot update node {astunparse.unparse(node)} for unique, type of node should "
1036
+ f"be one of (ast.Attribute, ast.Subscript).")
1037
+ return
1038
+ scope = node.value
1039
+ if not isinstance(scope, ast.Name):
1040
+ self.update_scope_for_unique(scope)
1041
+ return
1042
+ scope_name = scope.id
1043
+ scope_name_unique = self._target_namer.get_real_arg(scope_name)
1044
+ scope.id = scope_name_unique
1045
+
1046
+ def _insert_to_ast_while_insert_node(self, node: Node, position: Optional[Position]):
1047
+ """ insert_to_ast_while_insert_node. """
1048
+ node.set_func(ScopedValue.create_naming_value(node.get_name(), "self"))
1049
+ node_ast = node.get_ast()
1050
+ if not isinstance(node_ast, ast.Assign):
1051
+ raise RuntimeError("Only support insert cell op now")
1052
+ if isinstance(node, TreeNode):
1053
+ setattr(self._origin_network, node.get_name(), node.get_instance())
1054
+ args_call = AstModifier.create_call(ScopedValue(ValueType.NamingValue, "", "getattr"),
1055
+ [ScopedValue(ValueType.NamingValue, "", "obj"),
1056
+ ScopedValue(ValueType.StringValue, "", node.get_name())])
1057
+ value = ast.Call(func=ast.Name(node.symbol_tree.get_opt_cls_name(), ast.Store(), lineno=0,
1058
+ col_offset=0), args=[args_call], keywords=[], lineno=0, col_offset=0)
1059
+
1060
+ ast_target = ast.Name("self." + node.get_name(), ast.Store(), lineno=0, col_offset=0)
1061
+ assign = ast.Assign(targets=[ast_target], value=value, lineno=0, col_offset=0)
1062
+ AstModifier.insert_assign_ast_to_function(self._init_func_ast, assign)
1063
+
1064
+ AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
1065
+ None if position is None else position.node.get_ast(),
1066
+ position.before_node)
1067
+ sub_stree: SymbolTree = node.symbol_tree
1068
+ from .symbol_tree_builder import SymbolTreeBuilder
1069
+ SymbolTreeBuilder.merge_module_of_subtree(self, sub_stree)
1070
+ else:
1071
+ AstModifier.insert_assign_to_function(self._init_func_ast,
1072
+ targets=[ScopedValue(ValueType.NamingValue, "self", node.get_name())],
1073
+ expr=ScopedValue(ValueType.NamingValue, "", "getattr"),
1074
+ args=[ScopedValue(ValueType.NamingValue, "", "obj"),
1075
+ ScopedValue(ValueType.StringValue, "", node.get_name())])
1076
+ AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
1077
+ None if position is None else position.node.get_ast(),
1078
+ position.before_node)
1079
+ setattr(self._origin_network, node.get_name(), node.get_instance())
1080
+
1052
1081
  def _remove_unused_import(self):
1053
1082
  """remove unused import in self._module_ast"""
1054
1083
  str_checker = StrChecker(self._module_ast)
@@ -1070,49 +1099,43 @@ class SymbolTree(Observer, Observable):
1070
1099
  else:
1071
1100
  body.names.remove(alias)
1072
1101
 
1102
+ def _replace_container_node(self, old_node, new_nodes):
1103
+ cellcontainer = getattr(old_node, "container")
1104
+ index = cellcontainer.node_list.index(old_node)
1105
+ for n in reversed(new_nodes):
1106
+ cellcontainer.insert(index, n)
1107
+ index = cellcontainer.node_list.index(old_node)
1108
+ cellcontainer.erase(old_node)
1109
+
1073
1110
  def _filter_out_to_delete_field(self, to_delete_field):
1074
1111
  """filter out used field from `to_delete_field`"""
1075
- # filter _handler field
1076
- if to_delete_field.get("_handler"):
1077
- to_delete_field.pop("_handler")
1078
- # filter field used in node of construct
1079
- for node in self._nodes.values():
1080
- if node.get_node_type() in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree):
1081
- func: ScopedValue = node.get_func()
1082
- if func.scope == "self" and to_delete_field.get(func.value):
1083
- to_delete_field.pop(func.value)
1084
- if node.get_node_type() == NodeType.CallMethod and node.get_func() == PASS_THROUGH_METHOD:
1085
- var_name = node.get_args()[0].value
1086
- if to_delete_field.get(var_name):
1087
- to_delete_field.pop(var_name)
1088
- # filter field used in test-of-if of construct function
1089
- for body in self._root_ast.body:
1090
- if not isinstance(body, ast.If):
1112
+ for func_def in self._class_ast.body:
1113
+ if not isinstance(func_def, ast.FunctionDef):
1091
1114
  continue
1092
- test = body.test
1093
- field_finder = FieldFinder(test)
1094
- to_delete_to_delete_keys = []
1095
- for key, _ in to_delete_field.items():
1096
- if field_finder.check(key):
1097
- to_delete_to_delete_keys.append(key)
1098
- for key in to_delete_to_delete_keys:
1099
- to_delete_field.pop(key)
1100
- # filter field used in test-of-if of init function
1101
- for body in self._init_func_ast.body:
1102
- if not isinstance(body, ast.If):
1103
- continue
1104
- test = body.test
1105
- field_finder = FieldFinder(test)
1106
- to_delete_to_delete_keys = []
1107
- for key, _ in to_delete_field.items():
1108
- if field_finder.check(key):
1109
- to_delete_to_delete_keys.append(key)
1110
- for key in to_delete_to_delete_keys:
1111
- to_delete_field.pop(key)
1115
+ if func_def.name != "__init__":
1116
+ to_delete_to_delete_keys = []
1117
+ property_checker = CheckPropertyIsUsed(func_def)
1118
+ for key, _ in self._deleted_field.items():
1119
+ if property_checker.check("self", key):
1120
+ to_delete_to_delete_keys.append(key)
1121
+ property_checker = CheckPropertyIsUsed(func_def)
1122
+ for key in to_delete_to_delete_keys:
1123
+ self._deleted_field.pop(key)
1124
+ else:
1125
+ for body in func_def.body:
1126
+ if not isinstance(body, ast.If):
1127
+ continue
1128
+ test = body.test
1129
+ field_finder = FieldFinder(test)
1130
+ to_delete_to_delete_keys = []
1131
+ for key, _ in self._deleted_field.items():
1132
+ if field_finder.check(key):
1133
+ to_delete_to_delete_keys.append(key)
1134
+ for key in to_delete_to_delete_keys:
1135
+ self._deleted_field.pop(key)
1112
1136
 
1113
1137
  def _remove_unused_field(self):
1114
1138
  """remove unused field in __init__ function"""
1115
- to_delete_field = {}
1116
1139
  multi_targets = []
1117
1140
  for index, body in enumerate(self._init_func_ast.body):
1118
1141
  if not isinstance(body, ast.Assign):
@@ -1121,12 +1144,12 @@ class SymbolTree(Observer, Observable):
1121
1144
  for target in targets:
1122
1145
  if isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name) \
1123
1146
  and target.value.id == "self":
1124
- to_delete_field[target.attr] = index
1147
+ self._deleted_field[target.attr] = index
1125
1148
  if len(targets) > 1:
1126
1149
  multi_targets.append(index)
1127
- self._filter_out_to_delete_field(to_delete_field)
1150
+ self._filter_out_to_delete_field(self._deleted_field)
1128
1151
  for i in range(len(self._init_func_ast.body) - 1, -1, -1):
1129
- if i in to_delete_field.values():
1152
+ if i in self._deleted_field.values():
1130
1153
  if i in multi_targets:
1131
1154
  raise RuntimeError("Can not erase field ast node in __init__ function because of multi-targets")
1132
1155
  AstModifier.erase_ast_from_function(self._init_func_ast, self._init_func_ast.body[i])
@@ -1144,12 +1167,9 @@ class SymbolTree(Observer, Observable):
1144
1167
  self._module_ast.body.remove(body)
1145
1168
 
1146
1169
  def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
1147
- if isinstance(node_or_name, Node):
1148
- result = self.get_node(node_or_name.get_name())
1149
- return result if result is node_or_name else None
1150
1170
  if isinstance(node_or_name, str):
1151
1171
  return self.get_node(node_or_name)
1152
- return None
1172
+ return node_or_name
1153
1173
 
1154
1174
  def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node:
1155
1175
  """
@@ -1298,7 +1318,7 @@ class SymbolTree(Observer, Observable):
1298
1318
  raise TypeError("value should be ScopedValue, got: ", type(value))
1299
1319
  if value.type == ValueType.CustomObjValue:
1300
1320
  field = self._node_name_namer.get_name(f"var_{type(value.value).__name__}")
1301
- self._global_vars[field] = value.value
1321
+ setattr(self._origin_network, field, value.value)
1302
1322
  init_targets = [ScopedValue.create_naming_value(field, "self")]
1303
1323
  AstModifier.append_global_vars_expr_to_init(self._init_func_ast, init_targets, field)
1304
1324
  result[arg] = init_targets[0]
@@ -1316,15 +1336,34 @@ class SymbolTree(Observer, Observable):
1316
1336
  Returns:
1317
1337
  A class handle.
1318
1338
  """
1319
- file_name = "new_network_{0}.py".format(int(time.time() * 10000))
1320
- with os.fdopen(os.open(file_name, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
1339
+ self._update_container()
1340
+ file_path = os.getcwd()
1341
+ file_path = os.path.join(file_path, "rewritten_network")
1342
+ if not os.path.exists(file_path):
1343
+ os.mkdir(file_path)
1344
+ file_name = "{0}_{1}.py".format(self._opt_cls_name, id(self))
1345
+ network_file = os.path.join(file_path, file_name)
1346
+ with os.fdopen(os.open(network_file, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
1321
1347
  source = self.get_code()
1322
1348
  f.write(source.encode('utf-8'))
1323
1349
  f.flush()
1324
- tmp_module_path, tmp_module_file = os.path.split(file_name)
1350
+ os.fsync(f)
1351
+ tmp_module_path, tmp_module_file = os.path.split(network_file)
1325
1352
  tmp_module_name = tmp_module_file[:-3]
1326
1353
  sys.path.append(tmp_module_path)
1327
- tmp_module = importlib.import_module(tmp_module_name)
1354
+ tmp_module = None
1355
+
1356
+ i = 0
1357
+ while not tmp_module:
1358
+ try:
1359
+ tmp_module = importlib.import_module(tmp_module_name)
1360
+ except ModuleNotFoundError:
1361
+ if i > 10:
1362
+ break
1363
+ time.sleep(0.1)
1364
+ i += 1
1365
+ if not tmp_module:
1366
+ logger.error(f"load module {tmp_module_name} failed.")
1328
1367
  network_cls = getattr(tmp_module, self._opt_cls_name)
1329
1368
  if network_cls is None:
1330
1369
  raise RuntimeError("Can not find network class:", self._opt_cls_name)
@@ -1333,3 +1372,87 @@ class SymbolTree(Observer, Observable):
1333
1372
  def _on_change(self, event: Event):
1334
1373
  self._modified = True
1335
1374
  self.changed(event)
1375
+
1376
+ def _update_container(self):
1377
+ """Update instance of node in container."""
1378
+ for node in self.nodes():
1379
+ index = 0
1380
+ if node.get_node_type() == NodeType.CellContainer:
1381
+ for n in node.node_list:
1382
+ if not n.valid:
1383
+ continue
1384
+ if n.get_node_type() == NodeType.Tree:
1385
+ obj = n.symbol_tree.get_network()
1386
+ node.get_instance()[index] = obj
1387
+ else:
1388
+ node.get_instance()[index] = n.get_instance()
1389
+ index += 1
1390
+
1391
+ def _cal_difference_set(self, input, other):
1392
+ """Calculate different set of two sets."""
1393
+ set1 = set(input)
1394
+ set2 = set(other)
1395
+ return set1 - set2
1396
+
1397
+ def _merge_origin_property(self, new_net):
1398
+ """Merge property of two network."""
1399
+ tmp = self._cal_difference_set(dir(self._origin_network), dir(new_net))
1400
+ new_attr_names = self._cal_difference_set(tmp, self._deleted_field.keys())
1401
+ for name in new_attr_names:
1402
+ setattr(new_net, name, getattr(self._origin_network, name))
1403
+ # merger cells
1404
+ cells = self._cal_difference_set(self._origin_network.name_cells().keys(), new_net.name_cells().keys())
1405
+ cells = self._cal_difference_set(cells, self._deleted_node)
1406
+ for c in cells:
1407
+ new_net.insert_child_to_cell(c, self._origin_network.name_cells()[c])
1408
+ # merge primitives
1409
+ primitives = self._cal_difference_set(self._origin_network._primitives.keys(), new_net._primitives.keys())
1410
+ for p in primitives:
1411
+ new_net._primitives[p] = self._origin_network._primitives[p]
1412
+
1413
+ def _update_names_for_unique(self, node: ast.AST):
1414
+ """ Update names of ast nodes for unique. """
1415
+ if isinstance(node, (ast.For, ast.If, ast.While)):
1416
+ self._update_names_for_unique_branchs(node)
1417
+ elif isinstance(node, ast.Assign):
1418
+ self._update_names_for_unique(node.value)
1419
+ for target in node.targets:
1420
+ self._update_names_for_unique(target)
1421
+ elif isinstance(node, ast.Call):
1422
+ if isinstance(node.func, ast.Attribute):
1423
+ self._update_names_for_unique(node.func.value)
1424
+ for arg in node.args:
1425
+ self._update_names_for_unique(arg)
1426
+ for keyword in node.keywords:
1427
+ self._update_names_for_unique(keyword)
1428
+ elif isinstance(node, ast.UnaryOp):
1429
+ self._update_names_for_unique(node.operand)
1430
+ elif isinstance(node, ast.BinOp):
1431
+ self._update_names_for_unique(node.left)
1432
+ self._update_names_for_unique(node.right)
1433
+ elif isinstance(node, (ast.Attribute, ast.Subscript, ast.Return)):
1434
+ self._update_names_for_unique(node.value)
1435
+ elif isinstance(node, (ast.List, ast.Tuple)):
1436
+ for elt in node.elts:
1437
+ self._update_names_for_unique(elt)
1438
+ elif isinstance(node, ast.Compare):
1439
+ for comparator in node.comparators:
1440
+ self._update_names_for_unique(comparator)
1441
+ elif isinstance(node, ast.Name):
1442
+ node.id = self._target_namer.get_real_arg(node.id)
1443
+
1444
+ def _update_names_for_unique_branchs(self, node: Union[ast.For, ast.If, ast.While]):
1445
+ """ Update names of ast nodes for unique with ast.For, ast.If or ast.While """
1446
+ if isinstance(node, ast.For):
1447
+ self._update_names_for_unique(node.target)
1448
+ self._update_names_for_unique(node.iter)
1449
+ for body in node.body:
1450
+ self._update_names_for_unique(body)
1451
+ for body in node.orelse:
1452
+ self._update_names_for_unique(body)
1453
+ elif isinstance(node, (ast.If, ast.While)):
1454
+ self._update_names_for_unique(node.test)
1455
+ for body in node.body:
1456
+ self._update_names_for_unique(body)
1457
+ for body in node.orelse:
1458
+ self._update_names_for_unique(body)
@@ -28,6 +28,41 @@ from .ast_helpers import AstModifier
28
28
  from .ast_helpers import AstFinder
29
29
 
30
30
 
31
+ class FunctionSymbolTreeBuilder:
32
+ """Create function SymbolTree"""
33
+ def __init__(self, network: Cell, ast_root):
34
+ self._origin_net = network
35
+ self._ast_root: ast.Module = ast_root
36
+ self._root_tree: Optional[SymbolTree] = None
37
+
38
+ @staticmethod
39
+ def _ast_transform(ast_root: ast.AST) -> ast.AST:
40
+ """
41
+ Optimize ast before parse.
42
+
43
+ Args:
44
+ ast_root (ast.AST): An instance of ast to be optimized.
45
+
46
+ Returns:
47
+ An instance of ast been optimized.
48
+ """
49
+ transform_list = [FlattenRecursiveStmt()]
50
+ for transformer in transform_list:
51
+ ast_root = transformer.transform(ast_root)
52
+ return ast_root
53
+
54
+ def build(self) -> SymbolTree:
55
+ """
56
+ Build SymbolTree.
57
+
58
+ Returns:
59
+ An instance of SymbolTree.
60
+ """
61
+ self._root_tree: SymbolTree = SymbolTree(self._origin_net, self._ast_root)
62
+ self._root_tree.finish_build()
63
+ return self._root_tree
64
+
65
+
31
66
  class SymbolTreeBuilder:
32
67
  """
33
68
  `SymbolTreeBuilder` for building a SymbolTree from network.
@@ -43,6 +78,8 @@ class SymbolTreeBuilder:
43
78
  network_str = inspect.getsource(type(network))
44
79
  self._ast_root: ast.Module = ast.parse(network_str)
45
80
  self._root_tree: Optional[SymbolTree] = None
81
+ if isinstance(network, Cell) and network.jit_config_dict:
82
+ self._jit_config_dict = network.jit_config_dict
46
83
 
47
84
  @staticmethod
48
85
  def merge_module_of_subtree(main_tree: SymbolTree, sub_stree: SymbolTree):
@@ -140,7 +177,7 @@ class SymbolTreeBuilder:
140
177
  """
141
178
 
142
179
  for node in self._root_tree.nodes():
143
- if isinstance(node, TreeNode):
180
+ if isinstance(node, TreeNode) and node.get_instance():
144
181
  SymbolTreeBuilder.merge_module_of_subtree(self._root_tree, node.symbol_tree)
145
182
 
146
183
  def _reduce_redundant_import(self):