mindspore 2.1.0__cp37-cp37m-win_amd64.whl → 2.2.11__cp37-cp37m-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the package registry's advisory page for more details.

Files changed (511)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +4 -1
  5. mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -1
  9. mindspore/_checkparam.py +23 -29
  10. mindspore/_extends/graph_kernel/__init__.py +0 -1
  11. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  12. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  13. mindspore/_extends/graph_kernel/splitter.py +4 -11
  14. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  15. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  16. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  17. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  18. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  19. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  20. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  21. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  22. mindspore/_extends/parse/__init__.py +13 -15
  23. mindspore/_extends/parse/namespace.py +7 -33
  24. mindspore/_extends/parse/parser.py +67 -72
  25. mindspore/_extends/parse/resources.py +1 -1
  26. mindspore/_extends/parse/standard_method.py +86 -106
  27. mindspore/_extends/parse/trope.py +1 -1
  28. mindspore/_extends/remote/kernel_build_server.py +25 -7
  29. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  30. mindspore/_install_custom.py +43 -0
  31. mindspore/amp.py +47 -11
  32. mindspore/atlprov.dll +0 -0
  33. mindspore/boost/boost.py +1 -8
  34. mindspore/boost/boost_cell_wrapper.py +3 -2
  35. mindspore/boost/grad_accumulation.py +1 -1
  36. mindspore/boost/group_loss_scale_manager.py +8 -7
  37. mindspore/c1.dll +0 -0
  38. mindspore/c1xx.dll +0 -0
  39. mindspore/c2.dll +0 -0
  40. mindspore/common/__init__.py +5 -3
  41. mindspore/common/_jit_fallback_utils.py +6 -0
  42. mindspore/common/_register_for_adapter.py +2 -0
  43. mindspore/common/_register_for_tensor.py +2 -2
  44. mindspore/common/_stub_tensor.py +13 -0
  45. mindspore/common/_utils.py +29 -0
  46. mindspore/common/api.py +174 -259
  47. mindspore/common/auto_dynamic_shape.py +494 -0
  48. mindspore/common/dtype.py +18 -11
  49. mindspore/common/dump.py +6 -4
  50. mindspore/common/initializer.py +14 -14
  51. mindspore/common/jit_config.py +33 -15
  52. mindspore/common/lazy_inline.py +126 -7
  53. mindspore/common/mindir_util.py +101 -0
  54. mindspore/common/parameter.py +51 -41
  55. mindspore/common/seed.py +4 -4
  56. mindspore/common/sparse_tensor.py +13 -14
  57. mindspore/common/tensor.py +243 -165
  58. mindspore/communication/__init__.py +7 -4
  59. mindspore/communication/_comm_helper.py +83 -4
  60. mindspore/communication/management.py +152 -84
  61. mindspore/config/op_info.config +14 -3
  62. mindspore/context.py +152 -61
  63. mindspore/dataset/__init__.py +5 -5
  64. mindspore/dataset/audio/__init__.py +2 -2
  65. mindspore/dataset/audio/transforms.py +52 -52
  66. mindspore/dataset/callback/ds_callback.py +16 -2
  67. mindspore/dataset/core/config.py +68 -51
  68. mindspore/dataset/engine/cache_client.py +33 -7
  69. mindspore/dataset/engine/datasets.py +250 -112
  70. mindspore/dataset/engine/datasets_audio.py +43 -211
  71. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  72. mindspore/dataset/engine/datasets_text.py +43 -67
  73. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  74. mindspore/dataset/engine/datasets_vision.py +219 -1029
  75. mindspore/dataset/engine/iterators.py +11 -4
  76. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  77. mindspore/dataset/engine/obs/util.py +3 -0
  78. mindspore/dataset/engine/samplers.py +1 -1
  79. mindspore/dataset/engine/validators.py +19 -5
  80. mindspore/dataset/text/__init__.py +3 -3
  81. mindspore/dataset/text/transforms.py +101 -127
  82. mindspore/dataset/text/utils.py +205 -138
  83. mindspore/dataset/transforms/__init__.py +1 -1
  84. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  85. mindspore/dataset/transforms/transforms.py +95 -40
  86. mindspore/dataset/utils/browse_dataset.py +8 -2
  87. mindspore/dataset/utils/line_reader.py +17 -19
  88. mindspore/dataset/vision/__init__.py +3 -3
  89. mindspore/dataset/vision/c_transforms.py +6 -3
  90. mindspore/dataset/vision/transforms.py +409 -287
  91. mindspore/dataset/vision/utils.py +13 -14
  92. mindspore/dataset/vision/validators.py +11 -1
  93. mindspore/dnnl.dll +0 -0
  94. mindspore/dpcmi.dll +0 -0
  95. mindspore/experimental/map_parameter.py +14 -0
  96. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  97. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  98. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  99. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  100. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  101. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  102. mindspore/gen_ops.py +273 -0
  103. mindspore/include/OWNERS +0 -1
  104. mindspore/include/api/data_type.h +2 -1
  105. mindspore/include/api/graph.h +0 -15
  106. mindspore/include/api/kernel.h +2 -0
  107. mindspore/include/api/kernel_api.h +37 -12
  108. mindspore/include/api/model.h +17 -14
  109. mindspore/include/api/status.h +8 -3
  110. mindspore/include/api/types.h +37 -4
  111. mindspore/include/c_api/ms/abstract.h +67 -0
  112. mindspore/include/c_api/ms/attribute.h +197 -0
  113. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  114. mindspore/include/c_api/ms/base/macros.h +32 -0
  115. mindspore/include/c_api/ms/base/status.h +33 -0
  116. mindspore/include/c_api/ms/base/types.h +282 -0
  117. mindspore/include/c_api/ms/context.h +102 -0
  118. mindspore/include/c_api/ms/graph.h +160 -0
  119. mindspore/include/c_api/ms/node.h +606 -0
  120. mindspore/include/c_api/ms/tensor.h +161 -0
  121. mindspore/include/c_api/ms/value.h +84 -0
  122. mindspore/include/dataset/constants.h +6 -5
  123. mindspore/include/dataset/execute.h +23 -13
  124. mindspore/include/dataset/text.h +26 -26
  125. mindspore/include/dataset/transforms.h +13 -13
  126. mindspore/include/dataset/vision.h +60 -60
  127. mindspore/include/dataset/vision_ascend.h +5 -6
  128. mindspore/include/dataset/vision_lite.h +17 -17
  129. mindspore/jpeg62.dll +0 -0
  130. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  131. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  132. mindspore/mindspore_backend.dll +0 -0
  133. mindspore/mindspore_common.dll +0 -0
  134. mindspore/mindspore_core.dll +0 -0
  135. mindspore/mindspore_glog.dll +0 -0
  136. mindspore/mindspore_shared_lib.dll +0 -0
  137. mindspore/msobj140.dll +0 -0
  138. mindspore/mspdb140.dll +0 -0
  139. mindspore/mspdbcore.dll +0 -0
  140. mindspore/mspdbst.dll +0 -0
  141. mindspore/mspft140.dll +0 -0
  142. mindspore/msvcdis140.dll +0 -0
  143. mindspore/msvcp140_1.dll +0 -0
  144. mindspore/msvcp140_2.dll +0 -0
  145. mindspore/msvcp140_atomic_wait.dll +0 -0
  146. mindspore/msvcp140_codecvt_ids.dll +0 -0
  147. mindspore/nn/__init__.py +0 -2
  148. mindspore/nn/cell.py +313 -74
  149. mindspore/nn/dynamic_lr.py +21 -21
  150. mindspore/nn/layer/activation.py +22 -30
  151. mindspore/nn/layer/basic.py +15 -13
  152. mindspore/nn/layer/channel_shuffle.py +1 -1
  153. mindspore/nn/layer/container.py +271 -9
  154. mindspore/nn/layer/conv.py +323 -204
  155. mindspore/nn/layer/dense.py +8 -5
  156. mindspore/nn/layer/embedding.py +33 -27
  157. mindspore/nn/layer/flash_attention.py +61 -95
  158. mindspore/nn/layer/image.py +8 -6
  159. mindspore/nn/layer/math.py +16 -25
  160. mindspore/nn/layer/normalization.py +107 -66
  161. mindspore/nn/layer/padding.py +1 -1
  162. mindspore/nn/layer/pooling.py +131 -109
  163. mindspore/nn/layer/rnn_cells.py +27 -22
  164. mindspore/nn/layer/rnns.py +13 -16
  165. mindspore/nn/layer/thor_layer.py +1 -1
  166. mindspore/nn/layer/transformer.py +221 -154
  167. mindspore/nn/learning_rate_schedule.py +9 -1
  168. mindspore/nn/loss/loss.py +235 -174
  169. mindspore/nn/optim/ada_grad.py +2 -1
  170. mindspore/nn/optim/adadelta.py +1 -0
  171. mindspore/nn/optim/adafactor.py +2 -1
  172. mindspore/nn/optim/adam.py +7 -4
  173. mindspore/nn/optim/adamax.py +3 -2
  174. mindspore/nn/optim/adasum.py +2 -2
  175. mindspore/nn/optim/asgd.py +2 -3
  176. mindspore/nn/optim/ftrl.py +6 -5
  177. mindspore/nn/optim/lamb.py +7 -4
  178. mindspore/nn/optim/lars.py +1 -1
  179. mindspore/nn/optim/lazyadam.py +5 -3
  180. mindspore/nn/optim/momentum.py +2 -1
  181. mindspore/nn/optim/optimizer.py +53 -4
  182. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  183. mindspore/nn/optim/rmsprop.py +4 -3
  184. mindspore/nn/optim/rprop.py +23 -12
  185. mindspore/nn/optim/sgd.py +26 -11
  186. mindspore/nn/optim/thor.py +9 -7
  187. mindspore/nn/probability/bijector/bijector.py +5 -5
  188. mindspore/nn/probability/bijector/power_transform.py +27 -27
  189. mindspore/nn/probability/bijector/softplus.py +3 -3
  190. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  191. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  192. mindspore/nn/probability/distribution/beta.py +3 -3
  193. mindspore/nn/probability/distribution/categorical.py +7 -7
  194. mindspore/nn/probability/distribution/cauchy.py +0 -1
  195. mindspore/nn/probability/distribution/distribution.py +3 -3
  196. mindspore/nn/probability/distribution/gamma.py +3 -3
  197. mindspore/nn/probability/distribution/geometric.py +4 -4
  198. mindspore/nn/probability/distribution/gumbel.py +4 -4
  199. mindspore/nn/probability/distribution/log_normal.py +2 -2
  200. mindspore/nn/probability/distribution/logistic.py +2 -2
  201. mindspore/nn/probability/distribution/poisson.py +4 -4
  202. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  203. mindspore/nn/probability/distribution/uniform.py +6 -6
  204. mindspore/nn/wrap/__init__.py +4 -2
  205. mindspore/nn/wrap/cell_wrapper.py +87 -34
  206. mindspore/nn/wrap/grad_reducer.py +8 -5
  207. mindspore/nn/wrap/loss_scale.py +105 -42
  208. mindspore/numpy/array_creations.py +1 -2
  209. mindspore/numpy/array_ops.py +3 -2
  210. mindspore/numpy/utils_const.py +5 -5
  211. mindspore/opencv_core452.dll +0 -0
  212. mindspore/opencv_imgcodecs452.dll +0 -0
  213. mindspore/opencv_imgproc452.dll +0 -0
  214. mindspore/ops/_grad_experimental/__init__.py +0 -5
  215. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  216. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  217. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  218. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  219. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  220. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  221. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  222. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  223. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  224. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  225. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  226. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  227. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  228. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  229. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  230. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  231. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  232. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  233. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  234. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  235. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  236. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  237. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  238. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  239. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  240. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  241. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  242. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  243. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  244. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  245. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  246. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  247. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  248. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  249. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  250. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  251. mindspore/ops/_primitive_cache.py +1 -1
  252. mindspore/ops/_tracefunc.py +45 -13
  253. mindspore/ops/_utils/utils.py +6 -1
  254. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  255. mindspore/ops/_vmap/vmap_base.py +3 -3
  256. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  257. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  258. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  259. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  260. mindspore/ops/arg_dtype_cast.py +54 -0
  261. mindspore/ops/composite/base.py +37 -10
  262. mindspore/ops/composite/math_ops.py +5 -4
  263. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  264. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  265. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  266. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  267. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  268. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  269. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  270. mindspore/ops/deprecated.py +304 -0
  271. mindspore/ops/function/__init__.py +4 -1
  272. mindspore/ops/function/array_func.py +174 -193
  273. mindspore/ops/function/clip_func.py +81 -13
  274. mindspore/ops/function/debug_func.py +1 -1
  275. mindspore/ops/function/grad/grad_func.py +18 -9
  276. mindspore/ops/function/image_func.py +10 -4
  277. mindspore/ops/function/linalg_func.py +5 -5
  278. mindspore/ops/function/math_func.py +575 -386
  279. mindspore/ops/function/nn_func.py +568 -260
  280. mindspore/ops/function/random_func.py +88 -57
  281. mindspore/ops/function/sparse_func.py +1 -1
  282. mindspore/ops/function/sparse_unary_func.py +14 -12
  283. mindspore/ops/function/vmap_func.py +6 -5
  284. mindspore/ops/functional.py +15 -10
  285. mindspore/ops/op_info_register.py +244 -25
  286. mindspore/ops/operations/__init__.py +31 -19
  287. mindspore/ops/operations/_grad_ops.py +71 -7
  288. mindspore/ops/operations/_inner_ops.py +350 -17
  289. mindspore/ops/operations/_quant_ops.py +4 -8
  290. mindspore/ops/operations/_sequence_ops.py +42 -0
  291. mindspore/ops/operations/array_ops.py +68 -282
  292. mindspore/ops/operations/comm_ops.py +107 -59
  293. mindspore/ops/operations/custom_ops.py +94 -70
  294. mindspore/ops/operations/debug_ops.py +8 -4
  295. mindspore/ops/operations/image_ops.py +18 -12
  296. mindspore/ops/operations/inner_ops.py +26 -3
  297. mindspore/ops/operations/math_ops.py +192 -144
  298. mindspore/ops/operations/nn_ops.py +857 -489
  299. mindspore/ops/operations/other_ops.py +0 -22
  300. mindspore/ops/operations/random_ops.py +53 -111
  301. mindspore/ops/operations/sparse_ops.py +3 -1
  302. mindspore/ops/primitive.py +24 -18
  303. mindspore/parallel/_auto_parallel_context.py +68 -8
  304. mindspore/parallel/_cost_model_context.py +2 -2
  305. mindspore/parallel/_offload_context.py +17 -3
  306. mindspore/parallel/_parallel_serialization.py +12 -5
  307. mindspore/parallel/_ps_context.py +12 -0
  308. mindspore/parallel/_tensor.py +18 -13
  309. mindspore/parallel/_transformer/layers.py +5 -3
  310. mindspore/parallel/_transformer/loss.py +1 -0
  311. mindspore/parallel/_transformer/moe.py +2 -2
  312. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  313. mindspore/parallel/_transformer/transformer.py +23 -3
  314. mindspore/parallel/_utils.py +11 -7
  315. mindspore/parallel/algo_parameter_config.py +85 -5
  316. mindspore/parallel/checkpoint_transform.py +19 -12
  317. mindspore/parallel/shard.py +21 -14
  318. mindspore/pgodb140.dll +0 -0
  319. mindspore/pgort140.dll +0 -0
  320. mindspore/profiler/common/struct_type.py +3 -3
  321. mindspore/profiler/common/util.py +4 -2
  322. mindspore/profiler/envprofiling.py +1 -1
  323. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  324. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  325. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  326. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  327. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  328. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  329. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  330. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  331. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  332. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  333. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  334. mindspore/profiler/parser/flops_parser.py +15 -11
  335. mindspore/profiler/parser/framework_parser.py +38 -22
  336. mindspore/profiler/parser/hccl_parser.py +16 -12
  337. mindspore/profiler/parser/integrator.py +22 -11
  338. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  339. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  340. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  341. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  342. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  343. mindspore/profiler/parser/optime_parser.py +1 -1
  344. mindspore/profiler/parser/profiler_info.py +21 -2
  345. mindspore/profiler/parser/step_trace_parser.py +11 -14
  346. mindspore/profiler/profiling.py +179 -89
  347. mindspore/rewrite/api/node.py +102 -19
  348. mindspore/rewrite/api/node_type.py +5 -1
  349. mindspore/rewrite/api/pattern_engine.py +1 -1
  350. mindspore/rewrite/api/scoped_value.py +9 -17
  351. mindspore/rewrite/api/symbol_tree.py +131 -47
  352. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  353. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  354. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  355. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  356. mindspore/rewrite/common/rewrite_elog.py +5 -1
  357. mindspore/rewrite/namer.py +33 -24
  358. mindspore/rewrite/namespace.py +14 -5
  359. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  360. mindspore/rewrite/node/call_function.py +79 -0
  361. mindspore/rewrite/node/cell_container.py +135 -0
  362. mindspore/rewrite/node/control_flow.py +88 -0
  363. mindspore/rewrite/{node.py → node/node.py} +273 -234
  364. mindspore/rewrite/node/node_manager.py +254 -0
  365. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  366. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  367. mindspore/rewrite/parsers/assign_parser.py +216 -221
  368. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  369. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  370. mindspore/rewrite/parsers/constant_parser.py +9 -6
  371. mindspore/rewrite/parsers/container_parser.py +9 -7
  372. mindspore/rewrite/parsers/for_parser.py +42 -21
  373. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  374. mindspore/rewrite/parsers/if_parser.py +28 -24
  375. mindspore/rewrite/parsers/module_parser.py +196 -25
  376. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  377. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  378. mindspore/rewrite/parsers/return_parser.py +6 -6
  379. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  380. mindspore/rewrite/sparsify/utils.py +1 -1
  381. mindspore/rewrite/symbol_tree.py +523 -578
  382. mindspore/rewrite/symbol_tree_builder.py +9 -193
  383. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  384. mindspore/run_check/_check_version.py +6 -4
  385. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  386. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  387. mindspore/tbbmalloc.dll +0 -0
  388. mindspore/tinyxml2.dll +0 -0
  389. mindspore/train/_utils.py +7 -3
  390. mindspore/train/amp.py +323 -123
  391. mindspore/train/anf_ir_pb2.py +14 -2
  392. mindspore/train/callback/_backup_and_restore.py +2 -12
  393. mindspore/train/callback/_callback.py +29 -4
  394. mindspore/train/callback/_checkpoint.py +23 -8
  395. mindspore/train/callback/_early_stop.py +2 -2
  396. mindspore/train/callback/_landscape.py +4 -4
  397. mindspore/train/callback/_loss_monitor.py +2 -2
  398. mindspore/train/callback/_on_request_exit.py +2 -2
  399. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  400. mindspore/train/callback/_summary_collector.py +15 -8
  401. mindspore/train/callback/_time_monitor.py +58 -5
  402. mindspore/train/data_sink.py +5 -11
  403. mindspore/train/dataset_helper.py +84 -57
  404. mindspore/train/loss_scale_manager.py +2 -2
  405. mindspore/train/metrics/__init__.py +3 -3
  406. mindspore/train/metrics/cosine_similarity.py +1 -1
  407. mindspore/train/metrics/hausdorff_distance.py +3 -2
  408. mindspore/train/metrics/mean_surface_distance.py +3 -2
  409. mindspore/train/metrics/metric.py +39 -19
  410. mindspore/train/metrics/roc.py +2 -2
  411. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  412. mindspore/train/mind_ir_pb2.py +85 -36
  413. mindspore/train/model.py +187 -47
  414. mindspore/train/serialization.py +487 -161
  415. mindspore/train/summary/_summary_adapter.py +1 -1
  416. mindspore/train/summary/_writer_pool.py +3 -2
  417. mindspore/train/summary/summary_record.py +37 -17
  418. mindspore/train/train_thor/convert_utils.py +3 -3
  419. mindspore/train/train_thor/dataset_helper.py +1 -1
  420. mindspore/turbojpeg.dll +0 -0
  421. mindspore/vcmeta.dll +0 -0
  422. mindspore/vcruntime140.dll +0 -0
  423. mindspore/vcruntime140_1.dll +0 -0
  424. mindspore/version.py +1 -1
  425. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +7 -4
  426. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +429 -486
  427. mindspore/_extends/graph_kernel/expander.py +0 -80
  428. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  429. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  430. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  431. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  432. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  433. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  434. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  435. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  436. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  437. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  438. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  439. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  440. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  441. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  442. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  443. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  444. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  445. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  446. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  447. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  448. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  449. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  450. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  451. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  452. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  453. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  454. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  455. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  456. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  457. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  458. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  459. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  460. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  461. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  462. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  463. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  464. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  465. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  466. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  467. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  468. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  469. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  470. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  471. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  472. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  473. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  474. mindspore/dataset/datapreprocess/__init__.py +0 -20
  475. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  476. mindspore/include/api/net.h +0 -142
  477. mindspore/nn/lr_scheduler.py +0 -262
  478. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  479. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  480. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  481. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  482. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  483. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  484. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  485. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  486. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  487. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  488. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  489. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  490. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  491. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  492. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  493. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  494. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  495. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  496. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  497. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  498. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  499. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  500. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  501. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  502. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  503. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  504. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  505. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  506. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  507. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  508. mindspore/rewrite/node_visitor.py +0 -44
  509. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  510. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
  511. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -15,13 +15,11 @@
15
15
  """Parse ast.Assign in construct function to node of SymbolTree."""
16
16
  import ast
17
17
 
18
- from mindspore.rewrite.parser import Parser
18
+ from mindspore.rewrite.parsers.parser import Parser
19
19
  from mindspore.rewrite.symbol_tree import SymbolTree
20
- from mindspore.rewrite.parser_register import ParserRegister
21
-
22
- from mindspore.rewrite.parser_register import reg_parser
20
+ from mindspore.rewrite.parsers.parser_register import ParserRegister, reg_parser
23
21
  from ..common import error_str
24
-
22
+ from ..node.node_manager import NodeManager
25
23
 
26
24
  class AttributeParser(Parser):
27
25
  """Parse ast.Attribute in construct function to node of SymbolTree."""
@@ -30,13 +28,14 @@ class AttributeParser(Parser):
30
28
  """Parse target type."""
31
29
  return ast.Attribute
32
30
 
33
- def process(self, stree: SymbolTree, node: ast.Attribute):
31
+ def process(self, stree: SymbolTree, node: ast.Attribute, node_manager: NodeManager):
34
32
  """
35
33
  Parse ast.Attribute node.
36
34
 
37
35
  Args:
38
36
  stree ([SymbolTree]): Symbol Tree under parsing.
39
37
  node ([ast.Attribute]): An ast.Attribute node.
38
+ node_manager (NodeManager): NodeManager those asts belong to.
40
39
 
41
40
  Returns:
42
41
  The value of node.
@@ -47,8 +46,11 @@ class AttributeParser(Parser):
47
46
  if not isinstance(node, ast.Attribute):
48
47
  raise TypeError(error_str(f"Attribute parser only supports parsing ast.Attribute type nodes, but got "
49
48
  f"'{type(node).__name__}'", father_node=node))
49
+ if not isinstance(node.value, (ast.Name, ast.Attribute)):
50
+ raise RuntimeError(error_str(f"Attribute parser only supports (ast.Attribute, ast.Name) as value of "
51
+ f"ast.Attribute, but got '{type(node).__name__}'", father_node=node))
50
52
  parser = ParserRegister.instance().get_parser(type(node.value))
51
- value = parser.process(stree, node.value)
53
+ value = parser.process(stree, node.value, node_manager)
52
54
 
53
55
  return ".".join([value, node.attr])
54
56
 
@@ -13,17 +13,19 @@
13
13
  # limitations under the License.
14
14
  # ============================================================================
15
15
  """Parse ast.ClassDef which is subclass of Cell to SymbolTree."""
16
- import sys
17
- import ast
18
16
  import inspect
17
+ from typing import Union, Dict
18
+ import ast
19
19
  from mindspore import log as logger
20
20
  from mindspore.nn import Cell
21
21
  from mindspore._extends.parse.namespace import CellNamespace
22
22
  from ..symbol_tree import SymbolTree
23
- from ..parser import Parser
24
- from ..parser_register import ParserRegister, reg_parser
23
+ from .parser import Parser
24
+ from .parser_register import ParserRegister, reg_parser
25
25
  from ..ast_helpers import AstReplacer
26
26
  from ..common import error_str
27
+ from ..parsers.module_parser import ModuleParser
28
+ from ..node.node_manager import NodeManager
27
29
 
28
30
 
29
31
  class AstScopeChecker:
@@ -106,25 +108,58 @@ class AstScopeChecker:
106
108
  class ClassDefParser(Parser):
107
109
  """Parse ast.ClassDef which is subclass of Cell to SymbolTree."""
108
110
 
111
+ # a denied_function_decorator_list which is registered by user
112
+ denied_function_decorator_list = []
113
+ # Entry function of the forward computation process
114
+ entry_function = "construct"
115
+
109
116
  def __init__(self):
110
117
  """Constructor"""
111
118
  super(ClassDefParser, self).__init__()
112
119
  self._cell_namespace = CellNamespace('mindspore.nn')
113
120
 
114
121
  @staticmethod
115
- def _is_super_expr(expr: ast.AST) -> bool:
116
- """Check whether ast node is super().__init__()"""
117
- if not isinstance(expr, ast.Expr):
118
- return False
119
- if not isinstance(expr.value, ast.Call):
120
- return False
121
- if not isinstance(expr.value.func, ast.Attribute):
122
- return False
123
- if expr.value.func.attr != "__init__" or not isinstance(expr.value.func.value, ast.Call):
124
- return False
125
- if not isinstance(expr.value.func.value.func, ast.Name) or expr.value.func.value.func.id != "super":
126
- return False
127
- return True
122
+ def _process_init_func_ast(init_ast: ast.FunctionDef, class_name: str, is_father_class: bool,
123
+ father_classes: dict):
124
+ """Process init func"""
125
+ ClassDefParser._modify_arguments_of_init_func(init_ast)
126
+ new_bodies = ClassDefParser._create_bodys_of_init_func(class_name, is_father_class, father_classes)
127
+ init_ast.body = new_bodies
128
+
129
+ @staticmethod
130
+ def _create_bodys_of_init_func(class_name: str, is_father_class: bool, father_classes: dict):
131
+ """Modify bodys of init func."""
132
+ new_bodies = []
133
+ # update father class init in new class
134
+ father_class_init_bodies = ClassDefParser._father_class_init_process(father_classes, is_father_class)
135
+ new_bodies.extend(father_class_init_bodies)
136
+ # copy variables into new class
137
+ if is_father_class:
138
+ ast_copy_attr = ast.parse(
139
+ "for key, value in obj.__dict__.items():\n"
140
+ " if not key.startswith('__'):\n"
141
+ f" setattr({class_name}, key, value)").body[0]
142
+ new_bodies.append(ast_copy_attr)
143
+ else:
144
+ ast_copy_attr = ast.parse(
145
+ "for key, value in obj.__dict__.items(): setattr(self, key, value)").body[0]
146
+ new_bodies.append(ast_copy_attr)
147
+ return new_bodies
148
+
149
+ @staticmethod
150
+ def _father_class_init_process(father_classes: dict, is_father_class: bool) -> [ast.AST]:
151
+ """Add ast bodies of code: father_class.__init__(...)"""
152
+ father_class_init_bodies = []
153
+ for idx, father_class in father_classes.items():
154
+ if father_class == "Cell":
155
+ father_class_init_code = "super().__init__()"
156
+ elif is_father_class:
157
+ father_class_init_code = f"{father_class}.__init__(self, obj.__bases__[{idx}])"
158
+ else:
159
+ father_class_init_code = f"{father_class}.__init__(self, obj.__class__.__bases__[{idx}])"
160
+ father_class_init_ast = ast.parse(father_class_init_code).body[0]
161
+ father_class_init_bodies.append(father_class_init_ast)
162
+ return father_class_init_bodies
128
163
 
129
164
  @staticmethod
130
165
  def _modify_arguments_of_init_func(ast_init_fn: ast.FunctionDef):
@@ -135,127 +170,153 @@ class ClassDefParser(Parser):
135
170
  kw_defaults=[], defaults=[], vararg=None, kwarg=None)
136
171
  ast.fix_missing_locations(ast_init_fn)
137
172
 
173
+ @staticmethod
174
+ def get_ast_name(ast_node: Union[ast.Name, ast.Attribute]) -> str:
175
+ """Get ast id name"""
176
+ if isinstance(ast_node, ast.Name):
177
+ return ast_node.id
178
+ if isinstance(ast_node, ast.Attribute):
179
+ return ast_node.attr
180
+ return ""
181
+
182
+ @staticmethod
183
+ def _process_class_variables(stree: SymbolTree, function_defs: list):
184
+ """Process class variables of class, only used in child class."""
185
+ init_func_ast = stree.get_init_func_ast()
186
+ for key, value in stree.get_origin_network().__class__.__dict__.items():
187
+ if key.startswith('__'):
188
+ # ignore inner functions
189
+ continue
190
+ if callable(value) and key in function_defs:
191
+ # ignore functions defined by self
192
+ continue
193
+ assign_code = f"self.__class__.{key} = obj.__class__.{key}"
194
+ assign_ast = ast.parse(assign_code).body[0]
195
+ init_func_ast.body.append(assign_ast)
196
+
197
+ @staticmethod
198
+ def _need_add_init_func(cls_ast: ast.ClassDef) -> bool:
199
+ """If the class don't have init func, we need to add an init func"""
200
+ for body in cls_ast.body:
201
+ if isinstance(body, ast.FunctionDef) and body.name == '__init__':
202
+ return False
203
+ return True
204
+
205
+ @staticmethod
206
+ def _add_init_func(cls_ast: ast.ClassDef):
207
+ """Add init func with super().__init__()"""
208
+ init_func_ast = ast.parse("def __init__(self): super().__init__()").body[0]
209
+ cls_ast.body.insert(0, init_func_ast)
210
+ ast.fix_missing_locations(cls_ast)
211
+
212
+ @staticmethod
213
+ def _process_father_classes(stree, node: ast.ClassDef, cur_class_def: type) -> list:
214
+ """Process father class."""
215
+ father_classes: Dict[int, str] = {}
216
+ for idx, base in enumerate(node.bases):
217
+ father_class_name = ClassDefParser.get_ast_name(base)
218
+ if not father_class_name:
219
+ continue
220
+ father_classes[idx] = father_class_name
221
+ if father_class_name == "Cell":
222
+ continue
223
+ father_class_def = cur_class_def.__bases__[idx]
224
+ ClassDefParser._process_one_father_class(stree, father_class_def, father_class_name)
225
+ node.bases[idx] = ast.Name(id=father_class_name, ctx=ast.Load())
226
+ return father_classes
227
+
228
+ @staticmethod
229
+ def _process_one_father_class(stree: SymbolTree, father_class_def: type, father_class_name: str):
230
+ """Process one father class"""
231
+ # save father class's file path and imports into symbol tree
232
+ net_path = inspect.getfile(father_class_def)
233
+ ModuleParser.save_file_path_to_sys(stree, 0, net_path)
234
+ ModuleParser.save_imports_from_file(stree, net_path)
235
+ # get father class's ast
236
+ source_code = inspect.getsource(father_class_def)
237
+ father_class_ast: ast.ClassDef = ast.parse(source_code).body[0]
238
+ # process father class's father classes
239
+ father_classes = ClassDefParser._process_father_classes(stree, father_class_ast, father_class_def)
240
+ # process father class's __init__ function
241
+ if ClassDefParser._need_add_init_func(father_class_ast):
242
+ ClassDefParser._add_init_func(father_class_ast)
243
+ for body in father_class_ast.body[:]:
244
+ if isinstance(body, ast.FunctionDef) and body.name == "__init__":
245
+ # Add function decorator
246
+ ClassDefParser._func_decorator_process(body)
247
+ ClassDefParser._process_init_func_ast(body, father_class_name, True, father_classes)
248
+ else:
249
+ # Remove other codes, which are copied in __init__ function.
250
+ father_class_ast.body.remove(body)
251
+ # save father class's ast into symbol tree
252
+ stree.get_father_class_ast().append(father_class_ast)
253
+
254
+ @staticmethod
255
+ def _func_decorator_process(node: ast.FunctionDef):
256
+ """
257
+ User should set the denied function decorators,
258
+ because the symbol_tree cant pass the correct parameters to decorators but the instance "obj".
259
+ """
260
+ for decorator in node.decorator_list[:]:
261
+ decorator_name = ""
262
+ if isinstance(decorator, ast.Call):
263
+ func = decorator.func
264
+ if isinstance(func, ast.Name):
265
+ decorator_name = func.id
266
+ elif isinstance(decorator, ast.Name):
267
+ decorator_name = decorator.id
268
+ if decorator_name in ClassDefParser.denied_function_decorator_list:
269
+ node.decorator_list.remove(decorator)
270
+
138
271
  def target(self):
139
272
  """Parse target type"""
140
273
  return ast.ClassDef
141
274
 
142
- def process(self, stree: SymbolTree, node: ast.ClassDef):
275
+ def process(self, stree: SymbolTree, node: ast.ClassDef, node_manager: NodeManager):
143
276
  """
144
- Parse init and construct in ast.ClassDef.
277
+ Parse init and entry function(default: construct) in ast.ClassDef.
145
278
 
146
279
  Args:
147
280
  stree ([SymbolTree]): Symbol Tree under parsing.
148
281
  node ([ast.ClassDef]): An ast.ClassDef node.
282
+ node_manager (NodeManager): NodeManager those asts belong to.
149
283
  """
284
+ # Update network's class name from xxx to xxxOpt in ast
150
285
  replacer = AstReplacer(node)
151
286
  replacer.replace_all(stree.get_ori_cls_name(), stree.get_opt_cls_name())
152
287
 
288
+ # process network's father classes
153
289
  stree.set_class_ast(node)
154
- has_father_class = self._handle_father_class(stree, node)
290
+ cur_class_def = type(stree.get_origin_network())
291
+ father_classes = ClassDefParser._process_father_classes(stree, node, cur_class_def)
155
292
 
156
- if self._need_add_init_func(stree, node):
157
- self._add_init_func(node)
293
+ # add __init__ function to network if necessary
294
+ if isinstance(stree.get_origin_network(), Cell) and ClassDefParser._need_add_init_func(node):
295
+ ClassDefParser._add_init_func(node)
158
296
 
159
- for body in node.body:
297
+ # save function defs in ast node to filter function class variables.
298
+ function_defs = []
299
+ for body in node.body[:]:
160
300
  if isinstance(body, ast.FunctionDef):
301
+ function_defs.append(body.name)
302
+ ClassDefParser._func_decorator_process(body)
161
303
  if body.name == "__init__":
162
- self._process_init_func_ast(body, has_father_class)
163
304
  stree.set_init_func_ast(body)
164
- elif body.name == "construct":
305
+ ClassDefParser._process_init_func_ast(body, stree.get_opt_cls_name(), False, father_classes)
306
+ elif body.name == ClassDefParser.entry_function:
307
+ stree.set_ast_root(body)
165
308
  parser: Parser = ParserRegister.instance().get_parser(ast.FunctionDef)
166
- parser.process(stree, body)
309
+ parser.process(stree, body, stree)
167
310
  else:
168
311
  logger.info(
169
312
  "Ignoring ast.FunctionDef in ast.ClassDef except __init__ and construct function: %s",
170
313
  body.name)
314
+ elif isinstance(body, (ast.Assign, ast.If, ast.IfExp)):
315
+ # Remove class variables, which are copied in __init__ function.
316
+ node.body.remove(body)
171
317
  else:
172
318
  logger.info("Ignoring unsupported node(%s) in ast.ClassDef.", type(body).__name__)
173
-
174
- def _is_subtree_field(self, ori_net, field) -> bool:
175
- op = getattr(ori_net, field)
176
- return not type(op).__name__ in self._cell_namespace
177
-
178
- def _process_init_func_ast(self, init_ast: ast.FunctionDef, has_father_class: bool):
179
- """Process init func"""
180
- ClassDefParser._modify_arguments_of_init_func(init_ast)
181
- new_bodies = self._replace_ori_field_of_init_func(init_ast.body, has_father_class)
182
- init_ast.body = new_bodies
183
-
184
- def _need_add_init_func(self, stree: SymbolTree, cls_ast: ast.ClassDef) -> bool:
185
- """If class is child class of nn.Cell but not have init func, then we need to add init func"""
186
- if not isinstance(stree.get_origin_network(), Cell):
187
- return False
188
- for body in cls_ast.body:
189
- if isinstance(body, ast.FunctionDef) and body.name == '__init__':
190
- return False
191
- return True
192
-
193
- def _add_init_func(self, cls_ast: ast.ClassDef):
194
- """Add init func with super().__init__()"""
195
- init_func_ast = ast.parse("def __init__(self): super().__init__()").body[0]
196
- cls_ast.body.insert(0, init_func_ast)
197
- ast.fix_missing_locations(cls_ast)
198
-
199
- def _replace_ori_field_of_init_func(self, bodies: [], has_father_class: bool):
200
- """
201
- Replace original field in init func to self.XX = getattr(self._handler, "XX").
202
- Only keep following two kinds of ast nodes in bodies right now:
203
- 1. Ast.If and test is self.XX.
204
- 2. Ast.Assign and target is self.XX.
205
-
206
- Args:
207
- bodies ([]): bodied of init ast.FunctionDef.
208
- has_father_class (bool): whether class has father class that is not nn.Cell
209
-
210
- Raises:
211
- RuntimeError: Not support multi-targets in assign.
212
- RuntimeError: Only support target.value in [ast.Name] in assign node.
213
- """
214
- new_bodies = []
215
- for body in bodies:
216
- if self._is_super_expr(body):
217
- if has_father_class:
218
- body.value.args = [ast.Name(id='obj', ctx=ast.Load())]
219
- body.value.keywords = []
220
- new_bodies.append(body)
221
- continue
222
- ast_copy_attr = ast.parse(
223
- "for key, value in obj.__dict__.items(): setattr(self, key, value)").body[0]
224
- new_bodies.append(ast_copy_attr)
225
- return new_bodies
226
-
227
- def _handle_father_class(self, stree, node: ast.ClassDef) -> bool:
228
- """Handle father class."""
229
- has_father_class = False
230
- for base in node.bases:
231
- parser: Parser = ParserRegister.instance().get_parser(type(base))
232
- father_class = parser.process(stree, base)
233
- if "Cell" not in father_class:
234
- for k, m in sys.modules.items():
235
- if k in ("_ast", "ast"):
236
- continue
237
- if hasattr(m, father_class):
238
- cls = getattr(m, father_class)
239
- if not inspect.isclass(cls):
240
- continue
241
- source_code = inspect.getsource(cls)
242
- father_class_ast: ast.Module = ast.parse(source_code)
243
- self._father_class_process_init_func_ast(stree, father_class_ast)
244
- stree._father_class_ast.append(father_class_ast) # pylint: disable=protected-access
245
- has_father_class = True
246
- break
247
- return has_father_class
248
-
249
- def _father_class_process_init_func_ast(self, stree: SymbolTree, father_class_ast: ast.Module):
250
- father_class_stree: SymbolTree = SymbolTree(stree.get_origin_network(), father_class_ast)
251
- for ast_body in father_class_ast.body:
252
- if isinstance(ast_body, ast.ClassDef):
253
- has_father_class = self._handle_father_class(stree, ast_body)
254
- if self._need_add_init_func(father_class_stree, ast_body):
255
- self._add_init_func(ast_body)
256
- for body in ast_body.body:
257
- if isinstance(body, ast.FunctionDef) and body.name == "__init__":
258
- self._process_init_func_ast(body, has_father_class)
259
-
319
+ # Copy function class variables into new network
320
+ ClassDefParser._process_class_variables(stree, function_defs)
260
321
 
261
322
  g_classdef_parser = reg_parser(ClassDefParser())
@@ -15,11 +15,11 @@
15
15
  """Parse ast.Assign in construct function to node of SymbolTree."""
16
16
  import ast
17
17
 
18
- from mindspore.rewrite.parser import Parser
18
+ from mindspore.rewrite.parsers.parser import Parser
19
19
  from mindspore.rewrite.symbol_tree import SymbolTree
20
- from mindspore.rewrite.parser_register import reg_parser
20
+ from mindspore.rewrite.parsers.parser_register import reg_parser
21
21
  from ..common import error_str
22
-
22
+ from ..node.node_manager import NodeManager
23
23
 
24
24
  class NameParser(Parser):
25
25
  """Parse ast.Name in construct function to node of SymbolTree."""
@@ -28,13 +28,14 @@ class NameParser(Parser):
28
28
  """Parse target type."""
29
29
  return ast.Name
30
30
 
31
- def process(self, stree: SymbolTree, node: ast.Name):
31
+ def process(self, stree: SymbolTree, node: ast.Name, node_manager: NodeManager):
32
32
  """
33
33
  Parse ast.Name node.
34
34
 
35
35
  Args:
36
36
  stree ([SymbolTree]): Symbol Tree under parsing.
37
37
  node ([ast.Name]): An ast.Name node.
38
+ node_manager (NodeManager): NodeManager those asts belong to.
38
39
 
39
40
  Raises:
40
41
  TypeError: Name parser only supports parsing ast.Name type nodes.
@@ -52,13 +53,14 @@ class NumParser(Parser):
52
53
  """Parse target type."""
53
54
  return ast.Num
54
55
 
55
- def process(self, stree: SymbolTree, node: ast.Num):
56
+ def process(self, stree: SymbolTree, node: ast.Num, node_manager: NodeManager):
56
57
  """
