mindspore 2.1.0__cp37-none-any.whl → 2.2.11__cp37-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (577) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +139 -22
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  25. mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
  26. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +56 -1
  31. mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +13 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +67 -72
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +86 -106
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +29 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +33 -7
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  164. mindspore/lib/libmindspore_shared_lib.so +0 -0
  165. mindspore/lib/libnnacl.so +0 -0
  166. mindspore/lib/libopencv_core.so.4.5 +0 -0
  167. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  169. mindspore/lib/libps_cache.so +0 -0
  170. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  183. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  187. mindspore/lib/plugin/ascend/libakg.so +0 -0
  188. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  189. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  190. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  191. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  193. mindspore/lib/plugin/cpu/libakg.so +0 -0
  194. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  195. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  196. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  197. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  198. mindspore/nn/__init__.py +0 -2
  199. mindspore/nn/cell.py +313 -74
  200. mindspore/nn/dynamic_lr.py +21 -21
  201. mindspore/nn/layer/activation.py +22 -30
  202. mindspore/nn/layer/basic.py +15 -13
  203. mindspore/nn/layer/channel_shuffle.py +1 -1
  204. mindspore/nn/layer/container.py +271 -9
  205. mindspore/nn/layer/conv.py +323 -204
  206. mindspore/nn/layer/dense.py +8 -5
  207. mindspore/nn/layer/embedding.py +33 -27
  208. mindspore/nn/layer/flash_attention.py +61 -95
  209. mindspore/nn/layer/image.py +8 -6
  210. mindspore/nn/layer/math.py +16 -25
  211. mindspore/nn/layer/normalization.py +107 -66
  212. mindspore/nn/layer/padding.py +1 -1
  213. mindspore/nn/layer/pooling.py +131 -109
  214. mindspore/nn/layer/rnn_cells.py +27 -22
  215. mindspore/nn/layer/rnns.py +13 -16
  216. mindspore/nn/layer/thor_layer.py +1 -1
  217. mindspore/nn/layer/transformer.py +221 -154
  218. mindspore/nn/learning_rate_schedule.py +9 -1
  219. mindspore/nn/loss/loss.py +235 -174
  220. mindspore/nn/optim/ada_grad.py +2 -1
  221. mindspore/nn/optim/adadelta.py +1 -0
  222. mindspore/nn/optim/adafactor.py +2 -1
  223. mindspore/nn/optim/adam.py +7 -4
  224. mindspore/nn/optim/adamax.py +3 -2
  225. mindspore/nn/optim/adasum.py +2 -2
  226. mindspore/nn/optim/asgd.py +2 -3
  227. mindspore/nn/optim/ftrl.py +6 -5
  228. mindspore/nn/optim/lamb.py +7 -4
  229. mindspore/nn/optim/lars.py +1 -1
  230. mindspore/nn/optim/lazyadam.py +5 -3
  231. mindspore/nn/optim/momentum.py +2 -1
  232. mindspore/nn/optim/optimizer.py +53 -4
  233. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  234. mindspore/nn/optim/rmsprop.py +4 -3
  235. mindspore/nn/optim/rprop.py +23 -12
  236. mindspore/nn/optim/sgd.py +26 -11
  237. mindspore/nn/optim/thor.py +9 -7
  238. mindspore/nn/probability/bijector/bijector.py +5 -5
  239. mindspore/nn/probability/bijector/power_transform.py +27 -27
  240. mindspore/nn/probability/bijector/softplus.py +3 -3
  241. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  242. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  243. mindspore/nn/probability/distribution/beta.py +3 -3
  244. mindspore/nn/probability/distribution/categorical.py +7 -7
  245. mindspore/nn/probability/distribution/cauchy.py +0 -1
  246. mindspore/nn/probability/distribution/distribution.py +3 -3
  247. mindspore/nn/probability/distribution/gamma.py +3 -3
  248. mindspore/nn/probability/distribution/geometric.py +4 -4
  249. mindspore/nn/probability/distribution/gumbel.py +4 -4
  250. mindspore/nn/probability/distribution/log_normal.py +2 -2
  251. mindspore/nn/probability/distribution/logistic.py +2 -2
  252. mindspore/nn/probability/distribution/poisson.py +4 -4
  253. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  254. mindspore/nn/probability/distribution/uniform.py +6 -6
  255. mindspore/nn/wrap/__init__.py +4 -2
  256. mindspore/nn/wrap/cell_wrapper.py +87 -34
  257. mindspore/nn/wrap/grad_reducer.py +8 -5
  258. mindspore/nn/wrap/loss_scale.py +105 -42
  259. mindspore/numpy/array_creations.py +1 -2
  260. mindspore/numpy/array_ops.py +3 -2
  261. mindspore/numpy/utils_const.py +5 -5
  262. mindspore/offline_debug/convert_async.py +2 -2
  263. mindspore/ops/_grad_experimental/__init__.py +0 -5
  264. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  265. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  266. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  267. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  268. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  269. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  270. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  271. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  272. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  273. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  274. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  275. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  276. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  277. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  278. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  279. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  280. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  281. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  282. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  283. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  284. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  285. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  286. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  287. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  288. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  289. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  290. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  291. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  292. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  293. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  294. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  295. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  296. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  297. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  298. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  299. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  300. mindspore/ops/_primitive_cache.py +1 -1
  301. mindspore/ops/_tracefunc.py +45 -13
  302. mindspore/ops/_utils/utils.py +6 -1
  303. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  304. mindspore/ops/_vmap/vmap_base.py +3 -3
  305. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  306. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  307. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  308. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  309. mindspore/ops/arg_dtype_cast.py +54 -0
  310. mindspore/ops/composite/base.py +37 -10
  311. mindspore/ops/composite/math_ops.py +5 -4
  312. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  313. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  314. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  315. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  316. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  317. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  319. mindspore/ops/deprecated.py +304 -0
  320. mindspore/ops/function/__init__.py +4 -1
  321. mindspore/ops/function/array_func.py +174 -193
  322. mindspore/ops/function/clip_func.py +81 -13
  323. mindspore/ops/function/debug_func.py +1 -1
  324. mindspore/ops/function/grad/grad_func.py +18 -9
  325. mindspore/ops/function/image_func.py +10 -4
  326. mindspore/ops/function/linalg_func.py +5 -5
  327. mindspore/ops/function/math_func.py +575 -386
  328. mindspore/ops/function/nn_func.py +568 -260
  329. mindspore/ops/function/random_func.py +88 -57
  330. mindspore/ops/function/sparse_func.py +1 -1
  331. mindspore/ops/function/sparse_unary_func.py +14 -12
  332. mindspore/ops/function/vmap_func.py +6 -5
  333. mindspore/ops/functional.py +15 -10
  334. mindspore/ops/op_info_register.py +244 -25
  335. mindspore/ops/operations/__init__.py +31 -19
  336. mindspore/ops/operations/_grad_ops.py +71 -7
  337. mindspore/ops/operations/_inner_ops.py +350 -17
  338. mindspore/ops/operations/_quant_ops.py +4 -8
  339. mindspore/ops/operations/_sequence_ops.py +42 -0
  340. mindspore/ops/operations/array_ops.py +68 -282
  341. mindspore/ops/operations/comm_ops.py +107 -59
  342. mindspore/ops/operations/custom_ops.py +94 -70
  343. mindspore/ops/operations/debug_ops.py +8 -4
  344. mindspore/ops/operations/image_ops.py +18 -12
  345. mindspore/ops/operations/inner_ops.py +26 -3
  346. mindspore/ops/operations/math_ops.py +192 -144
  347. mindspore/ops/operations/nn_ops.py +857 -489
  348. mindspore/ops/operations/other_ops.py +0 -22
  349. mindspore/ops/operations/random_ops.py +53 -111
  350. mindspore/ops/operations/sparse_ops.py +3 -1
  351. mindspore/ops/primitive.py +24 -18
  352. mindspore/parallel/_auto_parallel_context.py +68 -8
  353. mindspore/parallel/_cost_model_context.py +2 -2
  354. mindspore/parallel/_offload_context.py +17 -3
  355. mindspore/parallel/_parallel_serialization.py +12 -5
  356. mindspore/parallel/_ps_context.py +12 -0
  357. mindspore/parallel/_tensor.py +18 -13
  358. mindspore/parallel/_transformer/layers.py +5 -3
  359. mindspore/parallel/_transformer/loss.py +1 -0
  360. mindspore/parallel/_transformer/moe.py +2 -2
  361. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  362. mindspore/parallel/_transformer/transformer.py +23 -3
  363. mindspore/parallel/_utils.py +11 -7
  364. mindspore/parallel/algo_parameter_config.py +85 -5
  365. mindspore/parallel/checkpoint_transform.py +19 -12
  366. mindspore/parallel/shard.py +21 -14
  367. mindspore/profiler/common/struct_type.py +3 -3
  368. mindspore/profiler/common/util.py +4 -2
  369. mindspore/profiler/envprofiling.py +1 -1
  370. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  371. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  372. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  373. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  374. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  375. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  376. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  377. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  378. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  379. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  380. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  381. mindspore/profiler/parser/flops_parser.py +15 -11
  382. mindspore/profiler/parser/framework_parser.py +38 -22
  383. mindspore/profiler/parser/hccl_parser.py +16 -12
  384. mindspore/profiler/parser/integrator.py +22 -11
  385. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  386. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  387. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  388. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  389. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  390. mindspore/profiler/parser/optime_parser.py +1 -1
  391. mindspore/profiler/parser/profiler_info.py +21 -2
  392. mindspore/profiler/parser/step_trace_parser.py +11 -14
  393. mindspore/profiler/profiling.py +179 -89
  394. mindspore/rewrite/api/node.py +102 -19
  395. mindspore/rewrite/api/node_type.py +5 -1
  396. mindspore/rewrite/api/pattern_engine.py +1 -1
  397. mindspore/rewrite/api/scoped_value.py +9 -17
  398. mindspore/rewrite/api/symbol_tree.py +131 -47
  399. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  400. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  401. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  402. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  403. mindspore/rewrite/common/rewrite_elog.py +5 -1
  404. mindspore/rewrite/namer.py +33 -24
  405. mindspore/rewrite/namespace.py +14 -5
  406. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  407. mindspore/rewrite/node/call_function.py +79 -0
  408. mindspore/rewrite/node/cell_container.py +135 -0
  409. mindspore/rewrite/node/control_flow.py +88 -0
  410. mindspore/rewrite/{node.py → node/node.py} +273 -234
  411. mindspore/rewrite/node/node_manager.py +254 -0
  412. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  413. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  414. mindspore/rewrite/parsers/assign_parser.py +216 -221
  415. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  416. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  417. mindspore/rewrite/parsers/constant_parser.py +9 -6
  418. mindspore/rewrite/parsers/container_parser.py +9 -7
  419. mindspore/rewrite/parsers/for_parser.py +42 -21
  420. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  421. mindspore/rewrite/parsers/if_parser.py +28 -24
  422. mindspore/rewrite/parsers/module_parser.py +196 -25
  423. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  424. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  425. mindspore/rewrite/parsers/return_parser.py +6 -6
  426. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  427. mindspore/rewrite/sparsify/utils.py +1 -1
  428. mindspore/rewrite/symbol_tree.py +523 -578
  429. mindspore/rewrite/symbol_tree_builder.py +9 -193
  430. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  431. mindspore/run_check/_check_version.py +6 -4
  432. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  433. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  434. mindspore/scipy/linalg.py +1 -1
  435. mindspore/scipy/ops.py +55 -5
  436. mindspore/scipy/optimize/__init__.py +3 -2
  437. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  438. mindspore/scipy/optimize/minimize.py +7 -3
  439. mindspore/train/_utils.py +7 -3
  440. mindspore/train/amp.py +323 -123
  441. mindspore/train/anf_ir_pb2.py +14 -2
  442. mindspore/train/callback/_backup_and_restore.py +2 -12
  443. mindspore/train/callback/_callback.py +29 -4
  444. mindspore/train/callback/_checkpoint.py +23 -8
  445. mindspore/train/callback/_early_stop.py +2 -2
  446. mindspore/train/callback/_landscape.py +4 -4
  447. mindspore/train/callback/_loss_monitor.py +2 -2
  448. mindspore/train/callback/_on_request_exit.py +2 -2
  449. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  450. mindspore/train/callback/_summary_collector.py +15 -8
  451. mindspore/train/callback/_time_monitor.py +58 -5
  452. mindspore/train/data_sink.py +5 -11
  453. mindspore/train/dataset_helper.py +84 -57
  454. mindspore/train/loss_scale_manager.py +2 -2
  455. mindspore/train/metrics/__init__.py +3 -3
  456. mindspore/train/metrics/cosine_similarity.py +1 -1
  457. mindspore/train/metrics/hausdorff_distance.py +3 -2
  458. mindspore/train/metrics/mean_surface_distance.py +3 -2
  459. mindspore/train/metrics/metric.py +39 -19
  460. mindspore/train/metrics/roc.py +2 -2
  461. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  462. mindspore/train/mind_ir_pb2.py +85 -36
  463. mindspore/train/model.py +187 -47
  464. mindspore/train/serialization.py +487 -161
  465. mindspore/train/summary/_summary_adapter.py +1 -1
  466. mindspore/train/summary/_writer_pool.py +3 -2
  467. mindspore/train/summary/summary_record.py +37 -17
  468. mindspore/train/train_thor/convert_utils.py +3 -3
  469. mindspore/train/train_thor/dataset_helper.py +1 -1
  470. mindspore/version.py +1 -1
  471. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
  472. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +476 -527
  473. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
  474. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  475. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  476. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  477. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  478. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  479. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  480. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  481. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  482. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  483. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  484. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  485. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  486. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  487. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  488. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  489. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  490. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  491. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  492. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  493. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  494. mindspore/_extends/graph_kernel/expander.py +0 -80
  495. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  496. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  497. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  498. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  499. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  500. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  501. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  502. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  503. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  504. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  505. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  506. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  507. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  508. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  509. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  510. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  511. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  512. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  513. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  514. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  515. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  516. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  517. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  518. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  519. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  520. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  521. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  522. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  523. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  524. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  525. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  526. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  527. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  528. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  529. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  530. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  531. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  532. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  533. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  534. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  535. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  536. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  537. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  538. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  539. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  540. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  541. mindspore/dataset/datapreprocess/__init__.py +0 -20
  542. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  543. mindspore/include/api/net.h +0 -142
  544. mindspore/nn/lr_scheduler.py +0 -262
  545. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  546. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  547. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  548. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  549. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  550. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  551. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  552. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  553. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  554. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  555. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  556. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  557. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  558. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  559. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  560. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  561. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  563. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  567. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  568. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  569. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  571. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  574. mindspore/rewrite/node_visitor.py +0 -44
  575. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  576. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  577. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -15,24 +15,25 @@
