mindspore 2.1.0__cp37-none-any.whl → 2.2.11__cp37-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the package registry page for more details.

Files changed (577)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +139 -22
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  25. mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
  26. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +56 -1
  31. mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +13 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +67 -72
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +86 -106
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +29 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +33 -7
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  164. mindspore/lib/libmindspore_shared_lib.so +0 -0
  165. mindspore/lib/libnnacl.so +0 -0
  166. mindspore/lib/libopencv_core.so.4.5 +0 -0
  167. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  169. mindspore/lib/libps_cache.so +0 -0
  170. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  183. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  187. mindspore/lib/plugin/ascend/libakg.so +0 -0
  188. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  189. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  190. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  191. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  193. mindspore/lib/plugin/cpu/libakg.so +0 -0
  194. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  195. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  196. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  197. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  198. mindspore/nn/__init__.py +0 -2
  199. mindspore/nn/cell.py +313 -74
  200. mindspore/nn/dynamic_lr.py +21 -21
  201. mindspore/nn/layer/activation.py +22 -30
  202. mindspore/nn/layer/basic.py +15 -13
  203. mindspore/nn/layer/channel_shuffle.py +1 -1
  204. mindspore/nn/layer/container.py +271 -9
  205. mindspore/nn/layer/conv.py +323 -204
  206. mindspore/nn/layer/dense.py +8 -5
  207. mindspore/nn/layer/embedding.py +33 -27
  208. mindspore/nn/layer/flash_attention.py +61 -95
  209. mindspore/nn/layer/image.py +8 -6
  210. mindspore/nn/layer/math.py +16 -25
  211. mindspore/nn/layer/normalization.py +107 -66
  212. mindspore/nn/layer/padding.py +1 -1
  213. mindspore/nn/layer/pooling.py +131 -109
  214. mindspore/nn/layer/rnn_cells.py +27 -22
  215. mindspore/nn/layer/rnns.py +13 -16
  216. mindspore/nn/layer/thor_layer.py +1 -1
  217. mindspore/nn/layer/transformer.py +221 -154
  218. mindspore/nn/learning_rate_schedule.py +9 -1
  219. mindspore/nn/loss/loss.py +235 -174
  220. mindspore/nn/optim/ada_grad.py +2 -1
  221. mindspore/nn/optim/adadelta.py +1 -0
  222. mindspore/nn/optim/adafactor.py +2 -1
  223. mindspore/nn/optim/adam.py +7 -4
  224. mindspore/nn/optim/adamax.py +3 -2
  225. mindspore/nn/optim/adasum.py +2 -2
  226. mindspore/nn/optim/asgd.py +2 -3
  227. mindspore/nn/optim/ftrl.py +6 -5
  228. mindspore/nn/optim/lamb.py +7 -4
  229. mindspore/nn/optim/lars.py +1 -1
  230. mindspore/nn/optim/lazyadam.py +5 -3
  231. mindspore/nn/optim/momentum.py +2 -1
  232. mindspore/nn/optim/optimizer.py +53 -4
  233. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  234. mindspore/nn/optim/rmsprop.py +4 -3
  235. mindspore/nn/optim/rprop.py +23 -12
  236. mindspore/nn/optim/sgd.py +26 -11
  237. mindspore/nn/optim/thor.py +9 -7
  238. mindspore/nn/probability/bijector/bijector.py +5 -5
  239. mindspore/nn/probability/bijector/power_transform.py +27 -27
  240. mindspore/nn/probability/bijector/softplus.py +3 -3
  241. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  242. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  243. mindspore/nn/probability/distribution/beta.py +3 -3
  244. mindspore/nn/probability/distribution/categorical.py +7 -7
  245. mindspore/nn/probability/distribution/cauchy.py +0 -1
  246. mindspore/nn/probability/distribution/distribution.py +3 -3
  247. mindspore/nn/probability/distribution/gamma.py +3 -3
  248. mindspore/nn/probability/distribution/geometric.py +4 -4
  249. mindspore/nn/probability/distribution/gumbel.py +4 -4
  250. mindspore/nn/probability/distribution/log_normal.py +2 -2
  251. mindspore/nn/probability/distribution/logistic.py +2 -2
  252. mindspore/nn/probability/distribution/poisson.py +4 -4
  253. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  254. mindspore/nn/probability/distribution/uniform.py +6 -6
  255. mindspore/nn/wrap/__init__.py +4 -2
  256. mindspore/nn/wrap/cell_wrapper.py +87 -34
  257. mindspore/nn/wrap/grad_reducer.py +8 -5
  258. mindspore/nn/wrap/loss_scale.py +105 -42
  259. mindspore/numpy/array_creations.py +1 -2
  260. mindspore/numpy/array_ops.py +3 -2
  261. mindspore/numpy/utils_const.py +5 -5
  262. mindspore/offline_debug/convert_async.py +2 -2
  263. mindspore/ops/_grad_experimental/__init__.py +0 -5
  264. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  265. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  266. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  267. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  268. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  269. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  270. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  271. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  272. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  273. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  274. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  275. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  276. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  277. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  278. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  279. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  280. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  281. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  282. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  283. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  284. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  285. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  286. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  287. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  288. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  289. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  290. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  291. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  292. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  293. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  294. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  295. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  296. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  297. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  298. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  299. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  300. mindspore/ops/_primitive_cache.py +1 -1
  301. mindspore/ops/_tracefunc.py +45 -13
  302. mindspore/ops/_utils/utils.py +6 -1
  303. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  304. mindspore/ops/_vmap/vmap_base.py +3 -3
  305. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  306. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  307. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  308. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  309. mindspore/ops/arg_dtype_cast.py +54 -0
  310. mindspore/ops/composite/base.py +37 -10
  311. mindspore/ops/composite/math_ops.py +5 -4
  312. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  313. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  314. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  315. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  316. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  317. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  319. mindspore/ops/deprecated.py +304 -0
  320. mindspore/ops/function/__init__.py +4 -1
  321. mindspore/ops/function/array_func.py +174 -193
  322. mindspore/ops/function/clip_func.py +81 -13
  323. mindspore/ops/function/debug_func.py +1 -1
  324. mindspore/ops/function/grad/grad_func.py +18 -9
  325. mindspore/ops/function/image_func.py +10 -4
  326. mindspore/ops/function/linalg_func.py +5 -5
  327. mindspore/ops/function/math_func.py +575 -386
  328. mindspore/ops/function/nn_func.py +568 -260
  329. mindspore/ops/function/random_func.py +88 -57
  330. mindspore/ops/function/sparse_func.py +1 -1
  331. mindspore/ops/function/sparse_unary_func.py +14 -12
  332. mindspore/ops/function/vmap_func.py +6 -5
  333. mindspore/ops/functional.py +15 -10
  334. mindspore/ops/op_info_register.py +244 -25
  335. mindspore/ops/operations/__init__.py +31 -19
  336. mindspore/ops/operations/_grad_ops.py +71 -7
  337. mindspore/ops/operations/_inner_ops.py +350 -17
  338. mindspore/ops/operations/_quant_ops.py +4 -8
  339. mindspore/ops/operations/_sequence_ops.py +42 -0
  340. mindspore/ops/operations/array_ops.py +68 -282
  341. mindspore/ops/operations/comm_ops.py +107 -59
  342. mindspore/ops/operations/custom_ops.py +94 -70
  343. mindspore/ops/operations/debug_ops.py +8 -4
  344. mindspore/ops/operations/image_ops.py +18 -12
  345. mindspore/ops/operations/inner_ops.py +26 -3
  346. mindspore/ops/operations/math_ops.py +192 -144
  347. mindspore/ops/operations/nn_ops.py +857 -489
  348. mindspore/ops/operations/other_ops.py +0 -22
  349. mindspore/ops/operations/random_ops.py +53 -111
  350. mindspore/ops/operations/sparse_ops.py +3 -1
  351. mindspore/ops/primitive.py +24 -18
  352. mindspore/parallel/_auto_parallel_context.py +68 -8
  353. mindspore/parallel/_cost_model_context.py +2 -2
  354. mindspore/parallel/_offload_context.py +17 -3
  355. mindspore/parallel/_parallel_serialization.py +12 -5
  356. mindspore/parallel/_ps_context.py +12 -0
  357. mindspore/parallel/_tensor.py +18 -13
  358. mindspore/parallel/_transformer/layers.py +5 -3
  359. mindspore/parallel/_transformer/loss.py +1 -0
  360. mindspore/parallel/_transformer/moe.py +2 -2
  361. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  362. mindspore/parallel/_transformer/transformer.py +23 -3
  363. mindspore/parallel/_utils.py +11 -7
  364. mindspore/parallel/algo_parameter_config.py +85 -5
  365. mindspore/parallel/checkpoint_transform.py +19 -12
  366. mindspore/parallel/shard.py +21 -14
  367. mindspore/profiler/common/struct_type.py +3 -3
  368. mindspore/profiler/common/util.py +4 -2
  369. mindspore/profiler/envprofiling.py +1 -1
  370. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  371. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  372. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  373. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  374. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  375. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  376. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  377. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  378. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  379. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  380. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  381. mindspore/profiler/parser/flops_parser.py +15 -11
  382. mindspore/profiler/parser/framework_parser.py +38 -22
  383. mindspore/profiler/parser/hccl_parser.py +16 -12
  384. mindspore/profiler/parser/integrator.py +22 -11
  385. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  386. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  387. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  388. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  389. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  390. mindspore/profiler/parser/optime_parser.py +1 -1
  391. mindspore/profiler/parser/profiler_info.py +21 -2
  392. mindspore/profiler/parser/step_trace_parser.py +11 -14
  393. mindspore/profiler/profiling.py +179 -89
  394. mindspore/rewrite/api/node.py +102 -19
  395. mindspore/rewrite/api/node_type.py +5 -1
  396. mindspore/rewrite/api/pattern_engine.py +1 -1
  397. mindspore/rewrite/api/scoped_value.py +9 -17
  398. mindspore/rewrite/api/symbol_tree.py +131 -47
  399. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  400. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  401. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  402. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  403. mindspore/rewrite/common/rewrite_elog.py +5 -1
  404. mindspore/rewrite/namer.py +33 -24
  405. mindspore/rewrite/namespace.py +14 -5
  406. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  407. mindspore/rewrite/node/call_function.py +79 -0
  408. mindspore/rewrite/node/cell_container.py +135 -0
  409. mindspore/rewrite/node/control_flow.py +88 -0
  410. mindspore/rewrite/{node.py → node/node.py} +273 -234
  411. mindspore/rewrite/node/node_manager.py +254 -0
  412. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  413. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  414. mindspore/rewrite/parsers/assign_parser.py +216 -221
  415. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  416. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  417. mindspore/rewrite/parsers/constant_parser.py +9 -6
  418. mindspore/rewrite/parsers/container_parser.py +9 -7
  419. mindspore/rewrite/parsers/for_parser.py +42 -21
  420. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  421. mindspore/rewrite/parsers/if_parser.py +28 -24
  422. mindspore/rewrite/parsers/module_parser.py +196 -25
  423. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  424. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  425. mindspore/rewrite/parsers/return_parser.py +6 -6
  426. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  427. mindspore/rewrite/sparsify/utils.py +1 -1
  428. mindspore/rewrite/symbol_tree.py +523 -578
  429. mindspore/rewrite/symbol_tree_builder.py +9 -193
  430. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  431. mindspore/run_check/_check_version.py +6 -4
  432. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  433. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  434. mindspore/scipy/linalg.py +1 -1
  435. mindspore/scipy/ops.py +55 -5
  436. mindspore/scipy/optimize/__init__.py +3 -2
  437. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  438. mindspore/scipy/optimize/minimize.py +7 -3
  439. mindspore/train/_utils.py +7 -3
  440. mindspore/train/amp.py +323 -123
  441. mindspore/train/anf_ir_pb2.py +14 -2
  442. mindspore/train/callback/_backup_and_restore.py +2 -12
  443. mindspore/train/callback/_callback.py +29 -4
  444. mindspore/train/callback/_checkpoint.py +23 -8
  445. mindspore/train/callback/_early_stop.py +2 -2
  446. mindspore/train/callback/_landscape.py +4 -4
  447. mindspore/train/callback/_loss_monitor.py +2 -2
  448. mindspore/train/callback/_on_request_exit.py +2 -2
  449. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  450. mindspore/train/callback/_summary_collector.py +15 -8
  451. mindspore/train/callback/_time_monitor.py +58 -5
  452. mindspore/train/data_sink.py +5 -11
  453. mindspore/train/dataset_helper.py +84 -57
  454. mindspore/train/loss_scale_manager.py +2 -2
  455. mindspore/train/metrics/__init__.py +3 -3
  456. mindspore/train/metrics/cosine_similarity.py +1 -1
  457. mindspore/train/metrics/hausdorff_distance.py +3 -2
  458. mindspore/train/metrics/mean_surface_distance.py +3 -2
  459. mindspore/train/metrics/metric.py +39 -19
  460. mindspore/train/metrics/roc.py +2 -2
  461. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  462. mindspore/train/mind_ir_pb2.py +85 -36
  463. mindspore/train/model.py +187 -47
  464. mindspore/train/serialization.py +487 -161
  465. mindspore/train/summary/_summary_adapter.py +1 -1
  466. mindspore/train/summary/_writer_pool.py +3 -2
  467. mindspore/train/summary/summary_record.py +37 -17
  468. mindspore/train/train_thor/convert_utils.py +3 -3
  469. mindspore/train/train_thor/dataset_helper.py +1 -1
  470. mindspore/version.py +1 -1
  471. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
  472. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +476 -527
  473. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
  474. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  475. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  476. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  477. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  478. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  479. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  480. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  481. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  482. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  483. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  484. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  485. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  486. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  487. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  488. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  489. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  490. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  491. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  492. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  493. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  494. mindspore/_extends/graph_kernel/expander.py +0 -80
  495. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  496. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  497. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  498. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  499. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  500. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  501. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  502. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  503. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  504. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  505. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  506. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  507. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  508. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  509. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  510. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  511. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  512. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  513. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  514. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  515. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  516. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  517. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  518. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  519. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  520. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  521. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  522. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  523. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  524. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  525. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  526. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  527. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  528. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  529. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  530. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  531. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  532. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  533. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  534. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  535. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  536. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  537. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  538. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  539. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  540. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  541. mindspore/dataset/datapreprocess/__init__.py +0 -20
  542. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  543. mindspore/include/api/net.h +0 -142
  544. mindspore/nn/lr_scheduler.py +0 -262
  545. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  546. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  547. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  548. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  549. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  550. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  551. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  552. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  553. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  554. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  555. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  556. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  557. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  558. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  559. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  560. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  561. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  563. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  567. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  568. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  569. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  571. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  574. mindspore/rewrite/node_visitor.py +0 -44
  575. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  576. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  577. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -14,9 +14,9 @@
