mindspore 2.1.0__cp37-cp37m-manylinux1_x86_64.whl → 2.2.10__cp37-cp37m-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (580) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +46 -19
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  25. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  26. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +38 -0
  31. mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +12 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +61 -71
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +74 -104
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +13 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +28 -5
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  196. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  197. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  198. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  199. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  201. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  202. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  203. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  204. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  205. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  206. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  207. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  208. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  209. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  210. mindspore/nn/__init__.py +0 -2
  211. mindspore/nn/cell.py +313 -74
  212. mindspore/nn/dynamic_lr.py +21 -21
  213. mindspore/nn/layer/activation.py +22 -30
  214. mindspore/nn/layer/basic.py +15 -13
  215. mindspore/nn/layer/channel_shuffle.py +1 -1
  216. mindspore/nn/layer/container.py +271 -9
  217. mindspore/nn/layer/conv.py +323 -204
  218. mindspore/nn/layer/dense.py +8 -5
  219. mindspore/nn/layer/embedding.py +33 -27
  220. mindspore/nn/layer/flash_attention.py +141 -88
  221. mindspore/nn/layer/image.py +8 -6
  222. mindspore/nn/layer/math.py +16 -25
  223. mindspore/nn/layer/normalization.py +107 -66
  224. mindspore/nn/layer/padding.py +1 -1
  225. mindspore/nn/layer/pooling.py +131 -109
  226. mindspore/nn/layer/rnn_cells.py +27 -22
  227. mindspore/nn/layer/rnns.py +13 -16
  228. mindspore/nn/layer/thor_layer.py +1 -1
  229. mindspore/nn/layer/transformer.py +221 -154
  230. mindspore/nn/learning_rate_schedule.py +9 -1
  231. mindspore/nn/loss/loss.py +235 -174
  232. mindspore/nn/optim/ada_grad.py +2 -1
  233. mindspore/nn/optim/adadelta.py +1 -0
  234. mindspore/nn/optim/adafactor.py +2 -1
  235. mindspore/nn/optim/adam.py +7 -4
  236. mindspore/nn/optim/adamax.py +3 -2
  237. mindspore/nn/optim/adasum.py +2 -2
  238. mindspore/nn/optim/asgd.py +2 -3
  239. mindspore/nn/optim/ftrl.py +6 -5
  240. mindspore/nn/optim/lamb.py +7 -4
  241. mindspore/nn/optim/lars.py +1 -1
  242. mindspore/nn/optim/lazyadam.py +5 -3
  243. mindspore/nn/optim/momentum.py +2 -1
  244. mindspore/nn/optim/optimizer.py +53 -4
  245. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  246. mindspore/nn/optim/rmsprop.py +4 -3
  247. mindspore/nn/optim/rprop.py +23 -12
  248. mindspore/nn/optim/sgd.py +26 -11
  249. mindspore/nn/optim/thor.py +9 -7
  250. mindspore/nn/probability/bijector/bijector.py +5 -5
  251. mindspore/nn/probability/bijector/power_transform.py +27 -27
  252. mindspore/nn/probability/bijector/softplus.py +3 -3
  253. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  254. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  255. mindspore/nn/probability/distribution/beta.py +3 -3
  256. mindspore/nn/probability/distribution/categorical.py +7 -7
  257. mindspore/nn/probability/distribution/cauchy.py +0 -1
  258. mindspore/nn/probability/distribution/distribution.py +3 -3
  259. mindspore/nn/probability/distribution/gamma.py +3 -3
  260. mindspore/nn/probability/distribution/geometric.py +4 -4
  261. mindspore/nn/probability/distribution/gumbel.py +4 -4
  262. mindspore/nn/probability/distribution/log_normal.py +2 -2
  263. mindspore/nn/probability/distribution/logistic.py +2 -2
  264. mindspore/nn/probability/distribution/poisson.py +4 -4
  265. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  266. mindspore/nn/probability/distribution/uniform.py +6 -6
  267. mindspore/nn/wrap/cell_wrapper.py +84 -34
  268. mindspore/nn/wrap/grad_reducer.py +8 -5
  269. mindspore/nn/wrap/loss_scale.py +105 -42
  270. mindspore/numpy/array_creations.py +1 -2
  271. mindspore/numpy/array_ops.py +3 -2
  272. mindspore/numpy/utils_const.py +5 -5
  273. mindspore/offline_debug/convert_async.py +2 -2
  274. mindspore/ops/_grad_experimental/__init__.py +0 -5
  275. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  276. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  277. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  278. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  279. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  280. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  281. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  282. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  283. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  284. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  285. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  286. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  287. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  288. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  289. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  290. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  291. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  292. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  293. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  294. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  295. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  296. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  297. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  298. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  299. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  300. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  301. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  302. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  303. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  304. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  305. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  306. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  307. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  308. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  309. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  310. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  311. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  312. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  313. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  314. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  315. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  316. mindspore/ops/_primitive_cache.py +1 -1
  317. mindspore/ops/_tracefunc.py +45 -13
  318. mindspore/ops/_utils/utils.py +6 -1
  319. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  320. mindspore/ops/_vmap/vmap_base.py +3 -3
  321. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  322. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  323. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  324. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  325. mindspore/ops/arg_dtype_cast.py +54 -0
  326. mindspore/ops/composite/base.py +37 -10
  327. mindspore/ops/composite/math_ops.py +5 -4
  328. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  329. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  330. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  331. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  332. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  333. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  334. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  335. mindspore/ops/deprecated.py +304 -0
  336. mindspore/ops/function/__init__.py +4 -1
  337. mindspore/ops/function/array_func.py +174 -193
  338. mindspore/ops/function/clip_func.py +81 -13
  339. mindspore/ops/function/debug_func.py +1 -1
  340. mindspore/ops/function/grad/grad_func.py +18 -9
  341. mindspore/ops/function/image_func.py +10 -4
  342. mindspore/ops/function/linalg_func.py +5 -5
  343. mindspore/ops/function/math_func.py +575 -386
  344. mindspore/ops/function/nn_func.py +568 -260
  345. mindspore/ops/function/random_func.py +88 -57
  346. mindspore/ops/function/sparse_func.py +1 -1
  347. mindspore/ops/function/sparse_unary_func.py +14 -12
  348. mindspore/ops/function/vmap_func.py +6 -5
  349. mindspore/ops/functional.py +15 -10
  350. mindspore/ops/op_info_register.py +244 -25
  351. mindspore/ops/operations/__init__.py +28 -19
  352. mindspore/ops/operations/_grad_ops.py +72 -7
  353. mindspore/ops/operations/_inner_ops.py +350 -17
  354. mindspore/ops/operations/_quant_ops.py +4 -8
  355. mindspore/ops/operations/_sequence_ops.py +42 -0
  356. mindspore/ops/operations/array_ops.py +68 -282
  357. mindspore/ops/operations/comm_ops.py +107 -59
  358. mindspore/ops/operations/custom_ops.py +94 -70
  359. mindspore/ops/operations/debug_ops.py +8 -4
  360. mindspore/ops/operations/image_ops.py +18 -12
  361. mindspore/ops/operations/inner_ops.py +26 -3
  362. mindspore/ops/operations/math_ops.py +189 -141
  363. mindspore/ops/operations/nn_ops.py +794 -489
  364. mindspore/ops/operations/other_ops.py +0 -22
  365. mindspore/ops/operations/random_ops.py +53 -111
  366. mindspore/ops/operations/sparse_ops.py +3 -1
  367. mindspore/ops/primitive.py +24 -18
  368. mindspore/parallel/_auto_parallel_context.py +68 -8
  369. mindspore/parallel/_cost_model_context.py +2 -2
  370. mindspore/parallel/_offload_context.py +17 -3
  371. mindspore/parallel/_parallel_serialization.py +12 -5
  372. mindspore/parallel/_ps_context.py +12 -0
  373. mindspore/parallel/_tensor.py +18 -13
  374. mindspore/parallel/_transformer/layers.py +5 -3
  375. mindspore/parallel/_transformer/loss.py +1 -0
  376. mindspore/parallel/_transformer/moe.py +2 -2
  377. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  378. mindspore/parallel/_transformer/transformer.py +23 -3
  379. mindspore/parallel/_utils.py +11 -7
  380. mindspore/parallel/algo_parameter_config.py +85 -5
  381. mindspore/parallel/checkpoint_transform.py +19 -12
  382. mindspore/parallel/shard.py +21 -14
  383. mindspore/profiler/common/struct_type.py +3 -3
  384. mindspore/profiler/common/util.py +4 -2
  385. mindspore/profiler/envprofiling.py +1 -1
  386. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  387. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  388. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  389. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  390. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  391. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  392. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  393. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  394. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  395. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  396. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  397. mindspore/profiler/parser/flops_parser.py +15 -11
  398. mindspore/profiler/parser/framework_parser.py +38 -22
  399. mindspore/profiler/parser/hccl_parser.py +16 -12
  400. mindspore/profiler/parser/integrator.py +22 -11
  401. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  402. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  403. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  404. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  405. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  406. mindspore/profiler/parser/optime_parser.py +1 -1
  407. mindspore/profiler/parser/profiler_info.py +21 -2
  408. mindspore/profiler/parser/step_trace_parser.py +11 -14
  409. mindspore/profiler/profiling.py +179 -89
  410. mindspore/rewrite/api/node.py +102 -19
  411. mindspore/rewrite/api/node_type.py +5 -1
  412. mindspore/rewrite/api/pattern_engine.py +1 -1
  413. mindspore/rewrite/api/scoped_value.py +9 -17
  414. mindspore/rewrite/api/symbol_tree.py +131 -47
  415. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  416. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  417. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  418. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  419. mindspore/rewrite/common/rewrite_elog.py +5 -1
  420. mindspore/rewrite/namer.py +33 -24
  421. mindspore/rewrite/namespace.py +14 -5
  422. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  423. mindspore/rewrite/node/call_function.py +79 -0
  424. mindspore/rewrite/node/cell_container.py +135 -0
  425. mindspore/rewrite/node/control_flow.py +88 -0
  426. mindspore/rewrite/{node.py → node/node.py} +273 -234
  427. mindspore/rewrite/node/node_manager.py +254 -0
  428. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  429. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  430. mindspore/rewrite/parsers/assign_parser.py +216 -221
  431. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  432. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  433. mindspore/rewrite/parsers/constant_parser.py +9 -6
  434. mindspore/rewrite/parsers/container_parser.py +9 -7
  435. mindspore/rewrite/parsers/for_parser.py +36 -15
  436. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  437. mindspore/rewrite/parsers/if_parser.py +28 -24
  438. mindspore/rewrite/parsers/module_parser.py +196 -25
  439. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  440. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  441. mindspore/rewrite/parsers/return_parser.py +6 -6
  442. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  443. mindspore/rewrite/sparsify/utils.py +1 -1
  444. mindspore/rewrite/symbol_tree.py +523 -578
  445. mindspore/rewrite/symbol_tree_builder.py +9 -193
  446. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  447. mindspore/run_check/_check_version.py +6 -4
  448. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  449. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  450. mindspore/scipy/linalg.py +1 -1
  451. mindspore/scipy/optimize/minimize.py +7 -3
  452. mindspore/train/_utils.py +7 -3
  453. mindspore/train/amp.py +323 -123
  454. mindspore/train/anf_ir_pb2.py +14 -2
  455. mindspore/train/callback/_backup_and_restore.py +2 -12
  456. mindspore/train/callback/_callback.py +29 -4
  457. mindspore/train/callback/_checkpoint.py +23 -8
  458. mindspore/train/callback/_early_stop.py +2 -2
  459. mindspore/train/callback/_landscape.py +4 -4
  460. mindspore/train/callback/_loss_monitor.py +2 -2
  461. mindspore/train/callback/_on_request_exit.py +2 -2
  462. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  463. mindspore/train/callback/_summary_collector.py +15 -8
  464. mindspore/train/callback/_time_monitor.py +58 -5
  465. mindspore/train/data_sink.py +5 -11
  466. mindspore/train/dataset_helper.py +84 -57
  467. mindspore/train/loss_scale_manager.py +2 -2
  468. mindspore/train/metrics/__init__.py +3 -3
  469. mindspore/train/metrics/cosine_similarity.py +1 -1
  470. mindspore/train/metrics/hausdorff_distance.py +3 -2
  471. mindspore/train/metrics/mean_surface_distance.py +3 -2
  472. mindspore/train/metrics/metric.py +39 -19
  473. mindspore/train/metrics/roc.py +2 -2
  474. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  475. mindspore/train/mind_ir_pb2.py +85 -36
  476. mindspore/train/model.py +187 -47
  477. mindspore/train/serialization.py +487 -161
  478. mindspore/train/summary/_summary_adapter.py +1 -1
  479. mindspore/train/summary/_writer_pool.py +3 -2
  480. mindspore/train/summary/summary_record.py +37 -17
  481. mindspore/train/train_thor/convert_utils.py +3 -3
  482. mindspore/train/train_thor/dataset_helper.py +1 -1
  483. mindspore/version.py +1 -1
  484. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
  485. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +488 -528
  486. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
  487. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  488. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  489. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  490. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  491. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  492. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  493. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  494. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  495. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  496. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  497. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  498. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  499. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  500. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  501. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  502. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  503. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  504. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  505. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  506. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  507. mindspore/_extends/graph_kernel/expander.py +0 -80
  508. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  509. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  510. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  511. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  512. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  513. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  514. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  515. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  516. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  517. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  518. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  519. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  520. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  521. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  522. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  523. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  524. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  525. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  526. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  527. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  528. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  529. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  530. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  531. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  532. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  533. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  534. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  535. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  536. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  537. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  538. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  539. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  540. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  541. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  542. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  543. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  544. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  545. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  546. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  547. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  548. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  549. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  550. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  551. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  552. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  553. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  554. mindspore/dataset/datapreprocess/__init__.py +0 -20
  555. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  556. mindspore/include/api/net.h +0 -142
  557. mindspore/nn/lr_scheduler.py +0 -262
  558. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  559. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  560. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  561. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  562. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  563. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  567. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  568. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  569. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  571. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  575. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  577. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  578. mindspore/rewrite/node_visitor.py +0 -44
  579. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  580. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ============================================================================
