mindspore 2.1.0__cp39-cp39-win_amd64.whl → 2.2.10__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Further details were linked from the original page (the link was not preserved).

Files changed (505):
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +4 -1
  5. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -1
  9. mindspore/_checkparam.py +23 -29
  10. mindspore/_extends/graph_kernel/__init__.py +0 -1
  11. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  12. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  13. mindspore/_extends/graph_kernel/splitter.py +4 -11
  14. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  15. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  16. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  17. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  18. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  19. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  20. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  21. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  22. mindspore/_extends/parse/__init__.py +12 -15
  23. mindspore/_extends/parse/namespace.py +7 -33
  24. mindspore/_extends/parse/parser.py +61 -71
  25. mindspore/_extends/parse/resources.py +1 -1
  26. mindspore/_extends/parse/standard_method.py +74 -104
  27. mindspore/_extends/parse/trope.py +1 -1
  28. mindspore/_extends/remote/kernel_build_server.py +25 -7
  29. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  30. mindspore/_install_custom.py +43 -0
  31. mindspore/amp.py +47 -11
  32. mindspore/atlprov.dll +0 -0
  33. mindspore/boost/boost.py +1 -8
  34. mindspore/boost/boost_cell_wrapper.py +3 -2
  35. mindspore/boost/grad_accumulation.py +1 -1
  36. mindspore/boost/group_loss_scale_manager.py +8 -7
  37. mindspore/c1.dll +0 -0
  38. mindspore/c1xx.dll +0 -0
  39. mindspore/c2.dll +0 -0
  40. mindspore/common/__init__.py +5 -3
  41. mindspore/common/_jit_fallback_utils.py +6 -0
  42. mindspore/common/_register_for_adapter.py +2 -0
  43. mindspore/common/_register_for_tensor.py +2 -2
  44. mindspore/common/_stub_tensor.py +13 -0
  45. mindspore/common/_utils.py +13 -0
  46. mindspore/common/api.py +174 -259
  47. mindspore/common/auto_dynamic_shape.py +494 -0
  48. mindspore/common/dtype.py +18 -11
  49. mindspore/common/dump.py +6 -4
  50. mindspore/common/initializer.py +14 -14
  51. mindspore/common/jit_config.py +33 -15
  52. mindspore/common/lazy_inline.py +126 -7
  53. mindspore/common/mindir_util.py +101 -0
  54. mindspore/common/parameter.py +51 -41
  55. mindspore/common/seed.py +4 -4
  56. mindspore/common/sparse_tensor.py +13 -14
  57. mindspore/common/tensor.py +243 -165
  58. mindspore/communication/__init__.py +7 -4
  59. mindspore/communication/_comm_helper.py +83 -4
  60. mindspore/communication/management.py +152 -84
  61. mindspore/config/op_info.config +14 -3
  62. mindspore/context.py +152 -61
  63. mindspore/dataset/__init__.py +5 -5
  64. mindspore/dataset/audio/__init__.py +2 -2
  65. mindspore/dataset/audio/transforms.py +52 -52
  66. mindspore/dataset/callback/ds_callback.py +16 -2
  67. mindspore/dataset/core/config.py +68 -51
  68. mindspore/dataset/engine/cache_client.py +28 -5
  69. mindspore/dataset/engine/datasets.py +250 -112
  70. mindspore/dataset/engine/datasets_audio.py +43 -211
  71. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  72. mindspore/dataset/engine/datasets_text.py +43 -67
  73. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  74. mindspore/dataset/engine/datasets_vision.py +219 -1029
  75. mindspore/dataset/engine/iterators.py +11 -4
  76. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  77. mindspore/dataset/engine/obs/util.py +3 -0
  78. mindspore/dataset/engine/samplers.py +1 -1
  79. mindspore/dataset/engine/validators.py +19 -5
  80. mindspore/dataset/text/__init__.py +3 -3
  81. mindspore/dataset/text/transforms.py +101 -127
  82. mindspore/dataset/text/utils.py +205 -138
  83. mindspore/dataset/transforms/__init__.py +1 -1
  84. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  85. mindspore/dataset/transforms/transforms.py +95 -40
  86. mindspore/dataset/utils/browse_dataset.py +8 -2
  87. mindspore/dataset/utils/line_reader.py +17 -19
  88. mindspore/dataset/vision/__init__.py +3 -3
  89. mindspore/dataset/vision/c_transforms.py +6 -3
  90. mindspore/dataset/vision/transforms.py +409 -287
  91. mindspore/dataset/vision/utils.py +13 -14
  92. mindspore/dataset/vision/validators.py +11 -1
  93. mindspore/dnnl.dll +0 -0
  94. mindspore/dpcmi.dll +0 -0
  95. mindspore/experimental/map_parameter.py +14 -0
  96. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  97. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  98. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  99. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  100. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  101. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  102. mindspore/gen_ops.py +273 -0
  103. mindspore/include/OWNERS +0 -1
  104. mindspore/include/api/data_type.h +2 -1
  105. mindspore/include/api/graph.h +0 -15
  106. mindspore/include/api/kernel.h +2 -0
  107. mindspore/include/api/kernel_api.h +37 -12
  108. mindspore/include/api/model.h +17 -14
  109. mindspore/include/api/status.h +8 -3
  110. mindspore/include/api/types.h +37 -4
  111. mindspore/include/c_api/ms/abstract.h +67 -0
  112. mindspore/include/c_api/ms/attribute.h +197 -0
  113. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  114. mindspore/include/c_api/ms/base/macros.h +32 -0
  115. mindspore/include/c_api/ms/base/status.h +33 -0
  116. mindspore/include/c_api/ms/base/types.h +282 -0
  117. mindspore/include/c_api/ms/context.h +102 -0
  118. mindspore/include/c_api/ms/graph.h +160 -0
  119. mindspore/include/c_api/ms/node.h +606 -0
  120. mindspore/include/c_api/ms/tensor.h +161 -0
  121. mindspore/include/c_api/ms/value.h +84 -0
  122. mindspore/include/dataset/constants.h +6 -5
  123. mindspore/include/dataset/execute.h +23 -13
  124. mindspore/include/dataset/text.h +26 -26
  125. mindspore/include/dataset/transforms.h +13 -13
  126. mindspore/include/dataset/vision.h +60 -60
  127. mindspore/include/dataset/vision_ascend.h +5 -6
  128. mindspore/include/dataset/vision_lite.h +17 -17
  129. mindspore/jpeg62.dll +0 -0
  130. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  131. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  132. mindspore/mindspore_backend.dll +0 -0
  133. mindspore/mindspore_common.dll +0 -0
  134. mindspore/mindspore_core.dll +0 -0
  135. mindspore/mindspore_glog.dll +0 -0
  136. mindspore/mindspore_shared_lib.dll +0 -0
  137. mindspore/msobj140.dll +0 -0
  138. mindspore/mspdb140.dll +0 -0
  139. mindspore/mspdbcore.dll +0 -0
  140. mindspore/mspdbst.dll +0 -0
  141. mindspore/mspft140.dll +0 -0
  142. mindspore/msvcdis140.dll +0 -0
  143. mindspore/msvcp140_1.dll +0 -0
  144. mindspore/msvcp140_2.dll +0 -0
  145. mindspore/msvcp140_atomic_wait.dll +0 -0
  146. mindspore/msvcp140_codecvt_ids.dll +0 -0
  147. mindspore/nn/__init__.py +0 -2
  148. mindspore/nn/cell.py +313 -74
  149. mindspore/nn/dynamic_lr.py +21 -21
  150. mindspore/nn/layer/activation.py +22 -30
  151. mindspore/nn/layer/basic.py +15 -13
  152. mindspore/nn/layer/channel_shuffle.py +1 -1
  153. mindspore/nn/layer/container.py +271 -9
  154. mindspore/nn/layer/conv.py +323 -204
  155. mindspore/nn/layer/dense.py +8 -5
  156. mindspore/nn/layer/embedding.py +33 -27
  157. mindspore/nn/layer/flash_attention.py +141 -88
  158. mindspore/nn/layer/image.py +8 -6
  159. mindspore/nn/layer/math.py +16 -25
  160. mindspore/nn/layer/normalization.py +107 -66
  161. mindspore/nn/layer/padding.py +1 -1
  162. mindspore/nn/layer/pooling.py +131 -109
  163. mindspore/nn/layer/rnn_cells.py +27 -22
  164. mindspore/nn/layer/rnns.py +13 -16
  165. mindspore/nn/layer/thor_layer.py +1 -1
  166. mindspore/nn/layer/transformer.py +221 -154
  167. mindspore/nn/learning_rate_schedule.py +9 -1
  168. mindspore/nn/loss/loss.py +235 -174
  169. mindspore/nn/optim/ada_grad.py +2 -1
  170. mindspore/nn/optim/adadelta.py +1 -0
  171. mindspore/nn/optim/adafactor.py +2 -1
  172. mindspore/nn/optim/adam.py +7 -4
  173. mindspore/nn/optim/adamax.py +3 -2
  174. mindspore/nn/optim/adasum.py +2 -2
  175. mindspore/nn/optim/asgd.py +2 -3
  176. mindspore/nn/optim/ftrl.py +6 -5
  177. mindspore/nn/optim/lamb.py +7 -4
  178. mindspore/nn/optim/lars.py +1 -1
  179. mindspore/nn/optim/lazyadam.py +5 -3
  180. mindspore/nn/optim/momentum.py +2 -1
  181. mindspore/nn/optim/optimizer.py +53 -4
  182. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  183. mindspore/nn/optim/rmsprop.py +4 -3
  184. mindspore/nn/optim/rprop.py +23 -12
  185. mindspore/nn/optim/sgd.py +26 -11
  186. mindspore/nn/optim/thor.py +9 -7
  187. mindspore/nn/probability/bijector/bijector.py +5 -5
  188. mindspore/nn/probability/bijector/power_transform.py +27 -27
  189. mindspore/nn/probability/bijector/softplus.py +3 -3
  190. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  191. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  192. mindspore/nn/probability/distribution/beta.py +3 -3
  193. mindspore/nn/probability/distribution/categorical.py +7 -7
  194. mindspore/nn/probability/distribution/cauchy.py +0 -1
  195. mindspore/nn/probability/distribution/distribution.py +3 -3
  196. mindspore/nn/probability/distribution/gamma.py +3 -3
  197. mindspore/nn/probability/distribution/geometric.py +4 -4
  198. mindspore/nn/probability/distribution/gumbel.py +4 -4
  199. mindspore/nn/probability/distribution/log_normal.py +2 -2
  200. mindspore/nn/probability/distribution/logistic.py +2 -2
  201. mindspore/nn/probability/distribution/poisson.py +4 -4
  202. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  203. mindspore/nn/probability/distribution/uniform.py +6 -6
  204. mindspore/nn/wrap/cell_wrapper.py +84 -34
  205. mindspore/nn/wrap/grad_reducer.py +8 -5
  206. mindspore/nn/wrap/loss_scale.py +105 -42
  207. mindspore/numpy/array_creations.py +1 -2
  208. mindspore/numpy/array_ops.py +3 -2
  209. mindspore/numpy/utils_const.py +5 -5
  210. mindspore/opencv_core452.dll +0 -0
  211. mindspore/opencv_imgcodecs452.dll +0 -0
  212. mindspore/opencv_imgproc452.dll +0 -0
  213. mindspore/ops/_grad_experimental/__init__.py +0 -5
  214. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  215. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  216. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  217. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  218. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  219. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  220. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  221. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  222. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  223. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  224. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  225. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  226. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  227. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  228. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  229. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  230. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  231. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  232. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  233. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  234. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  235. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  236. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  237. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  238. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  239. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  240. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  241. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  242. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  243. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  244. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  245. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  246. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  247. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  248. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  249. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  250. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  251. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  252. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  253. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  254. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  255. mindspore/ops/_primitive_cache.py +1 -1
  256. mindspore/ops/_tracefunc.py +45 -13
  257. mindspore/ops/_utils/utils.py +6 -1
  258. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  259. mindspore/ops/_vmap/vmap_base.py +3 -3
  260. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  261. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  262. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  263. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  264. mindspore/ops/arg_dtype_cast.py +54 -0
  265. mindspore/ops/composite/base.py +37 -10
  266. mindspore/ops/composite/math_ops.py +5 -4
  267. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  268. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  269. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  270. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  271. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  272. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  273. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  274. mindspore/ops/deprecated.py +304 -0
  275. mindspore/ops/function/__init__.py +4 -1
  276. mindspore/ops/function/array_func.py +174 -193
  277. mindspore/ops/function/clip_func.py +81 -13
  278. mindspore/ops/function/debug_func.py +1 -1
  279. mindspore/ops/function/grad/grad_func.py +18 -9
  280. mindspore/ops/function/image_func.py +10 -4
  281. mindspore/ops/function/linalg_func.py +5 -5
  282. mindspore/ops/function/math_func.py +575 -386
  283. mindspore/ops/function/nn_func.py +568 -260
  284. mindspore/ops/function/random_func.py +88 -57
  285. mindspore/ops/function/sparse_func.py +1 -1
  286. mindspore/ops/function/sparse_unary_func.py +14 -12
  287. mindspore/ops/function/vmap_func.py +6 -5
  288. mindspore/ops/functional.py +15 -10
  289. mindspore/ops/op_info_register.py +244 -25
  290. mindspore/ops/operations/__init__.py +28 -19
  291. mindspore/ops/operations/_grad_ops.py +72 -7
  292. mindspore/ops/operations/_inner_ops.py +350 -17
  293. mindspore/ops/operations/_quant_ops.py +4 -8
  294. mindspore/ops/operations/_sequence_ops.py +42 -0
  295. mindspore/ops/operations/array_ops.py +68 -282
  296. mindspore/ops/operations/comm_ops.py +107 -59
  297. mindspore/ops/operations/custom_ops.py +94 -70
  298. mindspore/ops/operations/debug_ops.py +8 -4
  299. mindspore/ops/operations/image_ops.py +18 -12
  300. mindspore/ops/operations/inner_ops.py +26 -3
  301. mindspore/ops/operations/math_ops.py +189 -141
  302. mindspore/ops/operations/nn_ops.py +794 -489
  303. mindspore/ops/operations/other_ops.py +0 -22
  304. mindspore/ops/operations/random_ops.py +53 -111
  305. mindspore/ops/operations/sparse_ops.py +3 -1
  306. mindspore/ops/primitive.py +24 -18
  307. mindspore/parallel/_auto_parallel_context.py +68 -8
  308. mindspore/parallel/_cost_model_context.py +2 -2
  309. mindspore/parallel/_offload_context.py +17 -3
  310. mindspore/parallel/_parallel_serialization.py +12 -5
  311. mindspore/parallel/_ps_context.py +12 -0
  312. mindspore/parallel/_tensor.py +18 -13
  313. mindspore/parallel/_transformer/layers.py +5 -3
  314. mindspore/parallel/_transformer/loss.py +1 -0
  315. mindspore/parallel/_transformer/moe.py +2 -2
  316. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  317. mindspore/parallel/_transformer/transformer.py +23 -3
  318. mindspore/parallel/_utils.py +11 -7
  319. mindspore/parallel/algo_parameter_config.py +85 -5
  320. mindspore/parallel/checkpoint_transform.py +19 -12
  321. mindspore/parallel/shard.py +21 -14
  322. mindspore/pgodb140.dll +0 -0
  323. mindspore/pgort140.dll +0 -0
  324. mindspore/profiler/common/struct_type.py +3 -3
  325. mindspore/profiler/common/util.py +4 -2
  326. mindspore/profiler/envprofiling.py +1 -1
  327. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  328. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  329. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  330. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  331. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  332. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  333. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  334. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  335. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  336. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  337. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  338. mindspore/profiler/parser/flops_parser.py +15 -11
  339. mindspore/profiler/parser/framework_parser.py +38 -22
  340. mindspore/profiler/parser/hccl_parser.py +16 -12
  341. mindspore/profiler/parser/integrator.py +22 -11
  342. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  343. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  344. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  345. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  346. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  347. mindspore/profiler/parser/optime_parser.py +1 -1
  348. mindspore/profiler/parser/profiler_info.py +21 -2
  349. mindspore/profiler/parser/step_trace_parser.py +11 -14
  350. mindspore/profiler/profiling.py +179 -89
  351. mindspore/rewrite/api/node.py +102 -19
  352. mindspore/rewrite/api/node_type.py +5 -1
  353. mindspore/rewrite/api/pattern_engine.py +1 -1
  354. mindspore/rewrite/api/scoped_value.py +9 -17
  355. mindspore/rewrite/api/symbol_tree.py +131 -47
  356. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  357. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  358. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  359. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  360. mindspore/rewrite/common/rewrite_elog.py +5 -1
  361. mindspore/rewrite/namer.py +33 -24
  362. mindspore/rewrite/namespace.py +14 -5
  363. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  364. mindspore/rewrite/node/call_function.py +79 -0
  365. mindspore/rewrite/node/cell_container.py +135 -0
  366. mindspore/rewrite/node/control_flow.py +88 -0
  367. mindspore/rewrite/{node.py → node/node.py} +273 -234
  368. mindspore/rewrite/node/node_manager.py +254 -0
  369. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  370. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  371. mindspore/rewrite/parsers/assign_parser.py +216 -221
  372. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  373. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  374. mindspore/rewrite/parsers/constant_parser.py +9 -6
  375. mindspore/rewrite/parsers/container_parser.py +9 -7
  376. mindspore/rewrite/parsers/for_parser.py +36 -15
  377. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  378. mindspore/rewrite/parsers/if_parser.py +28 -24
  379. mindspore/rewrite/parsers/module_parser.py +196 -25
  380. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  381. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  382. mindspore/rewrite/parsers/return_parser.py +6 -6
  383. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  384. mindspore/rewrite/sparsify/utils.py +1 -1
  385. mindspore/rewrite/symbol_tree.py +523 -578
  386. mindspore/rewrite/symbol_tree_builder.py +9 -193
  387. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  388. mindspore/run_check/_check_version.py +6 -4
  389. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  390. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  391. mindspore/tbbmalloc.dll +0 -0
  392. mindspore/tinyxml2.dll +0 -0
  393. mindspore/train/_utils.py +7 -3
  394. mindspore/train/amp.py +323 -123
  395. mindspore/train/anf_ir_pb2.py +14 -2
  396. mindspore/train/callback/_backup_and_restore.py +2 -12
  397. mindspore/train/callback/_callback.py +29 -4
  398. mindspore/train/callback/_checkpoint.py +23 -8
  399. mindspore/train/callback/_early_stop.py +2 -2
  400. mindspore/train/callback/_landscape.py +4 -4
  401. mindspore/train/callback/_loss_monitor.py +2 -2
  402. mindspore/train/callback/_on_request_exit.py +2 -2
  403. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  404. mindspore/train/callback/_summary_collector.py +15 -8
  405. mindspore/train/callback/_time_monitor.py +58 -5
  406. mindspore/train/data_sink.py +5 -11
  407. mindspore/train/dataset_helper.py +84 -57
  408. mindspore/train/loss_scale_manager.py +2 -2
  409. mindspore/train/metrics/__init__.py +3 -3
  410. mindspore/train/metrics/cosine_similarity.py +1 -1
  411. mindspore/train/metrics/hausdorff_distance.py +3 -2
  412. mindspore/train/metrics/mean_surface_distance.py +3 -2
  413. mindspore/train/metrics/metric.py +39 -19
  414. mindspore/train/metrics/roc.py +2 -2
  415. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  416. mindspore/train/mind_ir_pb2.py +85 -36
  417. mindspore/train/model.py +187 -47
  418. mindspore/train/serialization.py +487 -161
  419. mindspore/train/summary/_summary_adapter.py +1 -1
  420. mindspore/train/summary/_writer_pool.py +3 -2
  421. mindspore/train/summary/summary_record.py +37 -17
  422. mindspore/train/train_thor/convert_utils.py +3 -3
  423. mindspore/train/train_thor/dataset_helper.py +1 -1
  424. mindspore/turbojpeg.dll +0 -0
  425. mindspore/vcmeta.dll +0 -0
  426. mindspore/vcruntime140.dll +0 -0
  427. mindspore/vcruntime140_1.dll +0 -0
  428. mindspore/version.py +1 -1
  429. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +5 -3
  430. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +433 -479
  431. mindspore/_extends/graph_kernel/expander.py +0 -80
  432. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  433. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  434. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  435. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  436. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  437. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  438. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  439. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  440. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  441. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  442. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  443. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  444. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  445. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  446. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  447. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  448. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  449. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  450. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  451. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  452. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  453. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  454. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  455. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  456. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  457. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  458. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  459. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  460. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  461. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  462. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  463. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  464. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  465. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  466. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  467. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  468. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  469. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  470. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  471. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  472. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  473. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  474. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  475. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  476. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  477. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  478. mindspore/dataset/datapreprocess/__init__.py +0 -20
  479. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  480. mindspore/include/api/net.h +0 -142
  481. mindspore/nn/lr_scheduler.py +0 -262
  482. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  483. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  484. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  485. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  486. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  487. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  488. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  489. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  490. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  491. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  492. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  493. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  494. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  495. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  496. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  497. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  498. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  499. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  500. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  501. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  502. mindspore/rewrite/node_visitor.py +0 -44
  503. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  504. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -0
  505. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
