mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (870)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  201. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  202. mindspore/log.py +9 -6
  203. mindspore/mindrecord/filereader.py +33 -4
  204. mindspore/mindrecord/filewriter.py +70 -35
  205. mindspore/mindrecord/mindpage.py +40 -34
  206. mindspore/mindrecord/shardreader.py +1 -1
  207. mindspore/mindrecord/shardsegment.py +1 -1
  208. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  209. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  210. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  211. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  212. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  213. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  214. mindspore/nn/cell.py +463 -169
  215. mindspore/nn/dynamic_lr.py +47 -43
  216. mindspore/nn/layer/activation.py +225 -82
  217. mindspore/nn/layer/basic.py +121 -79
  218. mindspore/nn/layer/channel_shuffle.py +21 -21
  219. mindspore/nn/layer/combined.py +33 -26
  220. mindspore/nn/layer/container.py +277 -22
  221. mindspore/nn/layer/conv.py +441 -304
  222. mindspore/nn/layer/dense.py +19 -13
  223. mindspore/nn/layer/embedding.py +62 -49
  224. mindspore/nn/layer/flash_attention.py +264 -0
  225. mindspore/nn/layer/image.py +50 -39
  226. mindspore/nn/layer/math.py +62 -51
  227. mindspore/nn/layer/normalization.py +219 -167
  228. mindspore/nn/layer/padding.py +58 -70
  229. mindspore/nn/layer/pooling.py +334 -287
  230. mindspore/nn/layer/rnn_cells.py +53 -38
  231. mindspore/nn/layer/rnns.py +59 -56
  232. mindspore/nn/layer/thor_layer.py +52 -44
  233. mindspore/nn/layer/timedistributed.py +6 -4
  234. mindspore/nn/layer/transformer.py +284 -164
  235. mindspore/nn/learning_rate_schedule.py +34 -25
  236. mindspore/nn/loss/__init__.py +3 -2
  237. mindspore/nn/loss/loss.py +554 -311
  238. mindspore/nn/optim/ada_grad.py +12 -9
  239. mindspore/nn/optim/adadelta.py +14 -11
  240. mindspore/nn/optim/adafactor.py +19 -16
  241. mindspore/nn/optim/adam.py +62 -47
  242. mindspore/nn/optim/adamax.py +13 -10
  243. mindspore/nn/optim/adasum.py +12 -8
  244. mindspore/nn/optim/asgd.py +10 -9
  245. mindspore/nn/optim/ftrl.py +20 -17
  246. mindspore/nn/optim/lamb.py +16 -12
  247. mindspore/nn/optim/lars.py +8 -6
  248. mindspore/nn/optim/lazyadam.py +25 -20
  249. mindspore/nn/optim/momentum.py +10 -7
  250. mindspore/nn/optim/optimizer.py +61 -9
  251. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  252. mindspore/nn/optim/rmsprop.py +17 -13
  253. mindspore/nn/optim/rprop.py +30 -17
  254. mindspore/nn/optim/sgd.py +40 -23
  255. mindspore/nn/optim/thor.py +24 -26
  256. mindspore/nn/probability/bijector/bijector.py +11 -11
  257. mindspore/nn/probability/bijector/exp.py +1 -1
  258. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  259. mindspore/nn/probability/bijector/invert.py +1 -1
  260. mindspore/nn/probability/bijector/power_transform.py +29 -29
  261. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  262. mindspore/nn/probability/bijector/softplus.py +5 -5
  263. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  264. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  265. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  266. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  267. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  268. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  269. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  270. mindspore/nn/probability/distribution/beta.py +8 -8
  271. mindspore/nn/probability/distribution/categorical.py +23 -15
  272. mindspore/nn/probability/distribution/cauchy.py +5 -6
  273. mindspore/nn/probability/distribution/distribution.py +3 -3
  274. mindspore/nn/probability/distribution/exponential.py +4 -4
  275. mindspore/nn/probability/distribution/gamma.py +10 -10
  276. mindspore/nn/probability/distribution/geometric.py +8 -8
  277. mindspore/nn/probability/distribution/gumbel.py +8 -9
  278. mindspore/nn/probability/distribution/half_normal.py +5 -5
  279. mindspore/nn/probability/distribution/laplace.py +5 -5
  280. mindspore/nn/probability/distribution/log_normal.py +12 -11
  281. mindspore/nn/probability/distribution/logistic.py +8 -8
  282. mindspore/nn/probability/distribution/normal.py +6 -5
  283. mindspore/nn/probability/distribution/poisson.py +10 -11
  284. mindspore/nn/probability/distribution/student_t.py +8 -9
  285. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  286. mindspore/nn/probability/distribution/uniform.py +11 -11
  287. mindspore/nn/reinforcement/tensor_array.py +2 -2
  288. mindspore/nn/sparse/sparse.py +9 -9
  289. mindspore/nn/wrap/cell_wrapper.py +188 -63
  290. mindspore/nn/wrap/grad_reducer.py +21 -12
  291. mindspore/nn/wrap/loss_scale.py +136 -49
  292. mindspore/numpy/__init__.py +4 -4
  293. mindspore/numpy/array_creations.py +55 -56
  294. mindspore/numpy/array_ops.py +134 -35
  295. mindspore/numpy/logic_ops.py +66 -20
  296. mindspore/numpy/math_ops.py +142 -139
  297. mindspore/numpy/utils_const.py +2 -2
  298. mindspore/offline_debug/convert_async.py +2 -2
  299. mindspore/ops/_grad_experimental/__init__.py +7 -5
  300. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  301. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  302. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  303. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  304. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  305. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  306. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  307. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  308. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  309. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  310. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  311. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  312. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  313. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  314. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  315. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  316. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  317. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  318. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  319. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  320. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  321. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  322. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  323. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  325. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  326. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  327. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  328. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  329. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  331. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  332. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  333. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  334. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  335. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  336. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  337. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  338. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  339. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  340. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  341. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  342. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  343. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  345. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  346. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  347. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  348. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  349. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  350. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  351. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  352. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  353. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  354. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  355. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  356. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  357. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  358. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  359. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  360. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  361. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  362. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  363. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  364. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  365. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  366. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  367. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  368. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  369. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  370. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  371. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  372. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  373. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  374. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  375. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  376. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  377. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  378. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  379. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  380. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  381. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  382. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  383. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  384. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  385. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  386. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  387. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  388. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  389. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  390. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  391. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  392. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  393. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  394. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  395. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  396. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  397. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  398. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  399. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  400. mindspore/ops/_primitive_cache.py +1 -1
  401. mindspore/ops/_tracefunc.py +241 -0
  402. mindspore/ops/_utils/utils.py +10 -2
  403. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  404. mindspore/ops/_vmap/vmap_base.py +5 -4
  405. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  406. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  407. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  408. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  409. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  410. mindspore/ops/arg_dtype_cast.py +54 -0
  411. mindspore/ops/composite/__init__.py +7 -5
  412. mindspore/ops/composite/base.py +78 -34
  413. mindspore/ops/composite/math_ops.py +5 -695
  414. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  415. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  416. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  417. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  418. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  419. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  420. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  421. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  422. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  423. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  424. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  425. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  426. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  427. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  428. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  429. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  430. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  431. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  432. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  433. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  434. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  436. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  437. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  438. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  440. mindspore/ops/deprecated.py +304 -0
  441. mindspore/ops/function/__init__.py +41 -4
  442. mindspore/ops/function/array_func.py +1108 -467
  443. mindspore/ops/function/clip_func.py +94 -27
  444. mindspore/ops/function/debug_func.py +3 -1
  445. mindspore/ops/function/grad/grad_func.py +82 -73
  446. mindspore/ops/function/image_func.py +28 -12
  447. mindspore/ops/function/linalg_func.py +135 -39
  448. mindspore/ops/function/math_func.py +3779 -894
  449. mindspore/ops/function/nn_func.py +1584 -657
  450. mindspore/ops/function/parameter_func.py +13 -3
  451. mindspore/ops/function/random_func.py +247 -153
  452. mindspore/ops/function/sparse_func.py +14 -11
  453. mindspore/ops/function/sparse_unary_func.py +173 -47
  454. mindspore/ops/function/spectral_func.py +8 -4
  455. mindspore/ops/function/vmap_func.py +8 -7
  456. mindspore/ops/functional.py +47 -16
  457. mindspore/ops/op_info_register.py +346 -86
  458. mindspore/ops/operations/__init__.py +38 -22
  459. mindspore/ops/operations/_grad_ops.py +145 -149
  460. mindspore/ops/operations/_inner_ops.py +298 -56
  461. mindspore/ops/operations/_ms_kernel.py +3 -3
  462. mindspore/ops/operations/_quant_ops.py +24 -28
  463. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  464. mindspore/ops/operations/_scalar_ops.py +115 -0
  465. mindspore/ops/operations/_sequence_ops.py +148 -10
  466. mindspore/ops/operations/_tensor_array.py +1 -1
  467. mindspore/ops/operations/_thor_ops.py +2 -2
  468. mindspore/ops/operations/array_ops.py +1239 -561
  469. mindspore/ops/operations/comm_ops.py +166 -90
  470. mindspore/ops/operations/control_ops.py +3 -3
  471. mindspore/ops/operations/custom_ops.py +124 -102
  472. mindspore/ops/operations/debug_ops.py +24 -11
  473. mindspore/ops/operations/image_ops.py +86 -71
  474. mindspore/ops/operations/inner_ops.py +18 -13
  475. mindspore/ops/operations/linalg_ops.py +30 -11
  476. mindspore/ops/operations/math_ops.py +1730 -435
  477. mindspore/ops/operations/nn_ops.py +1953 -943
  478. mindspore/ops/operations/other_ops.py +65 -43
  479. mindspore/ops/operations/random_ops.py +258 -98
  480. mindspore/ops/operations/rl_ops.py +4 -36
  481. mindspore/ops/operations/sparse_ops.py +38 -33
  482. mindspore/ops/operations/spectral_ops.py +8 -4
  483. mindspore/ops/primitive.py +66 -44
  484. mindspore/ops/signature.py +5 -5
  485. mindspore/parallel/_auto_parallel_context.py +80 -19
  486. mindspore/parallel/_cost_model_context.py +42 -0
  487. mindspore/parallel/_offload_context.py +162 -72
  488. mindspore/parallel/_parallel_serialization.py +2 -2
  489. mindspore/parallel/_ps_context.py +16 -4
  490. mindspore/parallel/_recovery_context.py +2 -1
  491. mindspore/parallel/_tensor.py +15 -13
  492. mindspore/parallel/_transformer/layers.py +8 -6
  493. mindspore/parallel/_transformer/loss.py +1 -0
  494. mindspore/parallel/_transformer/moe.py +7 -7
  495. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  496. mindspore/parallel/_transformer/transformer.py +34 -14
  497. mindspore/parallel/_utils.py +36 -14
  498. mindspore/parallel/algo_parameter_config.py +114 -20
  499. mindspore/parallel/checkpoint_transform.py +16 -18
  500. mindspore/parallel/shard.py +16 -13
  501. mindspore/profiler/__init__.py +1 -1
  502. mindspore/profiler/common/struct_type.py +3 -3
  503. mindspore/profiler/common/util.py +3 -2
  504. mindspore/profiler/envprofiling.py +11 -4
  505. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  506. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  507. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  508. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  509. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  510. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  511. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  512. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  513. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  514. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  515. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  516. mindspore/profiler/parser/flops_parser.py +15 -11
  517. mindspore/profiler/parser/framework_parser.py +92 -73
  518. mindspore/profiler/parser/hccl_parser.py +16 -12
  519. mindspore/profiler/parser/integrator.py +22 -11
  520. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  521. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  522. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  523. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  524. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  525. mindspore/profiler/parser/optime_parser.py +1 -1
  526. mindspore/profiler/parser/profiler_info.py +4 -5
  527. mindspore/profiler/parser/step_trace_parser.py +11 -14
  528. mindspore/profiler/profiling.py +678 -377
  529. mindspore/rewrite/api/node.py +211 -54
  530. mindspore/rewrite/api/node_type.py +5 -0
  531. mindspore/rewrite/api/pattern_engine.py +22 -23
  532. mindspore/rewrite/api/scoped_value.py +20 -17
  533. mindspore/rewrite/api/symbol_tree.py +252 -106
  534. mindspore/rewrite/api/tree_node_helper.py +3 -0
  535. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  536. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  537. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  538. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  539. mindspore/rewrite/common/rewrite_elog.py +5 -1
  540. mindspore/rewrite/namer.py +51 -51
  541. mindspore/rewrite/namespace.py +14 -5
  542. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  543. mindspore/rewrite/node/call_function.py +79 -0
  544. mindspore/rewrite/node/cell_container.py +135 -0
  545. mindspore/rewrite/node/control_flow.py +88 -0
  546. mindspore/rewrite/{node.py → node/node.py} +313 -247
  547. mindspore/rewrite/node/node_manager.py +254 -0
  548. mindspore/rewrite/node/node_topological_manager.py +243 -0
  549. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  550. mindspore/rewrite/parsers/assign_parser.py +225 -239
  551. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  552. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  553. mindspore/rewrite/parsers/constant_parser.py +9 -6
  554. mindspore/rewrite/parsers/container_parser.py +9 -7
  555. mindspore/rewrite/parsers/for_parser.py +36 -15
  556. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  557. mindspore/rewrite/parsers/if_parser.py +28 -24
  558. mindspore/rewrite/parsers/module_parser.py +202 -25
  559. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  560. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  561. mindspore/rewrite/parsers/return_parser.py +6 -6
  562. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  563. mindspore/rewrite/sparsify/sparsify.py +4 -1
  564. mindspore/rewrite/sparsify/utils.py +11 -5
  565. mindspore/rewrite/symbol_tree.py +577 -732
  566. mindspore/rewrite/symbol_tree_builder.py +9 -175
  567. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  568. mindspore/run_check/_check_version.py +46 -39
  569. mindspore/run_check/run_check.py +3 -2
  570. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  571. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  572. mindspore/scipy/__init__.py +1 -1
  573. mindspore/scipy/linalg.py +67 -61
  574. mindspore/scipy/ops.py +5 -41
  575. mindspore/scipy/ops_grad.py +3 -2
  576. mindspore/scipy/ops_wrapper.py +5 -5
  577. mindspore/scipy/optimize/line_search.py +8 -8
  578. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  579. mindspore/scipy/optimize/minimize.py +16 -12
  580. mindspore/scipy/utils.py +1 -52
  581. mindspore/scipy/utils_const.py +4 -4
  582. mindspore/train/__init__.py +4 -4
  583. mindspore/train/_utils.py +13 -5
  584. mindspore/train/amp.py +410 -148
  585. mindspore/train/anf_ir_pb2.py +16 -4
  586. mindspore/train/callback/_backup_and_restore.py +8 -11
  587. mindspore/train/callback/_callback.py +80 -3
  588. mindspore/train/callback/_checkpoint.py +82 -51
  589. mindspore/train/callback/_early_stop.py +12 -15
  590. mindspore/train/callback/_history.py +1 -1
  591. mindspore/train/callback/_lambda_callback.py +13 -13
  592. mindspore/train/callback/_landscape.py +21 -17
  593. mindspore/train/callback/_loss_monitor.py +9 -10
  594. mindspore/train/callback/_on_request_exit.py +16 -33
  595. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  596. mindspore/train/callback/_summary_collector.py +44 -30
  597. mindspore/train/callback/_time_monitor.py +62 -12
  598. mindspore/train/data_sink.py +10 -16
  599. mindspore/train/dataset_helper.py +154 -86
  600. mindspore/train/loss_scale_manager.py +14 -9
  601. mindspore/train/metrics/__init__.py +10 -2
  602. mindspore/train/metrics/accuracy.py +1 -1
  603. mindspore/train/metrics/auc.py +1 -1
  604. mindspore/train/metrics/bleu_score.py +2 -2
  605. mindspore/train/metrics/confusion_matrix.py +14 -14
  606. mindspore/train/metrics/cosine_similarity.py +3 -3
  607. mindspore/train/metrics/dice.py +1 -1
  608. mindspore/train/metrics/fbeta.py +1 -1
  609. mindspore/train/metrics/hausdorff_distance.py +8 -6
  610. mindspore/train/metrics/mean_surface_distance.py +5 -4
  611. mindspore/train/metrics/metric.py +49 -17
  612. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  613. mindspore/train/metrics/perplexity.py +1 -1
  614. mindspore/train/metrics/precision.py +2 -2
  615. mindspore/train/metrics/recall.py +2 -3
  616. mindspore/train/metrics/roc.py +7 -7
  617. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  618. mindspore/train/metrics/topk.py +7 -4
  619. mindspore/train/mind_ir_pb2.py +193 -48
  620. mindspore/train/model.py +377 -133
  621. mindspore/train/serialization.py +697 -245
  622. mindspore/train/summary/_summary_adapter.py +5 -2
  623. mindspore/train/summary/_writer_pool.py +4 -3
  624. mindspore/train/summary/summary_record.py +25 -23
  625. mindspore/train/train_thor/convert_utils.py +39 -23
  626. mindspore/train/train_thor/dataset_helper.py +4 -3
  627. mindspore/train/train_thor/model_thor.py +8 -8
  628. mindspore/version.py +1 -1
  629. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  630. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
  631. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  632. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  633. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  634. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  635. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  636. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  637. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  638. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  639. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  640. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  641. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  642. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  643. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  644. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  645. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  646. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  647. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  648. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  649. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  650. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  651. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  652. mindspore/_extends/graph_kernel/expander.py +0 -80
  653. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  654. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  655. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  656. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  657. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  658. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  659. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  660. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  661. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  662. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  663. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  664. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  665. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  666. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  667. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  668. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  669. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  670. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  671. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  672. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  673. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  674. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  675. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  676. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  677. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  678. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  679. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  680. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  681. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  682. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  683. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  684. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  685. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  686. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  687. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  688. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  689. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  690. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  691. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  692. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  693. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  694. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  695. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  696. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  697. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  698. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  699. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  700. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  701. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  702. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  703. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  704. mindspore/dataset/engine/graphdata.py +0 -1586
  705. mindspore/include/api/net.h +0 -142
  706. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  707. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  708. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  709. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  710. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  711. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  712. mindspore/ops/_grad/grad_other_ops.py +0 -89
  713. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  714. mindspore/ops/_grad/grad_sparse.py +0 -323
  715. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  716. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  717. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  718. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  719. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  720. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  721. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  722. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  723. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  724. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  725. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  726. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  727. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  728. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  729. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  730. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  731. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  732. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  733. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  734. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  735. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  736. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  737. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  738. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  739. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  740. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  741. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  742. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  743. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  744. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  745. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  746. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  747. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  748. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  749. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  751. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  752. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  753. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  754. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  755. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  756. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  758. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  759. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  760. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  761. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  762. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  763. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  764. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  765. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  766. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  767. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  768. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  769. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  770. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  771. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  772. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  774. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  775. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  776. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  777. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  778. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  779. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  780. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  784. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  785. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  786. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  787. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  789. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  790. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  791. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  792. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  793. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  794. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  796. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  797. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  798. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  799. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  800. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  801. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  802. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  803. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  804. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  805. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  806. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  807. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  808. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  809. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  810. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  811. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  812. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  814. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  815. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  816. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  817. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  818. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  819. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  820. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  821. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  822. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  823. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  824. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  825. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  826. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  827. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  828. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  829. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  830. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  831. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  832. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  833. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  834. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  835. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  836. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  837. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  839. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  840. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  841. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  842. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  843. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  844. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  845. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  846. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  847. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  848. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  849. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  850. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  851. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  852. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  853. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  854. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  855. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  856. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  857. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  858. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  859. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  860. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  861. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  862. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  863. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  864. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  865. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  866. mindspore/rewrite/node_visitor.py +0 -44
  867. mindspore/rewrite/topological_manager.py +0 -203
  868. mindspore/scipy/sparse/linalg.py +0 -192
  869. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  870. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -16,16 +16,18 @@