15
15
  """Parse bodies of ast.FunctionDef which is construct function to nodes of SymbolTree."""
16
16
  import ast
17
17
  from mindspore import log as logger
18
- from ..parser_register import ParserRegister, reg_parser
19
- from ..parser import Parser
18
+ from .parser_register import ParserRegister, reg_parser
19
+ from .parser import Parser
20
20
  from ..symbol_tree import SymbolTree
21
21
  from ..api.node_type import NodeType
22
+ from ..node.node_manager import NodeManager
22
23
 
23
24
 
24
25
  class FunctionDefParser(Parser):
25
- """Parse bodies of ast.FunctionDef which is construct function to nodes of SymbolTree."""
26
+ """Parse bodies of ast.FunctionDef in SymbolTree."""
26
27
 
27
28
  def target(self):
28
29
  """Parse target type"""
29
30
  return ast.FunctionDef
30
31
 
31
- def remove_dead_code(self, stree: SymbolTree):
32
+ def remove_dead_code(self, node_manager: NodeManager):
32
33
  """Remove dead codes"""
33
34
  # Find out return node position
34
35
  return_idx = -1
35
- for idx, node in enumerate(stree.nodes()):
36
+ for idx, node in enumerate(node_manager.nodes()):
36
37
  if node.get_node_type() == NodeType.Output:
37
38
  return_idx = idx
38
39
  break
@@ -40,29 +41,36 @@ class FunctionDefParser(Parser):
40
41
  return
41
42
  # Remove nodes after return node.
42
43
  # Reverse traversal to ensure that nodes are orphaned and can be deleted.
43
- for idx, node in reversed(list(enumerate(stree.nodes()))):
44
+ for idx, node in reversed(list(enumerate(node_manager.nodes()))):
44
45
  if idx <= return_idx:
45
46
  break
46
47
  logger.info(f"Remove dead code node:{node.get_name()}")
47
- stree.erase_node(node)
48
+ node_manager.erase_node(node)
48
49
 
49
- def process(self, stree: SymbolTree, node: ast.FunctionDef):
50
- """Parse bodies of ast.FunctionDef which is construct function to nodes of SymbolTree."""
51
- stree.set_ast_root(node)
50
+ def process(self, stree: SymbolTree, ast_node: ast.FunctionDef, node_manager: NodeManager):
51
+ """
52
+ Parse bodies of ast.FunctionDef in SymbolTree.
53
+
54
+ Args:
55
+ stree (SymbolTree): symbol tree under parsing.
56
+ ast_node (ast.FunctionDef): Ast FunctionDef node in construct.
57
+ node_manager (NodeManager): NodeManager those asts belong to.
58
+ """
52
59
  # parse args as inputs of stree
