mindspore 2.0.0a0__cp37-cp37m-win_amd64.whl → 2.0.0rc1__cp37-cp37m-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (655) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -2
  3. mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +102 -0
  7. mindspore/_checkparam.py +1066 -1001
  8. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
  9. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
  10. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
  11. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
  12. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  13. mindspore/_extends/parse/__init__.py +5 -3
  14. mindspore/_extends/parse/namespace.py +16 -1
  15. mindspore/_extends/parse/parser.py +107 -22
  16. mindspore/_extends/parse/resources.py +0 -7
  17. mindspore/_extends/parse/standard_method.py +885 -413
  18. mindspore/amp.py +52 -57
  19. mindspore/boost/boost.py +2 -2
  20. mindspore/boost/boost_cell_wrapper.py +38 -20
  21. mindspore/boost/dim_reduce.py +3 -3
  22. mindspore/boost/group_loss_scale_manager.py +1 -1
  23. mindspore/common/__init__.py +4 -6
  24. mindspore/common/_decorator.py +2 -0
  25. mindspore/common/_register_for_adapter.py +55 -0
  26. mindspore/common/_stub_tensor.py +201 -0
  27. mindspore/common/_utils.py +41 -7
  28. mindspore/common/api.py +215 -141
  29. mindspore/common/dtype.py +8 -1
  30. mindspore/common/dump.py +2 -2
  31. mindspore/common/initializer.py +4 -2
  32. mindspore/common/jit_config.py +17 -13
  33. mindspore/common/mutable.py +33 -13
  34. mindspore/common/parameter.py +23 -21
  35. mindspore/common/seed.py +8 -24
  36. mindspore/common/sparse_tensor.py +62 -41
  37. mindspore/common/tensor.py +852 -1154
  38. mindspore/communication/__init__.py +2 -2
  39. mindspore/communication/_comm_helper.py +11 -4
  40. mindspore/communication/management.py +22 -21
  41. mindspore/config/op_info.config +501 -1008
  42. mindspore/context.py +201 -23
  43. mindspore/dataset/__init__.py +6 -6
  44. mindspore/dataset/audio/__init__.py +7 -7
  45. mindspore/dataset/audio/transforms.py +670 -30
  46. mindspore/dataset/audio/utils.py +47 -4
  47. mindspore/dataset/audio/validators.py +223 -1
  48. mindspore/dataset/callback/ds_callback.py +2 -2
  49. mindspore/dataset/core/config.py +210 -14
  50. mindspore/dataset/core/validator_helpers.py +2 -2
  51. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  52. mindspore/dataset/debug/debug_hook.py +65 -0
  53. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  54. mindspore/dataset/engine/__init__.py +7 -3
  55. mindspore/dataset/engine/cache_client.py +1 -1
  56. mindspore/dataset/engine/datasets.py +322 -66
  57. mindspore/dataset/engine/datasets_audio.py +80 -76
  58. mindspore/dataset/engine/datasets_standard_format.py +51 -38
  59. mindspore/dataset/engine/datasets_text.py +232 -118
  60. mindspore/dataset/engine/datasets_user_defined.py +41 -17
  61. mindspore/dataset/engine/datasets_vision.py +746 -225
  62. mindspore/dataset/engine/graphdata.py +75 -10
  63. mindspore/dataset/engine/iterators.py +45 -5
  64. mindspore/dataset/engine/offload.py +48 -28
  65. mindspore/dataset/engine/validators.py +117 -8
  66. mindspore/dataset/text/__init__.py +6 -5
  67. mindspore/dataset/text/transforms.py +86 -3
  68. mindspore/dataset/text/utils.py +6 -4
  69. mindspore/dataset/text/validators.py +25 -0
  70. mindspore/dataset/transforms/__init__.py +3 -2
  71. mindspore/dataset/transforms/c_transforms.py +1 -1
  72. mindspore/dataset/transforms/transforms.py +2 -2
  73. mindspore/dataset/utils/__init__.py +2 -1
  74. mindspore/dataset/utils/line_reader.py +121 -0
  75. mindspore/dataset/vision/__init__.py +2 -3
  76. mindspore/dataset/vision/c_transforms.py +9 -9
  77. mindspore/dataset/vision/py_transforms.py +5 -5
  78. mindspore/dataset/vision/py_transforms_util.py +2 -0
  79. mindspore/dataset/vision/transforms.py +160 -161
  80. mindspore/dataset/vision/utils.py +3 -3
  81. mindspore/experimental/map_parameter.py +38 -26
  82. mindspore/include/OWNERS +0 -1
  83. mindspore/include/api/callback/callback.h +9 -13
  84. mindspore/include/api/callback/ckpt_saver.h +2 -2
  85. mindspore/include/api/callback/loss_monitor.h +2 -2
  86. mindspore/include/api/callback/lr_scheduler.h +5 -5
  87. mindspore/include/api/callback/time_monitor.h +2 -2
  88. mindspore/include/api/callback/train_accuracy.h +4 -6
  89. mindspore/include/api/cfg.h +19 -6
  90. mindspore/include/api/context.h +44 -9
  91. mindspore/include/api/delegate.h +1 -1
  92. mindspore/include/api/metrics/accuracy.h +2 -2
  93. mindspore/include/api/metrics/metrics.h +4 -3
  94. mindspore/include/api/model.h +9 -4
  95. mindspore/include/api/model_parallel_runner.h +2 -2
  96. mindspore/include/api/net.h +12 -11
  97. mindspore/include/api/serialization.h +19 -3
  98. mindspore/include/api/types.h +3 -3
  99. mindspore/include/dataset/constants.h +7 -0
  100. mindspore/include/dataset/text.h +59 -0
  101. mindspore/jpeg62.dll +0 -0
  102. mindspore/log.py +1 -1
  103. mindspore/mindrecord/filereader.py +18 -0
  104. mindspore/mindrecord/filewriter.py +197 -34
  105. mindspore/mindrecord/shardreader.py +9 -0
  106. mindspore/mindrecord/shardwriter.py +1 -1
  107. mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
  108. mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
  109. mindspore/mindrecord/tools/csv_to_mr.py +3 -3
  110. mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
  111. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  112. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  113. mindspore/mindspore_backend.dll +0 -0
  114. mindspore/mindspore_common.dll +0 -0
  115. mindspore/mindspore_core.dll +0 -0
  116. mindspore/mindspore_glog.dll +0 -0
  117. mindspore/mindspore_shared_lib.dll +0 -0
  118. mindspore/nn/__init__.py +0 -4
  119. mindspore/nn/cell.py +204 -132
  120. mindspore/nn/dynamic_lr.py +1 -1
  121. mindspore/nn/grad/cell_grad.py +7 -6
  122. mindspore/nn/layer/__init__.py +5 -4
  123. mindspore/nn/layer/activation.py +40 -89
  124. mindspore/nn/layer/basic.py +255 -624
  125. mindspore/nn/layer/channel_shuffle.py +7 -6
  126. mindspore/nn/layer/combined.py +1 -1
  127. mindspore/nn/layer/container.py +41 -4
  128. mindspore/nn/layer/conv.py +64 -28
  129. mindspore/nn/layer/dense.py +9 -8
  130. mindspore/nn/layer/embedding.py +27 -25
  131. mindspore/nn/layer/image.py +53 -46
  132. mindspore/nn/layer/math.py +97 -105
  133. mindspore/nn/layer/normalization.py +117 -86
  134. mindspore/nn/layer/padding.py +185 -95
  135. mindspore/nn/layer/pooling.py +817 -414
  136. mindspore/nn/layer/rnn_cells.py +10 -15
  137. mindspore/nn/layer/rnns.py +37 -38
  138. mindspore/nn/layer/thor_layer.py +11 -12
  139. mindspore/nn/layer/timedistributed.py +5 -5
  140. mindspore/nn/layer/transformer.py +701 -0
  141. mindspore/nn/learning_rate_schedule.py +8 -8
  142. mindspore/nn/loss/__init__.py +5 -4
  143. mindspore/nn/loss/loss.py +334 -199
  144. mindspore/nn/optim/ada_grad.py +6 -6
  145. mindspore/nn/optim/adadelta.py +2 -3
  146. mindspore/nn/optim/adafactor.py +4 -5
  147. mindspore/nn/optim/adam.py +126 -62
  148. mindspore/nn/optim/adamax.py +3 -4
  149. mindspore/nn/optim/adasum.py +6 -6
  150. mindspore/nn/optim/asgd.py +2 -2
  151. mindspore/nn/optim/ftrl.py +67 -38
  152. mindspore/nn/optim/lamb.py +4 -5
  153. mindspore/nn/optim/lars.py +2 -2
  154. mindspore/nn/optim/lazyadam.py +43 -4
  155. mindspore/nn/optim/momentum.py +6 -5
  156. mindspore/nn/optim/optimizer.py +3 -1
  157. mindspore/nn/optim/proximal_ada_grad.py +2 -2
  158. mindspore/nn/optim/rmsprop.py +1 -1
  159. mindspore/nn/optim/rprop.py +8 -9
  160. mindspore/nn/optim/sgd.py +19 -13
  161. mindspore/nn/optim/thor.py +10 -15
  162. mindspore/nn/probability/__init__.py +0 -2
  163. mindspore/nn/probability/bijector/bijector.py +4 -4
  164. mindspore/nn/probability/bijector/invert.py +1 -1
  165. mindspore/nn/probability/bijector/softplus.py +2 -2
  166. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  167. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  168. mindspore/nn/probability/distribution/_utils/utils.py +9 -15
  169. mindspore/nn/probability/distribution/bernoulli.py +3 -3
  170. mindspore/nn/probability/distribution/beta.py +1 -1
  171. mindspore/nn/probability/distribution/categorical.py +5 -7
  172. mindspore/nn/probability/distribution/cauchy.py +3 -3
  173. mindspore/nn/probability/distribution/distribution.py +2 -2
  174. mindspore/nn/probability/distribution/exponential.py +2 -2
  175. mindspore/nn/probability/distribution/gamma.py +3 -3
  176. mindspore/nn/probability/distribution/geometric.py +1 -1
  177. mindspore/nn/probability/distribution/gumbel.py +3 -3
  178. mindspore/nn/probability/distribution/half_normal.py +15 -11
  179. mindspore/nn/probability/distribution/laplace.py +16 -13
  180. mindspore/nn/probability/distribution/logistic.py +2 -2
  181. mindspore/nn/probability/distribution/normal.py +1 -1
  182. mindspore/nn/probability/distribution/poisson.py +1 -1
  183. mindspore/nn/probability/distribution/student_t.py +20 -15
  184. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  185. mindspore/nn/probability/distribution/uniform.py +2 -2
  186. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  187. mindspore/nn/reinforcement/tensor_array.py +2 -2
  188. mindspore/nn/sparse/sparse.py +2 -2
  189. mindspore/nn/wrap/cell_wrapper.py +27 -10
  190. mindspore/nn/wrap/grad_reducer.py +2 -2
  191. mindspore/nn/wrap/loss_scale.py +40 -24
  192. mindspore/numpy/array_creations.py +33 -22
  193. mindspore/numpy/array_ops.py +35 -30
  194. mindspore/numpy/logic_ops.py +6 -27
  195. mindspore/numpy/math_ops.py +22 -19
  196. mindspore/numpy/utils.py +1 -1
  197. mindspore/numpy/utils_const.py +108 -58
  198. mindspore/opencv_core452.dll +0 -0
  199. mindspore/opencv_imgcodecs452.dll +0 -0
  200. mindspore/opencv_imgproc452.dll +0 -0
  201. mindspore/ops/_constants.py +0 -6
  202. mindspore/ops/_grad/__init__.py +2 -1
  203. mindspore/ops/_grad/grad_array_ops.py +86 -117
  204. mindspore/ops/_grad/grad_base.py +23 -1
  205. mindspore/ops/_grad/grad_clip_ops.py +2 -3
  206. mindspore/ops/_grad/grad_comm_ops.py +34 -24
  207. mindspore/ops/_grad/grad_implementations.py +9 -45
  208. mindspore/ops/_grad/grad_inner_ops.py +47 -4
  209. mindspore/ops/_grad/grad_math_ops.py +142 -117
  210. mindspore/ops/_grad/grad_nn_ops.py +71 -165
  211. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  212. mindspore/ops/_grad/grad_sparse.py +7 -6
  213. mindspore/ops/_grad_experimental/__init__.py +1 -0
  214. mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
  215. mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
  216. mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
  217. mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
  218. mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
  219. mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
  220. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  221. mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
  222. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  223. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
  224. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  225. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  226. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
  227. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  228. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  229. mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
  230. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  231. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
  232. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  233. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  234. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  235. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  236. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  237. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  238. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  239. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  240. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  241. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  242. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  243. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  244. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  245. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  246. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  247. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  248. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  249. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  250. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
  251. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  252. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  253. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  254. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  255. mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
  256. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  257. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
  258. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  259. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  260. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  261. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  262. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  263. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
  264. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  265. mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
  266. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  267. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  268. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
  269. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  270. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  271. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
  272. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  273. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  274. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  275. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  276. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  277. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  278. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  279. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  280. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  281. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
  282. mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
  283. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  284. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  285. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  286. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  287. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  288. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  289. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  290. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  291. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  292. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  293. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  294. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  295. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  296. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  297. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  298. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  299. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  300. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  301. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  302. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  303. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  304. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  305. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  306. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  307. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  308. mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
  309. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
  310. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  311. mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
  312. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  313. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  314. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  315. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  316. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  317. mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
  318. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
  319. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  320. mindspore/ops/_op_impl/tbe/__init__.py +27 -611
  321. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  322. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  323. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  324. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
  325. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  326. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  327. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  328. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  329. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  330. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  331. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  332. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  333. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  334. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
  335. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  336. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  337. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  338. mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
  339. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  340. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  341. mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
  342. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
  343. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  344. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  345. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  346. mindspore/ops/_register_for_op.py +1 -0
  347. mindspore/ops/_utils/__init__.py +1 -2
  348. mindspore/ops/_utils/utils.py +19 -40
  349. mindspore/ops/_vmap/vmap_array_ops.py +116 -38
  350. mindspore/ops/_vmap/vmap_base.py +16 -9
  351. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  352. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  353. mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
  354. mindspore/ops/_vmap/vmap_image_ops.py +12 -5
  355. mindspore/ops/_vmap/vmap_math_ops.py +46 -5
  356. mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
  357. mindspore/ops/_vmap/vmap_random_ops.py +1 -1
  358. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  359. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  360. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  361. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  362. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  363. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  364. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  365. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  366. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
  367. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  368. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  369. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  370. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  371. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  372. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  373. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  374. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  375. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  376. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  377. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
  378. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
  379. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  380. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  381. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  382. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  383. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  384. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  385. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
  386. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  387. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
  388. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
  389. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  390. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  391. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  392. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  393. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  394. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  395. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  396. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  397. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  398. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  399. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  400. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  401. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  402. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  403. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  404. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  405. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  406. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  407. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  408. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
  409. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  410. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  411. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  412. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  413. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  414. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  415. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  416. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
  417. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
  418. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  419. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
  420. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  421. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  422. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  423. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  424. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  425. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  426. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  427. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  428. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  429. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
  430. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
  431. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  432. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  433. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  434. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  435. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
  436. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  437. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  438. mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
  439. mindspore/ops/composite/__init__.py +7 -8
  440. mindspore/ops/composite/base.py +101 -47
  441. mindspore/ops/composite/math_ops.py +188 -158
  442. mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
  443. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
  444. mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
  445. mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
  446. mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
  447. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  448. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  449. mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
  450. mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
  451. mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
  452. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
  453. mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
  454. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  455. mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
  456. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
  457. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
  458. mindspore/ops/function/__init__.py +152 -8
  459. mindspore/ops/function/array_func.py +2555 -674
  460. mindspore/ops/function/clip_func.py +209 -13
  461. mindspore/ops/function/debug_func.py +2 -2
  462. mindspore/ops/function/grad/__init__.py +2 -1
  463. mindspore/ops/function/grad/grad_func.py +147 -62
  464. mindspore/ops/function/image_func.py +54 -38
  465. mindspore/ops/function/linalg_func.py +167 -16
  466. mindspore/ops/function/math_func.py +4849 -1492
  467. mindspore/ops/function/nn_func.py +2573 -988
  468. mindspore/ops/function/other_func.py +115 -0
  469. mindspore/ops/function/parameter_func.py +3 -3
  470. mindspore/ops/function/random_func.py +790 -73
  471. mindspore/ops/function/sparse_func.py +98 -78
  472. mindspore/ops/function/sparse_unary_func.py +54 -53
  473. mindspore/ops/function/spectral_func.py +27 -24
  474. mindspore/ops/function/vmap_func.py +22 -2
  475. mindspore/ops/functional.py +97 -37
  476. mindspore/ops/op_info_register.py +70 -28
  477. mindspore/ops/operations/__init__.py +47 -14
  478. mindspore/ops/operations/_csr_ops.py +7 -7
  479. mindspore/ops/operations/_embedding_cache_ops.py +5 -5
  480. mindspore/ops/operations/_grad_ops.py +276 -187
  481. mindspore/ops/operations/_inner_ops.py +319 -113
  482. mindspore/ops/operations/_ms_kernel.py +10 -8
  483. mindspore/ops/operations/_ocr_ops.py +9 -9
  484. mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
  485. mindspore/ops/operations/_quant_ops.py +137 -102
  486. mindspore/ops/operations/_rl_inner_ops.py +121 -60
  487. mindspore/ops/operations/_scalar_ops.py +466 -0
  488. mindspore/ops/operations/_sequence_ops.py +1004 -2
  489. mindspore/ops/operations/_tensor_array.py +10 -11
  490. mindspore/ops/operations/_thor_ops.py +1 -1
  491. mindspore/ops/operations/array_ops.py +801 -466
  492. mindspore/ops/operations/comm_ops.py +51 -49
  493. mindspore/ops/operations/control_ops.py +2 -2
  494. mindspore/ops/operations/custom_ops.py +123 -44
  495. mindspore/ops/operations/debug_ops.py +24 -24
  496. mindspore/ops/operations/image_ops.py +240 -153
  497. mindspore/ops/operations/inner_ops.py +34 -50
  498. mindspore/ops/operations/linalg_ops.py +31 -9
  499. mindspore/ops/operations/math_ops.py +988 -757
  500. mindspore/ops/operations/nn_ops.py +965 -819
  501. mindspore/ops/operations/other_ops.py +51 -40
  502. mindspore/ops/operations/random_ops.py +204 -122
  503. mindspore/ops/operations/rl_ops.py +8 -9
  504. mindspore/ops/operations/sparse_ops.py +254 -93
  505. mindspore/ops/operations/spectral_ops.py +35 -3
  506. mindspore/ops/primitive.py +111 -9
  507. mindspore/parallel/_auto_parallel_context.py +189 -83
  508. mindspore/parallel/_offload_context.py +185 -0
  509. mindspore/parallel/_parallel_serialization.py +99 -7
  510. mindspore/parallel/_ps_context.py +9 -5
  511. mindspore/parallel/_recovery_context.py +1 -1
  512. mindspore/parallel/_tensor.py +7 -1
  513. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  514. mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
  515. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  516. mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
  517. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  518. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
  519. mindspore/parallel/_utils.py +1 -2
  520. mindspore/parallel/algo_parameter_config.py +1 -1
  521. mindspore/parallel/checkpoint_transform.py +37 -34
  522. mindspore/parallel/shard.py +17 -18
  523. mindspore/profiler/common/validator/validate_path.py +2 -2
  524. mindspore/profiler/envprofiling.py +69 -47
  525. mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
  526. mindspore/profiler/parser/base_timeline_generator.py +49 -56
  527. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
  528. mindspore/profiler/parser/hwts_log_parser.py +1 -1
  529. mindspore/profiler/parser/integrator.py +15 -14
  530. mindspore/profiler/parser/minddata_analyzer.py +2 -2
  531. mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
  532. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  533. mindspore/profiler/parser/optime_parser.py +17 -18
  534. mindspore/profiler/parser/profiler_info.py +2 -1
  535. mindspore/profiler/profiling.py +218 -186
  536. mindspore/rewrite/__init__.py +3 -1
  537. mindspore/rewrite/api/node.py +1 -114
  538. mindspore/rewrite/api/node_type.py +3 -0
  539. mindspore/rewrite/api/pattern_engine.py +31 -1
  540. mindspore/rewrite/api/scoped_value.py +4 -4
  541. mindspore/rewrite/api/symbol_tree.py +3 -78
  542. mindspore/rewrite/api/tree_node_helper.py +1 -1
  543. mindspore/rewrite/ast_creator_register.py +1 -0
  544. mindspore/rewrite/ast_helpers/__init__.py +2 -2
  545. mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
  546. mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
  547. mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
  548. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
  549. mindspore/rewrite/namespace.py +0 -2
  550. mindspore/rewrite/node.py +157 -11
  551. mindspore/rewrite/parsers/assign_parser.py +231 -53
  552. mindspore/rewrite/parsers/class_def_parser.py +187 -109
  553. mindspore/rewrite/parsers/for_parser.py +24 -14
  554. mindspore/rewrite/parsers/function_def_parser.py +21 -4
  555. mindspore/rewrite/parsers/if_parser.py +6 -2
  556. mindspore/rewrite/sparsify/__init__.py +0 -0
  557. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  558. mindspore/rewrite/sparsify/sparsify.py +109 -0
  559. mindspore/rewrite/sparsify/utils.py +173 -0
  560. mindspore/rewrite/symbol_tree.py +256 -133
  561. mindspore/rewrite/symbol_tree_builder.py +38 -1
  562. mindspore/run_check/_check_version.py +69 -63
  563. mindspore/run_check/run_check.py +2 -1
  564. mindspore/tinyxml2.dll +0 -0
  565. mindspore/train/__init__.py +1 -1
  566. mindspore/train/_utils.py +28 -5
  567. mindspore/train/amp.py +273 -102
  568. mindspore/train/callback/_backup_and_restore.py +5 -5
  569. mindspore/train/callback/_callback.py +2 -2
  570. mindspore/train/callback/_checkpoint.py +3 -3
  571. mindspore/train/callback/_early_stop.py +3 -3
  572. mindspore/train/callback/_lambda_callback.py +2 -2
  573. mindspore/train/callback/_landscape.py +29 -31
  574. mindspore/train/callback/_loss_monitor.py +3 -3
  575. mindspore/train/callback/_on_request_exit.py +3 -3
  576. mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
  577. mindspore/train/callback/_summary_collector.py +23 -16
  578. mindspore/train/callback/_time_monitor.py +3 -3
  579. mindspore/train/checkpoint_pb2.py +68 -8
  580. mindspore/train/data_sink.py +15 -3
  581. mindspore/train/dataset_helper.py +10 -15
  582. mindspore/train/loss_scale_manager.py +8 -11
  583. mindspore/train/metrics/__init__.py +1 -1
  584. mindspore/train/metrics/bleu_score.py +1 -1
  585. mindspore/train/metrics/confusion_matrix.py +1 -1
  586. mindspore/train/metrics/cosine_similarity.py +1 -1
  587. mindspore/train/metrics/dice.py +2 -2
  588. mindspore/train/metrics/fbeta.py +1 -1
  589. mindspore/train/metrics/hausdorff_distance.py +4 -3
  590. mindspore/train/metrics/mean_surface_distance.py +2 -2
  591. mindspore/train/metrics/occlusion_sensitivity.py +1 -1
  592. mindspore/train/metrics/perplexity.py +1 -1
  593. mindspore/train/metrics/precision.py +1 -1
  594. mindspore/train/metrics/recall.py +1 -1
  595. mindspore/train/metrics/roc.py +2 -2
  596. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  597. mindspore/train/mind_ir_pb2.py +116 -37
  598. mindspore/train/model.py +45 -28
  599. mindspore/train/serialization.py +295 -188
  600. mindspore/train/summary/_summary_adapter.py +1 -1
  601. mindspore/train/summary/summary_record.py +43 -13
  602. mindspore/train/train_thor/convert_utils.py +2 -2
  603. mindspore/train/train_thor/dataset_helper.py +3 -3
  604. mindspore/turbojpeg.dll +0 -0
  605. mindspore/version.py +1 -1
  606. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
  607. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
  608. mindspore/compression/__init__.py +0 -19
  609. mindspore/compression/common/constant.py +0 -124
  610. mindspore/compression/export/__init__.py +0 -19
  611. mindspore/compression/export/quant_export.py +0 -515
  612. mindspore/compression/quant/__init__.py +0 -28
  613. mindspore/compression/quant/qat.py +0 -634
  614. mindspore/compression/quant/quant_utils.py +0 -462
  615. mindspore/compression/quant/quantizer.py +0 -68
  616. mindspore/nn/layer/quant.py +0 -1868
  617. mindspore/nn/layer/rnn_utils.py +0 -90
  618. mindspore/nn/probability/dpn/__init__.py +0 -22
  619. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  620. mindspore/nn/probability/dpn/vae/cvae.py +0 -140
  621. mindspore/nn/probability/dpn/vae/vae.py +0 -124
  622. mindspore/nn/probability/infer/__init__.py +0 -22
  623. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  624. mindspore/nn/probability/infer/variational/svi.py +0 -84
  625. mindspore/nn/probability/toolbox/__init__.py +0 -22
  626. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  627. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
  628. mindspore/nn/probability/transforms/__init__.py +0 -22
  629. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  630. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  631. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  632. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  633. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  634. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  635. mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
  636. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  637. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
  638. mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
  639. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
  640. mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
  642. mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
  643. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
  644. mindspore/ops/composite/array_ops.py +0 -241
  645. mindspore/ops/composite/clip_ops.py +0 -134
  646. mindspore/ops/composite/random_ops.py +0 -426
  647. mindspore/ops/composite/vmap_ops.py +0 -38
  648. mindspore/parallel/nn/__init__.py +0 -42
  649. mindspore/parallel/nn/loss.py +0 -22
  650. mindspore/parallel/nn/moe.py +0 -21
  651. mindspore/parallel/nn/op_parallel_config.py +0 -22
  652. mindspore/parallel/nn/transformer.py +0 -31
  653. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  654. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  655. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,448 @@