16
16
  from typing import Optional, Union
17
17
  import ast
18
18
  import inspect
19
+ from types import FunctionType
19
20
 
20
21
  from mindspore.nn import Cell
21
22
  from mindspore.ops import Primitive
22
23
  from mindspore import log as logger
23
- from .. import _checkparam as Validator
24
- from .ast_helpers import AstModifier
25
- from .api.scoped_value import ScopedValue, ValueType
26
- from .api.node_type import NodeType
27
- from .namespace import is_subtree
28
- from .ast_helpers.ast_replacer import AstReplacer
24
+ from ... import _checkparam as Validator
25
+ from ..ast_helpers import AstModifier
26
+ from ..api.scoped_value import ScopedValue, ValueType
27
+ from ..api.node_type import NodeType
28
+ from ..namespace import is_subtree
29
+ from ..ast_helpers.ast_replacer import AstReplacer
30
+ from ..ast_creator_register import ast_creator_registry
29
31
 
30
32
  PASS_THROUGH_METHOD = ScopedValue.create_naming_value("PassThrough")
31
33
 
@@ -36,35 +38,33 @@ class Node:
36
38
  invoking in forward which could be an instance of Cell, an instance of Primitive or a callable method. Fields of
37
39
  Node has different meaning in different type of node:
38
40
 