53
- arguments: ast.arguments = node.args
60
+ arguments: ast.arguments = ast_node.args
54
61
  parser: Parser = ParserRegister.instance().get_parser(ast.arguments)
55
- parser.process(stree, arguments)
62
+ parser.process(stree, arguments, node_manager)
56
63
 
57
64
  # parse body as node of stree
58
- for body in node.body:
65
+ for body in ast_node.body:
59
66
  # avoid add dead code, so we need to break if return is added.
60
67
  parser: Parser = ParserRegister.instance().get_parser(type(body))
61
68
  if parser is None:
62
- stree.append_python_node(node, body)
69
+ stree.append_python_node(ast_node, body, node_manager)
63
70
  else:
64
- parser.process(stree, body)
65
- self.remove_dead_code(stree)
71
+ parser.process(stree, body, node_manager)
72
+
73
+ self.remove_dead_code(node_manager)
66
74
 
67
75
 
68
76
  g_functiondef_parser = reg_parser(FunctionDefParser())
@@ -15,11 +15,12 @@
15
15
  """Parse ast.If in construct function to node of SymbolTree."""
16
16
 
17
17
  import ast
18
- import astunparse
19
18
 
20
19
  from ..symbol_tree import SymbolTree
21
- from ..parser import Parser
22
- from ..parser_register import reg_parser
20
+ from .parser import Parser
21
+ from .parser_register import ParserRegister, reg_parser
22
+ from ..node import NodeManager, ControlFlow
23
+ from ..ast_transformers.flatten_recursive_stmt import FlattenRecursiveStmt
23
24
 
24
25
 
25
26
  class IfParser(Parser):
@@ -29,35 +30,38 @@ class IfParser(Parser):
29
30
  """Parse target type"""
30
31
  return ast.If
31
32
 
32
- def process(self, stree: SymbolTree, node: ast.If):
33
+ def process(self, stree: SymbolTree, node: ast.If, node_manager: NodeManager):
33
34
  """
34
- Parse ast.If and create a node in symbol tree.
35
+ Parse ast.If and create nodes into symbol tree.
35
36
 
36
37
  Args:
37
38
  stree ([SymbolTree]): Symbol Tree under parsing.
38
39
  node ([ast.If]): An ast.If node.
40
+ node_manager (NodeManager): NodeManager those asts belong to.
39
41
 
40
42
  Raises:
41
43
  NotImplementedError: If test of ast.If can not be eval.
42
44
  """
45
+ # expand codes in ast.if
46
+ ast_if = FlattenRecursiveStmt().transform_if(node, stree)
47
+ # parse ast codes of if branch into ControlFlow Node
48
+ if_node = ControlFlow("if_node", ast_if.body, stree)
49
+ for body in ast_if.body:
50
+ parser: Parser = ParserRegister.instance().get_parser(type(body))
51
+ if parser is None:
52
+ stree.append_python_node(ast_if, body, node_manager=if_node)
53
+ else:
54
+ parser.process(stree, body, node_manager=if_node)
55
+ stree.append_origin_field(if_node, node_manager)
56
+ # parse ast codes of else branch into ControlFlow Node
57
+ if ast_if.orelse:
58
+ else_node = ControlFlow("else_node", ast_if.orelse, stree)
59
+ for body in ast_if.orelse:
60
+ parser: Parser = ParserRegister.instance().get_parser(type(body))
61
+ if parser is None:
62
+ stree.append_python_node(ast_if, body, node_manager=else_node)
63
+ else:
64
+ parser.process(stree, body, node_manager=else_node)
65
+ stree.append_origin_field(else_node, node_manager)
43
66
 