15
15
  """Ast utils for create or update ast node."""
16
- from typing import Optional
16
+ from typing import Optional, List
17
17
  import ast
18
18
 
19
19
  from ..api.scoped_value import ScopedValue, ValueType
@@ -34,20 +34,15 @@ class AstModifier(ast.NodeTransformer):
34
34
  Returns:
35
35
  A bool if to_erase-node been found and been erased.
36
36
  """
37
- for body in ast_func.body:
37
+ return AstModifier.erase_ast_from_bodies(ast_func.body, to_erase)
38
+
39
+ @staticmethod
40
+ def erase_ast_from_bodies(ast_bodies: List[ast.AST], to_erase: ast.AST) -> bool:
41
+ """Erase ast node from ast bodies."""
42
+ for body in ast_bodies:
38
43
  if id(body) == id(to_erase):
39
- ast_func.body.remove(body)
44
+ ast_bodies.remove(body)
40
45
  return True
41
- # hardcode for ast.If
42
- if isinstance(body, ast.If):
43
- for if_body in body.body:
44
- if id(if_body) == id(to_erase):
45
- body.body.remove(if_body)
46
- return True
47
- for else_body in body.orelse:
48
- if id(else_body) == id(to_erase):
49
- body.orelse.remove(else_body)
50
- return True
51
46
  return False
52
47
 
53
48
  @staticmethod
@@ -147,13 +142,6 @@ class AstModifier(ast.NodeTransformer):
147
142
  RuntimeError: If 'index_ast' is not contained in 'ast_func'.
148
143
  """