@@ -13,8 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ============================================================================
15
15
  """Rewrite module api: SymbolTree."""
16
- from typing import Optional, Union
17
- from types import FunctionType
16
+ from typing import Optional, Union, List
18
17
  import mindspore as ms
19
18
 
20
19
  from mindspore.nn import Cell
@@ -53,7 +52,61 @@ class SymbolTree:
53
52
 
54
53
  This interface parses the `network` instance, expands each source
55
54
  code statement of the forward computation process, and parses it into nodes,
56
- which is stored in the SymbolTree.
55
+ which is stored in the SymbolTree. The specific process is as follows:
56
+
57
+ 1. Obtain the source code of the network instance.
58
+ 2. Perform AST parsing on the network and obtain the AST nodes (abstract syntax trees) of each
59
+ statement in the network.
60
+ 3. Expand complex statements in the network forward evaluation process into multiple simple statements.
61
+ 4. Create a SymbolTree object. Each SymbolTree corresponds to one network instance.
62
+ 5. Use the rewrite node to store each statement of the network forward computation process. The node records
63
+ the input, output, and other information of the statement.
64
+ 6. Save the rewrite node to the SymbolTree, and update and maintain the topological connection between
65
+ the nodes.
66
+ 7. Return the SymbolTree object corresponding to the network instance.
67
+
68
+ If a user-defined network of type :class:`mindspore.nn.Cell` is called in the forward computation process
69
+ of the network, rewrite will generate a node of type `NodeType.Tree` for the corresponding statement. This
70
+ type of node stores a new SymbolTree, which parses and maintains the node information of the user-defined
71
+ network.
72
+
73
+ If the following types of statements are called in the forward computation process of the network, rewrite
74
+ will parse the internal statements in the statement and generate corresponding nodes:
75
+
76
+ - :class:`mindspore.nn.SequentialCell`
77
+ - Functions within classes
78
+ - Control flow statements, such as `if` statements
79
+
80
+ Note:
81
+ Because the specific execution branch of control flows are still unknown during the rewrite operation
82
+ of the network, no topology information will be established between the nodes inside the control flow
83
+ and the nodes outside.
84
+ Users cannot obtain nodes inside the control flow when they acquire nodes outside the control flow using
85
+ interfaces like :func:`mindspore.rewrite.Node.get_inputs` and :func:`mindspore.rewrite.Node.get_users` .
86
+ Users also cannot obtain nodes outside the control flow, if they use these interfaces inside the control
87
+ flow.
88
+ Therefore, when users modify the network, they need to manually handle the node information inside and
89
+ outside the control flow.
90
+
91
+ The current rewrite module has the following syntax limitations:
92
+
93
+ - Only networks of type :class:`mindspore.nn.Cell` are supported as input to the rewrite module.
94
+ - Parsing assignment statements with multiple output values is not currently supported.
95
+ - Parsing loop statements is not currently supported.
96
+ - Parsing decorator syntax is not currently supported.
97
+ - Parsing class variable syntax is not currently supported. If class variable uses external data,
98
+ the network after rewrite may be missing data.
99
+ - Parsing local classes and embedded classes is not currently supported, that is, the definition
100
+ of classes need to be placed on the outermost layer.
101
+ - Parsing closure syntax is not currently supported, that is, the definition of out-of-class
102
+ functions need to be placed at the outermost layer.
103
+ - Parsing lambda expression syntax is not currently supported.
104
+
105
+ For statements that do not support parsing, rewrite will generate nodes of type `NodeType.Python`
106
+ for corresponding statements to ensure that the network after rewrite can run normally.
107
+ The `Python` node does not support modifying the input and output of statements, and there may be
108
+ a problem between variable names and those generated by the rewrite. In this case, users need to
109
+ adjust the variable names manually.
57
110
 
