mindspore 2.2.14__cp39-cp39-manylinux1_x86_64.whl → 2.3.0rc1__cp39-cp39-manylinux1_x86_64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (1154):
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -4
  3. mindspore/_akg/akg/composite/build_module.py +155 -11
  4. mindspore/_akg/akg/config/repository.json +38 -0
  5. mindspore/_akg/akg/ms/info_version_adapt.py +29 -0
  6. mindspore/_akg/akg/tvm/contrib/nvcc.py +4 -1
  7. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +2 -1
  8. mindspore/_akg/akg/utils/composite_op_helper.py +4 -2
  9. mindspore/_akg/akg/utils/dump_ascend_meta.py +2 -2
  10. mindspore/_akg/akg/utils/gen_random.py +14 -8
  11. mindspore/_akg/akg/utils/op_dsl.py +11 -0
  12. mindspore/_akg/akg/utils/tbe_codegen_utils.py +5 -5
  13. mindspore/_c_dataengine.cpython-39-x86_64-linux-gnu.so +0 -0
  14. mindspore/_c_expression.cpython-39-x86_64-linux-gnu.so +0 -0
  15. mindspore/_c_mindrecord.cpython-39-x86_64-linux-gnu.so +0 -0
  16. mindspore/_checkparam.py +58 -0
  17. mindspore/_extends/builtin_operations.py +2 -1
  18. mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
  19. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
  20. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
  21. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
  22. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  23. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
  24. mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
  25. mindspore/_extends/parse/__init__.py +18 -14
  26. mindspore/_extends/parse/compile_config.py +229 -0
  27. mindspore/_extends/parse/parser.py +155 -59
  28. mindspore/_extends/parse/resources.py +40 -7
  29. mindspore/_extends/parse/standard_method.py +124 -204
  30. mindspore/_extends/remote/kernel_build_server.py +2 -0
  31. mindspore/_mindspore_offline_debug.cpython-39-x86_64-linux-gnu.so +0 -0
  32. mindspore/_profiler.py +30 -0
  33. mindspore/amp.py +24 -18
  34. mindspore/bin/cache_admin +0 -0
  35. mindspore/bin/cache_server +0 -0
  36. mindspore/boost/boost_cell_wrapper.py +1 -1
  37. mindspore/boost/group_loss_scale_manager.py +1 -1
  38. mindspore/common/__init__.py +3 -1
  39. mindspore/common/_jit_fallback_utils.py +2 -3
  40. mindspore/common/_register_for_adapter.py +7 -0
  41. mindspore/common/_stub_tensor.py +6 -1
  42. mindspore/common/_utils.py +5 -17
  43. mindspore/common/api.py +91 -48
  44. mindspore/common/auto_dynamic_shape.py +27 -14
  45. mindspore/common/dtype.py +5 -4
  46. mindspore/common/dump.py +5 -4
  47. mindspore/common/initializer.py +1 -1
  48. mindspore/common/jit_config.py +20 -11
  49. mindspore/common/lazy_inline.py +58 -17
  50. mindspore/common/mindir_util.py +12 -2
  51. mindspore/common/mutable.py +79 -14
  52. mindspore/common/parameter.py +19 -4
  53. mindspore/common/seed.py +9 -9
  54. mindspore/common/sparse_tensor.py +251 -18
  55. mindspore/common/symbol.py +122 -0
  56. mindspore/common/tensor.py +321 -433
  57. mindspore/communication/__init__.py +3 -3
  58. mindspore/communication/_comm_helper.py +5 -0
  59. mindspore/communication/management.py +53 -38
  60. mindspore/config/op_info.config +22 -54
  61. mindspore/context.py +167 -59
  62. mindspore/dataset/__init__.py +5 -5
  63. mindspore/dataset/audio/__init__.py +6 -6
  64. mindspore/dataset/audio/transforms.py +711 -158
  65. mindspore/dataset/callback/ds_callback.py +2 -2
  66. mindspore/dataset/engine/cache_client.py +2 -2
  67. mindspore/dataset/engine/datasets.py +72 -38
  68. mindspore/dataset/engine/datasets_audio.py +14 -14
  69. mindspore/dataset/engine/datasets_standard_format.py +33 -3
  70. mindspore/dataset/engine/datasets_text.py +38 -38
  71. mindspore/dataset/engine/datasets_user_defined.py +7 -7
  72. mindspore/dataset/engine/datasets_vision.py +75 -71
  73. mindspore/dataset/engine/offload.py +5 -7
  74. mindspore/dataset/text/__init__.py +3 -3
  75. mindspore/dataset/text/transforms.py +408 -121
  76. mindspore/dataset/text/utils.py +9 -9
  77. mindspore/dataset/transforms/__init__.py +1 -1
  78. mindspore/dataset/transforms/transforms.py +261 -76
  79. mindspore/dataset/utils/browse_dataset.py +9 -9
  80. mindspore/dataset/vision/__init__.py +3 -3
  81. mindspore/dataset/vision/c_transforms.py +5 -5
  82. mindspore/dataset/vision/transforms.py +2264 -514
  83. mindspore/dataset/vision/utils.py +40 -9
  84. mindspore/dataset/vision/validators.py +7 -1
  85. mindspore/experimental/optim/__init__.py +12 -2
  86. mindspore/experimental/optim/adadelta.py +161 -0
  87. mindspore/experimental/optim/adagrad.py +168 -0
  88. mindspore/experimental/optim/adam.py +35 -34
  89. mindspore/experimental/optim/adamax.py +170 -0
  90. mindspore/experimental/optim/adamw.py +40 -16
  91. mindspore/experimental/optim/asgd.py +153 -0
  92. mindspore/experimental/optim/lr_scheduler.py +60 -119
  93. mindspore/experimental/optim/nadam.py +157 -0
  94. mindspore/experimental/optim/optimizer.py +15 -8
  95. mindspore/experimental/optim/radam.py +194 -0
  96. mindspore/experimental/optim/rmsprop.py +154 -0
  97. mindspore/experimental/optim/rprop.py +164 -0
  98. mindspore/experimental/optim/sgd.py +28 -19
  99. mindspore/hal/__init__.py +34 -0
  100. mindspore/hal/_ascend.py +57 -0
  101. mindspore/hal/_base.py +57 -0
  102. mindspore/hal/_cpu.py +56 -0
  103. mindspore/hal/_gpu.py +57 -0
  104. mindspore/hal/device.py +356 -0
  105. mindspore/hal/event.py +179 -0
  106. mindspore/hal/stream.py +337 -0
  107. mindspore/include/api/data_type.h +2 -2
  108. mindspore/include/api/dual_abi_helper.h +16 -3
  109. mindspore/include/api/model.h +1 -3
  110. mindspore/include/api/status.h +14 -0
  111. mindspore/include/c_api/model_c.h +173 -0
  112. mindspore/include/c_api/ms/base/types.h +1 -0
  113. mindspore/include/c_api/types_c.h +19 -0
  114. mindspore/include/dataset/execute.h +1 -3
  115. mindspore/include/mindapi/base/format.h +125 -23
  116. mindspore/include/mindapi/base/types.h +7 -0
  117. mindspore/lib/libdnnl.so.2 +0 -0
  118. mindspore/lib/libmindspore.so +0 -0
  119. mindspore/lib/libmindspore_backend.so +0 -0
  120. mindspore/lib/libmindspore_common.so +0 -0
  121. mindspore/lib/libmindspore_core.so +0 -0
  122. mindspore/lib/libmindspore_glog.so.0 +0 -0
  123. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  124. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  125. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  126. mindspore/lib/libmindspore_shared_lib.so +0 -0
  127. mindspore/lib/libmpi_adapter.so +0 -0
  128. mindspore/lib/libmpi_collective.so +0 -0
  129. mindspore/lib/libnnacl.so +0 -0
  130. mindspore/lib/libopencv_core.so.4.5 +0 -0
  131. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  132. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  133. mindspore/lib/libps_cache.so +0 -0
  134. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +2044 -154
  135. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +2044 -33
  136. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/build_tbe_kernel.py +529 -0
  137. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/compiler.py +56 -0
  138. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/custom.py +1109 -0
  139. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/get_file_path.py +36 -0
  140. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
  141. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/tbe_topi.py +556 -0
  142. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
  143. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  144. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +6325 -1767
  145. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  146. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_add_custom.h +49 -0
  147. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_decoder_kv_cache.h +59 -0
  148. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_prompt_kv_cache.h +59 -0
  149. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/lib/libcust_opapi.so +0 -0
  150. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +52 -0
  151. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +232 -0
  152. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +232 -0
  153. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.cpp +81 -0
  154. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.py +134 -0
  155. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/decoder_kv_cache.cpp +192 -0
  156. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/decoder_kv_cache.py +134 -0
  157. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/prompt_kv_cache.cpp +274 -0
  158. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/prompt_kv_cache.py +134 -0
  159. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/lib/linux/x86_64/libcust_opmaster_rt2.0.so +0 -0
  160. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/liboptiling.so +0 -0
  161. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/inc/op_proto.h +39 -0
  162. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/lib/linux/x86_64/libcust_opsproto_rt2.0.so +0 -0
  163. mindspore/lib/plugin/ascend/libakg.so +0 -0
  164. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  165. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  166. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  167. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  168. mindspore/lib/plugin/cpu/libakg.so +0 -0
  169. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  170. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  171. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  172. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  173. mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
  174. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  175. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  176. mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
  177. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  178. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  179. mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
  180. mindspore/lib/plugin/{libmindspore_ascend.so.1 → libmindspore_ascend.so.2} +0 -0
  181. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  182. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  183. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  184. mindspore/mindrecord/__init__.py +5 -1
  185. mindspore/mindrecord/config.py +809 -0
  186. mindspore/mindrecord/filereader.py +25 -0
  187. mindspore/mindrecord/filewriter.py +74 -56
  188. mindspore/mindrecord/mindpage.py +40 -6
  189. mindspore/mindrecord/shardutils.py +3 -2
  190. mindspore/mindrecord/shardwriter.py +7 -0
  191. mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
  192. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
  193. mindspore/mindrecord/tools/csv_to_mr.py +4 -9
  194. mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
  195. mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
  196. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
  197. mindspore/multiprocessing/__init__.py +68 -0
  198. mindspore/nn/cell.py +86 -133
  199. mindspore/nn/dynamic_lr.py +2 -2
  200. mindspore/nn/layer/activation.py +79 -90
  201. mindspore/nn/layer/basic.py +4 -80
  202. mindspore/nn/layer/channel_shuffle.py +3 -16
  203. mindspore/nn/layer/container.py +3 -3
  204. mindspore/nn/layer/conv.py +71 -71
  205. mindspore/nn/layer/embedding.py +105 -44
  206. mindspore/nn/layer/image.py +4 -7
  207. mindspore/nn/layer/normalization.py +46 -38
  208. mindspore/nn/layer/padding.py +26 -39
  209. mindspore/nn/layer/pooling.py +13 -9
  210. mindspore/nn/layer/rnn_cells.py +5 -15
  211. mindspore/nn/layer/rnns.py +6 -5
  212. mindspore/nn/layer/thor_layer.py +1 -2
  213. mindspore/nn/layer/timedistributed.py +1 -1
  214. mindspore/nn/layer/transformer.py +52 -50
  215. mindspore/nn/learning_rate_schedule.py +6 -5
  216. mindspore/nn/loss/loss.py +43 -64
  217. mindspore/nn/optim/ada_grad.py +4 -2
  218. mindspore/nn/optim/adadelta.py +3 -1
  219. mindspore/nn/optim/adafactor.py +1 -1
  220. mindspore/nn/optim/adam.py +102 -181
  221. mindspore/nn/optim/adamax.py +4 -2
  222. mindspore/nn/optim/adasum.py +2 -2
  223. mindspore/nn/optim/asgd.py +4 -2
  224. mindspore/nn/optim/ftrl.py +31 -61
  225. mindspore/nn/optim/lamb.py +5 -3
  226. mindspore/nn/optim/lars.py +2 -2
  227. mindspore/nn/optim/lazyadam.py +6 -4
  228. mindspore/nn/optim/momentum.py +13 -25
  229. mindspore/nn/optim/optimizer.py +6 -3
  230. mindspore/nn/optim/proximal_ada_grad.py +4 -2
  231. mindspore/nn/optim/rmsprop.py +9 -3
  232. mindspore/nn/optim/rprop.py +4 -2
  233. mindspore/nn/optim/sgd.py +6 -5
  234. mindspore/nn/optim/thor.py +2 -2
  235. mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
  236. mindspore/nn/probability/distribution/beta.py +2 -2
  237. mindspore/nn/probability/distribution/categorical.py +4 -6
  238. mindspore/nn/probability/distribution/cauchy.py +2 -2
  239. mindspore/nn/probability/distribution/exponential.py +1 -1
  240. mindspore/nn/probability/distribution/gumbel.py +2 -2
  241. mindspore/nn/probability/distribution/poisson.py +2 -2
  242. mindspore/nn/probability/distribution/uniform.py +2 -2
  243. mindspore/nn/reinforcement/_tensors_queue.py +13 -1
  244. mindspore/nn/wrap/__init__.py +2 -1
  245. mindspore/nn/wrap/cell_wrapper.py +33 -12
  246. mindspore/nn/wrap/grad_reducer.py +148 -8
  247. mindspore/nn/wrap/loss_scale.py +7 -7
  248. mindspore/numpy/__init__.py +2 -0
  249. mindspore/numpy/array_creations.py +2 -0
  250. mindspore/numpy/array_ops.py +1 -5
  251. mindspore/numpy/fft.py +431 -0
  252. mindspore/numpy/math_ops.py +54 -60
  253. mindspore/numpy/utils.py +3 -0
  254. mindspore/ops/__init__.py +5 -4
  255. mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
  256. mindspore/ops/_grad_experimental/grad_comm_ops.py +16 -22
  257. mindspore/ops/_grad_experimental/grad_math_ops.py +68 -283
  258. mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
  259. mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
  260. mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
  261. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  262. mindspore/ops/_op_impl/__init__.py +0 -1
  263. mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
  264. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
  265. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
  266. mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
  267. mindspore/ops/_op_impl/cpu/__init__.py +1 -3
  268. mindspore/ops/_op_impl/cpu/adam.py +2 -2
  269. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
  270. mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
  271. mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
  272. mindspore/ops/_vmap/vmap_array_ops.py +137 -101
  273. mindspore/ops/_vmap/vmap_base.py +8 -1
  274. mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
  275. mindspore/ops/_vmap/vmap_grad_nn_ops.py +102 -56
  276. mindspore/ops/_vmap/vmap_image_ops.py +70 -13
  277. mindspore/ops/_vmap/vmap_math_ops.py +74 -49
  278. mindspore/ops/_vmap/vmap_nn_ops.py +164 -89
  279. mindspore/ops/_vmap/vmap_other_ops.py +1 -1
  280. mindspore/ops/auto_generate/__init__.py +31 -0
  281. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +133 -0
  282. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +248 -0
  283. mindspore/ops/auto_generate/gen_arg_handler.py +147 -0
  284. mindspore/ops/auto_generate/gen_extend_func.py +130 -0
  285. mindspore/ops/auto_generate/gen_ops_def.py +4786 -0
  286. mindspore/ops/auto_generate/gen_ops_prim.py +8335 -0
  287. mindspore/ops/auto_generate/pyboost_inner_prim.py +77 -0
  288. mindspore/ops/composite/__init__.py +5 -2
  289. mindspore/ops/composite/base.py +118 -17
  290. mindspore/ops/composite/math_ops.py +9 -48
  291. mindspore/ops/composite/multitype_ops/_compile_utils.py +166 -601
  292. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +15 -133
  293. mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
  294. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
  295. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
  296. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
  297. mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
  298. mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
  299. mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
  300. mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
  301. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
  302. mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
  303. mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
  304. mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
  305. mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
  306. mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
  307. mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
  308. mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
  309. mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
  310. mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
  311. mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
  312. mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
  313. mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
  314. mindspore/ops/composite/multitype_ops/not_in_impl.py +6 -1
  315. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
  316. mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
  317. mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
  318. mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
  319. mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
  320. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
  321. mindspore/ops/deprecated.py +14 -3
  322. mindspore/ops/extend/__init__.py +46 -0
  323. mindspore/ops/extend/array_func.py +152 -0
  324. mindspore/ops/extend/math_func.py +76 -0
  325. mindspore/ops/{_op_impl/tbe/atomic_addr_clean.py → extend/nn_func.py} +5 -15
  326. mindspore/ops/function/__init__.py +19 -11
  327. mindspore/ops/function/array_func.py +251 -1440
  328. mindspore/ops/function/clip_func.py +12 -13
  329. mindspore/ops/function/debug_func.py +1 -4
  330. mindspore/ops/function/fft_func.py +31 -0
  331. mindspore/ops/function/grad/grad_func.py +24 -17
  332. mindspore/ops/function/image_func.py +27 -21
  333. mindspore/ops/function/linalg_func.py +35 -68
  334. mindspore/ops/function/math_func.py +451 -2360
  335. mindspore/ops/function/nn_func.py +459 -780
  336. mindspore/ops/function/other_func.py +4 -5
  337. mindspore/ops/function/parameter_func.py +5 -93
  338. mindspore/ops/function/random_func.py +24 -80
  339. mindspore/ops/function/sparse_unary_func.py +9 -16
  340. mindspore/ops/function/spectral_func.py +1 -1
  341. mindspore/ops/function/vmap_func.py +14 -14
  342. mindspore/ops/functional.py +56 -62
  343. mindspore/ops/op_info_register.py +22 -19
  344. mindspore/ops/operations/__init__.py +19 -19
  345. mindspore/ops/operations/_grad_ops.py +20 -723
  346. mindspore/ops/operations/_inner_ops.py +178 -286
  347. mindspore/ops/operations/_scalar_ops.py +5 -480
  348. mindspore/ops/operations/_sequence_ops.py +4 -34
  349. mindspore/ops/operations/array_ops.py +99 -2491
  350. mindspore/ops/operations/comm_ops.py +38 -46
  351. mindspore/ops/operations/custom_ops.py +8 -8
  352. mindspore/ops/operations/debug_ops.py +100 -31
  353. mindspore/ops/operations/image_ops.py +1 -217
  354. mindspore/ops/operations/inner_ops.py +3 -38
  355. mindspore/ops/operations/linalg_ops.py +1 -49
  356. mindspore/{rewrite/ast_transformers → ops/operations/manually_defined}/__init__.py +11 -4
  357. mindspore/ops/operations/manually_defined/_inner.py +61 -0
  358. mindspore/ops/operations/manually_defined/ops_def.py +1391 -0
  359. mindspore/ops/operations/math_ops.py +703 -4601
  360. mindspore/ops/operations/nn_ops.py +374 -1748
  361. mindspore/ops/operations/other_ops.py +50 -42
  362. mindspore/ops/operations/random_ops.py +3 -52
  363. mindspore/ops/primitive.py +196 -96
  364. mindspore/ops_generate/__init__.py +27 -0
  365. mindspore/ops_generate/arg_dtype_cast.py +248 -0
  366. mindspore/ops_generate/arg_handler.py +147 -0
  367. mindspore/ops_generate/gen_aclnn_implement.py +266 -0
  368. mindspore/ops_generate/gen_ops.py +1062 -0
  369. mindspore/ops_generate/gen_ops_inner_prim.py +129 -0
  370. mindspore/ops_generate/gen_pyboost_func.py +932 -0
  371. mindspore/ops_generate/gen_utils.py +188 -0
  372. mindspore/ops_generate/op_proto.py +138 -0
  373. mindspore/ops_generate/pyboost_utils.py +364 -0
  374. mindspore/ops_generate/template.py +238 -0
  375. mindspore/parallel/__init__.py +5 -4
  376. mindspore/parallel/_auto_parallel_context.py +21 -76
  377. mindspore/parallel/_cell_wrapper.py +16 -9
  378. mindspore/parallel/_cost_model_context.py +1 -1
  379. mindspore/parallel/_dp_allreduce_fusion.py +159 -159
  380. mindspore/parallel/_parallel_serialization.py +30 -46
  381. mindspore/parallel/_ps_context.py +1 -1
  382. mindspore/parallel/_recovery_context.py +1 -1
  383. mindspore/parallel/_tensor.py +19 -7
  384. mindspore/parallel/_transformer/__init__.py +1 -1
  385. mindspore/parallel/_transformer/layers.py +1 -1
  386. mindspore/parallel/_transformer/loss.py +1 -1
  387. mindspore/parallel/_transformer/moe.py +1 -1
  388. mindspore/parallel/_transformer/op_parallel_config.py +1 -1
  389. mindspore/parallel/_transformer/transformer.py +1 -1
  390. mindspore/parallel/_utils.py +131 -6
  391. mindspore/parallel/algo_parameter_config.py +6 -6
  392. mindspore/parallel/checkpoint_transform.py +180 -196
  393. mindspore/parallel/cluster/__init__.py +15 -0
  394. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  395. mindspore/parallel/cluster/process_entity/_api.py +345 -0
  396. mindspore/parallel/cluster/process_entity/_utils.py +116 -0
  397. mindspore/parallel/cluster/run.py +139 -0
  398. mindspore/parallel/mpi/__init__.py +1 -1
  399. mindspore/parallel/mpi/_mpi_config.py +1 -1
  400. mindspore/parallel/parameter_broadcast.py +152 -0
  401. mindspore/parallel/shard.py +99 -2
  402. mindspore/profiler/common/util.py +20 -0
  403. mindspore/profiler/envprofiling.py +1 -1
  404. mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
  405. mindspore/profiler/parser/ascend_analysis/constant.py +66 -0
  406. mindspore/profiler/parser/ascend_analysis/file_manager.py +77 -0
  407. mindspore/profiler/parser/ascend_analysis/function_event.py +146 -0
  408. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +108 -0
  409. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +80 -0
  410. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +52 -0
  411. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +104 -0
  412. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  413. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +59 -0
  414. mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
  415. mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
  416. mindspore/profiler/parser/ascend_flops_generator.py +20 -4
  417. mindspore/profiler/parser/ascend_hccl_generator.py +25 -277
  418. mindspore/profiler/parser/ascend_msprof_exporter.py +112 -132
  419. mindspore/profiler/parser/ascend_msprof_generator.py +68 -285
  420. mindspore/profiler/parser/ascend_op_generator.py +75 -42
  421. mindspore/profiler/parser/ascend_timeline_generator.py +293 -135
  422. mindspore/profiler/parser/base_timeline_generator.py +6 -0
  423. mindspore/profiler/parser/framework_parser.py +3 -2
  424. mindspore/profiler/parser/integrator.py +3 -1
  425. mindspore/profiler/parser/msadvisor_analyzer.py +1 -1
  426. mindspore/profiler/parser/msadvisor_parser.py +1 -1
  427. mindspore/profiler/parser/profiler_info.py +5 -0
  428. mindspore/profiler/profiling.py +296 -166
  429. mindspore/rewrite/__init__.py +2 -13
  430. mindspore/rewrite/api/node.py +121 -35
  431. mindspore/rewrite/api/pattern_engine.py +2 -3
  432. mindspore/rewrite/api/scoped_value.py +16 -15
  433. mindspore/rewrite/api/symbol_tree.py +45 -29
  434. mindspore/rewrite/ast_helpers/__init__.py +3 -6
  435. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  436. mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
  437. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  438. mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
  439. mindspore/rewrite/common/__init__.py +1 -2
  440. mindspore/rewrite/common/config.py +24 -0
  441. mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
  442. mindspore/rewrite/{namer.py → common/namer.py} +63 -18
  443. mindspore/rewrite/common/namespace.py +118 -0
  444. mindspore/rewrite/node/__init__.py +5 -5
  445. mindspore/rewrite/node/call_function.py +23 -7
  446. mindspore/rewrite/node/cell_container.py +7 -3
  447. mindspore/rewrite/node/control_flow.py +53 -28
  448. mindspore/rewrite/node/node.py +212 -196
  449. mindspore/rewrite/node/node_manager.py +51 -22
  450. mindspore/rewrite/node/node_topological_manager.py +3 -23
  451. mindspore/rewrite/parsers/__init__.py +12 -0
  452. mindspore/rewrite/parsers/arguments_parser.py +8 -9
  453. mindspore/rewrite/parsers/assign_parser.py +635 -413
  454. mindspore/rewrite/parsers/attribute_parser.py +3 -4
  455. mindspore/rewrite/parsers/class_def_parser.py +107 -144
  456. mindspore/rewrite/parsers/constant_parser.py +5 -5
  457. mindspore/rewrite/parsers/container_parser.py +4 -6
  458. mindspore/rewrite/parsers/expr_parser.py +55 -0
  459. mindspore/rewrite/parsers/for_parser.py +31 -98
  460. mindspore/rewrite/parsers/function_def_parser.py +13 -5
  461. mindspore/rewrite/parsers/if_parser.py +28 -10
  462. mindspore/rewrite/parsers/module_parser.py +8 -182
  463. mindspore/rewrite/parsers/parser.py +1 -5
  464. mindspore/rewrite/parsers/parser_register.py +1 -1
  465. mindspore/rewrite/parsers/return_parser.py +5 -10
  466. mindspore/rewrite/parsers/while_parser.py +59 -0
  467. mindspore/rewrite/sparsify/utils.py +1 -1
  468. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  469. mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +704 -185
  470. mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
  471. mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
  472. mindspore/run_check/_check_version.py +6 -14
  473. mindspore/run_check/run_check.py +1 -1
  474. mindspore/safeguard/rewrite_obfuscation.py +9 -19
  475. mindspore/scipy/__init__.py +2 -1
  476. mindspore/scipy/fft.py +133 -0
  477. mindspore/scipy/linalg.py +140 -55
  478. mindspore/scipy/ops.py +15 -71
  479. mindspore/scipy/ops_grad.py +5 -34
  480. mindspore/scipy/optimize/line_search.py +2 -2
  481. mindspore/scipy/optimize/minimize.py +1 -1
  482. mindspore/train/__init__.py +3 -2
  483. mindspore/train/_utils.py +178 -4
  484. mindspore/train/amp.py +167 -245
  485. mindspore/train/callback/_backup_and_restore.py +4 -4
  486. mindspore/train/callback/_callback.py +4 -4
  487. mindspore/train/callback/_checkpoint.py +39 -13
  488. mindspore/train/callback/_early_stop.py +2 -2
  489. mindspore/train/callback/_landscape.py +14 -8
  490. mindspore/train/callback/_loss_monitor.py +2 -2
  491. mindspore/train/callback/_on_request_exit.py +2 -2
  492. mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
  493. mindspore/train/callback/_summary_collector.py +7 -7
  494. mindspore/train/callback/_time_monitor.py +2 -2
  495. mindspore/train/data_sink.py +1 -1
  496. mindspore/train/dataset_helper.py +13 -4
  497. mindspore/train/loss_scale_manager.py +2 -2
  498. mindspore/train/metrics/accuracy.py +7 -7
  499. mindspore/train/metrics/confusion_matrix.py +8 -6
  500. mindspore/train/metrics/cosine_similarity.py +6 -4
  501. mindspore/train/metrics/error.py +2 -2
  502. mindspore/train/metrics/metric.py +3 -3
  503. mindspore/train/metrics/perplexity.py +2 -1
  504. mindspore/train/metrics/topk.py +2 -2
  505. mindspore/train/mind_ir_pb2.py +75 -6
  506. mindspore/train/model.py +24 -22
  507. mindspore/train/serialization.py +256 -132
  508. mindspore/train/summary/summary_record.py +51 -28
  509. mindspore/train/train_thor/convert_utils.py +3 -3
  510. mindspore/version.py +1 -1
  511. {mindspore-2.2.14.dist-info → mindspore-2.3.0rc1.dist-info}/METADATA +2 -2
  512. {mindspore-2.2.14.dist-info → mindspore-2.3.0rc1.dist-info}/RECORD +515 -1061
  513. {mindspore-2.2.14.dist-info → mindspore-2.3.0rc1.dist-info}/entry_points.txt +1 -0
  514. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
  515. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
  516. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
  517. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
  518. mindspore/config/super_bar_config.json +0 -544
  519. mindspore/gen_ops.py +0 -273
  520. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  521. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  522. mindspore/nn/layer/flash_attention.py +0 -189
  523. mindspore/ops/_op_impl/cpu/concat.py +0 -39
  524. mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
  525. mindspore/ops/_op_impl/tbe/__init__.py +0 -47
  526. mindspore/ops/_op_impl/tbe/abs.py +0 -38
  527. mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
  528. mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
  529. mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
  530. mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
  531. mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
  532. mindspore/ops/_op_impl/tbe/acos.py +0 -37
  533. mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
  534. mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
  535. mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
  536. mindspore/ops/_op_impl/tbe/acosh.py +0 -37
  537. mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
  538. mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
  539. mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
  540. mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
  541. mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
  542. mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
  543. mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
  544. mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
  545. mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
  546. mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
  547. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
  548. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
  549. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
  550. mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
  551. mindspore/ops/_op_impl/tbe/add.py +0 -42
  552. mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
  553. mindspore/ops/_op_impl/tbe/add_n.py +0 -39
  554. mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
  555. mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
  556. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
  557. mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
  558. mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
  559. mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
  560. mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
  561. mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
  562. mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
  563. mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
  564. mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
  565. mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
  566. mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
  567. mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
  568. mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
  569. mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
  570. mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
  571. mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
  572. mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
  573. mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
  574. mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
  575. mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
  576. mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
  577. mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
  578. mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
  579. mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
  580. mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
  581. mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
  582. mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
  583. mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
  584. mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
  585. mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
  586. mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
  587. mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
  588. mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
  589. mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
  590. mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
  591. mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
  592. mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
  593. mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
  594. mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
  595. mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
  596. mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
  597. mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
  598. mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
  599. mindspore/ops/_op_impl/tbe/asin.py +0 -37
  600. mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
  601. mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
  602. mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
  603. mindspore/ops/_op_impl/tbe/asinh.py +0 -37
  604. mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
  605. mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
  606. mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
  607. mindspore/ops/_op_impl/tbe/assign.py +0 -79
  608. mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
  609. mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
  610. mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
  611. mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
  612. mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
  613. mindspore/ops/_op_impl/tbe/atan.py +0 -37
  614. mindspore/ops/_op_impl/tbe/atan2.py +0 -38
  615. mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
  616. mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
  617. mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
  618. mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
  619. mindspore/ops/_op_impl/tbe/atanh.py +0 -37
  620. mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
  621. mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
  622. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
  623. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
  624. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
  625. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
  626. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
  627. mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
  628. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
  629. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
  630. mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
  631. mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
  632. mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
  633. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
  634. mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
  635. mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
  636. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
  637. mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
  638. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
  639. mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
  640. mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
  641. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
  642. mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
  643. mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
  644. mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
  645. mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
  646. mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
  647. mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
  648. mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
  649. mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
  650. mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
  651. mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
  652. mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
  653. mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
  654. mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
  655. mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
  656. mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
  657. mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
  658. mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
  659. mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
  660. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
  661. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
  662. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
  663. mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
  664. mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
  665. mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
  666. mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
  667. mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
  668. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
  669. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
  670. mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
  671. mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
  672. mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
  673. mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
  674. mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
  675. mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
  676. mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
  677. mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
  678. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
  679. mindspore/ops/_op_impl/tbe/cast.py +0 -55
  680. mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
  681. mindspore/ops/_op_impl/tbe/cdist.py +0 -38
  682. mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
  683. mindspore/ops/_op_impl/tbe/ceil.py +0 -37
  684. mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
  685. mindspore/ops/_op_impl/tbe/celu.py +0 -39
  686. mindspore/ops/_op_impl/tbe/centralization.py +0 -39
  687. mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
  688. mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
  689. mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
  690. mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
  691. mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
  692. mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
  693. mindspore/ops/_op_impl/tbe/concat.py +0 -40
  694. mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
  695. mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
  696. mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
  697. mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
  698. mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
  699. mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
  700. mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
  701. mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
  702. mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
  703. mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
  704. mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
  705. mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
  706. mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
  707. mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
  708. mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
  709. mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
  710. mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
  711. mindspore/ops/_op_impl/tbe/cos.py +0 -37
  712. mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
  713. mindspore/ops/_op_impl/tbe/cosh.py +0 -37
  714. mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
  715. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
  716. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
  717. mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
  718. mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
  719. mindspore/ops/_op_impl/tbe/cummin.py +0 -41
  720. mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
  721. mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
  722. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
  723. mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
  724. mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
  725. mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
  726. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
  727. mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
  728. mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
  729. mindspore/ops/_op_impl/tbe/diag.py +0 -38
  730. mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
  731. mindspore/ops/_op_impl/tbe/dilation.py +0 -40
  732. mindspore/ops/_op_impl/tbe/div.py +0 -41
  733. mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
  734. mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
  735. mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
  736. mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
  737. mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
  738. mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
  739. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
  740. mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
  741. mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
  742. mindspore/ops/_op_impl/tbe/elu.py +0 -38
  743. mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
  744. mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
  745. mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
  746. mindspore/ops/_op_impl/tbe/equal.py +0 -42
  747. mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
  748. mindspore/ops/_op_impl/tbe/erf.py +0 -37
  749. mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
  750. mindspore/ops/_op_impl/tbe/erfc.py +0 -37
  751. mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
  752. mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
  753. mindspore/ops/_op_impl/tbe/exp.py +0 -40
  754. mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
  755. mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
  756. mindspore/ops/_op_impl/tbe/expm1.py +0 -37
  757. mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
  758. mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
  759. mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
  760. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
  761. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
  762. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
  763. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
  764. mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
  765. mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
  766. mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
  767. mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
  768. mindspore/ops/_op_impl/tbe/fill.py +0 -56
  769. mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
  770. mindspore/ops/_op_impl/tbe/flatten.py +0 -48
  771. mindspore/ops/_op_impl/tbe/floor.py +0 -37
  772. mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
  773. mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
  774. mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
  775. mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
  776. mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
  777. mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
  778. mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
  779. mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
  780. mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
  781. mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
  782. mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
  783. mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
  784. mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
  785. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  786. mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
  787. mindspore/ops/_op_impl/tbe/gelu.py +0 -37
  788. mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
  789. mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
  790. mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
  791. mindspore/ops/_op_impl/tbe/ger.py +0 -43
  792. mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
  793. mindspore/ops/_op_impl/tbe/greater.py +0 -43
  794. mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
  795. mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
  796. mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
  797. mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
  798. mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
  799. mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
  800. mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
  801. mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
  802. mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
  803. mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
  804. mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
  805. mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
  806. mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
  807. mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
  808. mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
  809. mindspore/ops/_op_impl/tbe/im2col.py +0 -42
  810. mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
  811. mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
  812. mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
  813. mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
  814. mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
  815. mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
  816. mindspore/ops/_op_impl/tbe/inv.py +0 -38
  817. mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
  818. mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
  819. mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
  820. mindspore/ops/_op_impl/tbe/invert.py +0 -37
  821. mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
  822. mindspore/ops/_op_impl/tbe/iou.py +0 -38
  823. mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
  824. mindspore/ops/_op_impl/tbe/is_close.py +0 -40
  825. mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
  826. mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
  827. mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
  828. mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
  829. mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
  830. mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
  831. mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
  832. mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
  833. mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
  834. mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
  835. mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
  836. mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
  837. mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
  838. mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
  839. mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
  840. mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
  841. mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
  842. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
  843. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
  844. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
  845. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
  846. mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
  847. mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
  848. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
  849. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
  850. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
  851. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
  852. mindspore/ops/_op_impl/tbe/lerp.py +0 -38
  853. mindspore/ops/_op_impl/tbe/less.py +0 -41
  854. mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
  855. mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
  856. mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
  857. mindspore/ops/_op_impl/tbe/log.py +0 -40
  858. mindspore/ops/_op_impl/tbe/log1p.py +0 -37
  859. mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
  860. mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
  861. mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
  862. mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
  863. mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
  864. mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
  865. mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
  866. mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
  867. mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
  868. mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
  869. mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
  870. mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
  871. mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
  872. mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
  873. mindspore/ops/_op_impl/tbe/lrn.py +0 -41
  874. mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
  875. mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
  876. mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
  877. mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
  878. mindspore/ops/_op_impl/tbe/matmul.py +0 -53
  879. mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
  880. mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
  881. mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
  882. mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
  883. mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
  884. mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
  885. mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
  886. mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
  887. mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
  888. mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
  889. mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
  890. mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
  891. mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
  892. mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
  893. mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
  894. mindspore/ops/_op_impl/tbe/maximum.py +0 -39
  895. mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
  896. mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
  897. mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
  898. mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
  899. mindspore/ops/_op_impl/tbe/minimum.py +0 -40
  900. mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
  901. mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
  902. mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
  903. mindspore/ops/_op_impl/tbe/mish.py +0 -37
  904. mindspore/ops/_op_impl/tbe/mod.py +0 -41
  905. mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
  906. mindspore/ops/_op_impl/tbe/mul.py +0 -37
  907. mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
  908. mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
  909. mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
  910. mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
  911. mindspore/ops/_op_impl/tbe/neg.py +0 -39
  912. mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
  913. mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
  914. mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
  915. mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
  916. mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
  917. mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
  918. mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
  919. mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
  920. mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
  921. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
  922. mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
  923. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
  924. mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
  925. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
  926. mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
  927. mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
  928. mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
  929. mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
  930. mindspore/ops/_op_impl/tbe/pack.py +0 -58
  931. mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
  932. mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
  933. mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
  934. mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
  935. mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
  936. mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
  937. mindspore/ops/_op_impl/tbe/pdist.py +0 -36
  938. mindspore/ops/_op_impl/tbe/pooling.py +0 -46
  939. mindspore/ops/_op_impl/tbe/population_count.py +0 -38
  940. mindspore/ops/_op_impl/tbe/pow.py +0 -41
  941. mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
  942. mindspore/ops/_op_impl/tbe/prelu.py +0 -37
  943. mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
  944. mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
  945. mindspore/ops/_op_impl/tbe/range.py +0 -39
  946. mindspore/ops/_op_impl/tbe/real_div.py +0 -38
  947. mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
  948. mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
  949. mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
  950. mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
  951. mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
  952. mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
  953. mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
  954. mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
  955. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
  956. mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
  957. mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
  958. mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
  959. mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
  960. mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
  961. mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
  962. mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
  963. mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
  964. mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
  965. mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
  966. mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
  967. mindspore/ops/_op_impl/tbe/relu.py +0 -39
  968. mindspore/ops/_op_impl/tbe/relu6.py +0 -38
  969. mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
  970. mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
  971. mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
  972. mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
  973. mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
  974. mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
  975. mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
  976. mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
  977. mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
  978. mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
  979. mindspore/ops/_op_impl/tbe/renorm.py +0 -39
  980. mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
  981. mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
  982. mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
  983. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
  984. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
  985. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
  986. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
  987. mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
  988. mindspore/ops/_op_impl/tbe/rint.py +0 -37
  989. mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
  990. mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
  991. mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
  992. mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
  993. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
  994. mindspore/ops/_op_impl/tbe/roll.py +0 -42
  995. mindspore/ops/_op_impl/tbe/round.py +0 -38
  996. mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
  997. mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
  998. mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
  999. mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
  1000. mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
  1001. mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
  1002. mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
  1003. mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
  1004. mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
  1005. mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
  1006. mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
  1007. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
  1008. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
  1009. mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
  1010. mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
  1011. mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
  1012. mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
  1013. mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
  1014. mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
  1015. mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
  1016. mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
  1017. mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
  1018. mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
  1019. mindspore/ops/_op_impl/tbe/select.py +0 -38
  1020. mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
  1021. mindspore/ops/_op_impl/tbe/selu.py +0 -39
  1022. mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
  1023. mindspore/ops/_op_impl/tbe/sgd.py +0 -62
  1024. mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
  1025. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
  1026. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
  1027. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
  1028. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
  1029. mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
  1030. mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
  1031. mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
  1032. mindspore/ops/_op_impl/tbe/sign.py +0 -38
  1033. mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
  1034. mindspore/ops/_op_impl/tbe/sin.py +0 -37
  1035. mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
  1036. mindspore/ops/_op_impl/tbe/sinh.py +0 -37
  1037. mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
  1038. mindspore/ops/_op_impl/tbe/slice.py +0 -58
  1039. mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
  1040. mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
  1041. mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
  1042. mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
  1043. mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
  1044. mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
  1045. mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
  1046. mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
  1047. mindspore/ops/_op_impl/tbe/softmax.py +0 -37
  1048. mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
  1049. mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
  1050. mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
  1051. mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
  1052. mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
  1053. mindspore/ops/_op_impl/tbe/softplus.py +0 -37
  1054. mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
  1055. mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
  1056. mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
  1057. mindspore/ops/_op_impl/tbe/softsign.py +0 -37
  1058. mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
  1059. mindspore/ops/_op_impl/tbe/sort.py +0 -38
  1060. mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
  1061. mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
  1062. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
  1063. mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
  1064. mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
  1065. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
  1066. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
  1067. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
  1068. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
  1069. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
  1070. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
  1071. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
  1072. mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
  1073. mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
  1074. mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
  1075. mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
  1076. mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
  1077. mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
  1078. mindspore/ops/_op_impl/tbe/split_d.py +0 -38
  1079. mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
  1080. mindspore/ops/_op_impl/tbe/split_v.py +0 -39
  1081. mindspore/ops/_op_impl/tbe/splitv.py +0 -39
  1082. mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
  1083. mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
  1084. mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
  1085. mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
  1086. mindspore/ops/_op_impl/tbe/square.py +0 -38
  1087. mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
  1088. mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
  1089. mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
  1090. mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
  1091. mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
  1092. mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
  1093. mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
  1094. mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
  1095. mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
  1096. mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
  1097. mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
  1098. mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
  1099. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
  1100. mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
  1101. mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
  1102. mindspore/ops/_op_impl/tbe/sub.py +0 -39
  1103. mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
  1104. mindspore/ops/_op_impl/tbe/tan.py +0 -38
  1105. mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
  1106. mindspore/ops/_op_impl/tbe/tanh.py +0 -37
  1107. mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
  1108. mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
  1109. mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
  1110. mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
  1111. mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
  1112. mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
  1113. mindspore/ops/_op_impl/tbe/tile.py +0 -37
  1114. mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
  1115. mindspore/ops/_op_impl/tbe/top_k.py +0 -42
  1116. mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
  1117. mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
  1118. mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
  1119. mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
  1120. mindspore/ops/_op_impl/tbe/transpose.py +0 -60
  1121. mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
  1122. mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
  1123. mindspore/ops/_op_impl/tbe/trunc.py +0 -39
  1124. mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
  1125. mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
  1126. mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
  1127. mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
  1128. mindspore/ops/_op_impl/tbe/unpack.py +0 -38
  1129. mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
  1130. mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
  1131. mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
  1132. mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
  1133. mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
  1134. mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
  1135. mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
  1136. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
  1137. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
  1138. mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
  1139. mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
  1140. mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
  1141. mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
  1142. mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
  1143. mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
  1144. mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
  1145. mindspore/ops/_tracefunc.py +0 -241
  1146. mindspore/ops/arg_dtype_cast.py +0 -54
  1147. mindspore/rewrite/api/tree_node_helper.py +0 -60
  1148. mindspore/rewrite/ast_creator_register.py +0 -37
  1149. mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
  1150. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
  1151. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
  1152. mindspore/rewrite/namespace.py +0 -53
  1153. {mindspore-2.2.14.dist-info → mindspore-2.3.0rc1.dist-info}/WHEEL +0 -0
  1154. {mindspore-2.2.14.dist-info → mindspore-2.3.0rc1.dist-info}/top_level.txt +0 -0