149
144
  assign = AstModifier.create_call_assign(targets, expr, args, kwargs)
150
- arguments: ast.arguments = ast_func.args
151
- if arguments.args:
152
- for arg in arguments.args:
153
- if id(arg) == id(index_ast):
154
- ast_func.body.insert(0, assign)
155
- ast.fix_missing_locations(ast_func)
156
- return assign
157
145
  return AstModifier.insert_assign_ast_to_function(ast_func, assign, index_ast, insert_before)
158
146
 
159
147
  @staticmethod
@@ -177,49 +165,63 @@ class AstModifier(ast.NodeTransformer):
177
165
  Raises:
178
166
  RuntimeError: If 'index_ast' is not contained in 'ast_func'.
179
167
  """
180
- if index_ast is None:
181
- ast_func.body.append(ast_assign)
182
- ast.fix_missing_locations(ast_func)
183
- return ast_assign
168
+ # Insert ast at the frontmost position of function body when index_ast is an argument of function
184
169
  arguments: ast.arguments = ast_func.args
185
- if arguments.args:
170
+ if index_ast and arguments.args:
186
171
  for arg in arguments.args:
187
172
  if id(arg) == id(index_ast):
188
173
  ast_func.body.insert(0, ast_assign)
189
174
  ast.fix_missing_locations(ast_func)
190
175
  return ast_assign
191
- for index in range(0, len(ast_func.body)):
192
- body = ast_func.body[index]
176
+ # Insert ast at position specified by index_ast in function body
177
+ ast_assign = AstModifier.insert_assign_ast_to_bodies(ast_func.body, ast_assign, index_ast, insert_before)
178
+ ast.fix_missing_locations(ast_assign)
179
+ return ast_assign
180
+
181
+ @staticmethod
182
+ def insert_assign_ast_to_bodies(ast_bodies: List[ast.AST], ast_assign: ast.Assign,
183
+ index_ast: Optional[ast.AST] = None, insert_before=True) -> ast.AST:
184
+ """Insert ast at position specified by index_ast of ast_bodies"""
185
+ # Append ast_assign to ast_bodies when index_ast is None
186
+ if index_ast is None:
187
+ ast_bodies.append(ast_assign)
188
+ return ast_assign
189
+ # Append ast_assign to ast_bodies
190
+ for index, body in enumerate(ast_bodies):
193
191
  if id(body) == id(index_ast):
194
- if insert_before:
195
- ast_func.body.insert(index, ast_assign)
196
- else:
197
- ast_func.body.insert(index + 1, ast_assign)
198
- ast.fix_missing_locations(ast_func)
199
- return ast_assign
200
- # hardcode for ast.If
201
- if isinstance(body, ast.If):
202
- for if_index in range(0, len(body.body)):
203
- if_body = body.body[if_index]
204
- if id(if_body) != id(index_ast):
205
- continue
206
- if insert_before:
207
- body.body.insert(if_index, ast_assign)
208
- else:
209
- body.body.insert(if_index + 1, ast_assign)
210
- ast.fix_missing_locations(body)
211
- return ast_assign
212
- for if_index in range(0, len(body.orelse)):
213
- else_body = body.orelse[if_index]
214
- if id(else_body) != id(index_ast):
215
- continue
216
- if insert_before:
217
- body.orelse.insert(if_index, ast_assign)
218
- else:
219
- body.orelse.insert(if_index + 1, ast_assign)
220
- ast.fix_missing_locations(body)
221
- return ast_assign
222
- raise RuntimeError("insert position is not contained in ast_func")
192
+ if not insert_before:
193
+ index += 1
194
+ ast_bodies.insert(index, ast_assign)
195
+ ast.fix_missing_locations(body)
196
+ break
197
+ else:
198
+ raise ValueError("insert position is not contained in ast_bodies")
199
+ return ast_assign
200
+
201
+ @staticmethod
202
+ def append_arg_to_function(ast_func: ast.FunctionDef, ast_arg: ast.arg) -> ast.AST:
203
+ """
204
+ Append an ast.arg to an ast.FunctionDef (e.g. self.construct).
205
+
206
+ Args:
207
+ ast_func (ast.FunctionDef): An instance of ast.FunctionDef which is "construct" function of network.
208
+ ast_arg (ast.arg): An instance of ast.arg to be inserted in.
209
+
210
+ Returns:
211
+ An instance of ast.arg which has been appended to 'ast_func'.
212
+
213
+ Raises:
214
+ RuntimeError: If 'ast_arg' is not an instance of ast_arg.
215
+ """
216
+ if not isinstance(ast_arg, ast.arg):
217
+ raise RuntimeError("ast_arg should be an instance of ast.arg.")
218
+ arguments: ast.arguments = ast_func.args
219
+ args: [ast.arg] = arguments.args
220
+ args.append(ast_arg)
221
+ defaults = arguments.defaults
222
+ arg_default = ast.Constant(value=None, kind=None)
223
+ defaults.append(arg_default)
224
+ return ast_arg
223
225
 
224
226
  @staticmethod
225
227
  def append_global_vars_expr_to_init(init_func: ast.FunctionDef, targets: [ScopedValue],
@@ -241,7 +243,7 @@ class AstModifier(ast.NodeTransformer):
241
243
  An instance of ast.Assign which has been appended to 'init_func'.
242
244
  """