58
111
  Args:
59
112
  network (Cell): `network` used to create SymbolTree.
@@ -67,7 +120,7 @@ class SymbolTree:
67
120
  Examples:
68
121
  >>> from mindspore.rewrite import SymbolTree
69
122
  >>> # Define the network structure of LeNet5. Refer to
70
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
123
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
71
124
  >>> net = LeNet5()
72
125
  >>> stree = SymbolTree.create(net)
73
126
  >>> print(type(stree))
@@ -90,42 +143,37 @@ class SymbolTree:
90
143
  if v not in MsDtypes and not isinstance(v, ParamTypes):
91
144
  raise TypeError(f"For call-function Node, got unsupported kwarg value: {v}, type: {type(v)}")
92
145
 
93
- def create_call_function(self, func, targets, *args, **kwargs): # pylint: disable=C0111
94
- Validator.check_value_type("func", func, [FunctionType], "SymbolTree node")
95
- Validator.check_element_type_of_iterable("targets", targets, [str], "SymbolTree node")
96
- args_ = list(args)
97
- SymbolTree._check_args_type(args_)
98
- for i, arg in enumerate(args_):
99
- if isinstance(arg, Node):
100
- args_[i] = arg.get_handler()
101
- SymbolTree._check_kwargs_type(kwargs)
102
- for key, value in kwargs.items():
103
- if isinstance(value, Node):
104
- kwargs[key] = value.get_handler()
105
- return Node(self._symbol_tree._create_call_function(func, targets, args_, kwargs)) # pylint: disable=W0212
106
-
107
146
  def get_handler(self) -> SymbolTreeImpl:
108
147
  return self._symbol_tree
109
148
 
110
- def nodes(self):
149
+ def nodes(self, all_nodes: bool = False):
111
150
  """
112
151
  Get the generator of the node in the current SymbolTree, which is used to iterate
113
152
  through the nodes in SymbolTree.
114
153
 
154
+ Args:
155
+ all_nodes (bool): Get all nodes including nodes in CallFunction node, CellContainer node
156
+ and sub symbol tree. Default: ``False`` .
157
+
115
158
  Returns:
116
- A generator for node of current SymbolTree.
159
+ A generator for nodes in SymbolTree.
160
+
161
+ Raises:
162
+ TypeError: If `all_nodes` is not bool.
117
163
 
118
164
  Examples:
119
165
  >>> from mindspore.rewrite import SymbolTree
120
166
  >>> # Define the network structure of LeNet5. Refer to
121
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
167
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
122
168
  >>> net = LeNet5()
123
169
  >>> stree = SymbolTree.create(net)
124
170
  >>> print([node.get_name() for node in stree.nodes()])
125
171
  ['input_x', 'Expr', 'conv1', 'relu', 'max_pool2d', 'conv2', 'relu_1', 'max_pool2d_1',
126
172
  'flatten', 'fc1', 'relu_2', 'fc2', 'relu_3', 'fc3', 'return']
127
173
  """