39
- - CallCell: a call-cell node represents an assign statement whose value is a calling to cell in mindspore. `targets`
40
- is corresponding to targets of ast.Assign which means return values of this cell-op. `args` and `kwargs` are
41
- corresponding to args and keywords of ast.Call which mean arguments to invoke cell-op's forward method. `func` is
42
- corresponding to func of call expression which means symbol of the cell-op.
41
+ - CallCell: a call-cell node represents an assign statement whose value is a calling to cell in mindspore.
42
+ `targets` is corresponding to targets of ast.Assign which means return values of this cell-op. `args` and
43
+ `kwargs` are corresponding to args and keywords of ast.Call which mean arguments to invoke cell-op's forward
44
+ method. `func` is corresponding to func of call expression which means symbol of the cell-op.
43
45
  - CallPrimitive: a call-primitive node represents an ast.Assign whose value is a calling to operator in mindspore.
44
- `targets`, `args`, `kwargs` and `func` are as previous.
46
+ `targets`, `args`, `kwargs` and `func_name` are as previous.
45
47
  - CallMethod: a call-method node represents an ast.Assign whose value is a calling to python-method such as `len`.
46
- `targets` is corresponding to targets of ast.Assign which means return values of this method. `func` represents
47
- the string name of method. `args` and `kwargs` are corresponding to args and keywords to invoke the method. When
48
- value of ast.Assign is an ast.Name or ast.Attribute, it means a simplest assign which would also be mapped to
49
- CallMethod node whose `func` is "PassThrough".
50
- - GetAttr: retrieves a parameter from the SymbolTree hierarchy. `func` represents which parameter in SymbolTree
51
- hierarchy. `targets` is corresponding to targets of ast.Assign which means what symbol to accept the result of
52
- get-attr. `args` and `kwargs` are don't-care.
48
+ `targets` is corresponding to targets of ast.Assign which means return values of this method. `func_name`
49
+ represents the string name of method. `args` and `kwargs` are corresponding to args and keywords to invoke the
50
+ method. When value of ast.Assign is an ast.Name or ast.Attribute, it means a simplest assign which would also be
51
+ mapped to CallMethod node whose `func_name` is "PassThrough".
53
52
  - Python: a python node holds an ast-node which is not parsed. a python node means some python statement is not