243
245
  return AstModifier.insert_assign_to_function(init_func, targets=targets,
244
- expr=ScopedValue(ValueType.NamingValue, "", "setattr"),
246
+ expr=ScopedValue(ValueType.NamingValue, "", "getattr"),
245
247
  args=[ScopedValue(ValueType.NamingValue, "obj"),
246
248
  ScopedValue.create_variable_value(field)])
247
249
 
@@ -265,23 +267,30 @@ class AstModifier(ast.NodeTransformer):
265
267
  RuntimeError: If 'targets' is None.
266
268
  RuntimeError: If value_type of element of 'targets' is not ValueType.NamingValue.
267
269
 
268
- RuntimeError: If length of 'targets' is not 1. Multi-targets will be support in the future.
269
270
  """
270
- if targets is None or len(targets) != 1:
271
- raise RuntimeError("Only support one target in insert_cell_to_init now")
272
- if targets[0].type != ValueType.NamingValue:
273
- raise RuntimeError("Target must be a right-value, got: ", targets[0])
274
- if targets[0].scope:
275
- ast_target = ast.Attribute(ast.Name(targets[0].scope, ast.Load()), targets[0].value, ast.Store())
276
- else:
277
- ast_target = ast.Name(targets[0].value, ast.Store())
271
+ if targets is None:
272
+ raise RuntimeError("'Targets should not be None.")
273
+ targets_list = []
274
+ for target in targets:
275
+ if target.type != ValueType.NamingValue:
276
+ raise RuntimeError("Target must be a right-value, got: ", target)
277
+ if target.scope:
278
+ ast_target = ast.Attribute(ast.Name(target.scope, ast.Load()), target.value, ast.Store())
279
+ else:
280
+ ast_target = ast.Name(target.value, ast.Store())
281
+ targets_list.append(ast_target)
278
282
  call = AstModifier.create_call(expr, args, kwargs)
279
- result = ast.Assign(targets=[ast_target], value=call)
283
+
284
+ if len(targets) == 1:
285
+ result = ast.Assign(targets=[targets_list[0]], value=call)
286
+ elif len(targets) > 1:
287
+ ast_targets = ast.Tuple(elts=targets_list, ctx=ast.Store())
288
+ result = ast.Assign(targets=[ast_targets], value=call)
280
289
  ast.fix_missing_locations(result)
281
290
  return result
282
291
 
283
292
  @staticmethod
284
- def _create_arg_by_single_value(value: ScopedValue):
293
+ def _create_arg_by_constant_value(value: ScopedValue):
285
294
  """