128
- for node in self._symbol_tree.nodes():
174
+ Validator.check_value_type("all_nodes", all_nodes, [bool], "nodes")
175
+ nodes = self._symbol_tree.all_nodes() if all_nodes else self._symbol_tree.nodes()
176
+ for node in nodes:
129
177
  yield Node(node)
130
178
 
131
179
  def get_node(self, node_name: str) -> Optional[Node]:
@@ -141,7 +189,7 @@ class SymbolTree:
141
189
  Examples:
142
190
  >>> from mindspore.rewrite import SymbolTree
143
191
  >>> # Define the network structure of LeNet5. Refer to
144
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
192
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
145
193
  >>> net = LeNet5()
146
194
  >>> stree = SymbolTree.create(net)
147
195
  >>> node = stree.get_node('conv1')
@@ -149,12 +197,12 @@ class SymbolTree:
149
197
  conv1
150
198
  """
151
199
  Validator.check_value_type("node_name", node_name, [str], "SymbolTree")
152
- node_impl = self._symbol_tree.get_node(node_name)
200
+ node_impl = self._symbol_tree.get_node_from_name(node_name)
153
201
  if node_impl is None:
154
202
  return None
155
203
  return Node(node_impl)
156
204
 
157
- def get_inputs(self) -> [Node]:
205
+ def get_inputs(self) -> List[Node]:
158
206
  return [Node(node_impl) for node_impl in self._symbol_tree.get_inputs()]
159
207
 
160
208
  def before(self, node: Union[Node, str]):
@@ -174,15 +222,17 @@ class SymbolTree:
174
222
  Examples:
175
223
  >>> from mindspore.rewrite import SymbolTree
176
224
  >>> # Define the network structure of LeNet5. Refer to
177
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
225
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
178
226
  >>> net = LeNet5()
179
227
  >>> stree = SymbolTree.create(net)
180
228
  >>> for node in stree.nodes():
181
229
  ... if node.get_name() == "conv1":
182
230
  ... position = stree.before(node)
183
231
  """