57
58
  Parse ast.Num node.
58
59
 
59
60
  Args:
60
61
  stree ([SymbolTree]): Symbol Tree under parsing.
61
62
  node ([ast.Num]): An ast.Num node.
63
+ node_manager (NodeManager): NodeManager those asts belong to.
62
64
 
63
65
  Raises:
64
66
  TypeError: Num parser only supports parsing ast.Num type nodes.
@@ -76,13 +78,14 @@ class StrParser(Parser):
76
78
  """Parse target type."""
77
79
  return ast.Str
78
80
 
79
- def process(self, stree: SymbolTree, node: ast.Str):
81
+ def process(self, stree: SymbolTree, node: ast.Str, node_manager: NodeManager):
80
82
  """
81
83
  Parse ast.Str node.
82
84
 
83
85
  Args:
84
86
  stree ([SymbolTree]): Symbol Tree under parsing.
85
87
  node ([ast.Str]): An ast.Str node.
88
+ node_manager (NodeManager): NodeManager those asts belong to.
86
89
 
87
90
  Returns:
88
91
  The value of node.
@@ -15,12 +15,12 @@
15
15
  """Parse Container in construct function to node of SymbolTree."""
16
16
  import ast
17
17
 
18
- from mindspore.rewrite.parser import Parser
18
+ from mindspore.rewrite.parsers.parser import Parser
19
19
  from mindspore.rewrite.symbol_tree import SymbolTree
20
- from mindspore.rewrite.parser_register import ParserRegister
20
+ from mindspore.rewrite.parsers.parser_register import ParserRegister, reg_parser
21
21
 
22
- from mindspore.rewrite.parser_register import reg_parser
23
22
  from ..common import error_str
23
+ from ..node.node_manager import NodeManager
24
24
 
25
25
 
26
26
  class ListParser(Parser):
@@ -30,13 +30,14 @@ class ListParser(Parser):
30
30
  """Parse target type."""
31
31
  return list
32
32
 
33
- def process(self, stree: SymbolTree, node: list):
33
+ def process(self, stree: SymbolTree, node: list, node_manager: NodeManager):
34
34
  """
35
35
  Parse list.
36
36
 
37
37
  Args:
38
38
  stree ([SymbolTree]): Symbol Tree under parsing.
39
39
  node ([list]): An list of node.
40
+ father_node_managernode (NodeManager): NodeManager those asts belong to.
40
41
 
41
42
  Returns:
42
43
  A list of value.
@@ -50,7 +51,7 @@ class ListParser(Parser):
50
51
  result = []
51
52
  for n in node:
52
53
  parser = ParserRegister.instance().get_parser(type(n))
53
- value = parser.process(stree, n)
54
+ value = parser.process(stree, n, node_manager)
54
55
  result.append(value)
55
56
  return result
56
57
 
@@ -62,13 +63,14 @@ class TupleParser(Parser):
62
63
  """Parse target type."""
63
64
  return tuple
64
65
 
65
- def process(self, stree: SymbolTree, node: tuple):
66
+ def process(self, stree: SymbolTree, node: tuple, node_manager: NodeManager):
66
67
  """
67
68
  Parse tuple.
68
69
 
69
70
  Args:
70
71
  stree ([SymbolTree]): Symbol Tree under parsing.
71
72
  node ([tuple]): An tuple of node.
73
+ node_manager (NodeManager): NodeManager those asts belong to.
72
74
 
73
75
  Returns:
74
76
  A tuple of value.
@@ -79,7 +81,7 @@ class TupleParser(Parser):
79
81
  result = []
80
82
  for n in node:
81
83
  parser = ParserRegister.instance().get_parser(type(n))
82
- value = parser.process(stree, n)
84
+ value = parser.process(stree, n, node_manager)
83
85
  result.append(value)
84
86
  return tuple(result)
85
87
 
@@ -13,17 +13,23 @@
13
13
  # limitations under the License.
14
14
  # ============================================================================
15
15
  """ Parse ast.For node """
16
+ import sys
16
17
  import ast
17
- import astunparse
18
18
 
19
19
  from mindspore.rewrite.api.scoped_value import ScopedValue, ValueType
20
20
  from mindspore.rewrite.ast_helpers.ast_modifier import AstModifier
21
21
  from mindspore import log as logger
22
22
  from mindspore import nn
23
23
  from ..symbol_tree import SymbolTree
24
- from ..parser import Parser
25
- from ..parser_register import reg_parser
24
+ from .parser import Parser
25
+ from .parser_register import reg_parser
26
26
  from ..common.event import Event
27
+ from ..node.node_manager import NodeManager
28
+
29
+ if sys.version_info >= (3, 9):
30
+ import ast as astunparse # pylint: disable=reimported, ungrouped-imports
31
+ else:
32
+ import astunparse
27
33
 
28
34
  EVAL_WHITE_LIST = ("self.", "range(", "zip(", "enumerate(", "reversed(")
29
35
 
@@ -34,20 +40,20 @@ class ForParser(Parser):
34
40
  @staticmethod
35
41
  def modify_init_ast(stree, i, obj, iter_var_name):
36
42
  """Modify the ast node in init function."""
37
- target = f"{iter_var_name.strip()}_{str(i)}"
43
+ target = f"{iter_var_name.strip()}{str(i)}"
38
44
  setattr(stree.get_origin_network(), target, obj)
39
45
  stree.get_origin_network().insert_child_to_cell(target, obj)
40
46
  AstModifier.insert_assign_to_function(stree.get_init_func_ast(),
41
47
  targets=[ScopedValue(ValueType.NamingValue, "self", target)],
42
48
  expr=ScopedValue(ValueType.NamingValue, "", "getattr"),
43
49
  args=[ScopedValue(ValueType.NamingValue, "", "obj"),
44
- ScopedValue(ValueType.StringValue, "", target)])
50
+ ScopedValue(ValueType.ConstantValue, "", target)])
45
51
 
46
52
  @staticmethod
47
53
  def modify_construct_ast(stree, ast_node, old_name, new_name):
48
54
  """Modify the ast node in construct function."""
49
55
  node_str: str = astunparse.unparse(ast_node)
50
- node_str = node_str.replace(old_name, new_name)
56
+ node_str = node_str.replace(old_name+'(', new_name+'(')
51
57
  module_node = ast.parse(node_str)
52
58
  new_node = module_node.body[0]
53
59
  return new_node
@@ -55,13 +61,18 @@ class ForParser(Parser):
55
61
  def target(self):
56
62
  return ast.For
57
63
 
58
- def process(self, stree: SymbolTree, node: ast.For):
64
+ def process(self, stree: SymbolTree, node: ast.For, node_manager: NodeManager):
59
65
  """ Process ast.For node """
60
66
  if isinstance(node.target, ast.Name):
61
67
  targets = node.target.id
68
+ if isinstance(node.iter, ast.Str) or (isinstance(node.iter, ast.Constant) and
69
+ isinstance(node.iter.val, str)):
70
+ # Ast.For which has iter with type of str is converted to python node to avoid instruction injection
71
+ stree.try_append_python_node(node, node)
72
+ return
62
73
  iter_code = astunparse.unparse(node.iter)
63
74
  if not iter_code.startswith(EVAL_WHITE_LIST):
64
- logger.warning(
75
+ logger.info(
65
76
  f"For MindSpore Rewrtie, illegal iteration condition for For node, it must start with{EVAL_WHITE_LIST}")
66
77
  return
67
78
  if "self" in iter_code:
@@ -71,37 +82,47 @@ class ForParser(Parser):
71
82
  except (NameError, TypeError) as e:
72
83
  _info = f"For MindSpore Rewrtie, when eval '{iter_code}' by using JIT Fallback feature, " \
73
84
  f"an error occurred: {str(e)}"
74
- logger.warning(_info)
75
- stree.try_append_python_node(node, node)
85
+ logger.info(_info)
86
+ stree.try_append_python_node(node, node, node_manager)
76
87
  return
77
88
 
78
89
  iter_var_name = iter_code.split(".")[-1]
79
- index = stree.get_ast_root().body.index(node) + 1
80
- if isinstance(iter_obj, list):
90
+ ast_functiondef = node_manager.get_ast_functiondef()
91
+ if not ast_functiondef:
92
+ logger.info(f"ast_functiondef is None in node_manager {node_manager.get_manager_name()} "
93
+ "when parsing 'for' statement.")
94
+ stree.try_append_python_node(node, node, node_manager)
95
+ return
96
+ index = ast_functiondef.body.index(node) + 1
97
+ if isinstance(iter_obj, (list, nn.CellList)):
81
98
  for obj in iter_obj:
82
99
  if not isinstance(obj, nn.Cell):
83
- stree.try_append_python_node(node, node)
100
+ stree.try_append_python_node(node, node, node_manager)
84
101
  return
85
102
  for i, obj in enumerate(iter_obj):
86
103
  ForParser.modify_init_ast(stree, i, obj, iter_var_name)
87
104
  for body in node.body:
88
- new_func_name = f"self.{iter_var_name.strip()}_{str(i)}".strip()
105
+ new_func_name = f"self.{iter_var_name.strip()}{str(i)}".strip()
89
106
  new_node = ForParser.modify_construct_ast(stree, body, targets, new_func_name)
90
- stree.get_ast_root().body.insert(index, new_node)
107
+ ast_functiondef.body.insert(index, new_node)
91
108
  index += 1
109
+ # Expand "for" statement and replace the body with Pass
110
+ for body in node.body[:]:
111
+ node.body.remove(body)
112
+ node.body.append(ast.Pass())
113
+
92
114
  if stree.get_ori_cls_name() == "SequentialCell":
93
115
  stree.on_change(Event.CodeChangeEvent)
94
- stree.get_ast_root().body.remove(node)
95
116
  return
96
117
  if isinstance(iter_obj, range):
97
- logger.warning("For MindSpore Rewrite, range not support.")
118
+ logger.info("For MindSpore Rewrite, range not support.")
98
119
  elif isinstance(iter_obj, zip):
99
- logger.warning("For MindSpore Rewrite, zip not support.")
120
+ logger.info("For MindSpore Rewrite, zip not support.")
100
121
  elif isinstance(iter_obj, enumerate):
101
- logger.warning("For MindSpore Rewrite, enumerate not support.")
122
+ logger.info("For MindSpore Rewrite, enumerate not support.")
102
123
  else:
103
- logger.warning(f"For MindSpore Rewrite, not supported type: {type(iter_obj).__name__}")
104
- stree.try_append_python_node(node, node)
124
+ logger.info(f"For MindSpore Rewrite, not supported type: {type(iter_obj).__name__}")
125
+ stree.try_append_python_node(node, node, node_manager)
105
126
  return
106
127
 
107
128
  g_for_parser = reg_parser(ForParser())