@@ -15,25 +15,28 @@
15
15
  """SymbolTree class define of Rewrite according to forward function of a network."""
16
16
  import stat
17
17
  from typing import Optional, Union, Tuple, Any, Dict, List
18
+ import types
18
19
  import os
19
20
  import sys
20
21
  import ast
21
22
  import importlib.util
22
23
  import time
24
+ import inspect
25
+ from textwrap import dedent
26
+ from collections import OrderedDict
23
27
 
24
28
  from mindspore.nn import Cell
25
29
  from mindspore import log as logger
26
- from .node.node import Node, TreeNode
27
- from .api.node_type import NodeType
28
- from .ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder, AstClassFinder, AstFunctionFinder
29
- from .api.scoped_value import ScopedValue, ValueType
30
30
  from .symbol_tree_dumper import SymbolTreeDumper
31
- from .node.node_topological_manager import TopoManager
32
- from .namer import TargetNamer, NodeNamer, ClassNamer
33
- from .common.observer import Observer
34
- from .common.observable import Observable
35
- from .common.event import Event
36
- from .node.node_manager import NodeManager
31
+ from ..node import Node, TreeNode, ControlFlow, CallFunction, NodeManager
32
+ from ..api.node_type import NodeType
33
+ from ..api.scoped_value import ScopedValue, ValueType
34
+ from ..ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder, AstClassFinder, AstFunctionFinder, \
35
+ AstImportFinder
36
+ from ..common.namer import TargetNamer, NodeNamer, ClassNamer
37
+ from ..common.observer import Observer
38
+ from ..common.observable import Observable
39
+ from ..common.event import Event
37
40
 