1
+ # Copyright 2022 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """Sparsify transformer"""
16
+ import ast
17
+ import inspect
18
+ import textwrap
19
+ from collections import deque
20
+ import astunparse
21
+
22
+ from mindspore import ops, nn
23
+ from mindspore import log as logger
24
+ from mindspore.rewrite.parsers.assign_parser import AssignParser
25
+ from mindspore.rewrite.sparsify.utils import ArgType, SparseFunc, sparse_rules, get_sparse_func, builtin_ops, \
26
+ get_binop_name, get_sparse_method_outputs, arg_type_to_prefix_map, get_inputs_outputs
27
+
28
+
29
# Module-name prefix of mindspore.ops. Callees whose __module__ starts with it are never expanded
# recursively; if no sparse rule matches such a callee, SparseTransformer.get_sparse_node raises ValueError.
OPS_MODULE = "mindspore.ops."
# Maximum nesting depth for recursively sparsifying called functions/cells before a RuntimeError is raised.
MAX_RECURSION_DEPTH = 10
31
+
32
+
33
def sparsify_helper(f, arg_types, user_defined_rules=None, sparse_name="", full_sparse_rules=None, depth=0):
    """
    Run SparseTransformer over the source of a Cell's ``construct`` method or of a plain function.

    Args:
        f (Union[Cell, function]): Object to sparsify. For a Cell, ``construct`` is parsed, ``self`` is dropped
            from its parameters and ``f._cells`` is used to resolve ``self.xxx`` calls.
        arg_types (tuple): arg_types zipped positionally with the parsed parameters.
        user_defined_rules (dict, optional): Extra sparse rules forwarded to SparseTransformer.
        sparse_name (str): Name of the generated function; the original function name is used when empty.
        full_sparse_rules (dict, optional): Rule table inherited from the calling transformer.
        depth (int): Recursion depth forwarded to SparseTransformer.

    Returns:
        tuple: ``(ast_tree, changed, return_types)``. If any statement changed, ``ast_tree`` is a new
        ``ast.Module`` holding a single FunctionDef named `sparse_name`, whose body starts with the
        definitions stored in the transformer's ``sparse_functiondef``; otherwise it is the parsed tree.
    """
    is_cell = isinstance(f, nn.Cell)
    source_fn = f.construct if is_cell else f
    tree = ast.parse(textwrap.dedent(inspect.getsource(source_fn)))
    functiondef = tree.body[0]
    if is_cell:
        # construct(self, ...): the generated function takes no `self` parameter
        functiondef.args.args.pop(0)
        # pylint: disable=protected-access
        init_vars = f._cells
    else:
        init_vars = {}
    global_vars = source_fn.__globals__

    param_names = [param.arg for param in functiondef.args.args]
    transformer = SparseTransformer(dict(zip(param_names, arg_types)), global_vars, init_vars,
                                    user_defined_rules, full_sparse_rules, depth)
    new_body = []
    any_changed = False
    for stmt in functiondef.body:
        new_body.append(transformer.transform(stmt))
        any_changed = any_changed or transformer.has_changed()
    return_types = transformer.return_types

    if not any_changed:
        return tree, False, return_types
    # definitions of nested sparsified callees are emitted ahead of the rewritten statements
    helper_defs = [entry[0] for entry in transformer.sparse_functiondef.values()]
    wrapper = ast.FunctionDef(sparse_name or functiondef.name, functiondef.args, helper_defs + new_body,
                              functiondef.decorator_list, functiondef.returns)
    return ast.Module([wrapper]), True, return_types