54
- supported by Rewrite or ignored by Rewrite. `targets`, `args`, `kwargs` and `func` are don't-care.
53
+ supported by Rewrite or ignored by Rewrite. `targets`, `args`, `kwargs` and `func_name` are don't-care.
55
54
  - Input: an input node represents an input of current network which also a parameter of forward method of Cell.
56
55
  `targets` is corresponding to arg-name of parameter of forward function. `args` means default-value of parameter
57
- of forward function. `kwargs` and `func` are don't-care.
56
+ of forward function. `kwargs` and `func_name` are don't-care.
58
57
  - Output: an output node represents the output of current network which is corresponding to return statement of
59
- forward method of Cell. `args` represents return values. `func` are always be "return". `targets` and `kwargs` are
60
- don't-care.
58
+ forward method of Cell. `args` represents return values. `func_name` are always be "return". `targets` and
59
+ `kwargs` are don't-care.
61
60
  - Tree: a tree node represents a sub-network call in current network. A sub-network is also a Cell in mindspore, so
62
- `targets`, `args`, `kwargs` and `func` are same as a call-cell node. `symbol_tree` is a handler of a SymbolTree
63
- instance.
61
+ `targets`, `args`, `kwargs` and `func_name` are same as a call-cell node. `symbol_tree` is a handler of a
62
+ SymbolTree instance.
64
63
  """
65
64
 
66
65
  def __init__(self, node_type: NodeType, ast_node: Optional[ast.AST], targets: [ScopedValue],
67
- func: Optional[ScopedValue], args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance):
66
+ func_name: Optional[ScopedValue], args: [ScopedValue], kwargs: {str: ScopedValue}, name: str,
67
+ instance):
68
68
  """
69
69
  Constructor of Node. Rewrite recommend invoking class method of Node to instantiate an instance of Node such
70
70
  as `create_call_op`, `create_call_method`, `create_python_node`, `create_input_node` and
@@ -75,7 +75,7 @@ class Node:
75
75
  ast_node (ast.AST, optional): An instance of ast.AST represents corresponding node in ast. `ast_node` should
76
76
  not be None except when node type is Unknown.
77
77
  targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
78
- func (ScopedValue, optional): An instance of ScopedValue. See detail in docstring of Node class.
78
+ func_name (ScopedValue, optional): An instance of ScopedValue. See detail in docstring of Node class.
79
79
  args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
80
80
  kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
81
81
  name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
@@ -89,58 +89,29 @@ class Node:
89
89
  self._attribute = Node._get_cell_or_prim_op_attribute(instance)
90
90
  self._instance = instance
91
91
  self._name = name
92
- self._func: Optional[ScopedValue] = func
92
+ self._func_name: Optional[ScopedValue] = func_name
93
93
  self._targets: [ScopedValue] = targets
94
94
  self._args_num = len(args) if args is not None else 0
95
95
  self._kwargs_num = len(kwargs) if kwargs is not None else 0
96
96
  self._normalized_args_keys = [] # for saving args' order
97
97
  self._normalized_args = self._get_normalized_args(args, kwargs)
98
- # edge of node
99
- self._inputs: [Node] = []
100
98
  # position in graph nodes list
101
99
  # it will affect code-order of python code
102
100
  self._prev: Optional[Node] = None
103
101
  self._next: Optional[Node] = None
104
102
  # A handler of SymbolTree current node belonging to
105
103
  self._belong_tree = None
106
-
107
- @classmethod
108
- def create_call_buildin_op(cls, op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [ScopedValue],
109
- func: ScopedValue, args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None,
110
- name: str = ""):
111
- """
112
- Class method of Node. Instantiate an instance of node whose type is `CallCell` or `CallPrimitive`.
113
- A `CallCell` node represents an invoking to cell-op.
114
- A `CallPrimitive` node represents an invoking to primitive-op.
115
-
116
- Args:
117
- op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node.
118
- ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
119
- targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
120
- func ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
121
- args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
122
- kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
123
- class.
124
- name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`.
125
- Name of node also used as field name in network class.
126
- """
127
-
128
- if not isinstance(op, (Cell, Primitive)):
129
- raise ValueError("Input op is not a buildin op(Cell or Primitive): ", type(op))
130
- non_custom_args = Node._handle_custom_obj_in_args(args)
131
- non_custom_kwargs = Node._handle_custom_obj_in_kwargs(kwargs)
132
- if ast_node is None:
133
- ast_node = AstModifier.create_call_assign(targets, func, non_custom_args, non_custom_kwargs)
134
- if isinstance(op, Cell):
135
- node_type = NodeType.CallCell
136
- else:
137
- node_type = NodeType.CallPrimitive
138
- return cls(node_type, ast_node, targets, func, args, kwargs, name, op)
104
+ # A handler of NodeManager current node belonging to
105
+ self._node_manager = None
106
+ # A dict that records which target of which Node current Node's argument come from
107
+ self._arg_providers: {int: (Node, int)} = {}
108
+ # A dict that records which argument of which Node uses current Node's target
109
+ self._target_users: {int: [(Node, int)]} = {}
139
110
 
140
111
  @classmethod
141
112
  def create_call_method(cls, ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
142
- func: Union[ScopedValue, str], args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None,
143
- name: str = ""):
113
+ func_name: Union[ScopedValue, str], args: [ScopedValue] = None,
114
+ kwargs: {str: ScopedValue}=None, name: str = ""):
144
115
  """
145
116
  Class method of Node. Instantiate an instance of node whose type is CallCell. A CallCell node represents an
146
117
  invoking to cell-op.
@@ -149,7 +120,7 @@ class Node:
149
120
  ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast. `ast_node`
150
121
  should not be None currently.
151
122
  targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
152
- func ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
123
+ func_name ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
153
124
  args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
154
125
  kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
155
126
  name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
@@ -159,12 +130,12 @@ class Node:
159
130
  args = []
160
131
  if kwargs is None:
161
132
  kwargs = {}
162
- if isinstance(func, str):
163
- func = ScopedValue.create_naming_value(func)
133
+ if isinstance(func_name, str):
134
+ func_name = ScopedValue.create_naming_value(func_name)
164
135
  new_targets = Node._handle_targets(targets)
165
136
  if ast_node is None:
166
137
  raise RuntimeError("Input ast_node is None")
167
- return cls(NodeType.CallMethod, ast_node, new_targets, func, args, kwargs, name, None)
138
+ return cls(NodeType.CallMethod, ast_node, new_targets, func_name, args, kwargs, name, None)
168
139
 
169
140
  @classmethod
170
141
  def create_call_pass_through_method(cls, ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
@@ -187,7 +158,8 @@ class Node:
187
158
  return cls(NodeType.Python, ast_node, None, None, [], {}, name, instance)
188
159
 
189
160
  @classmethod
190
- def create_input_node(cls, ast_node: ast.AST, arg_name: str, default: Optional[ScopedValue] = None, name: str = ""):
161
+ def create_input_node(cls, ast_node: Optional[ast.AST], arg_name: str, default: Optional[ScopedValue] = None,
162
+ name: str = ""):
191
163
  """
192
164
  Class method of Node. Instantiate an instance of node whose type is Input. An Input node represents input of
193
165
  SymbolTree which is corresponding to parameters of forward function.
@@ -204,6 +176,8 @@ class Node:
204
176
  args = []
205
177
  else:
206
178
  args = [default]