38
41
  if sys.version_info >= (3, 9):
39
42
  import ast as astunparse # pylint: disable=reimported, ungrouped-imports
@@ -115,27 +118,6 @@ class FieldFinder(AstFinder):
115
118
  return self._result
116
119
 
117
120
 
118
- class IfFixer(ast.NodeTransformer):
119
- """
120
- Fix ast.If if body is empty while orelse is not empty.
121
- """
122
-
123
- def visit_If(self, node: ast.If) -> Any:
124
- """Visit a node of type ast.If."""
125
- if not node.body and node.orelse:
126
- node.body.append(ast.Pass())
127
- return super().generic_visit(node)
128
-
129
- def fix(self, node):
130
- """
131
- Fix ast.If node in `node` if whose body is empty while whose orelse is not empty.
132
-
133
- Args:
134
- node (ast.AST): An ast node to be fixed.
135
- """
136
- self.generic_visit(node)
137
-
138
-
139
121
  class SymbolTree(Observer, Observable, NodeManager):
140
122
  """
141
123
  A symbol-tree usually corresponding to forward method of a network.
@@ -147,13 +129,16 @@ class SymbolTree(Observer, Observable, NodeManager):
147
129
  origin_network (Cell): A handler to original network instance.
148
130
  module_ast (ast.Module): An instance of ast.AST represents ast node of original network.
149
131
  """
