mindspore 2.1.0-cp38-cp38-manylinux1_x86_64.whl → 2.2.0-cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (550)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +49 -16
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  20. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  21. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  22. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  23. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  24. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  25. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  26. mindspore/_check_jit_forbidden_api.py +3 -1
  27. mindspore/_checkparam.py +26 -32
  28. mindspore/_extends/graph_kernel/__init__.py +0 -1
  29. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  30. mindspore/_extends/graph_kernel/splitter.py +1 -9
  31. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  32. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  33. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  34. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  35. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
  36. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  37. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  38. mindspore/_extends/parse/__init__.py +12 -15
  39. mindspore/_extends/parse/namespace.py +7 -33
  40. mindspore/_extends/parse/parser.py +61 -71
  41. mindspore/_extends/parse/resources.py +1 -1
  42. mindspore/_extends/parse/standard_method.py +72 -95
  43. mindspore/_extends/parse/trope.py +1 -1
  44. mindspore/_extends/remote/kernel_build_server.py +24 -7
  45. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  46. mindspore/_install_custom.py +43 -0
  47. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  48. mindspore/amp.py +47 -11
  49. mindspore/bin/cache_admin +0 -0
  50. mindspore/bin/cache_server +0 -0
  51. mindspore/boost/boost.py +1 -8
  52. mindspore/boost/boost_cell_wrapper.py +3 -2
  53. mindspore/boost/grad_accumulation.py +1 -1
  54. mindspore/boost/group_loss_scale_manager.py +8 -7
  55. mindspore/common/__init__.py +5 -3
  56. mindspore/common/_jit_fallback_utils.py +6 -0
  57. mindspore/common/_register_for_adapter.py +2 -0
  58. mindspore/common/_register_for_tensor.py +2 -2
  59. mindspore/common/_stub_tensor.py +13 -0
  60. mindspore/common/_utils.py +13 -0
  61. mindspore/common/api.py +173 -258
  62. mindspore/common/auto_dynamic_shape.py +498 -0
  63. mindspore/common/dtype.py +18 -11
  64. mindspore/common/dump.py +6 -4
  65. mindspore/common/initializer.py +14 -14
  66. mindspore/common/jit_config.py +33 -15
  67. mindspore/common/lazy_inline.py +126 -7
  68. mindspore/common/mindir_util.py +101 -0
  69. mindspore/common/parameter.py +51 -41
  70. mindspore/common/seed.py +4 -4
  71. mindspore/common/sparse_tensor.py +13 -14
  72. mindspore/common/tensor.py +240 -145
  73. mindspore/communication/__init__.py +7 -4
  74. mindspore/communication/_comm_helper.py +83 -4
  75. mindspore/communication/management.py +152 -84
  76. mindspore/config/op_info.config +13 -2
  77. mindspore/config/super_bar_config.json +4 -2
  78. mindspore/context.py +143 -59
  79. mindspore/dataset/__init__.py +5 -5
  80. mindspore/dataset/audio/__init__.py +2 -2
  81. mindspore/dataset/audio/transforms.py +52 -52
  82. mindspore/dataset/callback/ds_callback.py +16 -2
  83. mindspore/dataset/core/config.py +68 -51
  84. mindspore/dataset/engine/cache_client.py +28 -5
  85. mindspore/dataset/engine/datasets.py +250 -112
  86. mindspore/dataset/engine/datasets_audio.py +43 -211
  87. mindspore/dataset/engine/datasets_standard_format.py +11 -35
  88. mindspore/dataset/engine/datasets_text.py +43 -67
  89. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  90. mindspore/dataset/engine/datasets_vision.py +219 -1029
  91. mindspore/dataset/engine/iterators.py +11 -4
  92. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  93. mindspore/dataset/engine/obs/util.py +3 -0
  94. mindspore/dataset/engine/samplers.py +1 -1
  95. mindspore/dataset/engine/validators.py +19 -5
  96. mindspore/dataset/text/__init__.py +3 -3
  97. mindspore/dataset/text/transforms.py +101 -127
  98. mindspore/dataset/text/utils.py +205 -138
  99. mindspore/dataset/transforms/__init__.py +1 -1
  100. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  101. mindspore/dataset/transforms/transforms.py +95 -40
  102. mindspore/dataset/utils/browse_dataset.py +8 -2
  103. mindspore/dataset/utils/line_reader.py +17 -19
  104. mindspore/dataset/vision/__init__.py +3 -3
  105. mindspore/dataset/vision/c_transforms.py +6 -3
  106. mindspore/dataset/vision/transforms.py +409 -287
  107. mindspore/dataset/vision/utils.py +13 -14
  108. mindspore/dataset/vision/validators.py +11 -1
  109. mindspore/experimental/map_parameter.py +14 -0
  110. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  111. mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
  112. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  113. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  114. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  115. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  116. mindspore/gen_ops.py +273 -0
  117. mindspore/include/OWNERS +0 -1
  118. mindspore/include/api/data_type.h +2 -1
  119. mindspore/include/api/graph.h +0 -15
  120. mindspore/include/api/kernel.h +2 -0
  121. mindspore/include/api/kernel_api.h +37 -12
  122. mindspore/include/api/model.h +0 -14
  123. mindspore/include/api/types.h +37 -4
  124. mindspore/include/c_api/ms/abstract.h +67 -0
  125. mindspore/include/c_api/ms/attribute.h +197 -0
  126. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  127. mindspore/include/c_api/ms/base/macros.h +32 -0
  128. mindspore/include/c_api/ms/base/status.h +33 -0
  129. mindspore/include/c_api/ms/base/types.h +282 -0
  130. mindspore/include/c_api/ms/context.h +102 -0
  131. mindspore/include/c_api/ms/graph.h +160 -0
  132. mindspore/include/c_api/ms/node.h +606 -0
  133. mindspore/include/c_api/ms/tensor.h +161 -0
  134. mindspore/include/c_api/ms/value.h +84 -0
  135. mindspore/include/dataset/constants.h +6 -5
  136. mindspore/include/dataset/execute.h +23 -13
  137. mindspore/include/dataset/text.h +26 -26
  138. mindspore/include/dataset/transforms.h +13 -13
  139. mindspore/include/dataset/vision.h +60 -60
  140. mindspore/include/dataset/vision_ascend.h +5 -6
  141. mindspore/include/dataset/vision_lite.h +17 -17
  142. mindspore/include/mindapi/base/type_id.h +1 -0
  143. mindspore/include/mindapi/base/types.h +1 -0
  144. mindspore/lib/libdnnl.so.2 +0 -0
  145. mindspore/lib/libjemalloc.so.2 +0 -0
  146. mindspore/lib/libmindspore.so +0 -0
  147. mindspore/lib/libmindspore_backend.so +0 -0
  148. mindspore/lib/libmindspore_common.so +0 -0
  149. mindspore/lib/libmindspore_core.so +0 -0
  150. mindspore/lib/libmindspore_glog.so.0 +0 -0
  151. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  152. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  153. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  154. mindspore/lib/libmindspore_shared_lib.so +0 -0
  155. mindspore/lib/libnnacl.so +0 -0
  156. mindspore/lib/libopencv_core.so.4.5 +0 -0
  157. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  158. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  159. mindspore/lib/libps_cache.so +0 -0
  160. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  161. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  163. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  164. mindspore/lib/plugin/ascend/libakg.so +0 -0
  165. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  166. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  167. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  168. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  169. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  170. mindspore/lib/plugin/cpu/libakg.so +0 -0
  171. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  172. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  173. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  174. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  175. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  176. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  177. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  178. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  179. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  180. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  181. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  182. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  183. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  184. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  185. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  186. mindspore/nn/__init__.py +0 -2
  187. mindspore/nn/cell.py +316 -74
  188. mindspore/nn/dynamic_lr.py +21 -21
  189. mindspore/nn/layer/activation.py +21 -28
  190. mindspore/nn/layer/basic.py +15 -13
  191. mindspore/nn/layer/channel_shuffle.py +1 -1
  192. mindspore/nn/layer/container.py +271 -9
  193. mindspore/nn/layer/conv.py +310 -207
  194. mindspore/nn/layer/dense.py +8 -5
  195. mindspore/nn/layer/embedding.py +33 -27
  196. mindspore/nn/layer/flash_attention.py +82 -41
  197. mindspore/nn/layer/image.py +8 -6
  198. mindspore/nn/layer/math.py +13 -18
  199. mindspore/nn/layer/normalization.py +107 -66
  200. mindspore/nn/layer/padding.py +1 -1
  201. mindspore/nn/layer/pooling.py +131 -109
  202. mindspore/nn/layer/rnn_cells.py +22 -17
  203. mindspore/nn/layer/rnns.py +13 -16
  204. mindspore/nn/layer/thor_layer.py +1 -1
  205. mindspore/nn/layer/transformer.py +221 -154
  206. mindspore/nn/learning_rate_schedule.py +9 -1
  207. mindspore/nn/loss/loss.py +235 -174
  208. mindspore/nn/optim/ada_grad.py +2 -1
  209. mindspore/nn/optim/adadelta.py +1 -0
  210. mindspore/nn/optim/adafactor.py +2 -1
  211. mindspore/nn/optim/adam.py +7 -4
  212. mindspore/nn/optim/adamax.py +3 -2
  213. mindspore/nn/optim/adasum.py +2 -2
  214. mindspore/nn/optim/asgd.py +2 -3
  215. mindspore/nn/optim/ftrl.py +6 -5
  216. mindspore/nn/optim/lamb.py +7 -4
  217. mindspore/nn/optim/lars.py +1 -1
  218. mindspore/nn/optim/lazyadam.py +5 -3
  219. mindspore/nn/optim/momentum.py +2 -1
  220. mindspore/nn/optim/optimizer.py +53 -4
  221. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  222. mindspore/nn/optim/rmsprop.py +4 -3
  223. mindspore/nn/optim/rprop.py +23 -12
  224. mindspore/nn/optim/sgd.py +26 -11
  225. mindspore/nn/optim/thor.py +9 -7
  226. mindspore/nn/probability/bijector/bijector.py +5 -5
  227. mindspore/nn/probability/bijector/power_transform.py +27 -27
  228. mindspore/nn/probability/bijector/softplus.py +3 -3
  229. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  230. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  231. mindspore/nn/probability/distribution/beta.py +3 -3
  232. mindspore/nn/probability/distribution/categorical.py +7 -7
  233. mindspore/nn/probability/distribution/cauchy.py +0 -1
  234. mindspore/nn/probability/distribution/distribution.py +3 -3
  235. mindspore/nn/probability/distribution/gamma.py +3 -3
  236. mindspore/nn/probability/distribution/geometric.py +4 -4
  237. mindspore/nn/probability/distribution/gumbel.py +4 -4
  238. mindspore/nn/probability/distribution/log_normal.py +2 -2
  239. mindspore/nn/probability/distribution/logistic.py +2 -2
  240. mindspore/nn/probability/distribution/poisson.py +4 -4
  241. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  242. mindspore/nn/probability/distribution/uniform.py +6 -6
  243. mindspore/nn/wrap/cell_wrapper.py +78 -34
  244. mindspore/nn/wrap/grad_reducer.py +8 -5
  245. mindspore/nn/wrap/loss_scale.py +105 -42
  246. mindspore/numpy/array_creations.py +1 -2
  247. mindspore/numpy/array_ops.py +3 -2
  248. mindspore/offline_debug/convert_async.py +2 -2
  249. mindspore/ops/_grad_experimental/__init__.py +0 -5
  250. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
  251. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  252. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  253. mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
  254. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  255. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
  256. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  257. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  258. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  259. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  260. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  261. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  262. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  263. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  264. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  265. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  266. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  267. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  268. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  269. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  270. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  271. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  272. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  273. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  274. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  275. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  276. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  277. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  278. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  279. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  280. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  281. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  282. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  283. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  284. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  285. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  286. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  287. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  288. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  289. mindspore/ops/_primitive_cache.py +1 -1
  290. mindspore/ops/_tracefunc.py +45 -13
  291. mindspore/ops/_utils/utils.py +4 -1
  292. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  293. mindspore/ops/_vmap/vmap_base.py +3 -3
  294. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  295. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  296. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  297. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  298. mindspore/ops/arg_dtype_cast.py +54 -0
  299. mindspore/ops/composite/base.py +37 -10
  300. mindspore/ops/composite/math_ops.py +5 -4
  301. mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
  302. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  303. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  304. mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
  305. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  306. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  308. mindspore/ops/deprecated.py +304 -0
  309. mindspore/ops/function/__init__.py +4 -1
  310. mindspore/ops/function/array_func.py +167 -189
  311. mindspore/ops/function/clip_func.py +81 -13
  312. mindspore/ops/function/debug_func.py +1 -1
  313. mindspore/ops/function/grad/grad_func.py +18 -8
  314. mindspore/ops/function/image_func.py +10 -4
  315. mindspore/ops/function/linalg_func.py +5 -5
  316. mindspore/ops/function/math_func.py +575 -386
  317. mindspore/ops/function/nn_func.py +470 -251
  318. mindspore/ops/function/random_func.py +86 -56
  319. mindspore/ops/function/sparse_func.py +1 -1
  320. mindspore/ops/function/sparse_unary_func.py +14 -12
  321. mindspore/ops/function/vmap_func.py +6 -5
  322. mindspore/ops/functional.py +15 -10
  323. mindspore/ops/op_info_register.py +235 -19
  324. mindspore/ops/operations/__init__.py +25 -17
  325. mindspore/ops/operations/_grad_ops.py +52 -7
  326. mindspore/ops/operations/_inner_ops.py +213 -12
  327. mindspore/ops/operations/_quant_ops.py +4 -8
  328. mindspore/ops/operations/_sequence_ops.py +42 -0
  329. mindspore/ops/operations/array_ops.py +64 -280
  330. mindspore/ops/operations/comm_ops.py +105 -57
  331. mindspore/ops/operations/custom_ops.py +10 -3
  332. mindspore/ops/operations/debug_ops.py +8 -4
  333. mindspore/ops/operations/image_ops.py +18 -12
  334. mindspore/ops/operations/math_ops.py +185 -138
  335. mindspore/ops/operations/nn_ops.py +716 -492
  336. mindspore/ops/operations/other_ops.py +0 -22
  337. mindspore/ops/operations/random_ops.py +53 -111
  338. mindspore/ops/operations/sparse_ops.py +3 -1
  339. mindspore/ops/primitive.py +24 -18
  340. mindspore/parallel/_auto_parallel_context.py +68 -8
  341. mindspore/parallel/_cost_model_context.py +2 -2
  342. mindspore/parallel/_offload_context.py +17 -3
  343. mindspore/parallel/_parallel_serialization.py +2 -2
  344. mindspore/parallel/_ps_context.py +12 -0
  345. mindspore/parallel/_tensor.py +14 -12
  346. mindspore/parallel/_transformer/layers.py +5 -3
  347. mindspore/parallel/_transformer/loss.py +1 -0
  348. mindspore/parallel/_transformer/moe.py +2 -2
  349. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  350. mindspore/parallel/_transformer/transformer.py +23 -3
  351. mindspore/parallel/_utils.py +11 -7
  352. mindspore/parallel/algo_parameter_config.py +85 -5
  353. mindspore/parallel/checkpoint_transform.py +6 -10
  354. mindspore/parallel/shard.py +4 -4
  355. mindspore/profiler/common/struct_type.py +3 -3
  356. mindspore/profiler/common/util.py +3 -2
  357. mindspore/profiler/envprofiling.py +1 -1
  358. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  359. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  360. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  361. mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
  362. mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
  363. mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
  364. mindspore/profiler/parser/ascend_op_generator.py +5 -5
  365. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  366. mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
  367. mindspore/profiler/parser/base_timeline_generator.py +9 -7
  368. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
  369. mindspore/profiler/parser/flops_parser.py +15 -11
  370. mindspore/profiler/parser/framework_parser.py +37 -21
  371. mindspore/profiler/parser/hccl_parser.py +16 -12
  372. mindspore/profiler/parser/integrator.py +22 -11
  373. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  374. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  375. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  376. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  377. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  378. mindspore/profiler/parser/optime_parser.py +1 -1
  379. mindspore/profiler/parser/profiler_info.py +2 -2
  380. mindspore/profiler/parser/step_trace_parser.py +11 -14
  381. mindspore/profiler/profiling.py +139 -71
  382. mindspore/rewrite/api/node.py +102 -19
  383. mindspore/rewrite/api/node_type.py +5 -1
  384. mindspore/rewrite/api/scoped_value.py +9 -17
  385. mindspore/rewrite/api/symbol_tree.py +131 -47
  386. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  387. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  388. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  389. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  390. mindspore/rewrite/common/rewrite_elog.py +5 -1
  391. mindspore/rewrite/namer.py +33 -24
  392. mindspore/rewrite/namespace.py +14 -5
  393. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  394. mindspore/rewrite/node/call_function.py +79 -0
  395. mindspore/rewrite/node/cell_container.py +135 -0
  396. mindspore/rewrite/node/control_flow.py +88 -0
  397. mindspore/rewrite/{node.py → node/node.py} +273 -234
  398. mindspore/rewrite/node/node_manager.py +254 -0
  399. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  400. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  401. mindspore/rewrite/parsers/assign_parser.py +216 -221
  402. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  403. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  404. mindspore/rewrite/parsers/constant_parser.py +9 -6
  405. mindspore/rewrite/parsers/container_parser.py +9 -7
  406. mindspore/rewrite/parsers/for_parser.py +36 -15
  407. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  408. mindspore/rewrite/parsers/if_parser.py +28 -24
  409. mindspore/rewrite/parsers/module_parser.py +196 -25
  410. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  411. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  412. mindspore/rewrite/parsers/return_parser.py +6 -6
  413. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  414. mindspore/rewrite/sparsify/utils.py +1 -1
  415. mindspore/rewrite/symbol_tree.py +525 -577
  416. mindspore/rewrite/symbol_tree_builder.py +9 -193
  417. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  418. mindspore/run_check/_check_version.py +2 -2
  419. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  420. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  421. mindspore/scipy/linalg.py +1 -1
  422. mindspore/scipy/optimize/minimize.py +7 -3
  423. mindspore/train/_utils.py +7 -3
  424. mindspore/train/amp.py +323 -123
  425. mindspore/train/anf_ir_pb2.py +14 -2
  426. mindspore/train/callback/_backup_and_restore.py +2 -12
  427. mindspore/train/callback/_callback.py +29 -4
  428. mindspore/train/callback/_checkpoint.py +23 -8
  429. mindspore/train/callback/_early_stop.py +2 -2
  430. mindspore/train/callback/_landscape.py +4 -4
  431. mindspore/train/callback/_loss_monitor.py +2 -2
  432. mindspore/train/callback/_on_request_exit.py +2 -2
  433. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  434. mindspore/train/callback/_summary_collector.py +14 -7
  435. mindspore/train/callback/_time_monitor.py +58 -5
  436. mindspore/train/data_sink.py +5 -11
  437. mindspore/train/dataset_helper.py +83 -57
  438. mindspore/train/loss_scale_manager.py +2 -2
  439. mindspore/train/metrics/__init__.py +3 -3
  440. mindspore/train/metrics/cosine_similarity.py +1 -1
  441. mindspore/train/metrics/hausdorff_distance.py +3 -2
  442. mindspore/train/metrics/mean_surface_distance.py +3 -2
  443. mindspore/train/metrics/metric.py +39 -19
  444. mindspore/train/metrics/roc.py +2 -2
  445. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  446. mindspore/train/mind_ir_pb2.py +85 -36
  447. mindspore/train/model.py +185 -45
  448. mindspore/train/serialization.py +390 -150
  449. mindspore/train/summary/_writer_pool.py +3 -2
  450. mindspore/train/summary/summary_record.py +14 -10
  451. mindspore/train/train_thor/convert_utils.py +3 -3
  452. mindspore/train/train_thor/dataset_helper.py +1 -1
  453. mindspore/version.py +1 -1
  454. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
  455. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +458 -518
  456. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  457. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  458. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  459. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  460. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  461. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  462. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  463. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  464. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  465. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  466. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  467. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  468. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  469. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  470. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  471. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  472. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  473. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  474. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  475. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  476. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  477. mindspore/_extends/graph_kernel/expander.py +0 -80
  478. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  479. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  480. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  481. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  482. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  483. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  484. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  485. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  486. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  487. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  488. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  489. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  490. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  491. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  492. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  493. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  494. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  495. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  496. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  497. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  498. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  499. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  500. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  501. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  502. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  503. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  504. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  505. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  506. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  507. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  508. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  509. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  510. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  511. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  512. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  513. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  514. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  515. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  516. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  517. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  518. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  519. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  520. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  521. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  522. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  523. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  524. mindspore/dataset/datapreprocess/__init__.py +0 -20
  525. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  526. mindspore/include/api/net.h +0 -142
  527. mindspore/nn/lr_scheduler.py +0 -262
  528. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  529. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  530. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  531. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  532. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  533. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  534. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  537. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  538. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  539. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  541. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  542. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  543. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  544. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  545. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  546. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  547. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  548. mindspore/rewrite/node_visitor.py +0 -44
  549. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  550. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -71,15 +71,21 @@ AICORE_METRICS_DICT = {
71
71
  class DeviceSupportParam(Enum):
72
72
  """The device target enum."""
73
73
  CPU = ['start', 'start_profile', 'output_path', 'timeline_limit', 'profile_framework', 'op_time']
74
- GPU = ['start', 'start_profile', 'output_path', 'data_process', 'timeline_limit', 'sync_enable', 'op_time',
75
- 'profile_framework']
76
- ASCEND = ['start', 'start_profile', 'output_path', 'data_process', 'timeline_limit', 'profile_memory',
77
- 'parallel_strategy', 'profile_communication', 'aicore_metrics', 'l2_cache', 'op_time', 'ascend_job_id',
78
- 'profile_framework']
74
+ GPU = [
75
+ 'start', 'start_profile', 'output_path', 'data_process', 'timeline_limit', 'sync_enable', 'op_time',
76
+ 'profile_framework'
77
+ ]
78
+ ASCEND = [
79
+ 'start', 'start_profile', 'output_path', 'data_process', 'timeline_limit', 'profile_memory',
80
+ 'parallel_strategy', 'profile_communication', 'aicore_metrics', 'l2_cache', 'op_time', 'ascend_job_id',
81
+ 'profile_framework'
82
+ ]
79
83
 
80
84
 
81
- ALWAYS_VALID_PARAM = ['start', 'start_profile', 'output_path', 'data_process', 'parallel_strategy', 'l2_cache',
82
- 'ascend_job_id', 'op_time', 'profile_framework']
85
+ ALWAYS_VALID_PARAM = [
86
+ 'start', 'start_profile', 'output_path', 'data_process', 'parallel_strategy', 'l2_cache',
87
+ 'ascend_job_id', 'op_time', 'profile_framework'
88
+ ]
83
89
 
84
90
 
85
91
  def _environment_check():
@@ -161,6 +167,7 @@ def _calculate_dataset_execution_time(input_file, output_file):
161
167
  csv_writer.writerow(['Operation', 'Stage', 'Occurrences', 'Avg. time (us)', 'Custom Info'])
162
168
  for _, v in execution_time_map.items():
163
169
  csv_writer.writerow([v.event, v.stage, v.count, v.average_execution, v.custom_info])
170
+ os.chmod(output_file, modes)
164
171
  logger.info('Successfully calculate the execution time and write it to file: %s.', output_file)
165
172
 
166
173
 
@@ -188,8 +195,10 @@ def _extract_timeline_item(row, time_line, ts_map):
188
195
  # Put the instance event into timeline.
189
196
  elif start_end == '2':
190
197
  title = row['event'] + '::' + row['stage']
191
- event = {'name': title, 'cat': row['module_name'], 'ts': int(row['time_stamp(us)']), 'ph': 'i',
192
- 'pid': row['pid'], 'tid': row['tid'], 'args': {'parent_pid': row['parent_pid']}}
198
+ event = {
199
+ 'name': title, 'cat': row['module_name'], 'ts': int(row['time_stamp(us)']), 'ph': 'i',
200
+ 'pid': row['pid'], 'tid': row['tid'], 'args': {'parent_pid': row['parent_pid']}
201
+ }
193
202
  time_line.append(event)
194
203
  else:
195
204
  logger.warning("Can not map the start time for item: %s.", row)
@@ -209,8 +218,10 @@ def _parse_host_info(input_file, output_timeline_file, output_memory_file, is_de
209
218
  time_line = []
210
219
  # ts_map is used to store the start time of each event_stage_tid_pid
211
220
  ts_map = {}
212
- memory_header = ['tid', 'pid', 'parent_pid', 'module_name', 'event', 'stage', 'level', 'start_end', 'custom_info',
213
- 'memory_usage(kB)', 'time_stamp(us)']
221
+ memory_header = [
222
+ 'tid', 'pid', 'parent_pid', 'module_name', 'event', 'stage', 'level', 'start_end', 'custom_info',
223
+ 'memory_usage(kB)', 'time_stamp(us)'
224
+ ]
214
225
  memory_info = []
215
226
  with open(input_file, 'r') as f:
216
227
  for row in csv.DictReader(f):
@@ -226,12 +237,12 @@ def _parse_host_info(input_file, output_timeline_file, output_memory_file, is_de
226
237
  logger.error("Error occur when analyse line: %s, Details is: %s", row, e)
227
238
  continue
228
239
  if memory_info:
229
- with os.fdopen(os.open(output_memory_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, stat.S_IWUSR | stat.S_IRUSR),
230
- 'w') as csv_file:
240
+ with os.fdopen(os.open(output_memory_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), 'w') as csv_file:
231
241
  csv_writer = csv.DictWriter(csv_file, fieldnames=memory_header)
232
242
  csv_writer.writeheader()
233
243
  for item in memory_info:
234
244
  csv_writer.writerow(item)
245
+ os.chmod(output_memory_file, stat.S_IREAD | stat.S_IWRITE)
235
246
  else:
236
247
  logger.warning("No memory_usage is record in file: %s", input_file)
237
248
 
@@ -255,13 +266,23 @@ def _parse_host_info(input_file, output_timeline_file, output_memory_file, is_de
255
266
 
256
267
  if time_line:
257
268
  timeline_file = validate_and_normalize_path(output_timeline_file)
258
- with os.fdopen(os.open(timeline_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, stat.S_IWUSR | stat.S_IRUSR),
259
- 'w') as json_file:
269
+ with os.fdopen(os.open(timeline_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), 'w') as json_file:
260
270
  json.dump(time_line, json_file)
271
+ os.chmod(timeline_file, stat.S_IREAD | stat.S_IWRITE)
261
272
  else:
262
273
  logger.warning("No valid time_stamp is record in file: %s", input_file)
263
274
 
264
275
 
276
+ def _ascend_graph_msprof_generator(source_path, model_iteration_dict):
277
+ try:
278
+ msprof_exporter = AscendMsprofExporter(source_path)
279
+ msprof_exporter.export(model_iteration_dict)
280
+ except ProfilerException as err:
281
+ logger.warning(err.message)
282
+ finally:
283
+ pass
284
+
285
+
265
286
  def _ascend_graph_msprof_analyse(source_path):
266
287
  """
267
288
  Ascend graph model msprof data analyse.
@@ -287,7 +308,8 @@ class Profiler:
287
308
  This class to enable the profiling of MindSpore neural networks.
288
309
  MindSpore users can import the mindspore.Profiler, initialize the Profiler object to start profiling,
289
310
  and use Profiler.analyse() to stop profiling and analyse the results.
290
- Users can visualize the results using the MindInsight tool.
311
+ Users can visualize the results using the `MindSpore Insight
312
+ <https://www.mindspore.cn/mindinsight/docs/en/r2.2/index.html>`_ tool.
291
313
  Now, Profiler supports AICORE operator, AICPU operator, HostCPU operator, memory,
292
314
  correspondence, cluster, etc data analysis.
293
315
 
@@ -330,11 +352,16 @@ class Profiler:
330
352
  Default value: ``True`` .
331
353
  timeline_limit (int, optional): (Ascend/GPU) Set the maximum storage size of the timeline file (unit M).
332
354
  When using this parameter, `op_time` must be set to True. Default value: ``500`` .
333
- profile_framework (str, optional): (Ascend/GPU) Whether to collect host memory and time, it must be one of
334
- ["all", "time", "memory", None]. When is enabled, a subdirectory host_info will be generated in the
355
+ profile_framework (str, optional): (Ascend/GPU) The host information to collect, it must be one of
356
+ ["all", "time", "memory", None], When is not set to None, a subdirectory host_info will be generated in the
335
357
  specified profiler directory, which stores the collected memory and time files on the Host side.
336
358
  Default: "all".
337
359
 
360
+ - "all": Record both host timestamp and host memory usage.
361
+ - "time": Only record host timestamp.
362
+ - "memory": Only record host memory usage.
363
+ - None: Not record host information.
364
+
338
365
  Raises:
339
366
  RuntimeError: When the version of CANN does not match the version of MindSpore,
340
367
  MindSpore cannot parse the generated ascend_job_id directory structure.
@@ -407,7 +434,6 @@ class Profiler:
407
434
  self._rank_size = 1
408
435
  self._rank_id = 0
409
436
  self._ascend_profiler = None
410
- self._ascend_msprof_exporter = None
411
437
  self._timeline_size_limit_byte = 500 * 1024 * 1024 # 500MB
412
438
  self._parallel_strategy = True
413
439
  _environment_check()
@@ -424,6 +450,7 @@ class Profiler:
424
450
  self._sync_enable = True
425
451
  self._stop_time = 0
426
452
  self._dynamic_status = False
453
+ self._model_iteration_dict = None
427
454
  self._profile_framework = "all"
428
455
  self._msprof_enable = os.getenv("PROFILER_SAMPLECONFIG")
429
456
  if self._msprof_enable:
@@ -476,6 +503,25 @@ class Profiler:
476
503
 
477
504
  return job_start_time
478
505
 
506
+ @staticmethod
507
+ def _parse_info_json(info_file):
508
+ """
509
+ Parse info log file, get the rank id and device id of the job.
510
+ Args:
511
+ input_file (str): The file path of the parse info log file.
512
+
513
+ Returns:
514
+ rank id, device id
515
+ """
516
+ with open(info_file, "r") as f:
517
+ info_dict = json.load(f)
518
+
519
+ rank_id = info_dict.get("rank_id", 0)
520
+ dev_info = info_dict.get("DeviceInfo", [])
521
+ dev_id = dev_info[0].get("id", -1)
522
+
523
+ return str(rank_id), str(dev_id)
524
+
479
525
  def op_analyse(self, op_name, device_id=None):
480
526
  """
481
527
  Profiler users can use this interface to obtain operator performance data.
@@ -487,8 +533,8 @@ class Profiler:
487
533
  parse. If this interface is used for offline data parsing, Default: ``0`` .
488
534
 
489
535
  Raises:
490
- TypeError: If the op_name parameter type is incorrect.
491
- TypeError: If the device_id parameter type is incorrect.
536
+ TypeError: If the `op_name` parameter type is incorrect.
537
+ TypeError: If the `device_id` parameter type is incorrect.
492
538
  RuntimeError: If MindSpore runs on Ascend, this interface cannot be used.
493
539
 
494
540
  Supported Platforms:
@@ -501,12 +547,12 @@ class Profiler:
501
547
  >>> # Profiler init.
502
548
  >>> profiler = Profiler()
503
549
  >>> # Train Model or eval Model, taking LeNet5 as an example.
504
- >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
550
+ >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
505
551
  >>> net = LeNet5()
506
552
  >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
507
553
  >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
508
554
  >>> # Create the dataset taking MNIST as an example.
509
- >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/mnist.py
555
+ >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/mnist.py
510
556
  >>> dataloader = create_dataset()
511
557
  >>> model = Model(net, loss, optimizer)
512
558
  >>> model.train(5, dataloader, dataset_sink_mode=False)
@@ -550,7 +596,22 @@ class Profiler:
550
596
  Offline mode isused in abnormal exit scenario. This parameter should be set to ``None``
551
597
  for online mode. Default: ``None``.
552
598
  """
599
+ self._analyse(offline_path=offline_path)
600
+
601
+ def _analyse(self, offline_path=None, model_iteration_dict=None):
602
+ """
603
+ Collect and analyze training performance data, support calls during and after training. The example shows above.
604
+
605
+ Args:
606
+ offline_path (Union[str, None], optional): The data path which need to be analysed with offline mode.
607
+ Offline mode isused in abnormal exit scenario. This parameter should be set to ``None``
608
+ for online mode. Default: ``None``.
609
+ model_iteration_dict: Dictionary with model id as the key and iteration id as the value, Default: ``None``.
610
+ """
611
+ self._model_iteration_dict = model_iteration_dict
553
612
  if offline_path:
613
+ if self._is_offline_parser():
614
+ self._ascend_graph_analyse()
554
615
  _offline_parse(offline_path)
555
616
  return
556
617
  if self._msprof_enable:
@@ -602,7 +663,7 @@ class Profiler:
602
663
  Raises:
603
664
  RuntimeError: If the profiler has already started.
604
665
  RuntimeError: If MD profiling has stopped, repeated start action is not supported.
605
- RuntimeError: If the start_profile parameter is not set or is set to True.
666
+ RuntimeError: If the `start_profile` parameter is not set or is set to ``True``.
606
667
 
607
668
  Examples:
608
669
  >>> from mindspore.train import Callback
@@ -749,7 +810,6 @@ class Profiler:
749
810
 
750
811
  if self._device_target == DeviceTarget.ASCEND.value:
751
812
  self._ascend_profiler = c_expression.Profiler.get_instance("Ascend")
752
- self._ascend_msprof_exporter = AscendMsprofExporter(self._output_path)
753
813
  self._get_devid_rankid_and_devtarget()
754
814
 
755
815
  def _init_profiler_info(self):
@@ -827,7 +887,6 @@ class Profiler:
827
887
  # use context interface to open profiling, for the new mindspore version(after 2020.5.21)
828
888
  self._ascend_profiler = c_expression.Profiler.get_instance("Ascend")
829
889
  self._ascend_profiler.init(self._output_path, int(self._dev_id), self._ascend_profiling_options)
830
- self._ascend_msprof_exporter = AscendMsprofExporter(self._output_path)
831
890
  base_profiling_container_path = os.path.join(self._output_path, "container")
832
891
  container_path = os.path.join(base_profiling_container_path, self._dev_id)
833
892
  data_path = os.path.join(container_path, "data")
@@ -965,8 +1024,6 @@ class Profiler:
965
1024
  else:
966
1025
  logger.info("No need to stop profiler because profiler has been stopped.")
967
1026
  # export op data before analyse
968
- if self._op_time:
969
- self._ascend_msprof_exporter.export(self._start_time, support_step_trace=False)
970
1027
  self._ascend_graph_analyse()
971
1028
 
972
1029
  def _minddata_analyse(self, source_path):
@@ -1040,8 +1097,11 @@ class Profiler:
1040
1097
  aicpu_intermediate_detail_path = validate_and_normalize_path(aicpu_intermediate_detail_path)
1041
1098
  framework_raw_path = validate_and_normalize_path(framework_raw_path)
1042
1099
 
1043
- output_timeline_data_path = os.path.join(self._output_path, f'output_timeline_data_{dev_id}.txt')
1044
- output_timeline_data_path = validate_and_normalize_path(output_timeline_data_path)
1100
+ if context.get_context("mode") == context.GRAPH_MODE:
1101
+ output_timeline_data_path = os.path.join(self._output_path, f'output_timeline_data_{dev_id}.txt')
1102
+ output_timeline_data_path = validate_and_normalize_path(output_timeline_data_path)
1103
+ else:
1104
+ output_timeline_data_path = None
1045
1105
 
1046
1106
  op_analyser = AscendOPGenerator(op_summary, op_statistic, dynamic_status)
1047
1107
  op_analyser.parse()
@@ -1070,7 +1130,7 @@ class Profiler:
1070
1130
  finally:
1071
1131
  pass
1072
1132
 
1073
- def _ascend_timeline_analyse(self, source_path, op_summary, steptrace):
1133
+ def _ascend_timeline_analyse(self, op_summary, steptrace):
1074
1134
  """Analyse timeline info."""
1075
1135
  try:
1076
1136
  logger.info("Profiling: analyzing the timeline data")
@@ -1142,6 +1202,7 @@ class Profiler:
1142
1202
  if self._profile_communication and context.get_context("mode") == context.PYNATIVE_MODE:
1143
1203
  logger.warning("[Profiler]The parameter profile_communication is not supported on Ascend "
1144
1204
  "PyNative mode currently.")
1205
+ return
1145
1206
  try:
1146
1207
  logger.info("Profiling: analyzing the hccl profiler info.")
1147
1208
  dev_id = self._rank_id if self._device_target == DeviceTarget.ASCEND.value else self._dev_id
@@ -1191,9 +1252,10 @@ class Profiler:
1191
1252
  source_path = os.path.join(self._output_path, job_id)
1192
1253
  self._minddata_analyse(source_path)
1193
1254
  if self._op_time:
1255
+ _ascend_graph_msprof_generator(source_path, self._model_iteration_dict)
1194
1256
  op_summary, op_statistic, steptrace = _ascend_graph_msprof_analyse(source_path)
1195
1257
  self._ascend_op_analyse(op_summary, op_statistic, self._dynamic_status)
1196
- self._ascend_timeline_analyse(source_path, op_summary, steptrace)
1258
+ self._ascend_timeline_analyse(op_summary, steptrace)
1197
1259
  graph_ids = np.unique(op_summary['Model ID']).tolist()
1198
1260
  points = self._ascend_fpbp_analyse(op_summary, steptrace)
1199
1261
  if len(graph_ids) == 1:
@@ -1326,29 +1388,37 @@ class Profiler:
1326
1388
  point_info_file_path = validate_and_normalize_path(point_info_file_path)
1327
1389
 
1328
1390
  if self._device_target and self._device_target == DeviceTarget.GPU.value:
1329
- input_file_path = os.path.join(self._output_path, f'step_trace_profiling_{self._dev_id}.txt')
1330
- input_file_path = validate_and_normalize_path(input_file_path)
1331
- parser = GpuStepTraceParser(input_dir=input_file_path,
1332
- output_file_path=step_trace_intermediate_file_path,
1333
- is_training_mode=is_training_mode_flag,
1334
- is_gpu_kernel_async_launch=is_gpu_kernel_async_launch_flag)
1335
- parser.parse_and_save()
1336
- point_info = parser.record_point_info(point_info_file_path)
1337
- else:
1338
- # whether keep the first step
1339
- skip_first_step_flag = framework_parser.check_op_name(INIT_OP_NAME)
1340
- point_info = framework_parser.point_info
1341
- # recognize inference or training mode
1342
- is_training_mode_flag = framework_parser.check_op_name("Gradients")
1343
- # parser the step trace files and save the result to disk
1344
- source_path = validate_and_normalize_path(source_path)
1345
- parser = AscendStepTraceParser(input_dir=source_path,
1346
- output_file_path=step_trace_intermediate_file_path,
1347
- skip_first_step=skip_first_step_flag,
1348
- is_training_mode=is_training_mode_flag)
1349
- parser.set_task_id_op_name_dict(framework_parser.to_task_id_full_op_name_dict())
1350
- parser.parse_and_save()
1351
- point_info = parser.record_point_info(point_info_file_path)
1391
+ if context.get_context("mode") != context.PYNATIVE_MODE:
1392
+ input_file_path = os.path.join(self._output_path, f'step_trace_profiling_{self._dev_id}.txt')
1393
+ input_file_path = validate_and_normalize_path(input_file_path)
1394
+ parser = GpuStepTraceParser(input_dir=input_file_path,
1395
+ output_file_path=step_trace_intermediate_file_path,
1396
+ is_training_mode=is_training_mode_flag,
1397
+ is_gpu_kernel_async_launch=is_gpu_kernel_async_launch_flag)
1398
+ parser.parse_and_save()
1399
+ point_info = parser.record_point_info(point_info_file_path)
1400
+ # print parser result
1401
+ parser.show()
1402
+ logger.info("Finish saving the intermediate result: %s", step_trace_intermediate_file_path)
1403
+ logger.info("The point info is: %s", point_info)
1404
+
1405
+ return point_info, is_training_mode_flag
1406
+ return {}, is_training_mode_flag
1407
+
1408
+ # whether keep the first step
1409
+ skip_first_step_flag = framework_parser.check_op_name(INIT_OP_NAME)
1410
+ # recognize inference or training mode
1411
+ is_training_mode_flag = framework_parser.check_op_name("Gradients")
1412
+ # parser the step trace files and save the result to disk
1413
+ source_path = validate_and_normalize_path(source_path)
1414
+ parser = AscendStepTraceParser(input_dir=source_path,
1415
+ output_file_path=step_trace_intermediate_file_path,
1416
+ skip_first_step=skip_first_step_flag,
1417
+ is_training_mode=is_training_mode_flag)
1418
+ parser.set_task_id_op_name_dict(framework_parser.to_task_id_full_op_name_dict())
1419
+ parser.parse_and_save()
1420
+ point_info = parser.record_point_info(point_info_file_path)
1421
+
1352
1422
  # print parser result
1353
1423
  parser.show()
1354
1424
  logger.info("Finish saving the intermediate result: %s", step_trace_intermediate_file_path)
@@ -1393,11 +1463,10 @@ class Profiler:
1393
1463
  return job_id
1394
1464
 
1395
1465
  job_id = ""
1396
- job_dirs = filter(lambda item: item.startswith('JOB') or item.startswith('PROF') and \
1397
- os.path.isdir(os.path.join(self._output_path, item)),
1398
- os.listdir(self._output_path))
1399
- sorted_job_dirs = sorted(job_dirs, key=lambda x: os.path.getmtime(os.path.join(self._output_path, x)),
1400
- reverse=True)
1466
+ job_dirs = filter(lambda item: item.startswith('JOB') or item.startswith('PROF') and os.path.isdir(
1467
+ os.path.join(self._output_path, item)), os.listdir(self._output_path))
1468
+ sorted_job_dirs = sorted(
1469
+ job_dirs, key=lambda x: os.path.getmtime(os.path.join(self._output_path, x)), reverse=True)
1401
1470
 
1402
1471
  for dir_name in sorted_job_dirs:
1403
1472
  if dir_name.startswith('PROF'):
@@ -1414,22 +1483,21 @@ class Profiler:
1414
1483
  "profiler will ignore this job dir.", job_dir)
1415
1484
  continue
1416
1485
 
1417
- training_device_id = start_file_path.split('.')[-1]
1486
+ info_file_path = get_file_path(job_dir, "info.json")
1487
+ if info_file_path is None:
1488
+ logger.warning("Find profiling job path %s, but info.json not exist, "
1489
+ "profiler will ignore this job dir.", job_dir)
1490
+ continue
1491
+
1492
+ _, training_device_id = self._parse_info_json(info_file_path)
1493
+ job_start_time = self._parse_start_log(start_file_path)
1494
+
1418
1495
  if self._dev_id != training_device_id:
1419
1496
  logger.debug("Find profiling find job path %s, but not current training device id. "
1420
1497
  "Current training device id %s, but job path device id: %s, "
1421
1498
  "profiler will ignore this job dir.", job_dir, self._dev_id, training_device_id)
1422
1499
  continue
1423
1500
 
1424
- if not os.listdir(os.path.join(job_dir, 'data')):
1425
- continue
1426
-
1427
- job_start_time = self._parse_start_log(start_file_path)
1428
- if not job_start_time:
1429
- logger.warning("Find profiling job path %s, but fail to get job start info, "
1430
- "profiler will ignore this job dir.", job_start_time)
1431
- continue
1432
-
1433
1501
  if int(job_start_time) < self._start_time:
1434
1502
  logger.warning("Find profiling job path %s, but start_time(%d) is earlier than this training "
1435
1503
  "start_time(%d), profiler will ignore this job dir.",
@@ -1586,7 +1654,7 @@ class Profiler:
1586
1654
  self._profile_framework = kwargs.pop("profile_framework", "all")
1587
1655
  if self._profile_framework not in ["memory", "time", "all", None]:
1588
1656
  logger.warning(f"For '{self.__class__.__name__}', the parameter profile_framework must be one of ['memory',"
1589
- f" 'time', 'all', None]but got {self._profile_framework}, it will be set to 'all'.")
1657
+ f" 'time', 'all', None], but got {self._profile_framework}, it will be set to 'all'.")
1590
1658
  self._profile_framework = "all"
1591
1659
 
1592
1660
  def _host_info_analyse(self):
@@ -14,12 +14,13 @@
14
14
  # ============================================================================
15
15
  """Rewrite module api: Node."""
16
16
 
17
- from typing import Union, Optional
17
+ from typing import Union, Optional, List, Dict
18
+ from types import FunctionType
18
19
 
19
20
  from mindspore.nn import Cell
20
21
  from mindspore.ops.primitive import Primitive
21
22
  from mindspore import _checkparam as Validator
22
- from ..node import Node as NodeImpl
23
+ from ..node.node import Node as NodeImpl
23
24
  from ..symbol_tree import SymbolTree as SymbolTreeImpl
24
25
  from .node_type import NodeType
25
26
  from .scoped_value import ScopedValue
@@ -50,8 +51,8 @@ class Node:
50
51
  return self._node == other._node
51
52
 
52
53
  @staticmethod
53
- def create_call_cell(cell: Cell, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None,
54
- kwargs: {str: ScopedValue}=None, name: str = "", is_sub_net: bool = False) -> 'Node':
54
+ def create_call_cell(cell: Cell, targets: List[Union[ScopedValue, str]], args: List[ScopedValue] = None,
55
+ kwargs: Dict[str, ScopedValue] = None, name: str = "", is_sub_net: bool = False) -> 'Node':
55
56
  """
56
57
  Create a node. Only support create from a `Cell` now.
57
58
 
@@ -63,14 +64,15 @@ class Node:
63
64
 
64
65
  Args:
65
66
  cell (Cell): Cell-operator of this forward-layer.
66
- targets (list[ScopedValue]): Indicate output names. Used as targets of an assign statement in source code.
67
- args (list[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in
67
+ targets (List[Union[ScopedValue, str]]): Indicate output names. Used as targets of an assign statement in
68
+ source code.
69
+ args (List[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in
68
70
  source code. Default: ``None`` , which indicates the `cell` has no args inputs.
69
- kwargs (dict): Type of key must be `str` and type of value must be `ScopedValue`.
71
+ kwargs (Dict[str, ScopedValue]): Type of key must be `str` and type of value must be `ScopedValue`.
70
72
  Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source
71
73
  code. Default: ``None`` , which indicates the `cell` has no kwargs inputs.
72
74
  name (str): Indicate the name of node. Used as field name in source code. Default is None. Rewrite will
73
- generate name from `targets` when name is None. Rewrite will check and ensure the uniqueness of `name`
75
+ generate name from `cell` when name is None. Rewrite will check and ensure the uniqueness of `name`
74
76
  while node being inserted. Default: ``""`` .
75
77
  is_sub_net (bool): Indicate that is `cell` a network. If `is_sub_net` is true, Rewrite will try to parse
76
78
  the `cell` to a TreeNode, otherwise the `cell` is parsed to a CallCell node. Default: ``False`` .
@@ -89,7 +91,7 @@ class Node:
89
91
  >>> from mindspore.rewrite import SymbolTree, ScopedValue
90
92
  >>> import mindspore.nn as nn
91
93
  >>> # Define the network structure of LeNet5. Refer to
92
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
94
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
93
95
  >>> net = LeNet5()
94
96
  >>> stree = SymbolTree.create(net)
95
97
  >>> node = stree.get_node("conv1")
@@ -108,8 +110,66 @@ class Node:
108
110
  Validator.check_element_type_of_iterable("args", args, [ScopedValue], "Node")
109
111
  if kwargs is not None:
110
112
  Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "Node")
111
- return Node(NodeImpl.create_call_op(cell, None, targets, ScopedValue.create_naming_value(name, "self"),
112
- args, kwargs, name, is_sub_net))
113
+ return Node(NodeImpl.create_call_op(cell, None, targets, args, kwargs, name, is_sub_net))
114
+
115
+ @staticmethod
116
+ def create_call_function(function: FunctionType, targets: List[Union[ScopedValue, str]],
117
+ args: List[ScopedValue] = None, kwargs: Dict[str, ScopedValue] = None) -> 'Node':
118
+ """
119
+ Create a node that corresponds to a function call. The `function` object is saved into network, and used via
120
+ getting object from `self.` .
121
+
122
+ Args:
123
+ function (FunctionType): The function to be called.
124
+ targets (List[Union[ScopedValue, str]]): indicates output names. Used as targets of an assign statement in
125
+ source code.
126
+ args (List[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in
127
+ source code. Default: ``None`` , which indicates the `function` has no args inputs.
128
+ kwargs (Dict[str, ScopedValue]): Type of key must be `str` and type of value must be `ScopedValue`.
129
+ Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source
130
+ code. Default: ``None`` , which indicates the `function` has no kwargs inputs.
131
+
132
+ Returns:
133
+ An instance of `Node`.
134
+
135
+ Raises:
136
+ TypeError: If `function` is not a `FunctionType`.
137
+ TypeError: If `targets` is not `list`.
138
+ TypeError: If the type of `targets` is not in `[ScopedValue, str]`.
139
+ TypeError: If arg in `args` is not a `ScopedValue`.
140
+ TypeError: If key of `kwarg` is not a str or value of kwarg in `kwargs` is not a `ScopedValue`.
141
+
142
+ Examples:
143
+ >>> from mindspore.rewrite import SymbolTree, ScopedValue
144
+ >>> import mindspore.nn as nn
145
+ >>> import mindspore.ops as ops
146
+ >>> # Define the network structure of LeNet5. Refer to
147
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
148
+ >>> net = LeNet5()
149
+ >>> stree = SymbolTree.create(net)
150
+ >>> node = stree.get_node("conv1")
151
+ >>> position = stree.after(node)
152
+ >>> new_node = node.create_call_function(function=ops.abs, targets=['x'],
153
+ ... args=[ScopedValue.create_naming_value('x')])
154
+ >>> stree.insert(position, new_node)
155
+ >>> print(new_node.get_node_type())
156
+ NodeType.CallFunction
157
+ """
158
+ Validator.check_value_type("function", function, [FunctionType, type], "create_call_function")
159
+ Validator.check_element_type_of_iterable("targets", targets, [ScopedValue, str], "create_call_function")
160
+ if args is not None:
161
+ Validator.check_element_type_of_iterable("args", args, [ScopedValue], "create_call_function")
162
+ if kwargs is not None:
163
+ Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "create_call_function")
164
+ return Node(NodeImpl._create_call_function(function, targets, args, kwargs))
165
+
166
+ @staticmethod
167
+ def create_input(param_name: str, default: Optional[ScopedValue] = None) -> 'Node':
168
+ # pylint: disable=missing-function-docstring
169
+ Validator.check_value_type("param_name", param_name, [str], "Node")
170
+ if default is not None:
171
+ Validator.check_value_type("default", default, [ScopedValue], "Node")
172
+ return Node(NodeImpl.create_input_node(None, param_name, default, name=f"input_{param_name}"))
113
173
 
114
174
  def get_handler(self) -> NodeImpl:
115
175
  return self._node
@@ -124,7 +184,7 @@ class Node:
124
184
  Examples:
125
185
  >>> from mindspore.rewrite import SymbolTree
126
186
  >>> # Define the network structure of LeNet5. Refer to
127
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
187
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
128
188
  >>> net = LeNet5()
129
189
  >>> stree = SymbolTree.create(net)
130
190
  >>> node = stree.get_node("conv2")
@@ -144,7 +204,7 @@ class Node:
144
204
  Examples:
145
205
  >>> from mindspore.rewrite import SymbolTree
146
206
  >>> # Define the network structure of LeNet5. Refer to
147
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
207
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
148
208
  >>> net = LeNet5()
149
209
  >>> stree = SymbolTree.create(net)
150
210
  >>> node = stree.get_node("conv1")
@@ -177,7 +237,7 @@ class Node:
177
237
  Examples:
178
238
  >>> from mindspore.rewrite import SymbolTree
179
239
  >>> # Define the network structure of LeNet5. Refer to
180
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
240
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
181
241
  >>> net = LeNet5()
182
242
  >>> stree = SymbolTree.create(net)
183
243
  >>> node = stree.get_node("relu_3")
@@ -216,7 +276,7 @@ class Node:
216
276
  Examples:
217
277
  >>> from mindspore.rewrite import SymbolTree
218
278
  >>> # Define the network structure of LeNet5. Refer to
219
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
279
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
220
280
  >>> net = LeNet5()
221
281
  >>> stree = SymbolTree.create(net)
222
282
  >>> src_node = stree.get_node("fc1")
@@ -256,7 +316,7 @@ class Node:
256
316
  Examples:
257
317
  >>> from mindspore.rewrite import SymbolTree
258
318
  >>> # Define the network structure of LeNet5. Refer to
259
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
319
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
260
320
  >>> net = LeNet5()
261
321
  >>> stree = SymbolTree.create(net)
262
322
  >>> node = stree.get_node("conv1")
@@ -276,7 +336,7 @@ class Node:
276
336
  Examples:
277
337
  >>> from mindspore.rewrite import SymbolTree
278
338
  >>> # Define the network structure of LeNet5. Refer to
279
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
339
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
280
340
  >>> net = LeNet5()
281
341
  >>> stree = SymbolTree.create(net)
282
342
  >>> node = stree.get_node("conv1")
@@ -303,7 +363,7 @@ class Node:
303
363
  Examples:
304
364
  >>> from mindspore.rewrite import SymbolTree
305
365
  >>> # Define the network structure of LeNet5. Refer to
306
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
366
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
307
367
  >>> net = LeNet5()
308
368
  >>> stree = SymbolTree.create(net)
309
369
  >>> node = stree.get_node("conv1")
@@ -326,7 +386,7 @@ class Node:
326
386
  Examples:
327
387
  >>> from mindspore.rewrite import SymbolTree
328
388
  >>> # Define the network structure of LeNet5. Refer to
329
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
389
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
330
390
  >>> net = LeNet5()
331
391
  >>> stree = SymbolTree.create(net)
332
392
  >>> node = stree.get_node("conv1")
@@ -335,6 +395,29 @@ class Node:
335
395
  """
336
396
  return self._node.get_args()
337
397
 
398
+ def get_symbol_tree(self) -> 'SymbolTree':
399
+ """
400
+ Get the symbol tree which current node belongs to.
401
+
402
+ Returns:
403
+ SymbolTree, None if current node does not belong to any SymbolTree.
404
+
405
+ Examples:
406
+ >>> from mindspore.rewrite import SymbolTree
407
+ >>> # Define the network structure of LeNet5. Refer to
408
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
409
+ >>> net = LeNet5()
410
+ >>> stree = SymbolTree.create(net)
411
+ >>> node = stree.get_node("conv1")
412
+ >>> print(type(node.get_symbol_tree()))
413
+ <class 'mindspore.rewrite.api.symbol_tree.SymbolTree'>
414
+ """
415
+ from .symbol_tree import SymbolTree
416
+ stree_impl = self._node.get_belong_symbol_tree()
417
+ if not stree_impl:
418
+ return None
419
+ return SymbolTree(stree_impl)
420
+
338
421
  def get_kwargs(self) -> {str: ScopedValue}:
339
422
  return self._node.get_kwargs()
340
423