44
- test_code = astunparse.unparse(node.test)
45
- test_code = test_code.replace("self", "stree.get_origin_network()")
46
- bodies = None
47
- try:
48
- test_value = eval(test_code)
49
- except (NameError, TypeError):
50
- stree.try_append_python_node(node, node)
51
- return
52
-
53
- bodies = node.body if test_value else node.orelse
54
- index = stree.get_ast_root().body.index(node) + 1
55
- info_node = ast.Name(id=f"# If node has been replaced by {bool(test_value)} branch.",
56
- lineno=0, col_offset=0, ctx=ast.Load)
57
- exp_node = ast.Expr(value=info_node, lineno=0, col_offset=0, ctx=ast.Load)
58
- stree.get_ast_root().body.insert(index-1, exp_node)
59
- for body in bodies:
60
- stree.get_ast_root().body.insert(index, body)
61
- index += 1
62
- stree.get_ast_root().body.remove(node)
63
67
  g_if_parser = reg_parser(IfParser())
@@ -13,23 +13,32 @@
13
13
  # limitations under the License.
14
14
  # ============================================================================
15
15
  """Parse ast.Module to SymbolTrees."""
16
+ import sys
16
17
  from typing import Any
17
18
  import os
18
19
  import ast
19
20
  import copy
20
21
  import inspect
21
- import astunparse
22
22
 
23
23
  from mindspore import log as logger
24
24
  from ..symbol_tree import SymbolTree
25
- from ..parser import Parser
26
- from ..parser_register import ParserRegister, reg_parser
25
+ from .parser import Parser
26
+ from .parser_register import ParserRegister, reg_parser
27
27
  from ..ast_helpers import AstFinder
28
28
  from ..common import error_str
29
+ from ..node.node_manager import NodeManager
29
30
 
31
+ if sys.version_info >= (3, 9):
32
+ import ast as astunparse # pylint: disable=reimported, ungrouped-imports
33
+ else:
34
+ import astunparse
30
35
 
31
36
  class ModuleParser(Parser):
32
37
  """Parse ast.Module to SymbolTrees."""
38
+
39
+ # a denied_class_decorator_list represents the decorators should be banned, which is registered by user
40
+ denied_class_decorator_list = []
41
+
33
42
  @staticmethod
34
43
  def _find_class(ast_node: ast.Module) -> ast.ClassDef:
35
44
  """Find all ast.ClassDef in ast.Module, only support one ast.ClassDef in ast.Module now."""
@@ -45,18 +54,27 @@ class ModuleParser(Parser):
45
54
  def _get_import_node(ast_root):
46
55
  """Iterate over ast_root and return all ast.Import nodes or ast.ImportFrom nodes in ast_root."""
47
56
  import_nodes = []
57
+ try_nodes = []
58
+ imports_str = []
48
59
 
49
60
  class GetImportNode(ast.NodeVisitor):
50
61
  """Find all import nodes from input ast node."""
51
62
 
63
+ def visit_Try(self, node: ast.Try) -> Any:
64
+ if isinstance(node.body[0], (ast.Import, ast.ImportFrom)):
65
+ try_nodes.append(copy.deepcopy(node))
66
+ return node
67
+
52
68
  def visit_Import(self, node: ast.Import) -> Any:
53
69
  """Iterate over all nodes and save ast.Import nodes."""
54
70
  import_nodes.append(copy.deepcopy(node))
71
+ imports_str.append(astunparse.unparse(node))
55
72
  return node
56
73
 
57
74
  def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
58
75
  """Iterate over all nodes and save ast.ImportFrom nodes."""
59
76
  import_nodes.append(copy.deepcopy(node))
77
+ imports_str.append(astunparse.unparse(node))
60
78
  return node
61
79
 
62
80
  def get_node(self, input_ast):
@@ -64,19 +82,145 @@ class ModuleParser(Parser):
64
82
  self.generic_visit(input_ast)
65
83
  return True
66
84
 
85
+ def _remove_duplicated_import_in_try(node: [ast.Import, ast.ImportFrom]):
86
+ import_str = astunparse.unparse(node)
87
+ if import_str in imports_str:
88
+ import_nodes.remove(import_nodes[imports_str.index(import_str)])
89
+
67
90
  get_node_handler = GetImportNode()
68
91
  get_node_handler.get_node(ast_root)
92
+ for Try in try_nodes:
93
+ for body in Try.body:
94
+ _remove_duplicated_import_in_try(body)
95
+ for handler in Try.handlers:
96
+ for body in handler.body:
97
+ _remove_duplicated_import_in_try(body)
98
+ import_nodes.extend(try_nodes)
69
99
  return import_nodes
70
100
 
71
101
  @staticmethod
72
- def _add_import_to_module(module: ast.Module, origin_net):
102
+ def save_file_path_to_sys(stree, level_num, file_path):
103
+ """
104
+ Save file path into stree._import_asts. `level_num` is used when level exist in ast.ImportFrom.
105
+
106
+ When level_num = 0(e.g. from xxx import yyy), current path will be saved.
107
+ When level_num = 1(e.g. from .xxx import yyy), current path will be saved.
108
+ When level_num = 2(e.g. from ..xxx import yyy), the path one level above the current path will be saved.
109
+ """
110
+ file_path = os.path.dirname(os.path.abspath(file_path))
111
+ if level_num > 1:
112
+ for _ in range(level_num - 1):
113
+ file_path = os.path.dirname(file_path)
114
+ sys_path_append_ast = ast.parse(f"sys.path.insert(0, r'{file_path}')").body[0]
115
+ stree.get_import_asts().append(ast.Import([ast.alias(name='sys', asname=None)]))
116
+ stree.get_import_asts().append(sys_path_append_ast)
117
+
118
+ @staticmethod
119
+ def _save_imports(stree):
73
120
  """Insert two groups of import nodes to ast.Module, common ones and those from class definition file."""
74
- module.body.insert(0, ast.Import([ast.alias(name='mindspore', asname=None)]))
75
- module.body.insert(1, ast.ImportFrom(module='mindspore', names=[ast.alias(name='nn', asname=None)], level=0))
76
- module.body.insert(2, ast.ImportFrom(module='mindspore.nn', names=[ast.alias(name='Cell', asname=None)],
77
- level=0))
78
- module.body.insert(3, ast.ImportFrom(module='mindspore.ops', names=[ast.alias(name='functional', asname='F')],
79
- level=0))
121
+ stree.get_import_asts().append(ast.Import([ast.alias(name='mindspore', asname=None)]))
122
+ stree.get_import_asts().append(ast.ImportFrom(module='mindspore', names=[ast.alias(name='nn', asname=None)],
123
+ level=0))
124
+ stree.get_import_asts().append(ast.ImportFrom(module='mindspore.nn',
125
+ names=[ast.alias(name='Cell', asname=None)], level=0))
126
+ stree.get_import_asts().append(ast.ImportFrom(module='mindspore.ops',
127
+ names=[ast.alias(name='functional', asname='F')], level=0))
128
+ origin_net = stree.get_origin_network()
129
+ net_path = inspect.getfile(type(origin_net))
130
+ ModuleParser.save_file_path_to_sys(stree, 0, net_path)
131
+ ModuleParser.save_imports_from_file(stree, net_path)
132
+
133
+ @staticmethod
134
+ def get_valid_import_info(import_node, file_path):
135
+ """Get valid import info while import_node.module is at form of relative path"""
136
+ # copy to a new node to avoid origin import_node being modified.
137
+ import_node_test = copy.deepcopy(import_node)
138
+ file_path = os.path.dirname(os.path.abspath(file_path))
139
+ # get real path from import_node.level
140
+ # from .(A) import xxx: current path
141
+ # from ..(A) import xxx: last level path
142
+ import_node_module_name = import_node.module
143
+ level = import_node.level
144
+ # from A import xxx: it does not need to pad, directly return the module name
145
+ if level == 0:
146
+ return import_node_module_name, None
147
+ if level > 1:
148
+ for _ in range(level - 1):
149
+ file_path = os.path.dirname(file_path)
150
+ file_path_tmp = file_path[:]
151
+ max_level_count = file_path.count('/') + file_path.count('\\') - 1
152
+ level_count = 0
153
+ # suffix is the module_name, e.g. 'A' in 'from ..(A) import xxx'
154
+ suffix = ''
155
+ if import_node_module_name:
156
+ suffix = '.' + import_node_module_name
157
+ while level_count < max_level_count:
158
+ file_path_tmp = os.path.dirname(file_path_tmp)
159
+ import_node_test.module = file_path[len(file_path_tmp) + 1:].replace('/', '.') + suffix
160
+ import_node_test.level = 0
161
+ import_code = astunparse.unparse(import_node_test).strip()
162
+ test_code = f"import sys\nsys.path.insert(0, r'{file_path_tmp}')\n{import_code}"
163
+ try:
164
+ exec(test_code) # pylint: disable=W0122
165
+ except (ValueError, ImportError) as e:
166
+ # try upper level to avoid ValueError: attempted relative import beyond top-level package
167
+ # this exception is changed to ImportError after python3.9
168
+ logger.info(f"For MindSpore Rewrite, in module parser, test import code: "
169
+ f"{import_code} failed: {e}. Try upper level.")
170
+ level_count += 1
171
+ continue
172
+ except Exception as e: # pylint: disable=W0703
173
+ logger.info(f"For MindSpore Rewrite, in module parser, process import code: "
174
+ f"{import_code} failed: {e}. Ignore this import code.")
175
+ return None, None
176
+ else:
177
+ # try test code success
178
+ return import_node_test.module, file_path_tmp
179
+ # try codes with all level failed
180
+ logger.info(f"For MindSpore Rewrite, in module parser, test import code: "
181
+ f"{astunparse.unparse(import_node).strip()} failed. Ignore this import code.")
182
+ return None, None
183
+
184
+ @staticmethod
185
+ def save_imports_from_file(stree, file_path):
186
+ """Save imports from file"""
187
+ if not os.path.exists(file_path):
188
+ raise RuntimeError(f"For MindSpore Rewrite, in module parser, file {file_path} not exist.")
189
+ try:
190
+ with open(file_path, "r", encoding="utf-8") as f:
191
+ source_code = f.read()
192
+ import_nodes = ModuleParser._get_import_node(ast.parse(source_code))
193
+ except RuntimeError as err:
194
+ raise RuntimeError(f"For MindSpore Rewrite, in module parser, get import nodes error: {err}")
195
+ if not import_nodes:
196
+ return
197
+ for import_node in import_nodes:
198
+ import_node = ModuleParser._process_relative_import(stree, import_node, file_path)
199
+ if import_node:
200
+ stree.get_import_asts().append(import_node)
201
+
202
+ @staticmethod
203
+ def _process_relative_import(stree, import_node, file_path):
204
+ """Process relative imports"""
205
+ if isinstance(import_node, ast.ImportFrom):
206
+ # pad the ImportFrom with parent path
207
+ # e.g. from ..C import xxx -> from A.B.C import xxx
208
+ import_module, import_path = ModuleParser.get_valid_import_info(import_node, file_path)
209
+ if import_path:
210
+ ModuleParser.save_file_path_to_sys(stree, 0, import_path)
211
+ module_name_list = [alias.name.strip() for alias in import_node.names]
212
+ # add the module into _imported_modules to direct the class
213
+ stree.save_imported_modules(file_path, import_module, module_name_list)
214
+ import_node = ast.ImportFrom(module=import_module, names=import_node.names, level=0)
215
+ elif isinstance(import_node, ast.Import):
216
+ for alias in import_node.names:
217
+ name = alias.name
218
+ stree.save_imported_modules(file_path, name.strip(), [])
219
+ return import_node
220
+
221
+ @staticmethod
222
+ def _add_decorator_to_class(class_ast: ast.ClassDef, origin_net):
223
+ """Add decorators to class"""
80
224
  origin_net_source_code_file = inspect.getfile(type(origin_net))
81
225
  if not os.path.exists(origin_net_source_code_file):
82
226
  raise RuntimeError("For MindSpore Rewrite, in module parser, File ", origin_net_source_code_file,
@@ -84,35 +228,62 @@ class ModuleParser(Parser):
84
228
  try:
85
229
  with open(origin_net_source_code_file, "r", encoding="utf-8") as f:
86
230
  source_code = f.read()
87
- import_nodes = ModuleParser._get_import_node(ast.parse(source_code))
231
+ decorators = ModuleParser._get_decorator(ast.parse(source_code), origin_net)
88
232
  except RuntimeError:
89
- raise RuntimeError("For MindSpore Rewrite, in module parser, get import nodes error")
90
- if import_nodes:
91
- for import_index, import_node in enumerate(import_nodes):
92
- module.body.insert(import_index + 4, import_node)
93
- ast.fix_missing_locations(module)
233
+ raise RuntimeError("For MindSpore Rewrite, in module parser, get decorators error")
234
+ if decorators:
235
+ for decorator_index, decorator_node in enumerate(decorators):
236
+ class_ast.decorator_list.insert(decorator_index, decorator_node)
237
+ ast.fix_missing_locations(class_ast)
94
238
 
95
239
  @staticmethod
96
- def _save_net_file_path(stree: SymbolTree):
97
- origin_net_file = inspect.getfile(type(stree.get_origin_network()))
98
- file_full_path = os.path.abspath(origin_net_file)
99
- file_path, _ = os.path.split(file_full_path)
100
- stree.append_net_file_path(file_path)
240
+ def _get_decorator(ast_root, origin_net):
241
+ """Get the decorators of function"""
242
+ net_name = type(origin_net).__name__
243
+ decorators = []
244
+
245
+ class GetClassNode(ast.NodeVisitor):
246
+ """Find the class node from input ast node."""
247
+ def visit_ClassDef(self, node: ast.ClassDef) -> Any:
248
+ """Visit the class node and add the decorators to class node"""
249
+ if node.name == net_name:
250
+ for decorator in node.decorator_list[:]:
251
+ decorator_name = ""
252
+ if isinstance(decorator, ast.Call):
253
+ func = decorator.func
254
+ if isinstance(func, ast.Name):
255
+ decorator_name = func.id
256
+ elif isinstance(decorator, ast.Name):
257
+ decorator_name = decorator.id
258
+ # User should set the denied class_decorator,
259
+ # because the symbol_tree cant pass the correct parameters to decorators but the instance "obj".
260
+ if decorator_name not in ModuleParser.denied_class_decorator_list:
261
+ decorators.append(decorator)
262
+ return node
263
+
264
+ def get_node(self, input_ast):
265
+ """Interface of GetClassNode."""
266
+ self.generic_visit(input_ast)
267
+ return True
268
+
269
+ get_node_handler = GetClassNode()
270
+ get_node_handler.get_node(ast_root)
271
+ return decorators
101
272
 
102
273
  def target(self):
103
274
  """Parse target type"""
104
275
  return ast.Module
105
276
 
106
- def process(self, stree: SymbolTree, node: ast.Module):
277
+ def process(self, stree: SymbolTree, node: ast.Module, node_manager: NodeManager):
107
278
  """Process ast.ClassDef nodes in ast.Module."""
108
- ModuleParser._add_import_to_module(node, stree.get_origin_network())
279
+ ModuleParser._save_imports(stree)
109
280
  class_ast = ModuleParser._find_class(node)
281
+ ModuleParser._add_decorator_to_class(class_ast, stree.get_origin_network())
110
282
  stree.set_class_ast(class_ast)
111
- ModuleParser._save_net_file_path(stree)
112
283
  for body in node.body:
113
284
  if isinstance(body, ast.ClassDef):
114
285
  parser: Parser = ParserRegister.instance().get_parser(ast.ClassDef)
115
- parser.process(stree, body)
286
+ parser.process(stree, body, stree)
116
287
  else:
117
288
  logger.info(f"For MindSpore Rewrite, in module parser, Ignoring unsupported "
118
289
  f"node({astunparse.unparse(body)}) in ast.Module.")
@@ -16,7 +16,8 @@
16
16
  import abc
17
17
  import ast
18
18
 
19
- from .symbol_tree import SymbolTree
19
+ from ..symbol_tree import SymbolTree
20
+ from ..node.node_manager import NodeManager
20
21
 
21
22
 
22
23
  class Parser(abc.ABC):
@@ -34,12 +35,13 @@ class Parser(abc.ABC):
34
35
  return type(None)
35
36
 
36
37
  @abc.abstractmethod
37
- def process(self, stree: SymbolTree, node: ast.AST):
38
+ def process(self, stree: SymbolTree, node: ast.AST, node_manager: NodeManager):
38
39
  """
39
40
  Parse input ast node and add parse result into SymbolTree.
40
41
 
41
42
  Args:
42
43
  stree (SymbolTree): current symbol_tree
43
44
  node (ast.AST): node who is tried to be parsed
45
+ node_manager (NodeManager): NodeManager those asts belong to.
44
46
  """
45
47
  raise NotImplementedError
@@ -45,7 +45,7 @@ class ParserRegister:
45
45
  parser (Parser): An instance of Parser to be registered.
46
46
  """
47
47
  if isinstance(parser, Parser):
48
- ParserRegister.instance()._parsers[parser.target()] = parser
48
+ ParserRegister.instance().get_parsers()[parser.target()] = parser
49
49
 
50
50
  def get_parser(self, ast_type: type) -> Optional[Parser]:
51
51
  """
@@ -16,9 +16,10 @@
16
16
  import ast
17
17
 
18
18
  from ..symbol_tree import SymbolTree
19
- from ..node import Node
20
- from ..parser import Parser
21
- from ..parser_register import reg_parser
19
+ from ..node.node import Node
20
+ from ..node.node_manager import NodeManager
21
+ from .parser import Parser
22
+ from .parser_register import reg_parser
22
23
  from ..common import error_str
23
24
 
24
25
 
@@ -29,14 +30,13 @@ class ReturnParser(Parser):
29
30
  """Parse target type"""
30
31
  return ast.Return
31
32
 
32
- def process(self, stree: SymbolTree, node: ast.Return):
33
+ def process(self, stree: SymbolTree, node: ast.Return, node_manager: NodeManager):
33
34
  """Parse ast.Return to output-node of SymbolTree."""
34
35
  return_value = node.value
35
36
  if not isinstance(return_value, ast.Name):
36
37
  raise RuntimeError(error_str(f"only support ast.Name as return value, but got ast type "
37
38
  f"'{type(return_value).__name__}'", father_node=node, child_node=return_value))
38
39
  node_return = Node.create_output_node(node, [return_value.id])
39
- stree.append_origin_field(node_return)
40
-
40
+ stree.append_origin_field(node_return, node_manager)
41
41
 
42
42
  g_return_parser = reg_parser(ReturnParser())
@@ -13,11 +13,11 @@
13
13
  # limitations under the License.
14
14
  # ============================================================================
15
15
  """Sparsify transformer"""
16
+ import sys
16
17
  import ast
17
18
  import inspect
18
19
  import textwrap
19
20
  from collections import deque
20
- import astunparse
21
21
 
22
22
  from mindspore import ops, nn
23
23
  from mindspore import log as logger
@@ -25,6 +25,10 @@ from mindspore.rewrite.parsers.assign_parser import AssignParser
25
25
  from mindspore.rewrite.sparsify.utils import ArgType, SparseFunc, sparse_rules, get_sparse_func, builtin_ops, \
26
26
  get_binop_name, get_sparse_method_outputs, arg_type_to_prefix_map, get_inputs_outputs
27
27
 
28
+ if sys.version_info >= (3, 9):
29
+ import ast as astunparse # pylint: disable=reimported, ungrouped-imports
30
+ else:
31
+ import astunparse
28
32
 
29
33
  OPS_MODULE = "mindspore.ops."
30
34
  MAX_RECURSION_DEPTH = 10
@@ -61,8 +65,13 @@ def sparsify_helper(f, arg_types, user_defined_rules=None, sparse_name="", full_
61
65
 
62
66
  if changed:
63
67
  sparse_tree = list(x[0] for x in sparse_transformer.sparse_functiondef.values()) + sparse_tree
64
- ast_module = ast.Module([ast.FunctionDef(
65
- sparse_name, functiondef.args, sparse_tree, functiondef.decorator_list, functiondef.returns)])
68
+ if sys.version_info >= (3, 9):
69
+ ast_module = ast.Module([ast.FunctionDef(
70
+ sparse_name, functiondef.args, sparse_tree, functiondef.decorator_list, functiondef.returns)],
71
+ type_ignores=[])
72
+ else:
73
+ ast_module = ast.Module([ast.FunctionDef(
74
+ sparse_name, functiondef.args, sparse_tree, functiondef.decorator_list, functiondef.returns)])
66
75
  return ast_module, True, return_types
67
76
  return tree, False, return_types
68
77
 
@@ -176,4 +176,4 @@ def get_binop_name(binop):
176
176
  return "*"
177
177
  if binop == ast.Div():
178
178
  return "/"
179
- return None
179
+ return ""