286
295
  Create an instance of ast.Constant.
287
296
 
@@ -290,17 +299,16 @@ class AstModifier(ast.NodeTransformer):
290
299
 
291
300
  Raises:
292
301
  RuntimeError: if scope of value is not empty.
293
- TypeError: type of arg not in [ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue]
302
+ TypeError: type of arg is not ValueType.ConstantValue
294
303
 
295
304
  Returns:
296
305
  ast.Constant: An instance of ast.Constant
297
306
  """
298
- if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
307
+ if value.type == ValueType.ConstantValue:
299
308
  if value.scope:
300
309
  raise RuntimeError("For arg the scope should be empty")
301
310
  return ast.Constant(value=value.value, kind=None)
302
- raise TypeError("Type of arg only support [ValueType.IntValue, ValueType.FloatValue,"
303
- f" ValueType.StringValue], but got {type(value)}")
311
+ raise TypeError("Type of arg only support ValueType.ConstantValue, but got {type(value)}")
304
312
 
305
313
  @staticmethod
306
314
  def _create_list_or_tuple(value: ScopedValue):
@@ -315,7 +323,7 @@ class AstModifier(ast.NodeTransformer):
315
323
  """
316
324
  elts = []
317
325
  for v in value.value:
318
- elts.append(AstModifier._create_arg_by_single_value(v))
326
+ elts.append(AstModifier._create_arg_by_constant_value(v))
319
327
  if isinstance(value, list):
320
328
  return ast.List(elts=elts)
321
329
  return ast.Tuple(elts=elts)
@@ -331,22 +339,20 @@ class AstModifier(ast.NodeTransformer):
331
339
 
332
340
  Raises:
333
341
  RuntimeError: if scope of value is not empty.
334
- TypeError: type of arg not in [ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue,
335
- ValueType.ListValue, ValueType.TupleValue]
342
+ TypeError: type of arg is not ValueType.ConstantValue
336
343
 
337
344
  Returns:
338
345
  ast.keyword: a instance of ast.keyword.
339
346
  """
340
347
  if value.scope:
341
348
  raise RuntimeError("value.scope should be empty")
342
- if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
349
+ if value.type == ValueType.ConstantValue:
343
350
  v = ast.Constant(value=value.value, kind=None)
344
351
  elif value.type in (ValueType.ListValue, ValueType.TupleValue):
345
352
  v = AstModifier._create_list_or_tuple(value)
346
353
  else:
347
- raise TypeError("Type of keyword value only support [ValueType.IntValue, ValueType.FloatValue,"
348
- "ValueType.StringValue, ValueType.ListValue, ValueType.TupleValue],"
349
- f"but got {type(value)}")
354
+ raise TypeError("Type of keyword value only support [ValueType.ConstantValue, ValueType.ListValue, "
355
+ f"ValueType.TupleValue], but got {type(value)}")
350
356
  return ast.keyword(arg=arg, value=v)
351
357
 
352
358
  @staticmethod
@@ -371,14 +377,14 @@ class AstModifier(ast.NodeTransformer):
371
377
  for arg in args:
372
378
  if not isinstance(arg, ScopedValue):
373
379
  raise TypeError("arg should be ScopedValue, got: ", type(arg))
374
- if arg.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
375
- results.append(AstModifier._create_arg_by_single_value(arg))
380
+ if arg.type == ValueType.ConstantValue:
381
+ results.append(AstModifier._create_arg_by_constant_value(arg))
376
382
  elif arg.type == ValueType.NamingValue:
377
383
  if arg.scope:
378
384
  results.append(ast.Attribute(ast.Name(arg.scope, ast.Load()), arg.value, ast.Store()))
379
385
  else:
380
386
  results.append(ast.Name(arg.value, ast.Store()))
381
- elif arg.type == ValueType.ListValue or arg.type == ValueType.TupleValue:
387
+ elif arg.type in (ValueType.ListValue, ValueType.TupleValue):
382
388
  results.append(AstModifier._create_list_or_tuple(arg))
383
389
  else:
384
390
  raise RuntimeError("Please handle custom-object first")
@@ -406,8 +412,7 @@ class AstModifier(ast.NodeTransformer):
406
412
  for arg, value in kwargs.items():
407
413
  if not isinstance(value, ScopedValue):
408
414
  raise TypeError("value should be ScopedValue, got: ", type(value))
409
- if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue,
410
- ValueType.ListValue, ValueType.TupleValue):
415
+ if value.type in (ValueType.ConstantValue, ValueType.ListValue, ValueType.TupleValue):
411
416
  results.append(AstModifier._create_keyword(arg, value))
412
417
  elif value.type == ValueType.NamingValue:
413
418
  if value.scope:
@@ -466,7 +471,7 @@ class AstModifier(ast.NodeTransformer):
466
471
  Raises:
467
472
  TypeError: Input src_argument is not a ScopedValue
468
473
  RuntimeError: If 'dst_ast' is an instance of ast.Constant but type of 'src_argument' is not
469
- ValueType.IntValue, ValueType.FloatValue or ValueType.StringValue.
474
+ ValueType.ConstantValue.
470
475
  RuntimeError: If 'dst_ast' is an instance of ast.Name or ast.Attribute but type of 'src_argument' is not
471
476
  ValueType.NamingValue.
472
477
  RuntimeError: When 'dst_ast' is an instance of ast.Name, scope of 'src_argument' is not empty.
@@ -480,27 +485,14 @@ class AstModifier(ast.NodeTransformer):
480
485
  """