179
+ if ast_node is None:
180
+ ast_node = ast.arg(arg_name)
207
181
  return cls(NodeType.Input, ast_node, [target], None, args, {}, name, None)
208
182
 
209
183
  @classmethod
@@ -241,17 +215,83 @@ class Node:
241
215
  args (list[ScopedValue]): Values participating in the mathematical operations. All values are saved
242
216
  sequentially in the list.
243
217
  ops (dict[str:ScopedValue]): Operators participating in the mathematical operations. All operators are
244
- saved sequentially in the dict, and keys are numbers in string format, such as {'0':'add', '1':'sub'}.
218
+ saved sequentially in the dict, and keys are numbers in string format, such as {'0':'add', '1':'sub'}.
245
219
  name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`.
246
220
  Name of node also used as field name in network class. The format of mathops node name
247
221
  is 'AstNodeName_AstOpName_n'.
248
222
  """
249
223
  return cls(NodeType.MathOps, ast_node, targets, op_type, args, ops, name, None)
250
224
 
225
+ @staticmethod
226
+ def create_assign_node(targets, func_name, args, kwargs):
227
+ """Create a ast.Assign type node."""
228
+ # create targets
229
+ ast_targets = [ast_creator_registry.get("Name")(targets)]
230
+ # create call
231
+ ast_func = ast_creator_registry.get("Attribute")(func_name)
232
+ ast_args = ast_creator_registry.get("Args")(args)
233
+ ast_kwargs = ast_creator_registry.get("KwArgs")(kwargs) if kwargs else []
234
+ ast_value = ast_creator_registry.get("Call")(func=ast_func, args=ast_args, keywords=ast_kwargs)
235
+ # create assign
236
+ ast_node = ast_creator_registry.get("Assign")(targets=ast_targets, value=ast_value)
237
+ return ast_node
238
+
239
+ @staticmethod
240
+ def _create_call_function(function: FunctionType, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None,
241
+ kwargs: {str: ScopedValue}=None):
242
+ """
243
+ Create a node that corresponds to a function call.
244
+
245
+ Args:
246
+ function (FunctionType): The function to be called.
247
+ targets (list[str]): indicates output names. Used as targets of an assign statement in source code.
248
+ args (list[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in
249
+ source code. Default: ``None`` , which indicates the `function` has no args inputs.
250
+ kwargs (dict): Type of key must be `str` and type of value must be `ScopedValue`.
251
+ Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source
252
+ code. Default: ``None`` , which indicates the `function` has no kwargs inputs.
253
+
254
+ Returns:
255
+ An instance of `Node`.
256
+ """
257
+ if args is None:
258
+ args = []
259
+ if kwargs is None:
260
+ kwargs = {}
261
+ targets = Node._handle_targets(targets)
262
+ _package = None
263
+ if isinstance(function, FunctionType):
264
+ _package = function.__globals__['__package__']
265
+ func_full_name = ".".join([_package, function.__name__]) if _package else function.__name__
266
+ func_scope = ''
267
+ func_name = func_full_name.split('.')[-1]
268
+ if func_full_name.count('.') > 0:
269
+ func_scope = func_full_name.rsplit('.')[0]
270
+ func_scope_name = ScopedValue.create_naming_value(func_name, func_scope)
271
+ node = Node.inner_create_call_function(func_name, None, func_scope_name, function, targets, args, kwargs)
272
+ return node
273
+
274
+ @classmethod
275
+ def inner_create_call_function(cls, node_name, ast_node, func_name, function, targets, args, kwargs):
276
+ '''
277
+ Instantiate an instance of node whose type is `CallFunction`.
278
+
279
+ Args:
280
+ node_name (str): Name of node.
281
+ func_name (str): Name of function.
282
+ ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
283
+ targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
284
+ function (Object): An instance of function. See detail in docstring of Node class.
285
+ args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
286
+ kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
287
+ class.
288
+ '''
289
+ return cls(NodeType.CallFunction, ast_node, targets, func_name, args, kwargs, node_name, function)
290
+
251
291
  @staticmethod
252
292
  def create_call_op(op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
253
- func: Union[ScopedValue, str], args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None,
254
- name: str = "", is_sub_net: bool = False):
293
+ args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None, node_name: str = "",
294
+ is_sub_net: bool = False):
255
295
  """
256
296
  Static method of Node. Instantiate an instance of node whose type is `CallCell` or `CallPrimitive`.
257
297
  If op is custom defined, it is treated by TreeNode.
@@ -262,12 +302,11 @@ class Node:
262
302
  op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node.
263
303
  ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
264
304
  targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
265
- func ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
266
305
  args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
267
306
  kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
268
307
  class.
269
- name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`.
270
- Name of node also used as field name in network class.
308
+ node_name (str): A string represents name of node. Name of node will be unique when inserted into
309
+ `SymbolTree`. Name of node also used as field name in network class.
271
310
  is_sub_net (bool): Indicate that is `cell` a network. If `is_sub_net` is true, Rewrite will try to parse the
272
311
  `cell` to a TreeNode, else a CallCell Node. Default is a False.
273
312
  """
@@ -275,29 +314,58 @@ class Node:
275
314
  if ast_node is not None:
276
315
  Validator.check_value_type("ast_node", ast_node, [ast.AST], "Node")
277
316
  Validator.check_element_type_of_iterable("targets", targets, [ScopedValue, str], "Node")
278
- Validator.check_value_type("func", func, [ScopedValue, str], "Node")
279
317
  if args is not None:
280
318
  Validator.check_element_type_of_iterable("args", args, [ScopedValue], "Node")
281
319
  if kwargs is not None:
282
320
  Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "Node")
283
- cls_name = type(op).__name__
284
-
285
321
  if args is None:
286
322
  args = []
287
323
  if kwargs is None:
288
324
  kwargs = {}
289
- if isinstance(func, str):
290
- func = ScopedValue.create_naming_value(func)
325
+ Validator.check_value_type("node_name", node_name, [str], "Node")
291
326
  new_targets = Node._handle_targets(targets)
292
- if is_sub_net and is_subtree(cls_name):
293
- from .symbol_tree_builder import SymbolTreeBuilder
327
+ if isinstance(node_name, str):
328
+ func_name = ScopedValue.create_naming_value(node_name)
329
+ else:
330
+ func_name = node_name
331
+ if is_sub_net and is_subtree(op):
332
+ from ..symbol_tree_builder import SymbolTreeBuilder
294
333
  stb = SymbolTreeBuilder(op)
295
334
  stree = stb.build()
296
335
  replacer = AstReplacer(stree.get_class_ast())
297
336
  replacer.replace_all(stree.get_ori_cls_name(), stree.get_opt_cls_name())
298
- return TreeNode.create_tree_node(stree, ast_node, new_targets, func, args, kwargs, name, op)
337
+ return TreeNode.create_tree_node(stree, ast_node, new_targets, func_name, args, kwargs, node_name, op)
338
+
339
+ return Node.create_call_buildin_op(op, ast_node, new_targets, func_name, args, kwargs, node_name)
340
+
341
+ @classmethod
342
+ def create_call_buildin_op(cls, op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [ScopedValue],
343
+ func_name: ScopedValue, args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None,
344
+ node_name: str = ""):
345
+ """
346
+ Class method of Node. Instantiate an instance of node whose type is `CallCell` or `CallPrimitive`.
347
+ A `CallCell` node represents an invoking to cell-op.
348
+ A `CallPrimitive` node represents an invoking to primitive-op.
349
+
350
+ Args:
351
+ op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node.
352
+ ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
353
+ targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
354
+ func_name ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
355
+ args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
356
+ kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
357
+ class.
358
+ node_name (str): A string represents name of node. Name of node will be unique when inserted into
359
+ `SymbolTree`. Name of node also used as field name in network class.
360
+ """
299
361
 
300
- return Node.create_call_buildin_op(op, ast_node, new_targets, func, args, kwargs, name)
362
+ if not isinstance(op, (Cell, Primitive)):
363
+ raise ValueError("Input op is not a buildin op(Cell or Primitive): ", type(op))
364
+ if isinstance(op, Cell):
365
+ node_type = NodeType.CallCell
366
+ else:
367
+ node_type = NodeType.CallPrimitive
368
+ return cls(node_type, ast_node, targets, func_name, args, kwargs, node_name, op)
301
369
 
302
370
  @staticmethod
303
371
  def _get_construct_arg_names(parameters):
@@ -506,21 +574,23 @@ class Node:
506
574
  """