132
+ # whether parse CallFunction node inserted by user.
133
+ _unparse_inserted_function = True
150
134
 
151
135
  def __init__(self, origin_network: Cell, module_ast: ast.Module):
152
136
  Observer.__init__(self)
153
137
  Observable.__init__(self)
154
138
  self._node_namer = NodeNamer()
155
139
  self._node_namer.add_name('obj')
156
- NodeManager.__init__(self, self._node_namer)
140
+ NodeManager.__init__(self)
141
+ NodeManager.set_manager_node_namer(self, self._node_namer)
157
142
  NodeManager.reg_observer(self, observer=self)
158
143
  # init unique-namers
159
144
  self._target_namer = TargetNamer()
@@ -169,63 +154,69 @@ class SymbolTree(Observer, Observable, NodeManager):
169
154
  self._init_func_ast: Optional[ast.FunctionDef] = None
170
155
  self._deleted_field = {}
171
156
  self._deleted_node = []
172
- self._external_ast = []
173
- self._father_class_ast = []
157
+ # {ast_function: [import_asts]}
158
+ self._external_ast: Dict[ast.FunctionDef, list] = OrderedDict()
159
+ # {ast_class: [import_asts]}
160
+ self._father_class_ast: Dict[ast.ClassDef, list] = OrderedDict()
174
161
  self._modified = False
175
- self._tmp_file_limits = 20
176
- self._tmp_files = []
177
162
  self._saved_file_name = "./network_define.py"
178
163
  # used to insert "sys.path.append(xxx)"
179
164
  self._net_file_paths = []
180
165
  self._tmp_import_strs = []
181
- self._tmp_unmodified_strees: {type, str} = {}
166
+ self._tmp_unmodified_strees: {type, List[SymbolTree]} = {}
182
167
  self._tmp_replacers = []
183
- # Record imported modules and names of each files
184
- # The meanings of `module` and `name` are like code: from `module` import `nameA`, `nameB`
185
- # Format: {file_path: {module: [name, ...], ...}, ...}
186
- self._imported_modules: Dict[str, Dict[str, List[str]]] = {}
187
-
188
- def __del__(self):
189
- for tmp_file in self._tmp_files:
190
- tmp_file.close()
168
+ # user custom codes
169
+ self._custom_codes: List[ast.AST] = []
170
+ # local primitive instances initialized during forward method, e.g. abs_inst = P.Abs()
171
+ self._local_prim_inits: List[Node] = []
191
172
 
192
173
  @staticmethod
193
174
  def _remove_unused_import(module_ast):
194
175
  """remove unused import in self._module_ast"""
195
- str_checker = StrChecker(module_ast)
196
- for i in range(len(module_ast.body) - 1, -1, -1):
197
- body = module_ast.body[i]
198
- if not isinstance(body, (ast.Import, ast.ImportFrom)):
199
- continue
200
- if isinstance(body, ast.Import):
201
- continue
202
- if isinstance(body, ast.ImportFrom) and body.module == "cell":
203
- module_ast.body.remove(body)
204
- continue
205
- for alias in body.names:
206
- name = alias.asname if alias.asname else alias.name
207
- if not str_checker.check(name):
208
- if len(body.names) == 1:
209
- module_ast.body.remove(body)
210
- i += 1
211
- else:
212
- body.names.remove(alias)
176
+ import_nodes: List[Union[ast.Import, ast.ImportFrom]] = []
177
+
178
+ def is_divider(ast_node):
179
+ """judge if ast node is divider of new class or function by checking ast.Expr of '#'."""
180
+ return isinstance(ast_node, ast.Expr) and isinstance(ast_node.value, ast.Name) and ast_node.value.id == '#'
181
+
182
+ for ast_node in module_ast.body[:]:
183
+ if isinstance(ast_node, (ast.Import, ast.ImportFrom)):
184
+ import_nodes.append(ast_node)
185
+ if isinstance(ast_node, (ast.ClassDef, ast.FunctionDef)):
186
+ str_checker = StrChecker(ast_node)
187
+ for import_node in import_nodes:
188
+ for alias in import_node.names[:]:
189
+ name = alias.asname if alias.asname else alias.name
190
+ if name == '*':
191
+ continue
192
+ if not str_checker.check(name):
193
+ import_node.names.remove(alias)
194
+ if not import_node.names:
195
+ module_ast.body.remove(import_node)
196
+ if is_divider(ast_node):
197
+ import_nodes.clear()
213
198
 
214
199
  @staticmethod
215
200
  def _remove_duplicated_import(module_ast):
216
201
  """Remove duplicated import of 'net'."""
217
202
  imports = set()
218
203
  futures = set()
219
- classes = set()
204
+ names = set()
220
205
 
221
206
  class TransImportNode(ast.NodeTransformer):
222
207
  """Find all import nodes from input ast node."""
223
208
 
224
209
  def visit_ClassDef(self, node: ast.ClassDef) -> Any:
225
- class_str = astunparse.unparse(node)
226
- if class_str not in classes:
227
- classes.add(node.name)
210
+ if node.name not in names:
211
+ names.add(node.name)
212
+ return node
213
+ return None
214
+
215
+ def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
216
+ if node.name not in names:
217
+ names.add(node.name)
228
218
  return node
219
+ return None
229
220
 
230
221
  def visit_Try(self, node: ast.Try) -> Any:
231
222
  if isinstance(node.body[0], (ast.Import, ast.ImportFrom)):
@@ -233,12 +224,14 @@ class SymbolTree(Observer, Observable, NodeManager):
233
224
  if import_str not in imports:
234
225
  imports.add(import_str)
235
226
  return node
227
+ return None
236
228
 
237
229
  def visit_Import(self, node: ast.Import) -> Any:
238
230
  import_str = astunparse.unparse(node)
239
231
  if import_str not in imports:
240
232
  imports.add(import_str)
241
233
  return node
234
+ return None
242
235
 
243
236
  def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
244
237
  """
@@ -259,21 +252,225 @@ class SymbolTree(Observer, Observable, NodeManager):
259
252
  # remove "__future__" module
260
253
  if node.module == '__future__':
261
254
  futures.add(node.module)
262
- return
255
+ return None
263
256
  # remove modules which have been defined in the code file
264
257
  # it occurs when class A is a father class and other sub-classes import A
265
258
  for alias in node.names[:]:
266
- if alias.name in classes:
259
+ if alias.name in names:
267
260
  node.names.remove(alias)
268
261
  # if the alias(es) in node.names are all removed, this import statement should be removed
269
262
  if not node.names:
270
- return
263
+ return None
271
264
  return node
272
- return
265
+ return None
273
266
 
274
267
  get_node_handler = TransImportNode()
275
268
  get_node_handler.generic_visit(module_ast)
276
269
 
270
+ @staticmethod
271
+ def _remove_arg_annotations(module_ast):
272
+ """Remove annotations in ast.arg to avoid 'xxx is not defined'."""
273
+ ast_args: List[ast.arg] = AstFinder(module_ast).find_all(ast.arg)
274
+ for ast_arg in ast_args:
275
+ ast_arg.annotation = None
276
+
277
+ @staticmethod
278
+ def _check_import(import_path: str, import_module: str):
279
+ """
280
+ Check whether import operation is valid when importing module from specific path.
281
+ """
282
+ if import_path not in sys.path:
283
+ sys.path.append(import_path)
284
+ try:
285
+ importlib.import_module(name=import_module)
286
+ except (ValueError, ImportError) as e:
287
+ logger.info(f"Test import {import_module} from {import_path} failed: {e}.")
288
+ return False
289
+ except Exception as e: # pylint: disable=W0703
290
+ logger.info(f"Test import {import_module} from {import_path} failed: {e}.")
291
+ return False
292
+ return True
293
+
294
+ @staticmethod
295
+ def _process_relative_import(import_node: Union[ast.Import, ast.ImportFrom], file_path: str):
296
+ """Process relative imports"""
297
+ file_path = os.path.normcase(file_path)
298
+ file_path = os.path.normpath(file_path)
299
+ if isinstance(import_node, ast.ImportFrom):
300
+ # pad the ImportFrom with parent path
301
+ # e.g. from ..C import xxx -> from A.B.C import xxx
302
+ import_module = SymbolTree._get_valid_import_info(import_node, file_path)
303
+ if import_module:
304
+ import_node = ast.ImportFrom(module=import_module, names=import_node.names, level=0)
305
+ return import_node
306
+
307
+ @staticmethod
308
+ def _get_valid_import_info(import_node: ast.ImportFrom, file_path: str):
309
+ """Get valid import info while import_node.module is at form of relative path"""
310
+ file_path = os.path.dirname(os.path.abspath(file_path))
311
+ # get real path from import_node.level
312
+ # from .(A) import xxx: current path
313
+ # from ..(A) import xxx: last level path
314
+ level = import_node.level
315
+ # from A import xxx: it does not need to pad, directly return the module name
316
+ if level == 0:
317
+ return import_node.module
318
+ if level > 1:
319
+ for _ in range(level - 1):
320
+ file_path = os.path.dirname(file_path)
321
+ file_path_tmp = file_path[:]
322
+ max_level_count = file_path.count(os.path.sep) - 1
323
+ level_count = 0
324
+ # suffix is the module_name, e.g. 'A' in 'from ..(A) import xxx'
325
+ suffix = ''
326
+ if import_node.module:
327
+ suffix = '.' + import_node.module
328
+ while level_count < max_level_count:
329
+ file_path_tmp = os.path.dirname(file_path_tmp)
330
+ if file_path_tmp not in sys.path:
331
+ logger.debug(f"{file_path_tmp} not in sys.path, try upper level.")
332
+ level_count += 1
333
+ continue
334
+ import_module = file_path[len(file_path_tmp) + 1:].replace(os.path.sep, '.') + suffix
335
+ if SymbolTree._check_import(file_path_tmp, import_module):
336
+ # try test code success
337
+ return import_module
338
+ # test import ast failed, try upper level
339
+ level_count += 1
340
+ logger.info(f"Try upper level.")
341
+ # try codes with all level failed
342
+ logger.info(f"Test import code: {astunparse.unparse(import_node).strip()} failed, ignore this import code.")
343
+ return None
344
+
345
+ @staticmethod
346
+ def insert_to_ast_while_insert_input(new_node: Node, node_manager: NodeManager):
347
+ """update ast when inserting NodeType.Input node"""
348
+ if not isinstance(node_manager, (SymbolTree, CallFunction)):
349
+ raise ValueError(f"Only support insert Input node into a SymbolTree or a node with type of "
350
+ f"CallFunction, but get {type(node_manager)}")
351
+ # insert a new input
352
+ node_manager.get_input_nodes().append(new_node)
353
+ ast_function: ast.FunctionDef = node_manager.get_manager_ast()
354
+ arg: str = new_node.get_targets()[0].value
355
+ ast_arg = ast.arg(arg=arg, annotation=None, type_comment=None)
356
+ AstModifier.append_arg_to_function(ast_function, ast_arg)
357
+
358
+ @staticmethod
359
+ def insert_to_ast_while_insert_cell_primitive(new_node: Node, base_node: Node, before_node: bool,
360
+ node_manager: NodeManager, stree):
361
+ """update ast when inserting NodeType.CallCell or NodeType.CallPrimitive node"""
362
+ # create a new assign statement
363
+ ast_assign = new_node.get_ast()
364
+ if ast_assign is None:
365
+ func_name = stree.unique_func_name(new_node.get_name())
366
+ new_node.set_func_name(ScopedValue.create_naming_value(func_name, "self"))
367
+ ast_assign = new_node.update_ast_node()
368
+ if not isinstance(ast_assign, ast.Assign):
369
+ raise ValueError(f"Only support insert ast.Assign or Input now, but get {type(ast_assign)}")
370
+ # Save instance into _origin_network.
371
+ setattr(stree.get_origin_network(), new_node.get_name(), new_node.get_instance())
372
+ # Insert ast to __init__ function
373
+ if isinstance(new_node, TreeNode):
374
+ init_code = f"{new_node.get_func_name()} = " \
375
+ f"{new_node.symbol_tree.get_opt_cls_name()}(obj.{new_node.get_name()})"
376
+ else:
377
+ init_code = f"{new_node.get_func_name()} = obj.{new_node.get_name()}"
378
+ init_ast = ast.parse(init_code).body[0]
379
+ AstModifier.insert_ast_to_function(stree.get_init_func_ast(), init_ast)
380
+ # Insert ast to construct_function/class_internal_function
381
+ ast_base_node = base_node.get_ast() if base_node else None
382
+ ast_node_manager = node_manager.get_manager_ast()
383
+ if not ast_node_manager:
384
+ raise RuntimeError(f"ast_node_manager is None in node_manager {node_manager.get_manager_name()} "
385
+ "when inserting the ast.")
386
+ AstModifier.insert_ast_to_ast(ast_node_manager, ast_assign, ast_base_node, before_node)
387
+
388
+ @staticmethod
389
+ def insert_to_ast_while_insert_function(new_node: CallFunction, base_node: Node, before_node: bool,
390
+ node_manager: NodeManager, stree: 'SymbolTree'):
391
+ """update ast when inserting NodeType.CallFunction node"""
392
+ func_name = str(new_node.get_func_name())
393
+ # create a new assign statement
394
+ ast_assign = new_node.get_ast()
395
+ if ast_assign is None:
396
+ ast_assign = new_node.update_ast_node()
397
+ # Insert ast to node_manager
398
+ ast_base_node = base_node.get_ast() if base_node else None
399
+ ast_node_manager = node_manager.get_manager_ast()
400
+ if not ast_node_manager:
401
+ raise RuntimeError(f"ast_node_manager is None in node_manager {node_manager.get_manager_name()} "
402
+ "when inserting the ast.")
403
+ AstModifier.insert_ast_to_ast(ast_node_manager, ast_assign, ast_base_node, before_node)
404
+ # Ignore Python builtin functions
405
+ func_obj = new_node.get_instance()
406
+ if isinstance(func_obj, types.BuiltinFunctionType):
407
+ logger.warning(f"Ignore built in function: {func_name}")
408
+ return
409
+ # get ast.FunctionDef
410
+ source_code = inspect.getsource(func_obj)
411
+ ast_functiondef = ast.parse(dedent(source_code)).body[0]
412
+ if SymbolTree._unparse_inserted_function or not isinstance(ast_functiondef, ast.FunctionDef):
413
+ logger.debug(f"import '{func_name}' to access function object")
414
+ # add import to make sure that the function object can be accessed.
415
+ module = inspect.getmodule(func_obj)
416
+ top_node_manager = node_manager.get_top_manager()
417
+ belonging_ast = None if isinstance(top_node_manager, SymbolTree) else top_node_manager.get_manager_ast()
418
+ stree.add_import(module, func_name, belonging_ast)
419
+ return
420
+ # parse nodes in inserted function.
421
+ new_node.set_manager_ast(ast_functiondef)
422
+ new_node.set_manager_node_namer(stree.get_node_namer())
423
+ stree.get_external_ast()[ast_functiondef] = []
424
+ # import module which function defined in
425
+ func_file_path = inspect.getabsfile(func_obj)
426
+ stree.save_imports_from_file(func_file_path, ast_functiondef)
427
+ # expand ast codes in function
428
+ from ..ast_helpers import AstFlattener
429
+ ast_functiondef = AstFlattener().transform(ast_functiondef, [func_name], stree)
430
+ # parse ast codes into CallFunction Node
431
+ from ..parsers import ParserRegister
432
+ parser = ParserRegister.instance().get_parser(ast.FunctionDef)
433
+ parser.process(stree, ast_functiondef, node_manager=new_node)
434
+
435
+ @staticmethod
436
+ def insert_to_ast_while_insert_node(new_node: Node, base_node: Node, before_node: bool):
437
+ """ insert_to_ast_while_insert_node. """
438
+ stree = new_node.get_belong_symbol_tree()
439
+ if not stree:
440
+ raise ValueError(f"When inserting node to ast, the belonging symbol tree of new_node is None.")
441
+ node_manager = new_node.get_node_manager()
442
+ if not isinstance(node_manager, (SymbolTree, CallFunction, ControlFlow)):
443
+ raise ValueError(f"When inserting node to ast, the node_manager of new_node {new_node.get_name()} can "
444
+ f"only be one of [SymbolTree, CallFunction, ControlFlow], but get {type(node_manager)}")
445
+ if new_node.get_node_type() == NodeType.Input:
446
+ SymbolTree.insert_to_ast_while_insert_input(new_node, node_manager)
447
+ elif new_node.get_node_type() in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree):
448
+ SymbolTree.insert_to_ast_while_insert_cell_primitive(new_node, base_node, before_node, node_manager,
449
+ stree)
450
+ elif new_node.get_node_type() == NodeType.CallFunction:
451
+ SymbolTree.insert_to_ast_while_insert_function(new_node, base_node, before_node, node_manager, stree)
452
+ else:
453
+ raise ValueError(f"When insert node '{new_node.get_name()}' into ast, the type of node can only be "
454
+ f"one of [Input, CallCell, CallPrimitive, CallFunction, Tree], but got "
455
+ f"{new_node.get_node_type()}.")
456
+
457
+ @staticmethod
458
+ def get_node_full_name(node: Node) -> str:
459
+ """Get full name of node"""
460
+ name = node.get_manager_name() if isinstance(node, NodeManager) else node.get_name()
461
+ # traverse node_manager with type of Node
462
+ node_manager = node.get_node_manager()
463
+ while isinstance(node_manager, Node):
464
+ name = f"{node_manager.get_manager_name()}.{name}"
465
+ node_manager = node_manager.get_node_manager()
466
+ # type of node_manager is SymbolTree now
467
+ name = f"{node_manager.get_manager_name()}.{name}"
468
+ return name
469
+
470
+ def local_prim_inits(self) -> List[Node]:
471
+ """get local primitives constructed during forward method"""
472
+ return self._local_prim_inits
473
+
277
474
  def finish_build(self):
278
475
  """Add Event.TopologicalChangeEvent event when build is finished."""
279
476
  self.add_event(Event.TopologicalChangeEvent)
@@ -333,7 +530,7 @@ class SymbolTree(Observer, Observable, NodeManager):
333
530
  corresponding network class.
334
531
  """
335
532
  self._root_ast = ast_node
336
- NodeManager.set_ast_functiondef(self, ast_node)
533
+ NodeManager.set_manager_ast(self, ast_node)
337
534
 
338
535
  def get_class_ast(self):
339
536
  """
@@ -346,7 +543,7 @@ class SymbolTree(Observer, Observable, NodeManager):
346
543
 
347
544
  def set_class_ast(self, ast_node: ast.ClassDef):
348
545
  """
349
- Setter of `_init_func_ast`.
546
+ Setter of `_class_ast`.
350
547
 
351
548
  Args:
352
549
  ast_node (ast.ClassDef): An instance of ast.ClassDef represents ast node of corresponding network class.
@@ -420,19 +617,6 @@ class SymbolTree(Observer, Observable, NodeManager):
420
617
  """Get _father_class_ast"""
421
618
  return self._father_class_ast
422
619
 
423
- def get_imported_modules(self, file_path: str):
424
- """Get all modules and module_paths in file of `file_path` ."""
425
- return self._imported_modules.get(file_path, {})
426
-
427
- def save_imported_modules(self, file_path: str, module: str, names: List[str]):
428
- """Save module and names into _imported_modules."""
429
- imported_modules = self.get_imported_modules(file_path)
430
- if imported_modules.get(module):
431
- imported_modules[module].extend(names)
432
- else:
433
- imported_modules[module] = names
434
- self._imported_modules[file_path] = imported_modules
435
-
436
620
  def get_node_inputs(self, node_or_name: Union[Node, str]) -> [Node]:
437
621
  """
438
622
  Getter of inputs in topological relation of current 'node_or_name'.
@@ -469,7 +653,13 @@ class SymbolTree(Observer, Observable, NodeManager):
469
653
  return []
470
654
  if real_node.get_node_type() == NodeType.Output:
471
655
  return []
472
- return TopoManager.get_node_users(real_node)
656
+ node_users = []
657
+ for target_users in real_node.get_target_users().values():
658
+ if not target_users:
659
+ continue
660
+ if target_users not in node_users:
661
+ node_users.extend(target_users)
662
+ return node_users
473
663
 
474
664
  def before(self, node_or_name: Union[Node, str]) -> Position:
475
665
  """
@@ -566,8 +756,8 @@ class SymbolTree(Observer, Observable, NodeManager):
566
756
  if base_node is not None:
567
757
  stree = base_node.get_belong_symbol_tree()
568
758
  if stree is not None and stree is not self:
569
- raise RuntimeError(f"Position is not in current SymbolTree, node:{stree.get_ori_cls_name()}, "
570
- f"current: {self.get_ori_cls_name()}.")
759
+ raise ValueError(f"Position is not in current SymbolTree, node:{stree.get_ori_cls_name()}, "
760
+ f"current: {self.get_ori_cls_name()}.")
571
761
 
572
762
  # Check if node is inserted between Input node
573
763
  if base_node is not None and base_node.get_node_type() == NodeType.Input:
@@ -599,7 +789,7 @@ class SymbolTree(Observer, Observable, NodeManager):
599
789
  NodeManager.insert_node(self, new_node, base_node, before_node)
600
790
  if insert_to_ast:
601
791
  # update init-function-ast and construct-function-ast
602
- self.insert_to_ast_while_insert_node(new_node, base_node, before_node, self)
792
+ self.insert_to_ast_while_insert_node(new_node, base_node, before_node)
603
793
  else:
604
794
  node_manager.insert_node(new_node, base_node, before_node, insert_to_ast)
605
795
 
@@ -668,7 +858,7 @@ class SymbolTree(Observer, Observable, NodeManager):
668
858
  # check param_name duplicated
669
859
  if node_manager is None:
670
860
  node_manager = self
671
- for input_node in node_manager._inputs:
861
+ for input_node in node_manager.get_input_nodes():
672
862
  targets = input_node.get_targets()
673
863
  if len(targets) != 1:
674
864
  raise RuntimeError("targets should have 1 elements")
@@ -782,11 +972,15 @@ class SymbolTree(Observer, Observable, NodeManager):
782
972
 
783
973
  if node_manager is self:
784
974
  NodeManager.erase_node(self, node)
785
- ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast())
975
+ if isinstance(node, ControlFlow):
976
+ ret = AstModifier.earse_ast_of_control_flow(self._root_ast.body, node.get_ast(), node.is_orelse)
977
+ else:
978
+ ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast())
786
979
  if not ret:
787
980
  raise RuntimeError(f"erase node failed, node {node.get_name()} not in function ast tree.")
788
981
  else:
789
982
  node_manager.erase_node(node)
983
+ node.set_belong_symbol_tree(None)
790
984
  self._deleted_node.append(node.get_name())
791
985
  return node
792
986
 
@@ -815,7 +1009,7 @@ class SymbolTree(Observer, Observable, NodeManager):
815
1009
  for node in new_nodes:
816
1010
  self.insert_node(node, base_node, False, node_manager, True)
817
1011
  base_node = node
818
- _ = self.erase_node(old_node)
1012
+ self.erase_node(old_node)
819
1013
  return new_nodes[-1]
820
1014
 
821
1015
  def set_node_arg(self, node: Union[Node, str], index: int, arg: Union[ScopedValue, str]):
@@ -836,7 +1030,7 @@ class SymbolTree(Observer, Observable, NodeManager):
836
1030
  raise RuntimeError("Node is not belong to current SymbolTree: ", node)
837
1031
 
838
1032
  new_arg, old_arg = node.set_arg(arg, index)
839
- self._topo_mgr.on_update_arg(node, index, old_arg, new_arg)
1033
+ node.get_node_manager().on_update_arg(node, index, old_arg, new_arg)
840
1034
 
841
1035
  def set_node_arg_by_node(self, dst_node: Union[Node, str], arg_idx: int, src_node: Union[Node, str],
842
1036
  out_idx: Optional[int] = None):
@@ -873,7 +1067,7 @@ class SymbolTree(Observer, Observable, NodeManager):
873
1067
  raise RuntimeError("out_idx out of range: ", out_idx)
874
1068
  new_arg = targets[out_idx]
875
1069
  real_dst_node.set_arg(new_arg, arg_idx)
876
- self._topo_mgr.on_update_arg_by_node(real_dst_node, arg_idx, real_src_node, out_idx)
1070
+ real_dst_node.get_node_manager().on_update_arg_by_node(real_dst_node, arg_idx, real_src_node, out_idx)
877
1071
 
878
1072
  def unique_name(self, name: str):
879
1073
  """Get a unique name in the symboltree"""
@@ -915,10 +1109,13 @@ class SymbolTree(Observer, Observable, NodeManager):
915
1109
  node.set_targets(targets)
916
1110
  self._topo_mgr.on_update_target(node, index, old_target, target)
917
1111
 
918
- def all_nodes(self):
1112
+ def all_nodes(self, subtree_nodes: bool = True):
919
1113
  """
920
1114
  Get all nodes including nodes in CallFunction node, CellContainer node and sub symbol tree.
921
1115
 
1116
+ Args:
1117
+ subtree_nodes (bool): Whether include nodes in subtree. Default: True.
1118
+
922
1119
  Returns:
923
1120
  A list of nodes.
924
1121
  """
@@ -930,9 +1127,10 @@ class SymbolTree(Observer, Observable, NodeManager):
930
1127
  for node in node_manager.nodes():
931
1128
  if isinstance(node, NodeManager):
932
1129
  node_managers.append(node)
933
- for tree_node in self.get_tree_nodes():
934
- stree = tree_node.symbol_tree
935
- nodes.extend(stree.all_nodes())
1130
+ if subtree_nodes:
1131
+ for tree_node in self.get_tree_nodes():
1132
+ stree = tree_node.symbol_tree
1133
+ nodes.extend(stree.all_nodes())
936
1134
  return nodes
937
1135
 
938
1136
  def get_node_from_name(self, node_name: str):
@@ -956,13 +1154,16 @@ class SymbolTree(Observer, Observable, NodeManager):
956
1154
  node_managers.append(node)
957
1155
  return None
958
1156
 
959
- def print_node_tabulate(self, all_nodes: bool = False):
1157
+ def get_node_tabulate(self, all_nodes: bool = False) -> str:
960
1158
  """
961
- Print nodes information and nodes' topological relations.
1159
+ Get nodes information and nodes' topological relations.
962
1160
 
963
1161
  Args:
964
1162
  all_nodes (bool): Print nodes out of construct functions, such as nodes in CallFunction
965
1163
  nodes, CellContainer nodes and sub symbol trees.
1164
+
1165
+ Returns:
1166
+ String of nodes' information and topological relations.
966
1167
  """
967
1168
  try:
968
1169
  from tabulate import tabulate # pylint: disable=unused-import,reportMissingModuleSource
@@ -971,18 +1172,19 @@ class SymbolTree(Observer, Observable, NodeManager):
971
1172
  "which could not be found on this machine. Run `pip "
972
1173
  "install tabulate` to install the library.")
973
1174
  return ""
974
- print(NodeManager.dump(self, self.get_manager_name()))
1175
+ dump_str = NodeManager.dump(self, self.get_manager_name())
975
1176
  if all_nodes:
976
1177
  node_managers = [self]
977
1178
  while node_managers:
978
1179
  node_manager = node_managers.pop()
979
1180
  for node in node_manager.nodes():
980
1181
  if isinstance(node, NodeManager):
981
- print(node.dump(node.get_manager_name()))
1182
+ dump_str += node.dump(SymbolTree.get_node_full_name(node))
982
1183
  node_managers.append(node)
983
1184
  for tree_node in self.get_tree_nodes():
984
1185
  stree = tree_node.symbol_tree
985
- stree.print_node_tabulate(all_nodes)
1186
+ dump_str += stree.get_node_tabulate(all_nodes)
1187
+ return dump_str
986
1188
 
987
1189
  def dump(self):
988
1190
  """Dump graph."""
@@ -1019,20 +1221,76 @@ class SymbolTree(Observer, Observable, NodeManager):
1019
1221
 
1020
1222
  return False
1021
1223
 
1022
- def update_class_name_of_unmodified_stree(self, stree, code_bodies) -> bool:
1224
def deduplicate_unmodified_stree(self, code_bodies):
    """
    Merge unmodified symbol trees whose init functions unparse to the same code.

    Init function may be different even if stree is not modified manually, when subnets in stree
    are initialized by different arguments. Such strees are collected in
    self._tmp_unmodified_strees; once code_bodies is fully generated and the names of subnets
    have been updated, their init functions are compared again here.

    For every duplicate found, its class definition is removed from code_bodies and its class
    name is replaced by the name of the first stree with the same init code. The method calls
    itself again after any merge, because renaming may make further init functions identical.

    Args:
        code_bodies (list): Top-level asts of the generated code, modified in place.
    """
    # ast.Module is built with `type_ignores` on python 3.9+ and without it otherwise.
    module_kwargs = {"type_ignores": []} if sys.version_info >= (3, 9) else {}
    class_finder = AstClassFinder(ast.Module(body=code_bodies, **module_kwargs))
    name_replacer = AstReplacer(ast.Module(body=code_bodies, **module_kwargs))

    merged_any = False
    for candidates in self._tmp_unmodified_strees.values():
        if len(candidates) <= 1:
            continue
        init_sources = [astunparse.unparse(tree.get_init_func_ast()) for tree in candidates]
        # position of the first stree seen for each distinct init code
        first_seen = {}
        duplicate_positions = []
        for pos, source in enumerate(init_sources):
            keeper_pos = first_seen.setdefault(source, pos)
            if keeper_pos == pos:
                # first occurrence of this init code, keep it
                continue
            keeper_name = candidates[keeper_pos].get_opt_cls_name()
            dropped_name = candidates[pos].get_opt_cls_name()
            logger.debug(f"replace stree:{dropped_name} to {keeper_name}.")
            # delete duplicated class from code_bodies
            for class_def in class_finder.find_all(dropped_name):
                code_bodies.remove(class_def)
            # every use of the duplicated class now refers to the kept class
            name_replacer.replace_all(dropped_name, keeper_name)
            duplicate_positions.append(pos)
            merged_any = True
        # pop from the back so that the remaining positions stay valid
        for pos in reversed(duplicate_positions):
            candidates.pop(pos)

    # the name of subnets is updated, so we need to deduplicate again.
    if merged_any:
        self._tmp_replacers.append(name_replacer)
        self.deduplicate_unmodified_stree(code_bodies)
1269
+
1270
+ def update_unmodified_stree(self, stree, code_bodies) -> bool:
1023
1271
  """
1024
1272
  For the unmodified symbol tree, only one definition code remains in the generated code.
1025
1273
  Everywhere else calling this symbol tree will use the class in this definition code.
1026
1274
  """
1027
1275
  # all modified ast.ClassDef will be exported to code
1028
1276
  if stree.is_modified():
1277
+ logger.debug(f"stree:{stree.get_opt_cls_name()} is modified.")
1029
1278
  return False
1030
1279
  # all un-modified ast.ClassDef only keep one instance
1031
- first_cls_name = self._tmp_unmodified_strees.get(type(stree.get_origin_network()))
1032
- if first_cls_name is None:
1033
- class_ast = stree.get_class_ast()
1034
- if class_ast:
1035
- self._tmp_unmodified_strees[type(stree.get_origin_network())] = class_ast.name
1280
+ unmodified_strees = self._tmp_unmodified_strees.get(type(stree.get_origin_network()))
1281
+ if not unmodified_strees:
1282
+ self._tmp_unmodified_strees[type(stree.get_origin_network())] = [stree]
1283
+ logger.debug(f"stree:{stree.get_opt_cls_name()} is the first stree.")
1284
+ return False
1285
+ # Init function may be different even if stree is not modified, when subnets in stree is
1286
+ # initialized by different arguments.
1287
+ first_stree = unmodified_strees[0]
1288
+ first_stree_cls_name = first_stree.get_opt_cls_name()
1289
+ if astunparse.unparse(stree.get_init_func_ast()) != astunparse.unparse(first_stree.get_init_func_ast()):
1290
+ # init ast may be updated after inserting subtrees of stree, so we need to save unmodified strees
1291
+ # and deduplicate later
1292
+ self._tmp_unmodified_strees[type(stree.get_origin_network())].append(stree)
1293
+ logger.debug(f"init func different, stree:{stree.get_opt_cls_name()}, first_stree:{first_stree_cls_name}.")
1036
1294
  return False
1037
1295
  # Un-modified ast.ClassDef already exist in code_bodies,
1038
1296
  # replace class name to class name of first un-modified ast.ClassDef.
@@ -1040,66 +1298,105 @@ class SymbolTree(Observer, Observable, NodeManager):
1040
1298
  replacer = AstReplacer(ast.Module(body=code_bodies, type_ignores=[]))
1041
1299
  else:
1042
1300
  replacer = AstReplacer(ast.Module(body=code_bodies))
1043
- replacer.replace_all(stree.get_class_ast().name, first_cls_name)
1301
+ logger.debug(f"replace stree:{stree.get_opt_cls_name()} to {first_stree_cls_name}.")
1302
+ replacer.replace_all(stree.get_class_ast().name, first_stree_cls_name)
1044
1303
  self._tmp_replacers.append(replacer)
1045
1304
  return True
1046
1305
 
1047
- def convert_stree_to_code_bodies(self, stree, code_bodies, insert_pos=0):
1306
def init_code_bodies(self, code_bodies: list) -> int:
    """
    Init code bodies with the basic imports and the user custom codes.

    The basic imports and the custom codes are each followed by a divider, which is an
    ast.Expr wrapping an ast.Name whose id is '#'.

    Args:
        code_bodies (list): List of asts to be filled, modified in place.

    Returns:
        An int, the length of code_bodies after initialization.
    """
    def divider():
        return ast.Expr(ast.Name("#", ast.Load()))

    # Add basic imports
    code_bodies.extend([
        ast.Import([ast.alias(name='sys', asname=None)]),
        ast.Import([ast.alias(name='mindspore', asname=None)]),
        ast.ImportFrom(module='mindspore', names=[ast.alias(name='nn', asname=None)], level=0),
        ast.ImportFrom(module='mindspore.nn', names=[ast.alias(name='Cell', asname=None)], level=0),
        ast.ImportFrom(module='mindspore.ops', names=[ast.alias(name='functional', asname='F')], level=0),
        divider(),
    ])
    # Add user custom codes into code_bodies
    code_bodies.extend(self.get_custom_codes())
    code_bodies.append(divider())
    return len(code_bodies)
1322
+
1323
+ def convert_stree_to_code_bodies(self, stree: 'SymbolTree', code_bodies: list, dividing_pos=0) -> int:
1048
1324
  """
1049
1325
  Convert nodes in stree to code_bodies
1326
+ - Add external function asts into code_bodies
1327
+ - Add father class asts into code_bodies
1328
+ - Add import asts of symbol tree into code_bodies
1329
+ - Add user custom codes into code_bodies
1330
+ - Add class asts of symbol tree into code_bodies
1331
+ - Add subtrees to code_bodies
1332
+ """
1333
+ insert_pos = dividing_pos
1334
+ # Add external asts into code_bodies
1335
+ for ast_func, import_asts in reversed(stree.get_external_ast().items()):
1336
+ if self.check_body_exist(ast_func, code_bodies):
1337
+ continue
1338
+ # add imports of external_ast
1339
+ self._tmp_import_strs.clear()
1340
+ for ast_import in import_asts:
1341
+ if not self.check_body_exist(ast_import, code_bodies):
1342
+ code_bodies.insert(insert_pos, ast_import)
1343
+ insert_pos += 1
1344
+ # add external_ast
1345
+ code_bodies.insert(insert_pos, ast_func)
1346
+ insert_pos += 1
1347
+ # add divide
1348
+ code_bodies.insert(insert_pos, ast.Expr(ast.Name("#", ast.Load())))
1349
+ insert_pos += 1
1050
1350
 
1051
- 1. Add import asts into code_bodies
1052
- 2. Add class, function and other type of asts into code_bodies
1053
- 3. Add father class asts into code_bodies
1054
- 4. Add external function asts into code_bodies
1055
- 5. Add subtrees to code_bodies
1056
- 5.1 Add subtrees in construct to code_bodies
1057
- 5.2 Add subtrees in CellContainers to code_bodies
1058
-
1059
- """
1060
- # Add import asts into code_bodies
1351
+ # Add father class asts into code_bodies
1352
+ for ast_class, import_asts in stree.get_father_class_ast().items():
1353
+ if self.check_body_exist(ast_class, code_bodies):
1354
+ continue
1355
+ # add imports of father class
1356
+ self._tmp_import_strs.clear()
1357
+ for ast_import in import_asts:
1358
+ if not self.check_body_exist(ast_import, code_bodies):
1359
+ code_bodies.insert(insert_pos, ast_import)
1360
+ insert_pos += 1
1361
+ # add ast of father class
1362
+ code_bodies.insert(insert_pos, ast_class)
1363
+ insert_pos += 1
1364
+ # add divide
1365
+ code_bodies.insert(insert_pos, ast.Expr(ast.Name("#", ast.Load())))
1366
+ insert_pos += 1
1367
+
1368
+ # external functions and father class are above the dividing_pos to support deduplication.
1369
+ dividing_pos = insert_pos
1370
+
1371
+ # Add import asts of symbol tree into code_bodies
1372
+ self._tmp_import_strs.clear()
1061
1373
  for body in stree.get_import_asts():
1062
1374
  if not self.check_body_exist(body, code_bodies):
1063
1375
  code_bodies.insert(insert_pos, body)
1064
1376
  insert_pos += 1
1065
1377
 
1066
- # Add class, function and other type of asts into code_bodies
1378
+ # Add class asts of symbol tree into code_bodies
1067
1379
  if stree.get_module_ast():
1068
1380
  for body in stree.get_module_ast().body:
1069
1381
  if self.check_body_exist(body, code_bodies):
1070
1382
  continue
1071
- if isinstance(body, (ast.ClassDef, ast.FunctionDef)):
1072
- code_bodies.insert(insert_pos, body)
1073
- else:
1074
- code_bodies.append(body)
1075
-
1076
- # Add father class asts into code_bodies
1077
- for body in reversed(stree.get_father_class_ast()):
1078
- if self.check_body_exist(body, code_bodies):
1079
- # remove exist ast in old position, then insert ast to upper position
1080
- if sys.version_info >= (3, 9):
1081
- exist_ast = AstClassFinder(ast.Module(body=code_bodies, type_ignores=[])).find_all(body.name)[0]
1082
- else:
1083
- exist_ast = AstClassFinder(ast.Module(body=code_bodies)).find_all(body.name)[0]
1084
- code_bodies.remove(exist_ast)
1085
- code_bodies.insert(insert_pos, body)
1086
-
1087
- # Add external asts into code_bodies
1088
- for body in stree.get_external_ast():
1089
- if not self.check_body_exist(body, code_bodies):
1090
1383
  code_bodies.insert(insert_pos, body)
1091
1384
  insert_pos += 1
1092
1385
 
1386
+ # add divide
1387
+ code_bodies.insert(insert_pos, ast.Expr(ast.Name("#", ast.Load())))
1388
+ insert_pos += 1
1389
+
1093
1390
  # Add subtrees to code_bodies
1094
1391
  for node in stree.get_tree_nodes():
1095
1392
  sub_stree = node.symbol_tree
1096
- # Ignore TreeNode create by function in the class
1097
- if isinstance(sub_stree.get_module_ast(), ast.FunctionDef):
1098
- continue
1099
1393
  # For the unmodified class, update class name to name of first class
1100
- if self.update_class_name_of_unmodified_stree(sub_stree, code_bodies):
1394
+ if self.update_unmodified_stree(sub_stree, code_bodies):
1101
1395
  continue
1102
- self.convert_stree_to_code_bodies(node.symbol_tree, code_bodies, insert_pos)
1396
+ dividing_pos = self.convert_stree_to_code_bodies(node.symbol_tree, code_bodies, dividing_pos)
1397
+
1398
+ # return new dividing position
1399
+ return dividing_pos
1103
1400
 
1104
1401
  def get_code(self) -> str:
1105
1402
  """
@@ -1112,15 +1409,18 @@ class SymbolTree(Observer, Observable, NodeManager):
1112
1409
  self._tmp_unmodified_strees.clear()
1113
1410
  self._tmp_replacers.clear()
1114
1411
  code_bodies = []
1115
- self.convert_stree_to_code_bodies(self, code_bodies)
1412
+ begin_pos = self.init_code_bodies(code_bodies)
1413
+ self.convert_stree_to_code_bodies(self, code_bodies, begin_pos)
1414
+ self.deduplicate_unmodified_stree(code_bodies)
1116
1415
  if sys.version_info >= (3, 9):
1117
1416
  gencode_module = ast.Module(body=code_bodies, type_ignores=[])
1118
1417
  else:
1119
1418
  gencode_module = ast.Module(body=code_bodies)
1120
1419
  SymbolTree._remove_unused_import(gencode_module)
1420
+ self._process_duplicate_name_modules(gencode_module)
1121
1421
  SymbolTree._remove_duplicated_import(gencode_module)
1422
+ SymbolTree._remove_arg_annotations(gencode_module)
1122
1423
  ast.fix_missing_locations(self._module_ast)
1123
- IfFixer().fix(gencode_module)
1124
1424
  code = astunparse.unparse(gencode_module)
1125
1425
  # Revert the class name to its original state
1126
1426
  for replacer in self._tmp_replacers:
@@ -1137,6 +1437,9 @@ class SymbolTree(Observer, Observable, NodeManager):
1137
1437
  cls = self._get_cls_through_file()
1138
1438
  new_net = cls(self._origin_network)
1139
1439
  self._merge_origin_property(new_net)
1440
+ # update parameters' names to fix duplicated names bug
1441
+ # which occurs after inserting cell to celllist/sequentialcell
1442
+ new_net.update_parameters_name()
1140
1443
  return new_net
1141
1444
 
1142
1445
  def set_saved_file_name(self, file_name: str):
@@ -1157,42 +1460,189 @@ class SymbolTree(Observer, Observable, NodeManager):
1157
1460
  f.write(source.encode('utf-8'))
1158
1461
  f.flush()
1159
1462
 
1160
- def insert_to_ast_while_insert_node(self, new_node: Node, base_node: Node, before_node: bool,
1161
- node_manager: NodeManager):
1162
- """ insert_to_ast_while_insert_node. """
1163
- if new_node.get_node_type() == NodeType.Input:
1164
- # insert a new input
1165
- self._inputs.append(new_node)
1166
- ast_construct = self.get_ast_root()
1167
- arg: str = new_node.get_targets()[0].value
1168
- ast_arg = ast.arg(arg=arg, annotation=None, type_comment=None)
1169
- AstModifier.append_arg_to_function(ast_construct, ast_arg)
1463
+
1464
def flatten_nodes(self, node, erase_another_branch: bool = False, erase_nodes_after_return: bool = False):
    """
    Flatten nodes in ControlFlow node.

    Every node held by `node` is erased and re-inserted into the node manager that contains
    `node`, its ast is inserted into the ast bodies of that manager, and finally `node` itself
    is erased.

    Args:
        node (ControlFlow): The control flow node whose inner nodes are lifted out.
        erase_another_branch (bool): Whether to also erase the other branch: the body node when
            `node.is_orelse` is true, otherwise the orelse node if there is one. Default: False.
        erase_nodes_after_return (bool): Whether to erase the nodes which follow the first node
            of type NodeType.Output in the containing node manager. Default: False.

    Raises:
        ValueError: If `node` is not a ControlFlow.
        ValueError: If `node` is not contained in a SymbolTree, CallFunction or ControlFlow.
    """
    if not isinstance(node, ControlFlow):
        raise ValueError(f"For flatten_nodes, the type of node can only be ControlFlow, but got {type(node)}.")
    upper_node_manager = node.get_node_manager()
    # SymbolTree/CallFunction expose their statements through `.body` of the manager ast, while
    # the manager ast of a ControlFlow is used directly as the statement list.
    if isinstance(upper_node_manager, (SymbolTree, CallFunction)):
        ast_bodies = upper_node_manager.get_manager_ast().body
    elif isinstance(upper_node_manager, ControlFlow):
        ast_bodies = upper_node_manager.get_manager_ast()
    else:
        raise ValueError("For flatten_nodes, the node can only be contained in [SymbolTree, CallFunction, "
                         f"ControlFlow], but the node is in {type(upper_node_manager)}.")
    # Anchor for the first insertion: the orelse node when it exists, otherwise the body node.
    base_node = node.orelse_node if node.orelse_node else node.body_node
    # Iterate over a copy, the loop erases nodes from `node`.
    for n in node.nodes()[:]:
        self.erase_node(n)
        # NOTE(review): the positional False arguments presumably mean "insert after base_node"
        # and "do not insert the ast" (the ast is inserted explicitly below) - confirm against
        # the signatures of insert_node and AstModifier.insert_ast_to_bodies.
        self.insert_node(n, base_node, False, upper_node_manager, False)
        AstModifier.insert_ast_to_bodies(ast_bodies, n.get_ast(), base_node.get_ast(), False)
        # chain the insertions so that the original order of the nodes is kept
        base_node = n
    self.erase_node(node)
    # remove another branch
    if erase_another_branch:
        if node.is_orelse:
            self.erase_node(node.body_node)
        elif node.orelse_node is not None:
            self.erase_node(node.orelse_node)
    # remove nodes after return node
    if erase_nodes_after_return:
        has_return = False
        # NOTE(review): nodes are erased while iterating upper_node_manager.nodes(); this is only
        # safe if nodes() returns a copy - confirm.
        for n in upper_node_manager.nodes():
            if has_return:
                logger.warning(f"Node {n.get_name()} which is behind the flatten return node is "
                               f"automatically erased.")
                self.erase_node(n)
            elif n.get_node_type() == NodeType.Output:
                has_return = True
1499
+
1500
def eval_ast_result(self, ast_node: ast.AST) -> (bool, bool):
    """
    Eval ast_node and get result, only used in control flow node.

    Args:
        ast_node (ast.AST): Expression ast to be evaluated, e.g. the test of an if statement.

    Returns:
        A tuple of two bools. The first one tells whether the evaluation succeeded, the second
        one is the truth value of the result. `(False, False)` is returned when the module of
        the origin network cannot be found or the evaluation raises an exception.
    """
    # ast.Constant can be check without eval
    if isinstance(ast_node, ast.Constant):
        # the value lives on the node itself; `ast.value` would look it up on the `ast` module
        return True, bool(ast_node.value)
    # Get the module where the code of ast_node is located
    file_path = inspect.getfile(type(self.get_origin_network()))
    module = None
    for m in list(sys.modules.values()):
        if hasattr(m, "__file__") and m.__file__ and os.path.normcase(m.__file__) == os.path.normcase(file_path):
            module = m
            break
    if not module:
        logger.warning("Failed to get module of ast_node.")
        return False, False
    # eval ast_node and get result
    logger.debug(f"Eval ast node: {astunparse.unparse(ast_node)}")
    ast_expr = ast.Expression(ast_node)
    ast_expr = ast.fix_missing_locations(ast_expr)
    try:
        # The expression is evaluated with the globals of the network's module and with the locals
        # of this method, so names such as `self` can be resolved.
        # NOTE(review): the ast comes from the source of the network being rewritten; eval must
        # not be fed with asts built from untrusted input.
        # pylint: disable=eval-used
        result = eval(compile(ast_expr, "eval_ast_result", "eval"), {**globals(), **module.__dict__}, locals())
    except Exception as e:  # pylint: disable=broad-except
        # best effort: an expression which cannot be evaluated statically is reported as failure
        logger.debug(f"Cannot get result of ast_node by eval, err:{e}")
        return False, False
    logger.debug(f"Eval ast result success, result: {result}")
    return True, bool(result)
1530
+
1531
def flatten_static_if_control_flow(self):
    """
    Flatten the control flows whose test result is already known.

    For every ControlFlow node with a test result that is not None, the codes of the branch
    which will be executed are flattened and the other branch is erased. When the test result
    is false and there is no orelse branch, the body branch is simply erased.
    """
    for candidate in list(self.all_nodes()):
        # a node without symbol tree has been erased by a previous flatten
        if not candidate.get_belong_symbol_tree():
            continue
        if not isinstance(candidate, ControlFlow) or candidate.test_result is None:
            continue
        owner = candidate.get_belong_symbol_tree()
        if candidate.test_result:
            owner.flatten_nodes(candidate.body_node, True, True)
        elif candidate.orelse_node is not None:
            owner.flatten_nodes(candidate.orelse_node, True, True)
        else:
            owner.erase_node(candidate.body_node)
1548
+
1549
def add_custom_codes(self, code: str):
    """Parse `code` and append its top-level statements to the user custom codes."""
    statements = ast.parse(code).body
    self._custom_codes.extend(statements)
1553
+
1554
def get_custom_codes(self) -> List[ast.AST]:
    """
    Get user custom codes.

    Returns:
        The list of asts collected by add_custom_codes. The internal list itself is returned,
        not a copy.
    """
    return self._custom_codes
1557
+
1558
def save_file_path_to_sys(self, level_num, file_path, belonging_ast: ast.AST = None):
    """
    Save `import sys` and `sys.path.insert(0, <dir>)` into the import asts of `belonging_ast`.

    `level_num` is used when level exist in ast.ImportFrom.

    When level_num = 0(e.g. from xxx import yyy), current path will be saved.
    When level_num = 1(e.g. from .xxx import yyy), current path will be saved.
    When level_num = 2(e.g. from ..xxx import yyy), the path one level above the current path will be saved.

    Args:
        level_num (int): Level of the import.
        file_path (str): Path of the file which contains the import; its directory is the
            "current path".
        belonging_ast (ast.AST): Ast whose import list receives the new asts, as resolved by
            `_get_imports_list_of_ast`. Default: None.
    """
    dir_path = os.path.dirname(os.path.abspath(file_path))
    dir_path = os.path.normpath(os.path.normcase(dir_path))
    # every level beyond the first climbs one directory; range() is empty when level_num <= 1
    for _ in range(level_num - 1):
        dir_path = os.path.dirname(dir_path)
    # repr() always produces a valid string literal. A hand-written r'...' literal is a syntax
    # error when the path contains a single quote or ends with a backslash.
    sys_path_insert_ast = ast.parse(f"sys.path.insert(0, {dir_path!r})").body[0]
    # add imports to import_asts of belonging_ast
    import_asts = self._get_imports_list_of_ast(belonging_ast)
    import_asts.append(ast.Import([ast.alias(name='sys', asname=None)]))
    import_asts.append(sys_path_insert_ast)
1577
+
1578
def save_imports_from_file(self, file_path, belonging_ast: ast.AST = None):
    """
    Save the imports found in the file at `file_path` into the import asts of `belonging_ast`.

    The directory of the file is first recorded through `save_file_path_to_sys`. Every import
    found is passed through `SymbolTree._process_relative_import` and kept only when the
    result is truthy.

    Args:
        file_path (str): Path of the python file to be scanned.
        belonging_ast (ast.AST): Ast whose import list receives the imports. Default: None.

    Raises:
        RuntimeError: If `file_path` does not exist.
    """
    self.save_file_path_to_sys(0, file_path, belonging_ast)
    if not os.path.exists(file_path):
        raise RuntimeError(f"For MindSpore Rewrite, in module parser, file {file_path} not exist.")
    with open(file_path, "r", encoding="utf-8") as source_file:
        module_ast = ast.parse(dedent(source_file.read()))
    found_imports = AstImportFinder(module_ast).get_import_node()
    if not found_imports:
        return
    # add imports to import_asts of belonging_ast
    target_list = self._get_imports_list_of_ast(belonging_ast)
    for raw_import in found_imports:
        resolved = SymbolTree._process_relative_import(raw_import, file_path)
        if resolved:
            target_list.append(resolved)
1594
+
1595
def add_import(self, module: types.ModuleType, name: str, belonging_ast: ast.AST = None):
    """
    Add codes: from `module` import `name`.

    Nothing is added when `module` has no attribute `name` or when `module` is `builtins`.
    For the `__main__` module the attribute is fetched from `sys.modules` instead of being
    imported. For any other module the import path is derived from the entries of `sys.path`
    that prefix the file of the module, longest entry first; when none of them passes
    `SymbolTree._check_import`, the directory of the file is added to `sys.path` in the
    generated code and the file name is imported directly.

    Args:
        module (types.ModuleType): Module which provides `name`.
        name (str): Name to be imported from `module`.
        belonging_ast (ast.AST): Ast whose import list receives the import, as resolved by
            `_get_imports_list_of_ast`. Default: None.

    Raises:
        TypeError: If `module` is not a ModuleType.
    """
    if not isinstance(module, types.ModuleType):
        raise TypeError(f"For add_import, module should be ModuleType, but got {type(module)}")
    if not hasattr(module, name):
        logger.info(f"module {module.__name__} doesn't have attr '{name}', it may be a local variable.")
        return
    # add imports to import_asts of belonging_ast
    import_asts = self._get_imports_list_of_ast(belonging_ast)
    if module.__name__ == "builtins":
        # built-in functions are not need to be imported
        return
    if module.__name__ == "__main__":
        # get attr from module instead of import to avoid duplicate execution of __main__ module
        code = f"{name} = getattr(sys.modules['__main__'], '{name}')"
        import_asts.append(ast.parse(code).body[0])
        return
    # add import of obj to ast
    func_file_path = os.path.normcase(inspect.getabsfile(module))
    search_roots = (os.path.normcase(path) for path in sys.path)
    # try the longest matching root first, it gives the shortest import path
    prefix_paths = sorted((root for root in search_roots if func_file_path.startswith(root)),
                          key=len, reverse=True)
    for path in prefix_paths:
        import_path = func_file_path[len(path):]
        import_str = import_path.replace(os.path.sep, '.')
        import_str = import_str[1:]  # remove first '.'
        mod = import_str.rsplit('.', 1)[0]
        if SymbolTree._check_import(func_file_path[:len(path)], mod):
            import_node = ast.ImportFrom(module=mod, names=[ast.alias(name=name, asname=None)], level=0)
            import_asts.append(import_node)
            break
    else:
        # no entry of sys.path yields an importable path: expose the directory of the file and
        # import from the file name itself
        self.save_file_path_to_sys(0, func_file_path, belonging_ast)
        mod = os.path.basename(func_file_path).rsplit('.')[0]
        import_node = ast.ImportFrom(module=mod, names=[ast.alias(name=name, asname=None)], level=0)
        import_asts.append(import_node)
1636
+
1637
def _get_imports_list_of_ast(self, belonging_ast: ast.AST):
    """
    Get the list of import asts which belongs to `belonging_ast`.

    The list registered for a father class ast is preferred, then the list registered for an
    external ast. The import asts of the symbol tree itself are returned when `belonging_ast`
    is None or is not registered in either of them.
    """
    if belonging_ast is not None:
        if belonging_ast in self._father_class_ast:
            return self._father_class_ast.get(belonging_ast)
        if belonging_ast in self._external_ast:
            return self._external_ast.get(belonging_ast)
    return self._import_asts
1196
1646
 
1197
1647
  def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
1198
1648
  if isinstance(node_or_name, str):
@@ -1265,7 +1715,7 @@ class SymbolTree(Observer, Observable, NodeManager):
1265
1715
  time.sleep(0.5)
1266
1716
  i += 1
1267
1717
  if not tmp_module:
1268
- logger.error(f"load module {tmp_module_name} failed.")
1718
+ raise ImportError(f"load module {tmp_module_name} failed.")
1269
1719
  # Save new module to sys.modules to support inspect.getsource().
1270
1720
  sys.modules[tmp_module_name] = tmp_module
1271
1721
  network_cls = getattr(tmp_module, self._opt_cls_name)
@@ -1295,6 +1745,75 @@ class SymbolTree(Observer, Observable, NodeManager):
1295
1745
  for c in cells:
1296
1746
  new_net.insert_child_to_cell(c, self._origin_network.name_cells()[c])
1297
1747
  # merge primitives
1748
+ # pylint: disable=protected-access
1298
1749
  primitives = self._cal_difference_set(self._origin_network._primitives.keys(), new_net._primitives.keys())
1299
1750
  for p in primitives:
1300
- new_net._primitives[p] = self._origin_network._primitives[p]
1751
+ new_net._primitives[p] = self._origin_network._primitives[p] # pylint: disable=protected-access
1752
+
1753
def _process_duplicate_name_modules(self, module_ast: ast.Module):
    """
    Adjust names of imported modules with same name and different import path.

    The top-level asts of `module_ast` are scanned in order. When an ast.ImportFrom imports a
    name which was already imported from an unrelated path, the alias gets a numeric suffix
    (`<name>_<idx>`), and the same renaming is applied to the ast.ClassDef/ast.FunctionDef
    asts that follow, until the next divider (an ast.Expr of the name '#') is met.

    All renamings are recorded in AstReplacer objects appended to self._tmp_replacers.

    Args:
        module_ast (ast.Module): Module ast of the generated code, modified in place.
    """
    # {name1: [path1, path2, ...], ...}
    name_path_dict: Dict[str, List[str]] = {}
    # names of modules need to be suffixed: {name1: suffixed_name1, ...}
    name_need_suffix: Dict[str, str] = {}
    # used to record replace actions in ast.ImportFrom
    import_replacer = AstReplacer(None)
    self._tmp_replacers.append(import_replacer)

    def suffix_alias(alias: ast.alias, suffix: int):
        """suffix the name of alias in ast.ImportFrom"""
        new_name = f"{alias.asname}_{suffix}" if alias.asname else f"{alias.name}_{suffix}"
        # record (object, field, old value, new value) in the trace of the replacer
        # NOTE(review): presumably this lets the replacer undo the change later - confirm
        # against AstReplacer.
        import_replacer._trace.append((alias, 'asname', alias.asname, new_name))  # pylint: disable=protected-access
        alias.asname = new_name
        return new_name

    def is_divider(ast_node):
        """judge if ast node is divider of new class or function by checking ast.Expr of '#'."""
        return isinstance(ast_node, ast.Expr) and isinstance(ast_node.value, ast.Name) and ast_node.value.id == '#'

    def record_imports(ast_node: ast.ImportFrom):
        """record name and path of imported modules to find the duplicate name modules."""
        # NOTE(review): ast_node.module is None for imports such as `from . import x`, and
        # startswith() below would then raise TypeError - confirm that relative imports are
        # always resolved before this method runs.
        for alias in ast_node.names[:]:
            name = alias.asname if alias.asname else alias.name
            if name == '*':
                continue
            # current name is firstly imported, just record it
            if name not in name_path_dict:
                name_path_dict[name] = [ast_node.module]
                continue
            # current name is imported before, check whether it is a duplicated name
            for idx, path in enumerate(name_path_dict[name]):
                if path.startswith(ast_node.module):
                    # e.g. origin code is 'from a.b.c import A' and new code is 'from a.b import A'
                    # then we update name_path_dict[name][idx] from 'a.b.c' to 'a.b' and update name to A_{idx}
                    name_path_dict[name][idx] = ast_node.module
                    if idx > 0:
                        name_need_suffix[name] = suffix_alias(alias, idx)
                    break
                elif ast_node.module.startswith(path):
                    # e.g. origin code is 'from a.b import A' and new code is 'from a.b.c import A'
                    # then we just need to update name to A_{idx}
                    if idx > 0:
                        name_need_suffix[name] = suffix_alias(alias, idx)
                    break
            else:
                # the loop finished without break: no recorded path is related to this one
                # current name is imported from a new path, save the path and update the name
                name_path_dict[name].append(ast_node.module)
                name_need_suffix[name] = suffix_alias(alias, len(name_path_dict[name]) - 1)

    def suffix_names_in_ast(ast_node: Union[ast.ClassDef, ast.FunctionDef]):
        """suffix names in ast.ClassDef or ast.FunctionDef"""
        if not name_need_suffix:
            return
        name_replacer = AstReplacer(ast_node)
        self._tmp_replacers.append(name_replacer)
        for name, new_name in name_need_suffix.items():
            name_replacer.replace_all(name, new_name)

    for ast_node in module_ast.body:
        if isinstance(ast_node, ast.ImportFrom):
            record_imports(ast_node)
        if isinstance(ast_node, (ast.ClassDef, ast.FunctionDef)):
            suffix_names_in_ast(ast_node)
        if is_divider(ast_node):
            # a divider closes the current section, later definitions need their own imports
            name_need_suffix.clear()