481
486
  if not isinstance(src_argument, ScopedValue):
482
487
  raise TypeError("src_argument should be ScopedValue, got: ", type(src_argument))
483
- if isinstance(dst_ast, ast.Constant):
484
- if src_argument.type not in [ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue]:
485
- raise RuntimeError("src_argument should be a IntValue, FloatValue or StringValue, got:",
486
- str(src_argument.type))
487
- dst_ast.value = src_argument.value
488
- return
489
- if isinstance(dst_ast, ast.Num):
490
- if src_argument.type not in [ValueType.IntValue, ValueType.FloatValue]:
491
- raise RuntimeError("src_argument should be a IntValue or FloatValue, but got:",
492
- str(src_argument.type))
493
- dst_ast.n = src_argument.value
494
- return
495
- if isinstance(dst_ast, ast.Str):
496
- if src_argument.type not in [ValueType.StringValue]:
497
- raise RuntimeError("src_argument should be a StringValue, but got:",
498
- str(src_argument.type))
499
- dst_ast.s = src_argument.value
488
+ if isinstance(dst_ast, (ast.Constant, ast.Num, ast.Str)):
489
+ AstModifier.update_arg_value_constant(src_argument, dst_ast)
500
490
  return
501
491
  if isinstance(dst_ast, ast.Name):
502
- if src_argument.type not in [ValueType.NamingValue, ValueType.StringValue]:
503
- raise RuntimeError("src_argument.type should be ValueType.NamingValue or ValueType.StringValue.")
492
+ if src_argument.type not in [ValueType.NamingValue, ValueType.ConstantValue]\
493
+ or not isinstance(src_argument.value, str):
494
+ raise RuntimeError("src_argument.type should be ValueType.NamingValue or ValueType.ConstantValue, "
495
+ "but got:", type(src_argument.value).__name__)
504
496
  if src_argument.scope:
505
497
  raise RuntimeError("src_argument.scope should be empty")
506
498
  dst_ast.id = src_argument.value
@@ -523,3 +515,23 @@ class AstModifier(ast.NodeTransformer):
523
515
  AstModifier.update_arg_value(src_argument.value[elt_index], elt)
524
516
  return
525
517
  raise RuntimeError("keyword value type is not supported", type(dst_ast))
518
+
519
+ @staticmethod
520
+ def update_arg_value_constant(src_argument: ScopedValue, dst_ast: ast.AST):
521
+ """Update 'arg_value' of type constant by 'input_argument'"""
522
+ if src_argument.type != ValueType.ConstantValue:
523
+ raise RuntimeError("src_argument should be a ConstantValue, got:", str(src_argument.type))
524
+ if isinstance(dst_ast, ast.Constant):
525
+ dst_ast.value = src_argument.value
526
+ return
527
+ if isinstance(dst_ast, ast.Num):
528
+ if not isinstance(src_argument.value, (int, float)):
529
+ raise RuntimeError("Type of src_argument should be int or float, but got:",
530
+ type(src_argument.value).__name__)
531
+ dst_ast.n = src_argument.value
532
+ return
533
+ if isinstance(dst_ast, ast.Str):
534
+ if not isinstance(src_argument.value, str):
535
+ raise RuntimeError("Type of src_argument should be str, but got:", type(src_argument.value).__name__)
536
+ dst_ast.s = src_argument.value
537
+ return
@@ -14,14 +14,20 @@
14
14
  # ============================================================================
15
15
  """Ast optimizer for flatten recursive call."""
16
16
 
17
- from typing import Any, Tuple
17
+ import sys
18
+ from typing import Any, Tuple, List
19
+ import keyword
18
20
  import ast
19
- from ast import FunctionDef
20
- import astunparse
21
21
 
22
22
  from mindspore import log as logger
23
23
  from ..common import error_str
24
24
 
25
+ if sys.version_info >= (3, 9):
26
+ import ast as astunparse # pylint: disable=reimported, ungrouped-imports
27
+ else:
28
+ import astunparse
29
+
30
+ FLATTEN_BLACK_LIST = ["set_vertex_attr",]
25
31
 
26
32
  class FlattenRecursiveStmt(ast.NodeTransformer):
27
33
  """Ast optimizer for flatten recursive call."""
@@ -40,17 +46,35 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
40
46
  ast.BoolOp: ["values"],
41
47
  ast.UnaryOp: ["operand"],
42
48
  ast.Compare: ["left", "comparators"],
49
+ ast.If: ["test"]
43
50
  }
51
+ self._transform_functions = []
52
+ self._transform_if = False
53
+ self._symbol_tree = None # Used to get unique name
44
54
 
45
55
  @staticmethod