507
575
  return self._next
508
576
 
509
- def has_same_ast(self, node: Union['Node', ast.AST]) -> bool:
577
+ def set_prev(self, node: 'Node'):
510
578
  """
511
- Check if other node holds same ast node with self.
579
+ Set previous node of current node.
512
580
 
513
581
  Args:
514
- node (Union[Node, ast.AST]): An instance of ast.AST or an instance of node to be compared.
582
+ node (Node): Node to be set as previous node of current node.
583
+ """
584
+ self._prev = node
515
585
 
516
- Returns:
517
- A bool.
586
+ def set_next(self, node: 'Node'):
518
587
  """
519
- if isinstance(node, Node):
520
- return self.has_same_ast(node._ast_node)
521
- if isinstance(node, ast.AST):
522
- return id(self._ast_node) == id(node)
523
- return False
588
+ Set next node of current node.
589
+
590
+ Args:
591
+ node (Node): Node to be set as next node of current node.
592
+ """
593
+ self._next = node
524
594
 
525
595
  def get_ast(self) -> Optional[ast.AST]:
526
596
  """
@@ -550,16 +620,24 @@ class Node:
550
620
  """Set the symbol tree to which node belongs."""
551
621
  self._belong_tree = symbol_tree
552
622
 
623
+ def get_node_manager(self):
624
+ """Get the NodeManager current node belongs to."""
625
+ return self._node_manager
626
+
627
+ def set_node_manager(self, node_manager):
628
+ """Set NodeManager current node belongs."""
629
+ self._node_manager = node_manager
630
+
553
631
  def isolate(self):
554
632
  """Link prev node to next node and isolate node from source code order list."""
555
- origin_prev: Optional[Node] = self._prev
556
- origin_next: Optional[Node] = self._next
633
+ origin_prev: Optional[Node] = self.get_prev()
634
+ origin_next: Optional[Node] = self.get_next()
557
635
  if origin_prev is not None:
558
- origin_prev._next = origin_next
636
+ origin_prev.set_next(origin_next)
559
637
  if origin_next is not None:
560
- origin_next._prev = origin_prev
561
- self._prev = None
562
- self._next = None
638
+ origin_next.set_prev(origin_prev)
639
+ self.set_prev(None)
640
+ self.set_next(None)
563
641
 
564
642
  def insert_before(self, node: 'Node'):
565
643
  """
@@ -569,12 +647,12 @@ class Node:
569
647
  node (Node): An instance of node to be inserted in.
570
648
  """
571
649
  node.isolate()
572
- origin_prev: Optional[Node] = self._prev
650
+ origin_prev: Optional[Node] = self.get_prev()
573
651
  if origin_prev is not None:
574
- origin_prev._next = node
575
- node._prev = origin_prev
576
- node._next = self
577
- self._prev = node
652
+ origin_prev.set_next(node)
653
+ node.set_prev(origin_prev)
654
+ node.set_next(self)
655
+ self.set_prev(node)
578
656
 
579
657
  def insert_after(self, node: 'Node'):
580
658
  """
@@ -584,31 +662,26 @@ class Node:
584
662
  node (Node): An instance of node to be inserted in.
585
663
  """
586
664
  node.isolate()
587
- origin_next: Optional[Node] = self._next
588
- self._next = node
589
- node._prev = self
590
- node._next = origin_next
665
+ origin_next: Optional[Node] = self.get_next()
666
+ self.set_next(node)
667
+ node.set_prev(self)
668
+ node.set_next(origin_next)
591
669
  if origin_next is not None:
592
- origin_next._prev = node
670
+ origin_next.set_prev(node)
593
671
 
594
672
  def get_inputs(self) -> ['Node']:
595
673
  """
596
- Getter of _inputs which represents input nodes of current node in topological order.
674
+ Get input nodes of current node in topological order.
597
675
 
598
676
  Returns:
599
677
  A list of instances of Node as input nodes.
600
678
  """
601
- return self._inputs
602
-
603
- def set_inputs(self, inputs: ['Node']):
604
- """
605
- Setter of _inputs which represents input nodes of current node in topological order.
606
-
607
-
608
- Args:
609
- inputs (list[Node]): A list of instances of Node as new input nodes.
610
- """
611
- self._inputs = inputs
679
+ inputs = []
680
+ for arg_provider in self.get_arg_providers().values():
681
+ if not arg_provider:
682
+ continue
683
+ inputs.append(arg_provider[0])
684
+ return inputs
612
685
 
613
686
  def get_targets(self) -> [ScopedValue]:
614
687
  """
@@ -654,26 +727,26 @@ class Node:
654
727
  NodeType.MathOps):
655
728
  self._sync_assign_targets_to_ast()
656
729
 
657
- def get_func(self) -> ScopedValue:
730
+ def get_func_name(self) -> ScopedValue:
658
731
  """
659
- Getter of `_func`. See detail in docstring of Node class for meaning of func.
732
+ Getter of `_func_name`. See detail in docstring of Node class for meaning of func.
660
733
 
661
734
  Returns:
662
735
  An instance of ScopedValue.
663
736
  """
664
- return self._func
737
+ return self._func_name
665
738
 
666
- def set_func(self, func: ScopedValue):
739
+ def set_func_name(self, func_name: ScopedValue):
667
740
  """
668
- Setter of `_func`. See detail in docstring of Node class for meaning of func.
741
+ Setter of `_func_name`. See detail in docstring of Node class for meaning of func.
669
742
 
670
743
  Note:
671
- When `_func` is updated, corresponding ast node would be updated also.
744
+ When `_func_name` is updated, corresponding ast node would be updated also.
672
745
 
673
746
  Args:
674
747
  func (ScopedValue): An instance of ScopedValue as new func.
675
748
  """
676
- self._func = func
749
+ self._func_name = func_name
677
750
  if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive):
678
751
  self._sync_assign_func_to_ast()
679
752
 
@@ -750,11 +823,11 @@ class Node:
750
823
  Validator.check_value_type("node", node, [Node], "Node")
751
824
  Validator.check_int_range(arg_idx, 0, self._args_num, Validator.INC_LEFT, "arg_idx")
752
825
  if out_idx is None:
753
- if len(node._targets) != 1:
826
+ if len(node.get_targets()) != 1:
754
827
  raise RuntimeError("node should has one output when out_idx is not provided")
755
828
  out_idx = 0
756
- Validator.check_int_range(out_idx, 0, len(node._targets), Validator.INC_LEFT, "arg_idx")
757
- new_arg = node._targets[out_idx]
829
+ Validator.check_int_range(out_idx, 0, len(node.get_targets()), Validator.INC_LEFT, "arg_idx")
830
+ new_arg = node.get_targets()[out_idx]
758
831
  self._normalized_args[self._normalized_args_keys[arg_idx]] = new_arg
759
832
  self._sync_arg()
760
833
 
@@ -943,6 +1016,66 @@ class Node:
943
1016
  """
944
1017
  return self._attribute.get(key)
945
1018
 
1019
+ def get_arg_providers(self) -> dict:
1020
+ """
1021
+ Getter of _arg_providers.
1022
+
1023
+ Return:
1024
+ dict, key is type of int indicating the index of args, and value is type of tuple, which includes
1025
+ the node and the index of node's targets who provides the argument.
1026
+ """
1027
+ return self._arg_providers
1028
+
1029
+ def set_arg_providers(self, index: int, provider: tuple):
1030
+ """
1031
+ Setter of _arg_providers.
1032
+
1033
+ Args:
1034
+ index (int): Indicating provider of which argument need to be set.
1035
+ provider (tuple): A tuple includes the node and the index of node's targets who provides the argument.
1036
+ """
1037
+ self._arg_providers[index] = provider
1038
+
1039
+ def get_target_users(self, index=-1) -> Union[dict, list]:
1040
+ """
1041
+ Getter of _target_users.
1042
+
1043
+ Args:
1044
+ index (int): Indicating users of which target need to be got. Default: -1, means all targets's users will
1045
+ be returned.
1046
+
1047
+ Return:
1048
+ Union[dict, list]. When index is not -1, a list of users of specified target will be returned.
1049
+ The type of elements in list is tuple, which includes the user node and the index of node's arguments
1050
+ who uses the target. When index is -1, a dict will be returned. The key is index of targets, and the
1051
+ value is list of users of corresponding target.
1052
+ """
1053
+ if index == -1:
1054
+ return self._target_users
1055
+ if index not in self._target_users.keys():
1056
+ self._target_users[index] = []
1057
+ return self._target_users.get(index, None)
1058
+
1059
+ def append_target_users(self, index: int, provider: tuple):
1060
+ """
1061
+ Setter of _target_users.
1062
+
1063
+ Args:
1064
+ index (int): Indicating users of which target need to be append.
1065
+ provider (tuple): A tuple includes the node and the index of node's argument who uses the target.
1066
+
1067
+ """
1068
+ if index not in self._target_users.keys():
1069
+ self._target_users[index] = []
1070
+ self._target_users.get(index).append(provider)
1071
+
1072
+ def update_ast_node(self) -> ast.AST:
1073
+ """Update node's ast_node by current targets, func_name, args and kwargs."""
1074
+ ast_assign = AstModifier.create_call_assign(self.get_targets(), self.get_func_name(),
1075
+ self.get_args(), self.get_kwargs())
1076
+ self.set_ast(ast_assign)
1077
+ return ast_assign
1078
+
946
1079
  def _get_normalized_args(self, args: [ScopedValue], kwargs: {str: ScopedValue}) -> dict:
947
1080
  """
948
1081
  Merge args and kwargs to normalized args.
@@ -983,6 +1116,10 @@ class Node:
983
1116
  self._normalized_args_keys.append(arg_key)
984
1117
  return normalized_args
985
1118
 
1119
+ ##########################################################################################################
1120
+ # Synchronize rewrite node args to ast node
1121
+ ##########################################################################################################
1122
+
986
1123
  def _sync_assign_func_to_ast(self):
987
1124
  """Sync func of ast.Call of ast.Assign from self._name when NodeType is CallCell or CallPrimitive."""
988
1125
  if self._ast_node is None:
@@ -994,20 +1131,21 @@ class Node:
994
1131
  if not isinstance(call_ast, ast.Call):
995
1132
  raise TypeError("call_ast should be ast.Call, got: ", type(call_ast))
996
1133
  func_ast = call_ast.func
997
- if not self._func.value:
1134
+ if not self._func_name.value:
998
1135
  if isinstance(func_ast, ast.Name):
999
- func_ast.id = self._func.value
1136
+ func_ast.id = self._func_name.value
1000
1137
  else:
1001
- call_ast.func = ast.Name(self._func.value, ast.Store())
1138
+ call_ast.func = ast.Name(self._func_name.value, ast.Store())
1002
1139
  else:
1003
1140
  if isinstance(func_ast, ast.Attribute):
1004
1141
  func_value = func_ast.value
1005
1142
  if not isinstance(func_value, ast.Name):
1006
1143
  raise RuntimeError("Only support ast.Name as value of attribute ", type(func_ast.value))
1007
- func_value.id = self._func.scope
1008
- func_ast.attr = self._func.value
1144
+ func_value.id = self._func_name.scope
1145
+ func_ast.attr = self._func_name.value
1009
1146
  else:
1010
- call_ast.func = ast.Attribute(ast.Name(self._func.scope, ast.Load()), self._func.value, ast.Store())
1147
+ call_ast.func = ast.Attribute(ast.Name(self._func_name.scope, ast.Load()),
1148
+ self._func_name.value, ast.Store())
1011
1149
  ast.fix_missing_locations(assign_ast)
1012
1150
 
1013
1151
  def _sync_assign_targets_to_ast(self):
@@ -1023,7 +1161,7 @@ class Node:
1023
1161
  raise RuntimeError("self._targets should have the same length as targets_ast's elts")
1024
1162
  if not isinstance(targets_ast[0], ast.Tuple) and len(self._targets) != len(targets_ast):
1025
1163
  raise RuntimeError("self._targets should have targets_ast same length")
1026
- for i in range(0, len(self._targets)):
1164
+ for i, _ in enumerate(self._targets):
1027
1165
  target = self._targets[i]
1028
1166
  target_ast = targets_ast[0]
1029
1167
  if isinstance(target_ast, ast.Name):
@@ -1043,7 +1181,7 @@ class Node:
1043
1181
  return
1044
1182
  assign_ast = self._ast_node
1045
1183
  if not isinstance(assign_ast, ast.Assign):
1046
- raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
1184
+ raise TypeError(f"assign_ast should be ast.Assign, got: {type(assign_ast)}")
1047
1185
  assign_value = assign_ast.value
1048
1186
  if not isinstance(assign_value, ast.Call):
1049
1187
  return
@@ -1094,23 +1232,31 @@ class Node:
1094
1232
  if len(self._normalized_args_keys) != 1:
1095
1233
  raise RuntimeError("self._normalized_args_keys should have 1 elements")
1096
1234
  arg = self._normalized_args.get(self._normalized_args_keys[0])
1097
- if arg.type not in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
1098
- raise RuntimeError("arg should be an IntValue, FloatValue or StringValue")
1235
+ if arg.type != ValueType.ConstantValue:
1236
+ raise RuntimeError("arg should be an ConstantValue")
1099
1237
  if arg.scope != "":
1100
1238
  raise RuntimeError("arg.scope should be empty")
1101
1239
  assign_value.value = arg.value
1102
1240
 
1103
1241
  def _sync_call_method_args_to_ast(self):
1104
- """Sync args of ast.Cell of ast.Assign from self._normalized_args when NodeType is CallMethod."""
1242
+ """
1243
+ Sync args to value of ast.Assign from self._normalized_args when NodeType is CallMethod.
1244
+
1245
+ For node with type of CallMethod, the value of ast.Assign is one of:
1246
+ - ast.Tuple
1247
+ - ast.Name
1248
+ - ast.ast.Attribute
1249
+ - ...
1250
+ """
1105
1251
  if self._ast_node is None:
1106
1252
  return
1107
1253
  assign_ast = self._ast_node
1108
1254
  if not isinstance(assign_ast, ast.Assign):
1109
1255
  raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
1110
1256
  assign_value = assign_ast.value
1111
- if self._func == PASS_THROUGH_METHOD:
1257
+ if self._func_name == PASS_THROUGH_METHOD:
1112
1258
  self._sync_call_pass_through_method_args_to_ast(assign_value)
1113
- elif self._func.value == "tuple":
1259
+ elif self._func_name.value == "tuple":
1114
1260
  tuple_ast: ast.Tuple = assign_value
1115
1261
  if len(self._normalized_args_keys) != len(tuple_ast.elts):
1116
1262
  raise RuntimeError("size of self._normalized_args_keys should be equal to size of elements of tuple")
@@ -1130,10 +1276,16 @@ class Node:
1130
1276
  else:
1131
1277
  raise RuntimeError("Only support constant or symbol in tuple now")
1132
1278
  else:
1133
- raise RuntimeError("Only support pass_through or tuple method as call_method now, ", self._func.value)
1279
+ raise RuntimeError("Only support pass_through or tuple method as call_method now, ", self._func_name.value)
1134
1280
 
1135
1281
  def _sync_return_node_to_ast(self):
1136
- """Sync return value of ast.Return from self._normalized_args when NodeType is Output."""
1282
+ """
1283
+ Sync args to value of ast.Return from self._normalized_args when NodeType is Output.
1284
+
1285
+ For node with type of CallMethod, the value of ast.Assign is one of:
1286
+ - ast.Name
1287
+ - ast.Tuple
1288
+ """
1137
1289
  if self._ast_node is None:
1138
1290
  return
1139
1291
  return_ast = self._ast_node
@@ -1195,7 +1347,7 @@ class Node:
1195
1347
 
1196
1348
  def _sync_arg(self):
1197
1349
  """Sync _normalized_args to corresponding ast node when updated."""
1198
- if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree,\
1350
+ if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree, \
1199
1351
  NodeType.CellContainer, NodeType.CallFunction):
1200
1352
  self._sync_call_cell_args_to_ast()
1201
1353
  elif self._node_type == NodeType.Output:
@@ -1206,15 +1358,18 @@ class Node:
1206
1358
  self._sync_mathops_node_args_to_ast()
1207
1359
 
1208
1360
 
1361
+ ##########################################################################################################
1362
+ # Child classes
1363
+ ##########################################################################################################
1364
+
1209
1365
  class TreeNode(Node):
1210
1366
  """Tree type Node who holds a handler of SymbolTree."""
1211
1367
 
1212
1368
  def __init__(self, tree, ast_node: ast.AST, targets: [ScopedValue], func: ScopedValue,
1213
1369
  args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance):
1214
1370
  """
1215
- Constructor of Node. Rewrite recommend to invoking class method of Node to instantiate an instance of Node such
1216
- as `create_call_buildin_op`, `create_call_method`, `create_python_node`, `create_input_node` and
1217
- `create_output_node`, etc. rather than invoking constructor of Node directly.
1371
+ Constructor of TreeNode. Rewrite recommend to invoking class method of Node to instantiate an instance of
1372
+ TreeNode such as `create_tree_node` rather than invoking constructor of Node directly.
1218
1373
 
1219
1374
  Args:
1220
1375
  tree: An instance of SymbolTree represents a handler of sub-symbol-tree.
@@ -1233,8 +1388,9 @@ class TreeNode(Node):
1233
1388
  self.symbol_tree = tree
1234
1389
 
1235
1390
  @classmethod
1236
- def create_tree_node(cls, tree, ast_node: ast.AST, targets: Union[ScopedValue, str], func: Union[ScopedValue, str],
1237
- args: [ScopedValue], kwargs: {str: ScopedValue}, name: str = "", instance=None):
1391
+ def create_tree_node(cls, tree, ast_node: ast.AST, targets: Union[ScopedValue, str],
1392
+ func_name: Union[ScopedValue, str], args: [ScopedValue], kwargs: {str: ScopedValue},
1393
+ name: str = "", instance=None):
1238
1394
  """
1239
1395
  Class method of TreeNode. Instantiate an instance of node whose type is Tree. A Tree node represents an invoking
1240
1396
  to sub-network.
@@ -1243,104 +1399,14 @@ class TreeNode(Node):
1243
1399
  tree: An instance of SymbolTree represents a handler of sub-symbol-tree.
1244
1400
  ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
1245
1401
  targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
1246
- func ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
1402
+ func_name ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
1247
1403
  args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
1248
1404
  kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
1249
1405
  name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
1250
1406
  Name of node also used as field name in network class.
1251
1407
  instance: Object in network corresponding to this node.
1252
1408
  """
1253
-
1254
- non_custom_args = Node._handle_custom_obj_in_args(args)
1255
- non_custom_kwargs = Node._handle_custom_obj_in_kwargs(kwargs)
1256
1409
  new_targets = Node._handle_targets(targets)
1257
- if isinstance(func, str):
1258
- func = ScopedValue.create_naming_value(func)
1259
- if ast_node is None:
1260
- ast_node = AstModifier.create_call_assign(new_targets, func, non_custom_args, non_custom_kwargs)
1261
- return cls(tree, ast_node, new_targets, func, args, kwargs, name, instance)
1262
-
1263
-
1264
- class CellContainer(Node):
1265
- """ Container for saving cell-objects node. """
1266
- class _Visitor():
1267
- """ A iterator of CellContainer nodes. """
1268
- def __init__(self, cellcontainer):
1269
- self._cellcontainer = cellcontainer
1270
-
1271
- def __len__(self):
1272
- """ Get the number of nodes. """
1273
- return self._cellcontainer.node_count
1274
-
1275
- def __iter__(self):
1276
- """Create an iterator over the CellContainer."""
1277
- count = len(self._cellcontainer.node_list)
1278
- i = 0
1279
- while i < count:
1280
- curr = self._cellcontainer.node_list[i]
1281
- if curr.valid:
1282
- yield curr
1283
- i += 1
1284
-
1285
- def __init__(self, ast_node: ast.AST, targets: [ScopedValue], func: ScopedValue,
1286
- args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance):
1287
- """Constructor of CellContainer.
1288
-
1289
- Args:
1290
- ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
1291
- targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
1292
- func ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
1293
- args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
1294
- kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
1295
- name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
1296
- Name of node also used as field name in network class.
1297
- instance: Object in network corresponding to this node.
1298
- """
1299
- if isinstance(func, str):
1300
- func = ScopedValue.create_naming_value(func)
1301
- super().__init__(NodeType.CellContainer, ast_node, targets, func, args, kwargs, name, instance)
1302
- self._node_list = list()
1303
- self._node_count = 0
1304
-
1305
- @property
1306
- def node_count(self):
1307
- """Number of nodes."""
1308
- return len(self._node_list)
1309
-
1310
- @property
1311
- def node_list(self):
1312
- """ Get node list. """
1313
- return self._node_list
1314
-
1315
- def append(self, node):
1316
- """ Append new node to node list. """
1317
- setattr(node, "container", self)
1318
- setattr(node, "valid", True)
1319
- node.set_belong_symbol_tree(self.get_belong_symbol_tree())
1320
- self._node_list.append(node)
1321
- # when creating a cell_container, node instance is already in SequentialCell cell_list
1322
- # so here we need to write a if judgement
1323
- if node.get_instance() not in self.get_instance().cell_list:
1324
- self.get_instance().append(node.get_instance())
1325
-
1326
- def erase(self, node):
1327
- """Erase node form container."""
1328
- index_node = self.node_list.index(node)
1329
- index_instance = self.get_instance().cell_list.index(node.get_instance())
1330
- if index_node != index_instance:
1331
- raise RuntimeError("In MindSpore Rewrite CellContainer, erasing a node raises index error!!!")
1332
- setattr(node, "valid", False)
1333
- del self.get_instance()[index_node]
1334
- del self._node_list[index_node]
1335
-
1336
- def insert(self, index, node):
1337
- """Insert node into container"""
1338
- self.node_list.insert(index, node)
1339
- setattr(node, "container", self)
1340
- setattr(node, "valid", True)
1341
- node.set_belong_symbol_tree(self.get_belong_symbol_tree())
1342
- self.get_instance()._insert(index, node.get_instance())
1343
-
1344
- def nodes(self):
1345
- """ Return a iterator of node."""
1346
- return self._Visitor(self)
1410
+ if isinstance(func_name, str):
1411
+ func_name = ScopedValue.create_naming_value(func_name)
1412
+ return cls(tree, ast_node, new_targets, func_name, args, kwargs, name, instance)