184
- Validator.check_value_type("node", node, [Node], "SymbolTree")
185
- return self._symbol_tree.before(node.get_handler())
232
+ Validator.check_value_type("node", node, [Node, str], "SymbolTree")
233
+ if isinstance(node, Node):
234
+ node = node.get_handler()
235
+ return self._symbol_tree.before(node)
186
236
 
187
237
  def after(self, node: Union[Node, str]):
188
238
  """
@@ -201,15 +251,17 @@ class SymbolTree:
201
251
  Examples:
202
252
  >>> from mindspore.rewrite import SymbolTree
203
253
  >>> # Define the network structure of LeNet5. Refer to
204
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
254
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
205
255
  >>> net = LeNet5()
206
256
  >>> stree = SymbolTree.create(net)
207
257
  >>> for node in stree.nodes():
208
258
  ... if node.get_name() == "conv1":
209
259
  ... position = stree.after(node)
210
260
  """
211
- Validator.check_value_type("node", node, [Node], "SymbolTree")
212
- return self._symbol_tree.after(node.get_handler())
261
+ Validator.check_value_type("node", node, [Node, str], "SymbolTree")
262
+ if isinstance(node, Node):
263
+ node = node.get_handler()
264
+ return self._symbol_tree.after(node)
213
265
 
214
266
  def insert(self, position, node: Node) -> Node:
215
267
  """
@@ -233,7 +285,7 @@ class SymbolTree:
233
285
  >>> from mindspore.rewrite import SymbolTree, ScopedValue
234
286
  >>> import mindspore.nn as nn
235
287
  >>> # Define the network structure of LeNet5. Refer to
236
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
288
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
237
289
  >>> net = LeNet5()
238
290
  >>> stree = SymbolTree.create(net)
239
291
  >>> node = stree.get_node("conv1")
@@ -244,7 +296,7 @@ class SymbolTree:
244
296
  """
245
297
  Validator.check_value_type("position", position, [Position], "SymbolTree")
246
298
  Validator.check_value_type("node", node, [Node], "SymbolTree")
247
- return Node(self._symbol_tree.insert_node(position, node.get_handler()))
299
+ return Node(self._symbol_tree.insert_node(node.get_handler(), position.node, position.before_node))
248
300
 
249
301
  def erase(self, node: Union[Node, str]) -> Optional[Node]:
250
302
  """
@@ -262,16 +314,18 @@ class SymbolTree:
262
314
  Examples:
263
315
  >>> from mindspore.rewrite import SymbolTree
264
316
  >>> # Define the network structure of LeNet5. Refer to
265
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
317
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
266
318
  >>> net = LeNet5()
267
319
  >>> stree = SymbolTree.create(net)
268
320
  >>> node = stree.get_node("conv1")
269
321
  >>> stree.erase(node)
270
322
  """
271
- Validator.check_value_type("node", node, [Node], "SymbolTree")
272
- return Node(self._symbol_tree.erase_node(node.get_handler()))
323
+ Validator.check_value_type("node", node, [Node, str], "SymbolTree")
324
+ if isinstance(node, Node):
325
+ node = node.get_handler()
326
+ return Node(self._symbol_tree.erase_node(node))
273
327
 
274
- def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
328
+ def replace(self, old_node: Node, new_nodes: List[Node]) -> Node:
275
329
  """
276
330
  Replace the `old_node` with nodes in the `new_nodes` list.
277
331
 
@@ -285,7 +339,7 @@ class SymbolTree:
285
339
 
286
340
  Args:
287
341
  old_node (Node): Node to be replaced.
288
- new_nodes (list[Node]): Nodes of the node_tree to replace in.
342
+ new_nodes (List[Node]): Nodes of the node_tree to replace in.
289
343
 
290
344
  Returns:
291
345
  An instance of Node represents root of node_tree been replaced in.
@@ -299,7 +353,7 @@ class SymbolTree:
299
353
  >>> from mindspore.rewrite import SymbolTree, ScopedValue
300
354
  >>> import mindspore.nn as nn
301
355
  >>> # Define the network structure of LeNet5. Refer to
302
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
356
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
303
357
  >>> net = LeNet5()
304
358
  >>> stree = SymbolTree.create(net)
305
359
  >>> node = stree.get_node("conv1")
@@ -320,16 +374,38 @@ class SymbolTree:
320
374
  def dump(self):
321
375
  self._symbol_tree.dump()
322
376
 
323
- def print_node_tabulate(self):
324
- """
377
+ def print_node_tabulate(self, all_nodes: bool = False):
378
+ r"""
325
379
  Print the topology information of nodes in SymbolTree, including node type, node name, node code,
326
380
  and node input-output relationship.
327
- The information is output to the screen using the print interface.
328
381
 
329
- .. warning::
330
- This is an experimental API that is subject to change or deletion.
382
+ The information is output to the screen using the print interface, including the following information:
383
+
384
+ - **node type** (str): The type of node, refer to class:`mindspore.rewrite.NodeType` .
385
+ - **name** (str): The name of node.
386
+ - **codes** (str): The source code statement corresponding to the node.
387
+ - **arg providers** (Dict[int, Tuple[str, int]]): The format is `{[idx, (n, k)]}` , which means the
388
+ `idx` th parameter of the node is provided by the `k` th output of node `n` .
389
+ - **target users** (Dict[int, List[Tuple[str, int]]]): The format is '{[idx, [(n, k)]]}' , which means
390
+ the `idx` th output of the node is used as the `k` th parameter of node `n` .
391
+
392
+ Args:
393
+ all_nodes (bool): Print information of all nodes, including nodes in CallFunction
394
+ node, CellContainer node and sub symbol tree. Default: ``False`` .
395
+
396
+ Raises:
397
+ TypeError: If `all_nodes` is not bool.
398
+
399
+ Examples:
400
+ >>> from mindspore.rewrite import SymbolTree
401
+ >>> # Define the network structure of LeNet5. Refer to
402
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
403
+ >>> net = LeNet5()
404
+ >>> stree = SymbolTree.create(net)
405
+ >>> stree.print_node_tabulate()
331
406
  """
332
- self._symbol_tree.print_node_tabulate()
407
+ Validator.check_value_type("all_nodes", all_nodes, [bool], "print_node_tabulate")
408
+ self._symbol_tree.print_node_tabulate(all_nodes)
333
409
 
334
410
  def get_code(self) -> str:
335
411
  """
@@ -342,7 +418,7 @@ class SymbolTree:
342
418
  Examples:
343
419
  >>> from mindspore.rewrite import SymbolTree
344
420
  >>> # Define the network structure of LeNet5. Refer to
345
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
421
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
346
422
  >>> net = LeNet5()
347
423
  >>> stree = SymbolTree.create(net)
348
424
  >>> codes = stree.get_code()
@@ -355,13 +431,21 @@ class SymbolTree:
355
431
  Get the network object generated based on SymbolTree.
356
432
  The source code is saved to a file in the 'rewritten_network' folder of the current directory.
357
433
 
434
+ Note:
435
+ - The modification of network by rewrite module is based on the modification of AST tree of
436
+ original network instance, and the new network instance will obtain attribute information
437
+ from original network instance, so the new network instance and the original network instance
438
+ have data association, and the original network should no longer be used.
439
+ - Due to the data association between the new network and the original network instance, manually creating
440
+ a network instance using the source code file generated by rewrite is not currently supported.
441
+
358
442
  Returns:
359
443
  A network object generated from SymbolTree.
360
444
 
361
445
  Examples:
362
446
  >>> from mindspore.rewrite import SymbolTree
363
447
  >>> # Define the network structure of LeNet5. Refer to
364
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
448
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
365
449
  >>> net = LeNet5()
366
450
  >>> stree = SymbolTree.create(net)
367
451
  >>> new_net = stree.get_network()
@@ -17,7 +17,8 @@
17
17
  Define some ast helpers for manipulating python ast.
18
18
  """
19
19
 
20
- from .ast_finder import AstFinder, StrChecker, CheckPropertyIsUsed, GetPropertyOfObj
20
+ from .ast_finder import AstFinder, StrChecker, CheckPropertyIsUsed, GetPropertyOfObj, \
21
+ AstAssignFinder, AstClassFinder, AstFunctionFinder
21
22
  from .ast_replacer import AstReplacer
22
23
  from .ast_modifier import AstModifier
23
24
  from .ast_creator import ast_args_creator, ast_assign_creator, ast_attributer_creator, ast_call_creator, \
@@ -225,3 +225,132 @@ class GetPropertyOfObj(ast.NodeVisitor):
225
225
  self._property = set()
226
226
  self.generic_visit(self._context)
227
227
  return self._property
228
+
229
+
230
+ class AstAssignFinder(ast.NodeVisitor):
231
+ """
232
+ Get assign definition ast of specifical parameter in specific scope.
233
+
234
+ Args:
235
+ node (ast.AST): An instance of ast node as check scope.
236
+ """
237
+ def __init__(self, node: ast.AST):
238
+ self._context = node
239
+ self._scope = ""
240
+ self._value = ""
241
+ self._target = None
242
+
243
+ def visit_Assign(self, node: ast.Assign):
244
+ if self._scope and isinstance(node.targets[0], ast.Attribute):
245
+ if node.targets[0].attr == self._value and isinstance(node.targets[0].value, ast.Name) \
246
+ and node.targets[0].value.id == self._scope:
247
+ self._target = node
248
+ elif not self._scope and isinstance(node.targets[0], ast.Name):
249
+ if node.targets[0].id == self._value:
250
+ self._target = node
251
+
252
+ def get_ast(self, value: str, scope: str = "") -> bool:
253
+ """
254
+ Get assign ast of specifical parameter in specific ast.
255
+
256
+ Args:
257
+ value (str): A string indicates assign target value.
258
+ scope (str): A string indicates assign target scope.
259
+
260
+ Returns:
261
+ An assign ast with the same target name as `scope.value` .
262
+ """
263
+ self._scope = scope
264
+ self._value = value
265
+ self.generic_visit(self._context)
266
+ return self._target
267
+
268
+
269
+ class AstClassFinder(ast.NodeVisitor):
270
+ """
271
+ Find all specific name of ast class node in specific scope.
272
+
273
+ Args:
274
+ node (ast.AST): An instance of ast node as search scope.
275
+ """
276
+
277
+ def __init__(self, node: ast.AST):
278
+ self._scope: ast.AST = node
279
+ self._target: str = ""
280
+ self._results: [ast.ClassDef] = []
281
+
282
+ def visit_ClassDef(self, node):
283
+ """
284
+ An override method, iterating over all ClassDef nodes and save target ast nodes.
285
+
286
+ Args:
287
+ node (ast.AST): An instance of ast node which is visited currently.
288
+ """
289
+
290
+ if node.name == self._target:
291
+ self._results.append(node)
292
+
293
+ def find_all(self, class_name: str) -> [ast.AST]:
294
+ """
295
+ Find all matched ast node.
296
+
297
+ Args:
298
+ class_name (str): Name of class to be found.
299
+
300
+ Returns:
301
+ A list of instance of ast.ClassDef as matched result.
302
+
303
+ Raises:
304
+ TypeError: If input `class_name` is not str.
305
+ """
306
+ if not isinstance(class_name, str):
307
+ raise TypeError("Input class_name should be a str")
308
+ self._target = class_name
309
+ self._results.clear()
310
+ self.visit(self._scope)
311
+ return self._results
312
+
313
+
314
+ class AstFunctionFinder(ast.NodeVisitor):
315
+ """
316
+ Find all specific name of ast function node in specific scope.
317
+
318
+ Args:
319
+ node (ast.AST): An instance of ast node as search scope.
320
+ """
321
+
322
+ def __init__(self, node: ast.AST):
323
+ self._scope: ast.AST = node
324
+ self._target: str = ""
325
+ self._results: [ast.ClassDef] = []
326
+
327
+ def visit_FunctionDef(self, node):
328
+ """
329
+ An override method, iterating over all FunctionDef nodes and save target ast nodes.
330
+
331
+ Args:
332
+ node (ast.AST): An instance of ast node which is visited currently.
333
+ """
334
+
335
+ if node.name == self._target:
336
+ self._results.append(node)
337
+
338
+ def find_all(self, func_name: str) -> [ast.AST]:
339
+ """
340
+ Find all matched ast node.
341
+
342
+ Args:
343
+ func_name (str): Name of function to be found.
344
+
345
+ Returns:
346
+ A list of instance of ast.FunctionDef as matched result.
347
+
348
+ Raises:
349
+ TypeError: If input `func_name` is not str.
350
+ """
351
+ if not isinstance(func_name, str):
352
+ raise TypeError("Input func_name should be a str")
353
+ self._target = func_name
354
+ self._results.clear()
355
+ self.visit(self._scope)
356
+ return self._results