46
- def _generate_target_name(node: ast.AST, target_names):
56
+ def _check_flatten_black_list(node: ast.AST):
57
+ """Check whether node in flatten black list"""
58
+ func_name = ""
59
+ # Get func name of node
60
+ if isinstance(node, ast.Call):
61
+ if isinstance(node.func, ast.Name):
62
+ func_name = node.func.id
63
+ elif isinstance(node.func, ast.Attribute):
64
+ func_name = node.func.attr
65
+ # Check func name of node
66
+ if func_name and func_name in FLATTEN_BLACK_LIST:
67
+ return True
68
+ return False
69
+
70
+ def _generate_target_name(self, node: ast.AST, target_names):
47
71
  """Generate unique target name."""
48
72
  if isinstance(node, ast.Call):
49
73
  func = node.func
50
74
  if isinstance(func, ast.Name):
51
- target_name = func.id
75
+ target_name = func.id + "_var"
52
76
  elif isinstance(func, ast.Attribute):
53
- target_name = func.attr
77
+ target_name = func.attr + "_var"
54
78
  else:
55
79
  logger.info("unhandled type of func of ast.Call while generating new target name: %s ", type(func))
56
80
  target_name = "function"
@@ -67,30 +91,33 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
67
91
  else:
68
92
  logger.info("unhandled type of node while generating new target name: %s ", type(node))
69
93
  target_name = type(node).__name__.lower() + "_var"
94
+ # avoid python keyword
95
+ if keyword.iskeyword(target_name):
96
+ target_name = target_name + "_var"
70
97
  suffix = 0
71
98
  result = target_name
72
99
  while result in target_names:
73
100
  suffix += 1
74
101
  result = f"{target_name}_{suffix}"
102
+ if self._symbol_tree:
103
+ result = self._symbol_tree.unique_name(result)
75
104
  target_names.append(result)
76
105
  return result
77
106
 
78
- @staticmethod
79
- def _create_new_assign_node(node: ast.AST, target_names) -> Tuple[str, ast.AST]:
107
+ def _create_new_assign_node(self, node: ast.AST, target_names) -> Tuple[str, ast.AST]:
80
108
  """Create new assign node to be inserted into ast.FunctionDef."""
81
109
  if isinstance(node, (ast.Name, ast.Constant, ast.Num, ast.Str, ast.NameConstant, ast.Bytes, ast.Ellipsis)):
82
110
  return "", node
83
- new_target_name = FlattenRecursiveStmt._generate_target_name(node, target_names)
111
+ new_target_name = self._generate_target_name(node, target_names)
84
112
  return new_target_name, ast.Assign(targets=[ast.Name(id=new_target_name, ctx=ast.Store())], value=node)
85
113
 
86
- @staticmethod
87
- def _flatten_list(node_list, target_names):
114
+ def _flatten_list(self, node_list, target_names):
88
115
  """Flatten a list of node."""
89
116
  results = list()
90
117
  new_list = list()
91
118
  for node in node_list:
92
119
  if isinstance(node, ast.Call):
93
- new_target, new_node = FlattenRecursiveStmt._create_new_assign_node(node, target_names)
120
+ new_target, new_node = self._create_new_assign_node(node, target_names)
94
121
  results.append(new_node)
95
122
  new_list.append(ast.Name(id=new_target, ctx=ast.Load()))
96
123
  else:
@@ -99,6 +126,8 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
99
126
 
100
127
  def _flatten_statement(self, node: ast.AST, target_names) -> [ast.AST]:
101
128
  """Flatten recursive statement according to different node type."""
129
+ if FlattenRecursiveStmt._check_flatten_black_list(node):
130
+ return []
102
131
  flatten_config = self._flatten_table.get(type(node))
103
132
  if flatten_config is None:
104
133
  return []
@@ -112,21 +141,21 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
112
141
  if isinstance(todo, ast.Starred):
113
142
  new_list.append(todo)
114
143
  continue
115
- new_target_name, new_node = FlattenRecursiveStmt._create_new_assign_node(todo, target_names)
144
+ new_target_name, new_node = self._create_new_assign_node(todo, target_names)
116
145
  if id(new_node) == id(todo):
117
146
  new_list.append(todo)
118
147
  else:
119
148
  new_list.append(ast.Name(id=new_target_name, ctx=ast.Load()))
120
149
  results.append(new_node)
121
150
  if isinstance(todo, (ast.Tuple, tuple)):
122
- _res, _new_list = FlattenRecursiveStmt._flatten_list(new_node.value.elts, [new_target_name])
151
+ _res, _new_list = self._flatten_list(new_node.value.elts, [new_target_name])
123
152
  new_node.value.elts = _new_list
124
153
  results.extend(_res)
125
154
  setattr(node, todo_name, new_list)
126
155
  elif isinstance(todos, dict):
127
156
  new_dict = []
128
157
  for key, value in todos:
129
- new_target_name, new_node = FlattenRecursiveStmt._create_new_assign_node(value, target_names)
158
+ new_target_name, new_node = self._create_new_assign_node(value, target_names)
130
159
  if id(new_node) == id(value):
131
160
  new_dict[key] = value
132
161
  else:
@@ -134,16 +163,15 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
134
163
  results.append(new_node)
135
164
  setattr(node, todo_name, new_dict)
136
165
  else:
137
- new_target_name, new_node = FlattenRecursiveStmt._create_new_assign_node(todos, target_names)
166
+ new_target_name, new_node = self._create_new_assign_node(todos, target_names)
138
167
  if id(new_node) != id(todos):
139
168
  setattr(node, todo_name, ast.Name(id=new_target_name, ctx=ast.Load()))
140
169
  results.append(new_node)
141
170
  return results
142
171
 