68
+
69
+
70
class SparseTransformer(ast.NodeTransformer):
    """
    Transformer class for sparsify.

    Statements are transformed one at a time through `transform`. While visiting:

    - the arg_type of every evaluated operand is pushed onto a stack of frames (`_frames`);
    - the arg_types of assigned variables are recorded in `type_map`;
    - calls with sparse operands are retargeted according to the sparse rules, or, for non-ops functions
      and cells, recursively sparsified through `sparsify_helper`.
    """

    # Expression nodes treated as non-sparse scalar values. `ast.Constant` is what the parser produces on
    # Python 3.8+ (including None/True/False/bytes); the legacy node classes cover Python 3.7 and are skipped
    # on interpreters that no longer define them.
    _SCALAR_NODES = (ast.Attribute, ast.Constant) + tuple(
        getattr(ast, _legacy) for _legacy in ("Num", "Str", "Bytes", "NameConstant") if hasattr(ast, _legacy))

    def __init__(self, type_map, global_vars, init_vars, user_defined_rules=None, full_sparse_rules=None, depth=0):
        """
        Init method.

        Args:
            type_map (dict): Maps variable names to an arg_type (or a tuple of arg_types); updated in place
                as assignments are visited.
            global_vars (dict): Globals of the function being transformed, used to resolve called names.
            init_vars (dict): Sub-cells of the owning Cell, used to resolve ``self.xxx`` calls.
            user_defined_rules (dict, optional): Extra sparse rules keyed like the generic ``sparse_rules``.
            full_sparse_rules (dict, optional): Rule table already built by a caller transformer.
            depth (int): Current nesting depth of recursive sparsification.
        """
        super().__init__()
        self.type_map = type_map
        self.global_vars = global_vars
        self.init_vars = init_vars
        self.depth = depth
        self.return_types = (ArgType.NONSPARSE,)
        # maps function name and arg types to sparsified ast and return types, which are then inserted into module
        self.sparse_functiondef = {}
        # maps function name and arg types to return types for ast that do not change after sparsify
        self.origin_functiondef = {}

        # keeps track of arg_type for each operand on the call stack recursively
        self._frames = deque()
        self._changed = False
        # variables for which arg_types diverge with control flow are not supported, and are considered dead
        # after exiting the block
        self._dead_vars = {}
        if full_sparse_rules:
            # The inherited table (shared with the caller) already holds the generic rules plus the caller's
            # user-defined rules, which take precedence. Re-registering the generic rules here would overwrite
            # those user-defined entries, so only explicitly passed rules are layered on top.
            self.full_sparse_rules = full_sparse_rules
            if user_defined_rules:
                self._register_rules(user_defined_rules)
        else:
            self.full_sparse_rules = {}
            self.get_sparse_rules(user_defined_rules or {})

    @staticmethod
    def make_call(node, name="", args=None):
        """Returns a call node with given name and args, if provided; otherwise those of `node` are reused."""
        if name:
            func = ast.Name(name, ast.Load())
        else:
            func = node.func
        if args is None:
            args = node.args
        return ast.Call(func, args, node.keywords)

    def get_sparse_rules(self, user_defined_rules):
        """
        Generates sparse rules for the transformer from generic sparse rules and user-defined sparse rules.

        User-defined rules for a function replace the generic rules given for that same function.
        """
        self._register_rules({**sparse_rules, **user_defined_rules})

    def _register_rules(self, rules_by_func):
        """Adds every rule of `rules_by_func` to `full_sparse_rules`, keyed by function then input arg_types."""
        for func, rules in rules_by_func.items():
            for rule in rules:
                sparse_func = get_sparse_func(rule)
                # sparse rules are accessed by the function object and input arg_types pair
                self.full_sparse_rules.setdefault(func, {})[tuple(sparse_func.inputs)] = sparse_func

    def transform(self, node):
        """Transforms a single node which represents a stmt in the ast; resets the per-statement state first."""
        self.clear_stack()
        self._changed = False
        stmt = self.visit(node)
        return stmt

    def has_changed(self):
        """Whether the last transformed statement was changed by the SparseTransformer."""
        return self._changed

    def add_frame(self):
        """Add a frame into deque."""
        self._frames.append([])

    def pop_frame(self):
        """Pop a frame in deque, returned as a tuple of arg_types."""
        return tuple(self._frames.pop())

    def push_onto_frame(self, t):
        """Push an arg_type into the current frame."""
        if not self._frames:
            raise ValueError("Current frame not initialized!")
        self._frames[-1].append(t)

    def push_all_onto_frame(self, t):
        """Push all arg_types into the current frame."""
        if not self._frames:
            raise ValueError("Current frame not initialized!")
        self._frames[-1].extend(t)

    def clear_stack(self):
        """Clear frame deque"""
        self._frames.clear()

    def make_sparse_func(self, func, node_type, inputs):
        """
        Returns SparseFunc by looking up sparse_rules.

        Falls back to an ``ops.<prefix>_<name>`` function derived from the first input's arg_type. Returns None
        when neither exists and some input is sparse. Sets the changed flag when the resulting function
        differs from `func`.
        """
        if node_type == ast.Call:
            if isinstance(func, nn.Cell):
                func_name = func.__class__.__name__.lower()
            else:
                func_name = getattr(func, "__name__", func)
        elif node_type == ast.BinOp:
            func_name = func
        rules = self.full_sparse_rules.get(func, {})

        if ArgType.ANY in rules:
            sparse_func = rules[ArgType.ANY]
        elif inputs in rules:
            sparse_func = rules[inputs]
        else:
            # attempts to find sparse op based on sparse prefix if sparse rules not found
            sparse_func_name = arg_type_to_prefix_map.get(inputs[0], "$") + "_" + func_name
            sparse_op = getattr(ops, sparse_func_name, None)
            if sparse_op is None:
                if any(input_type != ArgType.NONSPARSE for input_type in inputs):
                    return None
                outputs = (ArgType.NONSPARSE,)
            else:
                func_name = sparse_func_name
                _, outputs = get_inputs_outputs(sparse_op)
            sparse_func = SparseFunc(func_name, inputs, outputs)

        if sparse_func.fn != func:
            self._changed = True
        return sparse_func

    def get_sparse_node(self, node, args, func, arg_types):
        """
        Retrieves target from sparse rules if matches, otherwise sparsify the node by recursively expanding `func`
        until maximum recursion depth is reached. Functions in mindspore.ops are not expanded: if no sparse rule
        matches one of them, a ValueError is raised.

        Args:
            node (ast.Call): The call being transformed.
            args (list): Visited argument nodes of the call.
            func: The resolved callee (function object or Cell).
            arg_types (tuple): arg_types of the call's positional arguments.

        Returns:
            ast.Call: The (possibly retargeted) call; its output arg_types are pushed onto the current frame.
        """
        sparse_func = self.make_sparse_func(func, type(node), arg_types)
        if sparse_func is not None:
            # Retarget only when the rule maps `func` to a different function. `self._changed` is shared by the
            # whole statement (an earlier operand may have set it), so it must not drive this decision.
            if sparse_func.fn != func:
                if sparse_func.fn in self.global_vars:
                    func_node = ast.Name(sparse_func.fn, ast.Load())
                else:
                    func_node = ast.Attribute(ast.Name("ops", ast.Load()), sparse_func.fn, ast.Load())
                node = ast.Call(func_node, args, node.keywords)
            self.push_all_onto_frame(sparse_func.outputs)
            return node

        module_name = getattr(func, "__module__", None) or ""
        if module_name.startswith(OPS_MODULE):
            raise ValueError(f"Sparse rules not registered for {func}!")

        if isinstance(func, nn.Cell):
            class_name = func.__class__.__name__
            func_name = class_name.lower()
            # NOTE(review): for a Cell instance, getfullargspec describes the instance's __call__ rather than
            # its __init__; confirm this check inspects what is intended.
            init_args = inspect.getfullargspec(func).args
            if len(init_args) != 1:
                raise ValueError(f"Nested cell {class_name} with arguments for init not supported!")
        else:
            func_name = func.__name__
        type_prefixes = "_".join(arg_type_to_prefix_map.get(t, "default") for t in arg_types)
        sparse_func_name = f"sparse_{type_prefixes}_{func_name}"
        cache_key = (func_name, arg_types)
        if cache_key in self.sparse_functiondef:
            self._changed = True
            # pylint: disable=get-dict-value-exception
            self.push_all_onto_frame(self.sparse_functiondef[cache_key][1])
            return SparseTransformer.make_call(node, sparse_func_name, args)
        if cache_key in self.origin_functiondef:
            # pylint: disable=get-dict-value-exception
            self.push_all_onto_frame(self.origin_functiondef[cache_key])
            # same as the first (uncached) visit below: original callee with the visited args
            return SparseTransformer.make_call(node, args=args)
        if self.depth >= MAX_RECURSION_DEPTH:
            raise RuntimeError(f"Maximum recursion depth {MAX_RECURSION_DEPTH} for sparsify reached at {func}!")
        functiondef, changed, return_types = sparsify_helper(
            func, arg_types, sparse_name=sparse_func_name, full_sparse_rules=self.full_sparse_rules,
            depth=self.depth + 1)
        self.push_all_onto_frame(return_types)
        if changed:
            self._changed = True
            self.sparse_functiondef[cache_key] = (functiondef, return_types)
            return SparseTransformer.make_call(node, sparse_func_name, args)
        self.origin_functiondef[cache_key] = return_types
        return SparseTransformer.make_call(node, args=args)

    def map_type_to_target(self, node_target, value_types):
        """Records arg_type for each target."""
        if isinstance(node_target, (ast.Tuple, ast.List)):
            targets = node_target.elts
            if len(targets) != len(value_types):
                raise ValueError(f"Target {astunparse.unparse(node_target)} size and value size not match for "
                                 f"ast.Assign {len(targets)} != {len(value_types)}")
            target_vars = []
            for target in targets:
                if not isinstance(target, ast.Name):
                    raise ValueError(f"Each target {ast.dump(target)} for ast.Assign should be ast.Name!")
                target_vars.append(target.id)
            for var, t in zip(target_vars, value_types):
                self.type_map[var] = t
        elif isinstance(node_target, ast.Name):
            var = node_target.id
            if len(value_types) == 1:
                self.type_map[var] = value_types[0]
            else:
                # a single name bound to several values keeps the whole tuple of arg_types
                self.type_map[var] = value_types
        else:
            raise ValueError(f"Targets for ast.Assign not supported for {type(node_target)}!")

    def visit_method(self, node):
        """Visits each node based on node class."""
        method = "visit_" + node.__class__.__name__
        visitor = getattr(self, method, None)
        if visitor is None:
            raise ValueError(f"{type(node)} is not supported in SparseTransformer!")
        return visitor(node)

    def visit(self, node):
        """Visitor interface for all nodes."""
        if not node._fields:
            return node
        if isinstance(node, (ast.AugAssign, ast.Expr)):
            return self.visit_generic_stmt(node)
        if isinstance(node, (ast.BoolOp, ast.Compare, ast.Subscript)):
            # node always evaluates to non-sparse values
            return self.visit_generic_expr(node)
        if isinstance(node, (ast.Tuple, ast.List, ast.UnaryOp)):
            # node contains multiple expressions but is not composable
            return self.visit_composite_generic_expr(node)
        if isinstance(node, self._SCALAR_NODES):
            return self.visit_scalar_expr(node)
        if isinstance(node, (ast.Index, ast.Slice)):
            # node forms only a part of an expression and does not exist as standalone expression
            return self.visit_partial_expr(node)
        return self.visit_method(node)

    def visit_generic_stmt(self, node):
        """Visitor for generic statement; the arg_types of its operands are discarded."""
        self.add_frame()
        node = self.generic_visit(node)
        self.pop_frame()
        return node

    def visit_scalar_expr(self, node):
        """Visitor for scalar expression."""
        self.push_onto_frame(ArgType.NONSPARSE)
        return node

    def visit_generic_expr(self, node):
        """Visitor for generic expression."""
        self.add_frame()
        node = self.generic_visit(node)
        self.pop_frame()
        self.push_onto_frame(ArgType.NONSPARSE)
        return node

    def visit_composite_generic_expr(self, node):
        """Visitor for composite generic expression; each element pushes its own arg_type."""
        return self.generic_visit(node)

    def visit_partial_expr(self, node):
        """Visitor for a part of an expression."""
        return node

    def visit_Assign(self, node):  # pylint: disable=invalid-name
        """Visitor for ast.Assign."""
        self.add_frame()
        value = self.visit(node.value)
        value_types = self.pop_frame()
        for node_target in node.targets:
            self.map_type_to_target(node_target, value_types)
        return ast.Assign(node.targets, value)

    def visit_BinOp(self, node):  # pylint: disable=invalid-name
        """Visitor for ast.Binop."""
        self.add_frame()
        node = self.generic_visit(node)
        arg_types = self.pop_frame()
        if len(arg_types) != 2:
            raise ValueError(f"Binary op {astunparse.unparse(node)} values for arg_type len({arg_types}) != 2")
        func = get_binop_name(node.op)
        if func:
            sparse_func = self.make_sparse_func(func, type(node), arg_types)
            if sparse_func is None:
                raise ValueError(f"Sparse rules not defined for {arg_types[0]} {func} {arg_types[1]}!")
            outputs = sparse_func.outputs
        else:
            outputs = (ArgType.NONSPARSE,)
        self.push_all_onto_frame(outputs)
        return node

    def visit_Call(self, node):  # pylint: disable=invalid-name
        """Visitor for ast.Call."""
        self.add_frame()
        args = [self.visit(arg) for arg in node.args]
        arg_types = self.pop_frame()
        # Keep the visited (possibly rewritten) argument nodes on every return path below; otherwise a
        # rewrite of a nested call would be lost whenever this outer call is returned unchanged.
        node.args = args

        if all(t == ArgType.NONSPARSE for t in arg_types):
            # if none of the arguments is sparse, do nothing
            self.push_onto_frame(ArgType.NONSPARSE)
            return node

        # pylint: disable=protected-access
        func_name = AssignParser._get_func_name(node)
        if func_name is None or func_name == "":
            raise RuntimeError(f"Function not exist for {ast.dump(node)}!")
        # pylint: disable=protected-access
        func_scope = AssignParser._get_func_scope(node)

        if not func_scope:
            if func_name in builtin_ops:
                self.push_onto_frame(ArgType.NONSPARSE)
                return node
            if func_name in self.global_vars:
                # external function with sparse arguments are inlined and cached
                func = self.global_vars[func_name]
                return self.get_sparse_node(node, args, func, arg_types)
            raise ValueError(f"Call to undefined {func_name}!")

        if func_scope in self.global_vars:
            namespace = self.global_vars[func_scope]
            func = getattr(namespace, func_name, None)
            if func is None:
                raise ValueError(f"{func_name} not defined in {namespace}!")
            return self.get_sparse_node(node, args, func, arg_types)

        if func_scope == "self":
            func = self.init_vars.get(func_name, None)
            if func is None:
                raise ValueError(f"{func_name} not defined in Cell.__init__!")
            return self.get_sparse_node(node, args, func, arg_types)

        func_scope_type = self.type_map.get(func_scope, None)
        if func_scope_type is not None:
            # tensor methods
            if func_scope_type == ArgType.NONSPARSE:
                outputs = (ArgType.NONSPARSE,)
            else:
                outputs = get_sparse_method_outputs(func_name, func_scope_type)
            self.push_all_onto_frame(outputs)
            return node
        raise ValueError(f"Undefined var {func_scope}!")

    def visit_Name(self, node):  # pylint: disable=invalid-name
        """Visitor for ast.Name."""
        if node.id in self.type_map:
            tensor_type = self.type_map[node.id]
        elif node.id in self.global_vars:
            logger.warning(f"Global variable {node.id} treated as nonsparse value by default.")
            tensor_type = ArgType.NONSPARSE
        elif node.id in self._dead_vars:
            raise ValueError(f"Divergent arg_types {self._dead_vars.get(node.id)} for {node.id} are currently not "
                             f"supported in control flow and the variable is considered dead upon leaving "
                             f"the block")
        else:
            raise ValueError(f"Undefined variable {node.id}!")

        if isinstance(tensor_type, tuple):
            self.push_all_onto_frame(tensor_type)
        else:
            self.push_onto_frame(tensor_type)
        return node

    def visit_Return(self, node):  # pylint: disable=invalid-name
        """Visitor for ast.Return; records the arg_types of the returned values."""
        self.add_frame()
        node = self.generic_visit(node)
        self.return_types = self.pop_frame()
        return node

    def visit_While(self, node):  # pylint: disable=invalid-name
        """
        Visitor for ast.While.
        Variables for which arg_types diverge with control flow are not supported, and as a fallback routine,
        unsupported variables are treated as out-of-scope after leaving the control flow body.
        """
        self.add_frame()
        test = self.visit(node.test)
        self.pop_frame()
        orig_type_map = self.type_map.copy()
        body = list(self.visit(expr) for expr in node.body)
        for var, t in self.type_map.items():
            if var not in orig_type_map:
                # new variables in while body are considered active after the leaving the block
                orig_type_map[var] = t
            elif orig_type_map[var] != t:
                # variables for which arg_types diverge are considered dead after leaving the block
                self._dead_vars[var] = (t, orig_type_map.pop(var))
        self.type_map = orig_type_map
        orelse = list(self.visit(expr) for expr in node.orelse)
        return ast.While(test, body, orelse)
@@ -0,0 +1,109 @@
1
+ # Copyright 2022 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """sparsify implementation"""
16
+ import os
17
+
18
+ from mindspore import ops
19
+ from mindspore.rewrite import SymbolTree, ScopedValue
20
+ from mindspore.rewrite.ast_helpers import AstModifier
21
+ from mindspore.rewrite.sparsify.sparse_transformer import SparseTransformer
22
+ from mindspore.rewrite.sparsify.utils import SparseFunc, ArgType
23
+
24
+
25
# Namespace of mindspore.ops (attribute name -> object). get_user_defined_rules checks it to accept a callable
# rule target that is not found in the cell's globals.
op_vars = vars(ops)
26
+
27
+
28
def get_user_defined_rules(sparse_rules, global_vars, tree):
    """
    Register user-defined sparse rules.

    Normalizes `sparse_rules` so that every source maps to a list of rules. Each callable rule target is
    checked as follows:

    - found in `global_vars`: appended once to the cell's ``__init__`` as a ``self.<name>`` target;
    - found in mindspore.ops: accepted as is.

    Args:
        sparse_rules (dict): Maps a source function to a single rule (callable, ``SparseFunc`` or other rule
            spec) or to a tuple/list of rules.
        global_vars (dict): Globals of the cell's ``construct`` method.
        tree: SymbolTree handler whose ``__init__`` AST receives the bindings.

    Returns:
        dict: Maps each source to the list of its rules, in the given order.

    Raises:
        ValueError: If a callable rule target is found neither in `global_vars` nor in mindspore.ops.
    """
    user_defined_rules = {}
    # ids of callables already handled, so a global used by several rules is bound in __init__ only once
    registered = set()

    def register_callable(fn):
        if id(fn) in registered:
            return
        func_name = fn.__name__
        if global_vars.get(func_name, None) is fn:
            init_targets = [ScopedValue.create_naming_value(func_name, "self")]
            AstModifier.append_global_vars_expr_to_init(tree.get_init_func_ast(), init_targets, func_name)
        elif op_vars.get(func_name, None) is not fn:
            raise ValueError(f"{fn} not found in globals or mindspore.ops!")
        registered.add(id(fn))

    for source, targets in sparse_rules.items():
        # SparseFunc is excluded explicitly because it may itself be a tuple (TODO confirm in utils)
        if isinstance(targets, (tuple, list)) and not isinstance(targets, SparseFunc):
            targets = list(targets)
        else:
            targets = [targets]
        for sparse_func in targets:
            if isinstance(sparse_func, SparseFunc) and callable(sparse_func.fn):
                register_callable(sparse_func.fn)
            elif callable(sparse_func):
                register_callable(sparse_func)
            user_defined_rules.setdefault(source, []).append(sparse_func)

    return user_defined_rules
55
+
56
+
57
def _arg_types_to_type_map(params, arg_types):
    """Maps each parameter name in `params` to the arg_type described by `arg_types` (see sparsify_tree)."""
    if isinstance(arg_types, (tuple, list)):
        if len(params) != len(arg_types):
            raise ValueError(f"arg_types should have the same length as function parameters, but "
                             f"{len(arg_types)} != {len(params)}!")
        return dict(zip(params, arg_types))
    if isinstance(arg_types, dict):
        keys = list(arg_types.keys())
        if all(isinstance(key, int) for key in keys):
            # reject indices that do not address a parameter instead of silently ignoring them
            unknown = [key for key in keys if not 0 <= key < len(params)]
            if unknown:
                raise ValueError(f"Indices {unknown} in arg_types are out of range for {len(params)} parameters!")
            return {param: arg_types.get(i, ArgType.NONSPARSE) for i, param in enumerate(params)}
        if all(isinstance(key, str) for key in keys):
            # reject names that are not parameters (e.g. typos) instead of silently ignoring them
            unknown = [key for key in keys if key not in params]
            if unknown:
                raise ValueError(f"Keys {unknown} in arg_types are not parameters {params}!")
            return {param: arg_types.get(param, ArgType.NONSPARSE) for param in params}
        raise ValueError(f"Keys for arg_types {keys} should be all ints or all strings!")
    raise ValueError(f"Unsupported type for arg_types {type(arg_types)}!")


def sparsify_tree(tree, arg_types, sparse_rules, f):
    """
    Sparsify SymbolTree object in place.

    The body of the root (``construct``) AST is rewritten statement by statement, and sparsified helper
    definitions are appended to the module AST.

    Args:
        tree: SymbolTree handler of `f`.
        arg_types (Union[tuple, list, dict]): arg_types of ``construct``'s parameters, excluding ``self``.
            A tuple/list must match the parameter count. A dict is keyed by parameter index or parameter
            name, and unlisted parameters are non-sparse.
        sparse_rules (dict): User-defined sparse rules.
        f (Cell): The cell being sparsified.

    Raises:
        ValueError: If `arg_types` is invalid (wrong length or type, mixed key types, or keys that do not
            address a parameter), or if a user-defined rule target cannot be resolved.
    """
    # skip self
    args = [arg.arg for arg in tree.get_ast_root().args.args[1:]]
    # validate arg_types before get_user_defined_rules modifies the tree's __init__
    type_map = _arg_types_to_type_map(args, arg_types)

    global_vars = f.construct.__globals__
    user_defined_rules = get_user_defined_rules(sparse_rules, global_vars, tree)

    # pylint: disable=protected-access
    init_vars = f._cells
    sparse_transformer = SparseTransformer(type_map, global_vars, init_vars, user_defined_rules)
    for i, node_ast in enumerate(tree.get_ast_root().body):
        sp_ast = sparse_transformer.transform(node_ast)
        if sparse_transformer.has_changed():
            tree.get_ast_root().body[i] = sp_ast
    for module, _ in sparse_transformer.sparse_functiondef.values():
        tree.get_module_ast().body.append(module)
88
+
89
+
90
def sparsify(f, arg_types, sparse_rules=None):
    """
    Sparsify a Cell object by inferring the appropriate sparse function calls to replace the original function calls by
    propagating sparse properties provided in `arg_types`.

    Args:
        f (Cell): Cell object to be sparsified.
        arg_types (Tuple[ArgType] | Dict[int, ArgType] | Dict[str, ArgType]): The type of argument (sparse csr,
            sparse coo, non-sparse etc.) expected by `f`. If `arg_types` is a tuple, its length should be the same
            as the number of arguments for `f`. If it is a dictionary, each key is either an index into the
            arguments or an argument name, and arguments not referenced by the dictionary are considered to be
            non-sparse.
        sparse_rules (Dict[str, SparseFunc], Optional): Additional sparse rules.

    Returns:
        The network returned by ``SymbolTree.get_network()`` after the rewrite.
    """
    env_key = "STREE_PYTHON_FALLBACK"
    previous = os.environ.get(env_key)
    # The variable is set only while the SymbolTree is created and rewritten (presumably it enables the
    # SymbolTree python-fallback handling; confirm).
    os.environ[env_key] = "1"
    try:
        tree = SymbolTree.create(f)
        sparsify_tree(tree.get_handler(), arg_types, sparse_rules or {}, f)
    finally:
        # Restore through os.environ: os.unsetenv alone leaves os.environ (and os.getenv) stale, and a
        # failure must not leak the setting or clobber a value the caller had set.
        if previous is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous
    return tree.get_network()