14
14
  # ============================================================================
15
15
  """Unique name producer for target, name of node, class name, etc."""
16
16
 
17
- from typing import Union
17
+ from typing import Union, Tuple
18
18
 
19
- from .node import Node
19
+ from .node.node import Node
20
20
  from .api.node_type import NodeType
21
21
 
22
22
 
@@ -33,7 +33,7 @@ class Namer:
33
33
  self._names: {str: int} = {}
34
34
 
35
35
  @staticmethod
36
- def _real_name(name: str) -> str:
36
+ def _real_name(name: str) -> Tuple[str, int]:
37
37
  """
38
38
  Find real name. For example, "name1" is the real name of "name1_10", "name1" is the real name of "name1_10_3".
39
39
  If not find real name before find unique name, unique name may be not unique. For example:
@@ -47,21 +47,21 @@ class Namer:
47
47
  name (str): Origin name which may have digit prefix.
48
48
 
49
49
  Returns:
50
- A string represents real-name.
50
+ A string represents real-name and a int represents suffix.
51
51
  """
52
52
  if name == '_':
53
- return name
53
+ return name, None
54
54
  pos = name.rfind("_")
55
- if pos == -1:
56
- return name
55
+ if pos == -1 or pos == len(name) - 1:
56
+ return name, None
57
57
  digit = True
58
58
  for i in range(pos + 1, len(name)):
59
59
  if not name[i].isdigit():
60
60
  digit = False
61
61
  break
62
62
  if digit:
63
- return Namer._real_name(name[:pos])
64
- return name
63
+ return name[:pos], int(name[pos + 1:])
64
+ return name, None
65
65
 
66
66
  def get_name(self, origin_name: str) -> str:
67
67
  """
@@ -75,15 +75,28 @@ class Namer:
75
75
  """
76
76
  if origin_name == '_':
77
77
  return origin_name
78
- origin_name = Namer._real_name(origin_name)
79
- number = self._names.get(origin_name)
78
+ real_name, suffix_idx = Namer._real_name(origin_name)
79
+ name = origin_name
80
+ number = self._names.get(name)
80
81
  if number is None:
81
- self._names[origin_name] = 1
82
- return origin_name
82
+ self._names[name] = 1
83
+ if not suffix_idx:
84
+ # When _names is {x:2} and origin_name is y,
85
+ # origin_name is not in _names and can be returned.
86
+ return name
87
+ if suffix_idx and not self._names.get(real_name, -1) >= suffix_idx:
88
+ # When _names is {x:2} and origin_name is x_3,
89
+ # return x_3 and update _names to {x:2, x_3:1}
90
+ return name
91
+ # When _names is {x:2} and origin_name is x_1,
92
+ # set new_name to x_1_1 by set number to 1, and continue to update name.
93
+ number = 1
83
94
  while True:
84
- new_name = f"{origin_name}_{number}"
95
+ new_name = f"{name}_{number}"
85
96
  number += 1
86
- self._names[origin_name] = number
97
+ self._names[name] = number
98
+ # When _names is {x:2, x_3:1}, origin_name is x and number is update to 3,
99
+ # new_name x_3 is conflict with key x_3, so this new_name need to be skipped.
87
100
  if new_name in self._names.keys():
88
101
  continue
89
102
  return new_name
@@ -141,16 +154,12 @@ class NodeNamer(Namer):
141
154
  if origin_name is None or not origin_name:
142
155
  if node_or_name.get_node_type() in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.CallFunction,
143
156
  NodeType.Tree):
144
- if not isinstance(node_or_name, Node):
145
- raise TypeError("node_or_name should be Node, got: ", type(node_or_name))
146
- targets = node_or_name.get_targets()
147
- # return node and head node will not call this method
148
- if not targets:
149
- raise RuntimeError("node should has at lease one target except return-node and head-node: ",
150
- node_or_name)
151
- origin_name = str(targets[0].value)
157
+ origin_name = type(node_or_name.get_instance()).__name__
152
158
  elif node_or_name.get_node_type() == NodeType.Python:
153
- origin_name = node_or_name.get_instance().__name__
159
+ if node_or_name.get_instance():
160
+ origin_name = type(node_or_name.get_instance()).__name__
161
+ else:
162
+ origin_name = "python_node"
154
163
  elif node_or_name.get_node_type() == NodeType.Input:
155
164
  origin_name = "parameter"
156
165
  elif node_or_name.get_node_type() == NodeType.Output:
@@ -21,12 +21,21 @@ _ms_nn_ns = CellNamespace('mindspore.nn')
21
21
  _ms_ops_ns = CellNamespace('mindspore.ops.operations')
22
22
  _ms_functional_ns = CellNamespace('mindspore.ops.functional')
23
23
 
24
-
25
- def is_subtree(cls_name):
26
- """Determine whether 'cls_name' is a subtree."""
27
- if cls_name == "QuantizeWrapperCell":
24
+ # Elements in _subtree_black_list will not be converted to symbol tree.
25
+ # Only str and types are stored in _subtree_black_list.
26
+ _subtree_black_list = ["QuantizeWrapperCell",]
27
+
28
+ def is_subtree(cls_inst):
29
+ """Determine whether 'cls_inst' is a subtree."""
30
+ cls_name = type(cls_inst).__name__
31
+ black_list_types = tuple([elem for elem in _subtree_black_list if not isinstance(elem, str)])
32
+ if cls_name in _subtree_black_list or isinstance(cls_inst, black_list_types):
33
+ return False
34
+ if cls_name in _ms_common_ns and isinstance(cls_inst, _ms_common_ns[cls_name]):
35
+ return False
36
+ if cls_name in _ms_nn_ns and isinstance(cls_inst, _ms_nn_ns[cls_name]):
28
37
  return False
29
- if cls_name in _ms_common_ns or cls_name in _ms_nn_ns or cls_name in _ms_ops_ns:
38
+ if cls_name in _ms_ops_ns and isinstance(cls_inst, _ms_ops_ns[cls_name]):
30
39
  return False
31
40
 
32
41
  return True
@@ -1,4 +1,4 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
1
+ # Copyright 2022 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,11 +12,11 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ============================================================================
15
- """complex expanders init"""
16
-
17
- from .abs import CAbs
18
- from .add import CAdd
19
- from .div import CDiv
20
- from .mul import CMul
21
- from .sub import CSub
22
- from .real_div import CRealDiv
15
+ """
16
+ SymbolTree node
17
+ """
18
+ from mindspore.rewrite.node.node import Node, TreeNode
19
+ from mindspore.rewrite.node.node_manager import NodeManager
20
+ from mindspore.rewrite.node.call_function import CallFunction
21
+ from mindspore.rewrite.node.cell_container import CellContainer
22
+ from mindspore.rewrite.node.control_flow import ControlFlow
@@ -0,0 +1,79 @@
1
+ # Copyright 2022 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """CallFunction Node."""
16
+ import ast
17
+ from .node import Node
18
+ from .node_manager import NodeManager
19
+ from ..api.scoped_value import ScopedValue
20
+ from ..api.node_type import NodeType
21
+ from ..ast_helpers import AstModifier
22
+
23
+
24
+ class CallFunction(Node, NodeManager):
25
+ """CallFunction is used for class internal function."""
26
+ def __init__(self, targets: [ScopedValue], func_name: ScopedValue, args: [ScopedValue],
27
+ kwargs: {str: ScopedValue}, node_name: str, ast_node: ast.AST, ast_functiondef: ast.FunctionDef,
28
+ stree, instance):
29
+ """
30
+ Constructor of CallFunction.
31
+
32
+ Args:
33
+ targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
34
+ args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
35
+ kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
36
+ func_name ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
37
+ node_name (str): A string represents name of node. Name of node will be unique when inserted into
38
+ SymbolTree. Name of node also used as field name in network class.
39
+ ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
40
+ ast_functiondef (ast.FunctionDef): An instance of ast.FunctionDef represents corresponding function
41
+ definition in ast.
42
+ stree (SymbolTree): Symbol tree used to get node_namer.
43
+ instance: Object in network corresponding to this node.
44
+ """
45
+ if isinstance(func_name, str):
46
+ func_name = ScopedValue.create_naming_value(func_name)
47
+ Node.__init__(self, NodeType.CallFunction, ast_node, targets, func_name, args, kwargs, node_name, instance)
48
+ NodeManager.__init__(self, stree.get_node_namer())
49
+ NodeManager.set_ast_functiondef(self, ast_functiondef)
50
+ NodeManager.set_manager_name(self, func_name.value)
51
+
52
+ def erase_node(self, node):
53
+ """Erase node from CallFunction."""
54
+ NodeManager.erase_node(self, node)
55
+ # erase asts
56
+ ret = AstModifier.erase_ast_from_function(self.get_ast_functiondef(), node.get_ast())
57
+ if not ret:
58
+ raise ValueError(f"erase node failed, node {node.get_name()} not in function ast tree.")
59
+
60
+ def insert_node(self, new_node: Node, base_node: Node, before_node: bool, insert_to_ast: bool = True):
61
+ """
62
+ Insert a node before or after base_node.
63
+
64
+ Args:
65
+ new_node (Node): Node to be inserted.
66
+ base_node (Node): New node will be inserted before or after base_node.
67
+ before_node (bool): Indicate whether new node is inserted before base_node.
68
+ insert_to_ast (bool): Indicate whether ast nodes need to be updated.
69
+ """
70
+ NodeManager.insert_node(self, new_node, base_node, before_node)
71
+ if insert_to_ast:
72
+ stree = self.get_belong_symbol_tree()
73
+ stree.insert_to_ast_while_insert_node(new_node, base_node, before_node, self)
74
+
75
+ def set_belong_symbol_tree(self, symbol_tree):
76
+ """Set the symbol tree to which node belongs."""
77
+ self._belong_tree = symbol_tree
78
+ for node in self.nodes():
79
+ node.set_belong_symbol_tree(symbol_tree)
@@ -0,0 +1,135 @@
1
+ # Copyright 2022 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """CellContainer Node."""
16
+ import ast
17
+ from mindspore import log as logger
18
+ from .node import Node
19
+ from .node_manager import NodeManager
20
+ from ..api.scoped_value import ScopedValue
21
+ from ..api.node_type import NodeType
22
+
23
+
24
+ class CellContainer(Node, NodeManager):
25
+ """CellContainer is used for nn.SequencialCell."""
26
+
27
+ def __init__(self, ast_node: ast.AST, targets: [ScopedValue], func_name: ScopedValue,
28
+ args: [ScopedValue], kwargs: {str: ScopedValue}, node_name: str, stree, instance):
29
+ """Constructor of CellContainer.
30
+
31
+ Args:
32
+ ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
33
+ targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
34
+ func_name ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
35
+ args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
36
+ kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
37
+ node_name (str): A string represents name of node. Name of node will be unique when inserted into
38
+ SymbolTree. Name of node also used as field name in network class.
39
+ stree (SymbolTree): Symbol tree used to get node_namer.
40
+ instance: Object in network corresponding to this node.
41
+ """
42
+ if isinstance(func_name, str):
43
+ func_name = ScopedValue.create_naming_value(func_name)
44
+ Node.__init__(self, NodeType.CellContainer, ast_node, targets, func_name, args, kwargs, node_name, instance)
45
+ NodeManager.__init__(self, stree.get_node_namer())
46
+ NodeManager.set_manager_name(self, func_name.value)
47
+
48
+ def append(self, node, insert_to_ast: bool = True):
49
+ """ Append new node to node list. """
50
+ self.append_node(node, insert_to_ast)
51
+
52
+ def append_node(self, node, insert_to_ast: bool = True):
53
+ """ Append new node to node list. """
54
+ self.insert_node(node, self.get_tail(), False, insert_to_ast)
55
+
56
+ def erase(self, node):
57
+ """Erase node from container."""
58
+ self.erase_node(node)
59
+
60
+ def erase_node(self, node):
61
+ """Erase node from container."""
62
+ # add code `del self.container_name[node_index]` into __init__ function
63
+ _, init_ast_functiondef = self._get_stree_and_init_ast()
64
+ if not init_ast_functiondef:
65
+ logger.error(f"Erase node {node.get_name()} failed: get symboltree and __init__ ast failed.")
66
+ return
67
+ node_idx = self.nodes().index(node)
68
+ erase_code = f"del {self.get_func_name()}[{node_idx}]"
69
+ erase_ast = ast.parse(erase_code).body[0]
70
+ init_ast_functiondef.body.append(erase_ast)
71
+ # earse node in NodeManager
72
+ NodeManager.erase_node(self, node)
73
+
74
+ def insert(self, index, node, insert_to_ast: bool = True):
75
+ """Insert node into container according index"""
76
+ node_index = index + len(self._inputs)
77
+ if node_index >= self.node_count:
78
+ raise IndexError("In MindSpore Rewrite CellContainer, inserting a node raises index error! "
79
+ f"node_index: {node_index} >= node_num: {self.node_count}")
80
+ self.insert_node(node, self.nodes()[node_index], False, insert_to_ast)
81
+
82
+ def insert_node(self, new_node: Node, base_node: Node, before_node: bool, insert_to_ast: bool = True):
83
+ """
84
+ Insert a node before or after base_node.
85
+
86
+ The instance is modified here. The scenario needs to be optimized.
87
+
88
+ Args:
89
+ new_node (Node): Node to be inserted.
90
+ base_node (Node): New node will be inserted before or after base_node.
91
+ before_node (bool): Indicate whether new node is inserted before base_node.
92
+ insert_to_ast (bool): Indicate whether ast nodes need to be updated.
93
+ """
94
+ # Insert node to NodeManager firstly to update node_name, which is used during insert ast.
95
+ # tail_node may be changed after insert node into node_manager, so we record tail node here.
96
+ tail_node = self.get_tail()
97
+ NodeManager.insert_node(self, new_node, base_node, before_node)
98
+ new_node.set_func_name(ScopedValue.create_naming_value(new_node.get_name()))
99
+ new_node.update_ast_node()
100
+ # add insert/append code into __init__ function
101
+ if insert_to_ast:
102
+ stree, init_ast_functiondef = self._get_stree_and_init_ast()
103
+ if not init_ast_functiondef:
104
+ logger.error(f"Insert new_node {new_node.get_name()} failed: get symboltree and __init__ ast failed.")
105
+ return
106
+ setattr(stree.get_origin_network(), new_node.get_name(), new_node.get_instance())
107
+ node_idx = self.nodes().index(base_node)
108
+ if before_node:
109
+ insert_code = f"{self.get_func_name()}._insert({node_idx}, self.{new_node.get_name()})"
110
+ else:
111
+ if base_node == tail_node:
112
+ insert_code = f"{self.get_func_name()}.append(self.{new_node.get_name()})"
113
+ else:
114
+ insert_code = f"{self.get_func_name()}._insert({node_idx + 1}, self.{new_node.get_name()})"
115
+ insert_ast = ast.parse(insert_code).body[0]
116
+ init_ast_functiondef.body.append(insert_ast)
117
+
118
+ def set_belong_symbol_tree(self, symbol_tree):
119
+ """Set the symbol tree to which node belongs."""
120
+ self._belong_tree = symbol_tree
121
+ for node in self.nodes():
122
+ node.set_belong_symbol_tree(symbol_tree)
123
+
124
+ def _get_stree_and_init_ast(self):
125
+ """Get symbol tree and ast of __init__ function from container."""
126
+ # add codes `del self.container_name[node_index]`` into __init__ function
127
+ stree = self.get_belong_symbol_tree()
128
+ if stree is None:
129
+ logger.error(f"Get symboltree of CellContainer {self.get_name()} failed.")
130
+ return None, None
131
+ init_ast_functiondef = stree.get_init_func_ast()
132
+ if init_ast_functiondef is None:
133
+ logger.error(f"Get ast of __init__ function from class {stree.get_opt_cls_name()} failed.")
134
+ return None, None
135
+ return stree, init_ast_functiondef
@@ -0,0 +1,88 @@
1
+ # Copyright 2022 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """ControlFlow Node."""
16
+ from typing import List
17
+ import ast
18
+ from .node import Node, TreeNode
19
+ from .node_manager import NodeManager
20
+ from ..api.scoped_value import ScopedValue
21
+ from ..api.node_type import NodeType
22
+ from ..ast_helpers import AstModifier
23
+
24
+
25
+ class ControlFlow(Node, NodeManager):
26
+ """ControlFlow node is used for statements like loops and `if` ."""
27
+ def __init__(self, node_name: str, ast_body: List[ast.AST], stree):
28
+ """
29
+ Constructor of ControlFlow.
30
+
31
+ Args:
32
+ node_name (str): A string represents name of node. Name of node will be unique when inserted into
33
+ SymbolTree. Name of node also used as field name in network class.
34
+ ast_node (ast.AST): An instance of ast.AST represents control flow statements, can be one of ast.If,
35
+ ast.Ifexp, ast.For, ast.While.
36
+ is_orelse (bool): Whether process else branch of node.
37
+ stree (SymbolTree): Symbol tree used to get node_namer.
38
+ """
39
+ Node.__init__(self, NodeType.ControlFlow, ast_body, None, node_name, [], [], node_name, None)
40
+ NodeManager.__init__(self, stree.get_node_namer())
41
+ NodeManager.set_manager_name(self, node_name)
42
+ self.ast_body = ast_body
43
+
44
+ def erase_node(self, node):
45
+ """Erase node from container."""
46
+ NodeManager.erase_node(self, node)
47
+ # erase node's ast
48
+ ret = AstModifier.erase_ast_from_bodies(self.ast_body, node.get_ast())
49
+ if not ret:
50
+ raise ValueError(f"Erase node failed, node {node.get_name()} is not in ControlFlow ast tree.")
51
+
52
+ def insert_node(self, new_node: Node, base_node: Node, before_node: bool, insert_to_ast: bool = True):
53
+ """
54
+ Insert a node before or after base_node.
55
+
56
+ Args:
57
+ new_node (Node): Node to be inserted.
58
+ base_node (Node): New node will be inserted before or after base_node.
59
+ before_node (bool): Indicate whether new node is inserted before base_node.
60
+ insert_to_ast (bool): Indicate whether ast nodes need to be updated.
61
+ """
62
+ NodeManager.insert_node(self, new_node, base_node, before_node)
63
+ if insert_to_ast:
64
+ ast_assign = new_node.get_ast()
65
+ if ast_assign is None:
66
+ func_name = new_node.get_belong_symbol_tree().unique_func_name(new_node.get_name())
67
+ new_node.set_func_name(ScopedValue.create_naming_value(func_name, "self"))
68
+ ast_assign = new_node.update_ast_node()
69
+ # Save instance into _origin_network.
70
+ stree = self.get_belong_symbol_tree()
71
+ setattr(stree.get_origin_network(), new_node.get_name(), new_node.get_instance())
72
+ # Insert ast_assign to __init__ function
73
+ if isinstance(new_node, TreeNode):
74
+ init_code = f"self.{new_node.get_name()} = " \
75
+ f"{new_node.symbol_tree.get_opt_cls_name()}(obj.{new_node.get_name()})"
76
+ else:
77
+ init_code = f"self.{new_node.get_name()} = obj.{new_node.get_name()}"
78
+ init_ast = ast.parse(init_code).body[0]
79
+ AstModifier.insert_assign_ast_to_function(stree.get_init_func_ast(), init_ast)
80
+ # Insert ast_assign to bodies
81
+ ast_base_node = base_node.get_ast() if base_node else None
82
+ AstModifier.insert_assign_ast_to_bodies(self.ast_body, ast_assign, ast_base_node, before_node)
83
+
84
+ def set_belong_symbol_tree(self, symbol_tree):
85
+ """Set the symbol tree to which node belongs."""
86
+ self._belong_tree = symbol_tree
87
+ for node in self.nodes():
88
+ node.set_belong_symbol_tree(symbol_tree)