143
- def _fill_in_original_target_names(self, target_names, node):
144
- """Fill in original target names before getting unique names."""
145
- for function_index in range(len(node.body)):
146
- child = node.body[function_index]
172
+ def _save_target_names(self, target_names, ast_body: List[ast.AST]):
173
+ """Saving target names in ast_body before getting unique names."""
174
+ for child in ast_body:
147
175
  if isinstance(child, (ast.Assign, ast.Expr)):
148
176
  child_value = child.value
149
177
  else:
@@ -155,7 +183,7 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
155
183
  continue
156
184
  targets = child.targets
157
185
  for target in targets:
158
- if not isinstance(target, (ast.Name, ast.Tuple)):
186
+ if not isinstance(target, (ast.Name, ast.Tuple, ast.List)):
159
187
  raise RuntimeError(
160
188
  error_str(f"currently only support ast.Name targets, but got ast type "
161
189
  f"'{type(target).__name__}'", child_node=target, father_node=child))
@@ -163,7 +191,7 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
163
191
  target_name = target.id
164
192
  if target_name not in target_names:
165
193
  target_names.append(target_name)
166
- elif isinstance(target, ast.Tuple):
194
+ elif isinstance(target, (ast.Tuple, ast.List)):
167
195
  for elt in target.elts:
168
196
  if not isinstance(elt, ast.Name):
169
197
  raise RuntimeError(
@@ -174,47 +202,66 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
174
202
  if target_name not in target_names:
175
203
  target_names.append(target_name)
176
204
 
177
- def visit_FunctionDef(self, node: FunctionDef) -> Any:
178
- """Traverse construct node and flatten recursive nodes."""
179
- if node.name != "construct":
180
- return node
181
-
205
+ def _visit_ast_bodies(self, ast_body: List[ast.AST]):
206
+ """Traverse nodes in ast_body and flatten nodes recursive."""
182
207
  target_names = []
183
- self._fill_in_original_target_names(target_names, node)
184
- index = len(node.body) - 1
208
+ self._save_target_names(target_names, ast_body)
209
+ index = len(ast_body) - 1
185
210
  while index >= 0:
186
- child = node.body[index]
211
+ child = ast_body[index]
187
212
  if isinstance(child, ast.Assign):
188
213
  stmt = child.value
189
214
  elif isinstance(child, ast.If):
190
215
  if isinstance(child.body[0], ast.Return) and not isinstance(child.test, ast.UnaryOp):
191
- if isinstance(child.body[0].value, ast.Call):
192
- if_body = child.body
193
- if_func = if_body[0].value
194
- expr = "x = " + astunparse.unparse(if_func)
195
- if_body = ast.parse(expr)
196
- if_body = if_body.body+ast.parse("return x").body
197
- child.body = if_body
198
- stmt = child
199
- else:
200
- stmt = child
201
- else:
202
- stmt = child
216
+ if not isinstance(child.body[0].value, (ast.Name, ast.Constant)):
217
+ return_val_ast = child.body[0].value
218
+ return_name = self._generate_target_name(return_val_ast, target_names)
219
+ new_assign_code = f"{return_name} = {astunparse.unparse(return_val_ast)}"
220
+ new_assign_ast = ast.parse(new_assign_code).body[0]
221
+ new_return_ast = ast.parse(f"return {return_name}").body[0]
222
+ child.body = [new_assign_ast, new_return_ast]
223
+ stmt = child
203
224
  elif isinstance(child, ast.Expr):
204
225
  stmt = child.value
205
226
  else:
206
227
  stmt = child
207
228
  results = self._flatten_statement(stmt, target_names)
208
229
  if results:
209
- results.reverse()
210
- for result in results:
211
- node.body.insert(index, result)
230
+ for result in reversed(results):
231
+ ast_body.insert(index, result)
212
232
  index += 1
213
233
  index -= 1
234
+
235
+ def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: # pylint: disable=invalid-name
236
+ """Traverse nodes in _transform_functions and flatten recursive nodes."""
237
+ if node.name not in self._transform_functions:
238
+ return node
239
+ self._visit_ast_bodies(node.body)
240
+ return node
241
+
242
+ def visit_If(self, node: ast.If) -> Any: # pylint: disable=invalid-name
243
+ """Traverse nodes in if node and flatten recursive nodes."""
244
+ if not self._transform_if:
245
+ return node
246
+ self._visit_ast_bodies(node.body)
247
+ if node.orelse:
248
+ self._visit_ast_bodies(node.orelse)
214
249
  return node
215
250
 
216
- def transform(self, ast_root):
251
+ def transform(self, ast_root, transform_functions=None, stree=None):
217
252
  """Interface of FlattenRecursiveStmt."""
253
+ self._transform_functions = transform_functions if transform_functions else ["construct"]
254
+ self._transform_if = False
255
+ self._symbol_tree = stree
218
256
  ast_root = self.visit(ast_root)
219
257
  ast_root = ast.fix_missing_locations(ast_root)
220
258
  return ast_root
259
+
260
+ def transform_if(self, ast_if, stree=None):
261
+ """Interface of FlattenRecursiveStmt."""
262
+ self._transform_functions = []
263
+ self._transform_if = True
264
+ self._symbol_tree = stree
265
+ ast_if = self.visit(ast_if)
266
+ ast_if = ast.fix_missing_locations(ast_if)
267
+ return ast_if
@@ -14,8 +14,12 @@
14
14
  # ============================================================================
15
15
  """Error Log for Rewrite."""
16
16
 
17
+ import sys
17
18
  import ast
18
- import astunparse
19
+ if sys.version_info >= (3, 9):
20
+ import ast as astunparse # pylint: disable=reimported, ungrouped-imports
21
+ else:
22
+ import astunparse
19
23
 
20
24
 
21
25
  def error_str(reason: str, child_node: ast.expr = None, father_node: ast.expr = None) -> str: