mindspore 2.1.0__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (539)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +49 -16
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  20. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  21. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  22. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  23. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  24. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  25. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  26. mindspore/_check_jit_forbidden_api.py +3 -1
  27. mindspore/_checkparam.py +26 -32
  28. mindspore/_extends/graph_kernel/__init__.py +0 -1
  29. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  30. mindspore/_extends/graph_kernel/splitter.py +1 -9
  31. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  32. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  33. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  34. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  35. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
  36. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  37. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  38. mindspore/_extends/parse/__init__.py +12 -15
  39. mindspore/_extends/parse/namespace.py +7 -33
  40. mindspore/_extends/parse/parser.py +61 -71
  41. mindspore/_extends/parse/resources.py +1 -1
  42. mindspore/_extends/parse/standard_method.py +72 -95
  43. mindspore/_extends/parse/trope.py +1 -1
  44. mindspore/_extends/remote/kernel_build_server.py +24 -7
  45. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  46. mindspore/_install_custom.py +43 -0
  47. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  48. mindspore/amp.py +47 -11
  49. mindspore/bin/cache_admin +0 -0
  50. mindspore/bin/cache_server +0 -0
  51. mindspore/boost/boost.py +1 -8
  52. mindspore/boost/boost_cell_wrapper.py +3 -2
  53. mindspore/boost/grad_accumulation.py +1 -1
  54. mindspore/boost/group_loss_scale_manager.py +8 -7
  55. mindspore/common/__init__.py +5 -3
  56. mindspore/common/_jit_fallback_utils.py +6 -0
  57. mindspore/common/_register_for_adapter.py +2 -0
  58. mindspore/common/_register_for_tensor.py +2 -2
  59. mindspore/common/_stub_tensor.py +13 -0
  60. mindspore/common/_utils.py +13 -0
  61. mindspore/common/api.py +173 -258
  62. mindspore/common/auto_dynamic_shape.py +498 -0
  63. mindspore/common/dtype.py +18 -11
  64. mindspore/common/dump.py +6 -4
  65. mindspore/common/initializer.py +14 -14
  66. mindspore/common/jit_config.py +33 -15
  67. mindspore/common/lazy_inline.py +126 -7
  68. mindspore/common/mindir_util.py +101 -0
  69. mindspore/common/parameter.py +51 -41
  70. mindspore/common/seed.py +4 -4
  71. mindspore/common/sparse_tensor.py +13 -14
  72. mindspore/common/tensor.py +240 -145
  73. mindspore/communication/__init__.py +7 -4
  74. mindspore/communication/_comm_helper.py +83 -4
  75. mindspore/communication/management.py +152 -84
  76. mindspore/config/op_info.config +13 -2
  77. mindspore/config/super_bar_config.json +4 -2
  78. mindspore/context.py +143 -59
  79. mindspore/dataset/__init__.py +5 -5
  80. mindspore/dataset/audio/__init__.py +2 -2
  81. mindspore/dataset/audio/transforms.py +52 -52
  82. mindspore/dataset/callback/ds_callback.py +16 -2
  83. mindspore/dataset/core/config.py +68 -51
  84. mindspore/dataset/engine/cache_client.py +28 -5
  85. mindspore/dataset/engine/datasets.py +250 -112
  86. mindspore/dataset/engine/datasets_audio.py +43 -211
  87. mindspore/dataset/engine/datasets_standard_format.py +11 -35
  88. mindspore/dataset/engine/datasets_text.py +43 -67
  89. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  90. mindspore/dataset/engine/datasets_vision.py +219 -1029
  91. mindspore/dataset/engine/iterators.py +11 -4
  92. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  93. mindspore/dataset/engine/obs/util.py +3 -0
  94. mindspore/dataset/engine/samplers.py +1 -1
  95. mindspore/dataset/engine/validators.py +19 -5
  96. mindspore/dataset/text/__init__.py +3 -3
  97. mindspore/dataset/text/transforms.py +101 -127
  98. mindspore/dataset/text/utils.py +205 -138
  99. mindspore/dataset/transforms/__init__.py +1 -1
  100. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  101. mindspore/dataset/transforms/transforms.py +95 -40
  102. mindspore/dataset/utils/browse_dataset.py +8 -2
  103. mindspore/dataset/utils/line_reader.py +17 -19
  104. mindspore/dataset/vision/__init__.py +3 -3
  105. mindspore/dataset/vision/c_transforms.py +6 -3
  106. mindspore/dataset/vision/transforms.py +409 -287
  107. mindspore/dataset/vision/utils.py +13 -14
  108. mindspore/dataset/vision/validators.py +11 -1
  109. mindspore/experimental/map_parameter.py +14 -0
  110. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  111. mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
  112. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  113. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  114. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  115. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  116. mindspore/gen_ops.py +273 -0
  117. mindspore/include/OWNERS +0 -1
  118. mindspore/include/api/data_type.h +2 -1
  119. mindspore/include/api/graph.h +0 -15
  120. mindspore/include/api/kernel.h +2 -0
  121. mindspore/include/api/kernel_api.h +37 -12
  122. mindspore/include/api/model.h +0 -14
  123. mindspore/include/api/types.h +37 -4
  124. mindspore/include/c_api/ms/abstract.h +67 -0
  125. mindspore/include/c_api/ms/attribute.h +197 -0
  126. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  127. mindspore/include/c_api/ms/base/macros.h +32 -0
  128. mindspore/include/c_api/ms/base/status.h +33 -0
  129. mindspore/include/c_api/ms/base/types.h +282 -0
  130. mindspore/include/c_api/ms/context.h +102 -0
  131. mindspore/include/c_api/ms/graph.h +160 -0
  132. mindspore/include/c_api/ms/node.h +606 -0
  133. mindspore/include/c_api/ms/tensor.h +161 -0
  134. mindspore/include/c_api/ms/value.h +84 -0
  135. mindspore/include/dataset/constants.h +6 -5
  136. mindspore/include/dataset/execute.h +23 -13
  137. mindspore/include/dataset/text.h +26 -26
  138. mindspore/include/dataset/transforms.h +13 -13
  139. mindspore/include/dataset/vision.h +60 -60
  140. mindspore/include/dataset/vision_ascend.h +5 -6
  141. mindspore/include/dataset/vision_lite.h +17 -17
  142. mindspore/include/mindapi/base/type_id.h +1 -0
  143. mindspore/include/mindapi/base/types.h +1 -0
  144. mindspore/lib/libdnnl.so.2 +0 -0
  145. mindspore/lib/libjemalloc.so.2 +0 -0
  146. mindspore/lib/libmindspore.so +0 -0
  147. mindspore/lib/libmindspore_backend.so +0 -0
  148. mindspore/lib/libmindspore_common.so +0 -0
  149. mindspore/lib/libmindspore_core.so +0 -0
  150. mindspore/lib/libmindspore_glog.so.0 +0 -0
  151. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  152. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  153. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  154. mindspore/lib/libmindspore_shared_lib.so +0 -0
  155. mindspore/lib/libnnacl.so +0 -0
  156. mindspore/lib/libopencv_core.so.4.5 +0 -0
  157. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  158. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  159. mindspore/lib/libps_cache.so +0 -0
  160. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  161. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  163. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  164. mindspore/lib/plugin/ascend/libakg.so +0 -0
  165. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  166. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  167. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  168. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  169. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  170. mindspore/lib/plugin/cpu/libakg.so +0 -0
  171. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  172. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  173. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  174. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  175. mindspore/nn/__init__.py +0 -2
  176. mindspore/nn/cell.py +316 -74
  177. mindspore/nn/dynamic_lr.py +21 -21
  178. mindspore/nn/layer/activation.py +21 -28
  179. mindspore/nn/layer/basic.py +15 -13
  180. mindspore/nn/layer/channel_shuffle.py +1 -1
  181. mindspore/nn/layer/container.py +271 -9
  182. mindspore/nn/layer/conv.py +310 -207
  183. mindspore/nn/layer/dense.py +8 -5
  184. mindspore/nn/layer/embedding.py +33 -27
  185. mindspore/nn/layer/flash_attention.py +82 -41
  186. mindspore/nn/layer/image.py +8 -6
  187. mindspore/nn/layer/math.py +13 -18
  188. mindspore/nn/layer/normalization.py +107 -66
  189. mindspore/nn/layer/padding.py +1 -1
  190. mindspore/nn/layer/pooling.py +131 -109
  191. mindspore/nn/layer/rnn_cells.py +22 -17
  192. mindspore/nn/layer/rnns.py +13 -16
  193. mindspore/nn/layer/thor_layer.py +1 -1
  194. mindspore/nn/layer/transformer.py +221 -154
  195. mindspore/nn/learning_rate_schedule.py +9 -1
  196. mindspore/nn/loss/loss.py +235 -174
  197. mindspore/nn/optim/ada_grad.py +2 -1
  198. mindspore/nn/optim/adadelta.py +1 -0
  199. mindspore/nn/optim/adafactor.py +2 -1
  200. mindspore/nn/optim/adam.py +7 -4
  201. mindspore/nn/optim/adamax.py +3 -2
  202. mindspore/nn/optim/adasum.py +2 -2
  203. mindspore/nn/optim/asgd.py +2 -3
  204. mindspore/nn/optim/ftrl.py +6 -5
  205. mindspore/nn/optim/lamb.py +7 -4
  206. mindspore/nn/optim/lars.py +1 -1
  207. mindspore/nn/optim/lazyadam.py +5 -3
  208. mindspore/nn/optim/momentum.py +2 -1
  209. mindspore/nn/optim/optimizer.py +53 -4
  210. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  211. mindspore/nn/optim/rmsprop.py +4 -3
  212. mindspore/nn/optim/rprop.py +23 -12
  213. mindspore/nn/optim/sgd.py +26 -11
  214. mindspore/nn/optim/thor.py +9 -7
  215. mindspore/nn/probability/bijector/bijector.py +5 -5
  216. mindspore/nn/probability/bijector/power_transform.py +27 -27
  217. mindspore/nn/probability/bijector/softplus.py +3 -3
  218. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  219. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  220. mindspore/nn/probability/distribution/beta.py +3 -3
  221. mindspore/nn/probability/distribution/categorical.py +7 -7
  222. mindspore/nn/probability/distribution/cauchy.py +0 -1
  223. mindspore/nn/probability/distribution/distribution.py +3 -3
  224. mindspore/nn/probability/distribution/gamma.py +3 -3
  225. mindspore/nn/probability/distribution/geometric.py +4 -4
  226. mindspore/nn/probability/distribution/gumbel.py +4 -4
  227. mindspore/nn/probability/distribution/log_normal.py +2 -2
  228. mindspore/nn/probability/distribution/logistic.py +2 -2
  229. mindspore/nn/probability/distribution/poisson.py +4 -4
  230. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  231. mindspore/nn/probability/distribution/uniform.py +6 -6
  232. mindspore/nn/wrap/cell_wrapper.py +78 -34
  233. mindspore/nn/wrap/grad_reducer.py +8 -5
  234. mindspore/nn/wrap/loss_scale.py +105 -42
  235. mindspore/numpy/array_creations.py +1 -2
  236. mindspore/numpy/array_ops.py +3 -2
  237. mindspore/offline_debug/convert_async.py +2 -2
  238. mindspore/ops/_grad_experimental/__init__.py +0 -5
  239. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
  240. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  241. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  242. mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
  243. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  244. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
  245. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  246. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  247. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  248. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  249. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  250. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  251. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  252. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  253. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  254. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  255. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  256. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  257. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  258. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  259. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  260. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  261. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  262. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  263. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  264. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  265. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  266. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  267. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  268. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  269. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  270. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  271. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  272. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  273. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  274. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  275. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  276. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  277. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  278. mindspore/ops/_primitive_cache.py +1 -1
  279. mindspore/ops/_tracefunc.py +45 -13
  280. mindspore/ops/_utils/utils.py +4 -1
  281. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  282. mindspore/ops/_vmap/vmap_base.py +3 -3
  283. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  284. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  285. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  286. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  287. mindspore/ops/arg_dtype_cast.py +54 -0
  288. mindspore/ops/composite/base.py +37 -10
  289. mindspore/ops/composite/math_ops.py +5 -4
  290. mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
  291. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  292. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  293. mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
  294. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  295. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  296. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  297. mindspore/ops/deprecated.py +304 -0
  298. mindspore/ops/function/__init__.py +4 -1
  299. mindspore/ops/function/array_func.py +167 -189
  300. mindspore/ops/function/clip_func.py +81 -13
  301. mindspore/ops/function/debug_func.py +1 -1
  302. mindspore/ops/function/grad/grad_func.py +18 -8
  303. mindspore/ops/function/image_func.py +10 -4
  304. mindspore/ops/function/linalg_func.py +5 -5
  305. mindspore/ops/function/math_func.py +575 -386
  306. mindspore/ops/function/nn_func.py +470 -251
  307. mindspore/ops/function/random_func.py +86 -56
  308. mindspore/ops/function/sparse_func.py +1 -1
  309. mindspore/ops/function/sparse_unary_func.py +14 -12
  310. mindspore/ops/function/vmap_func.py +6 -5
  311. mindspore/ops/functional.py +15 -10
  312. mindspore/ops/op_info_register.py +235 -19
  313. mindspore/ops/operations/__init__.py +25 -17
  314. mindspore/ops/operations/_grad_ops.py +52 -7
  315. mindspore/ops/operations/_inner_ops.py +213 -12
  316. mindspore/ops/operations/_quant_ops.py +4 -8
  317. mindspore/ops/operations/_sequence_ops.py +42 -0
  318. mindspore/ops/operations/array_ops.py +64 -280
  319. mindspore/ops/operations/comm_ops.py +105 -57
  320. mindspore/ops/operations/custom_ops.py +10 -3
  321. mindspore/ops/operations/debug_ops.py +8 -4
  322. mindspore/ops/operations/image_ops.py +18 -12
  323. mindspore/ops/operations/math_ops.py +185 -138
  324. mindspore/ops/operations/nn_ops.py +716 -492
  325. mindspore/ops/operations/other_ops.py +0 -22
  326. mindspore/ops/operations/random_ops.py +53 -111
  327. mindspore/ops/operations/sparse_ops.py +3 -1
  328. mindspore/ops/primitive.py +24 -18
  329. mindspore/parallel/_auto_parallel_context.py +68 -8
  330. mindspore/parallel/_cost_model_context.py +2 -2
  331. mindspore/parallel/_offload_context.py +17 -3
  332. mindspore/parallel/_parallel_serialization.py +2 -2
  333. mindspore/parallel/_ps_context.py +12 -0
  334. mindspore/parallel/_tensor.py +14 -12
  335. mindspore/parallel/_transformer/layers.py +5 -3
  336. mindspore/parallel/_transformer/loss.py +1 -0
  337. mindspore/parallel/_transformer/moe.py +2 -2
  338. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  339. mindspore/parallel/_transformer/transformer.py +23 -3
  340. mindspore/parallel/_utils.py +11 -7
  341. mindspore/parallel/algo_parameter_config.py +85 -5
  342. mindspore/parallel/checkpoint_transform.py +6 -10
  343. mindspore/parallel/shard.py +4 -4
  344. mindspore/profiler/common/struct_type.py +3 -3
  345. mindspore/profiler/common/util.py +3 -2
  346. mindspore/profiler/envprofiling.py +1 -1
  347. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  348. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  349. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  350. mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
  351. mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
  352. mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
  353. mindspore/profiler/parser/ascend_op_generator.py +5 -5
  354. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  355. mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
  356. mindspore/profiler/parser/base_timeline_generator.py +9 -7
  357. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
  358. mindspore/profiler/parser/flops_parser.py +15 -11
  359. mindspore/profiler/parser/framework_parser.py +37 -21
  360. mindspore/profiler/parser/hccl_parser.py +16 -12
  361. mindspore/profiler/parser/integrator.py +22 -11
  362. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  363. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  364. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  365. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  366. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  367. mindspore/profiler/parser/optime_parser.py +1 -1
  368. mindspore/profiler/parser/profiler_info.py +2 -2
  369. mindspore/profiler/parser/step_trace_parser.py +11 -14
  370. mindspore/profiler/profiling.py +139 -71
  371. mindspore/rewrite/api/node.py +102 -19
  372. mindspore/rewrite/api/node_type.py +5 -1
  373. mindspore/rewrite/api/scoped_value.py +9 -17
  374. mindspore/rewrite/api/symbol_tree.py +131 -47
  375. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  376. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  377. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  378. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  379. mindspore/rewrite/common/rewrite_elog.py +5 -1
  380. mindspore/rewrite/namer.py +33 -24
  381. mindspore/rewrite/namespace.py +14 -5
  382. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  383. mindspore/rewrite/node/call_function.py +79 -0
  384. mindspore/rewrite/node/cell_container.py +135 -0
  385. mindspore/rewrite/node/control_flow.py +88 -0
  386. mindspore/rewrite/{node.py → node/node.py} +273 -234
  387. mindspore/rewrite/node/node_manager.py +254 -0
  388. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  389. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  390. mindspore/rewrite/parsers/assign_parser.py +216 -221
  391. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  392. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  393. mindspore/rewrite/parsers/constant_parser.py +9 -6
  394. mindspore/rewrite/parsers/container_parser.py +9 -7
  395. mindspore/rewrite/parsers/for_parser.py +36 -15
  396. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  397. mindspore/rewrite/parsers/if_parser.py +28 -24
  398. mindspore/rewrite/parsers/module_parser.py +196 -25
  399. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  400. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  401. mindspore/rewrite/parsers/return_parser.py +6 -6
  402. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  403. mindspore/rewrite/sparsify/utils.py +1 -1
  404. mindspore/rewrite/symbol_tree.py +525 -577
  405. mindspore/rewrite/symbol_tree_builder.py +9 -193
  406. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  407. mindspore/run_check/_check_version.py +2 -2
  408. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  409. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  410. mindspore/scipy/linalg.py +1 -1
  411. mindspore/scipy/optimize/minimize.py +7 -3
  412. mindspore/train/_utils.py +7 -3
  413. mindspore/train/amp.py +323 -123
  414. mindspore/train/anf_ir_pb2.py +14 -2
  415. mindspore/train/callback/_backup_and_restore.py +2 -12
  416. mindspore/train/callback/_callback.py +29 -4
  417. mindspore/train/callback/_checkpoint.py +23 -8
  418. mindspore/train/callback/_early_stop.py +2 -2
  419. mindspore/train/callback/_landscape.py +4 -4
  420. mindspore/train/callback/_loss_monitor.py +2 -2
  421. mindspore/train/callback/_on_request_exit.py +2 -2
  422. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  423. mindspore/train/callback/_summary_collector.py +14 -7
  424. mindspore/train/callback/_time_monitor.py +58 -5
  425. mindspore/train/data_sink.py +5 -11
  426. mindspore/train/dataset_helper.py +83 -57
  427. mindspore/train/loss_scale_manager.py +2 -2
  428. mindspore/train/metrics/__init__.py +3 -3
  429. mindspore/train/metrics/cosine_similarity.py +1 -1
  430. mindspore/train/metrics/hausdorff_distance.py +3 -2
  431. mindspore/train/metrics/mean_surface_distance.py +3 -2
  432. mindspore/train/metrics/metric.py +39 -19
  433. mindspore/train/metrics/roc.py +2 -2
  434. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  435. mindspore/train/mind_ir_pb2.py +85 -36
  436. mindspore/train/model.py +185 -45
  437. mindspore/train/serialization.py +390 -150
  438. mindspore/train/summary/_writer_pool.py +3 -2
  439. mindspore/train/summary/summary_record.py +14 -10
  440. mindspore/train/train_thor/convert_utils.py +3 -3
  441. mindspore/train/train_thor/dataset_helper.py +1 -1
  442. mindspore/version.py +1 -1
  443. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
  444. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +447 -507
  445. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  446. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  447. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  448. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  449. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  450. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  451. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  452. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  453. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  454. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  455. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  456. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  457. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  458. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  459. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  460. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  461. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  462. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  463. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  464. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  465. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  466. mindspore/_extends/graph_kernel/expander.py +0 -80
  467. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  468. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  469. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  470. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  471. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  472. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  473. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  474. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  475. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  476. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  477. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  478. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  479. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  480. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  481. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  482. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  483. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  484. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  485. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  486. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  487. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  488. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  489. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  490. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  491. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  492. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  493. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  494. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  495. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  496. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  497. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  498. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  499. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  500. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  501. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  502. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  503. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  504. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  505. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  506. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  507. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  508. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  509. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  510. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  511. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  512. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  513. mindspore/dataset/datapreprocess/__init__.py +0 -20
  514. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  515. mindspore/include/api/net.h +0 -142
  516. mindspore/nn/lr_scheduler.py +0 -262
  517. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  518. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  519. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  520. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  521. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  522. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  523. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  524. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  525. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  526. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  527. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  528. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  529. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  530. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  531. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  532. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  533. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  534. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  537. mindspore/rewrite/node_visitor.py +0 -44
  538. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  539. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
Diff content (the hunks below correspond to files 331-344 in the list above):

mindspore/parallel/_offload_context.py

@@ -33,6 +33,7 @@ class _OffloadConfig:
     OFFLOAD_PARAM = "offload_param"
     OFFLOAD_PATH = "offload_path"
     OFFLOAD_CPU_SIZE = "offload_cpu_size"
+    OFFLOAD_CHECKPOINT = "offload_checkpoint"
     OFFLOAD_DISK_SIZE = "offload_disk_size"
     ENABLE_AIO = "enable_aio"
     AIO_BLOCK_SIZE = "aio_block_size"

@@ -84,6 +85,16 @@ class _OffloadContext:
         Validator.check_string(offload_param.lower(), ["cpu", "disk"], "offload_param", "set_offload_param")
         self._context_handle.set_offload_param(offload_param.lower())
 
+    def set_offload_checkpoint(self, offload_checkpoint):
+        """Set offload_checkpoint"""
+        if not isinstance(offload_checkpoint, str):
+            raise TypeError("For 'set_offload_checkpoint', "
+                            "the argument 'offload_checkpoint' must be str, but got the type : {}."
+                            .format(type(offload_checkpoint)))
+        Validator.check_string(offload_checkpoint.lower(), ["cpu", "disk"], "offload_checkpoint",
+                               "set_offload_checkpoint")
+        self._context_handle.set_offload_checkpoint(offload_checkpoint.lower())
+
     def set_offload_path(self, offload_path):
         """Set offload_path"""
         if not isinstance(offload_path, str):

@@ -194,7 +205,8 @@ class _OffloadContext:
                                    _OffloadConfig.HBM_RATIO, _OffloadConfig.OFFLOAD_CPU_SIZE,
                                    _OffloadConfig.OFFLOAD_DISK_SIZE, _OffloadConfig.ENABLE_AIO,
                                    _OffloadConfig.AIO_BLOCK_SIZE, _OffloadConfig.AIO_QUEUE_DEPTH,
-                                   _OffloadConfig.ENABLE_PINNED_MEM, _OffloadConfig.AUTO_OFFLOAD]:
+                                   _OffloadConfig.ENABLE_PINNED_MEM, _OffloadConfig.AUTO_OFFLOAD,
+                                   _OffloadConfig.OFFLOAD_CHECKPOINT]:
                 unknown_config.append(config_name)
 
         if unknown_config:

@@ -220,7 +232,8 @@ class _OffloadContext:
             _OffloadConfig.AUTO_OFFLOAD: self._context_handle.auto_offload(),
             _OffloadConfig.HOST_MEM_BLOCk_SIZE: self._context_handle.host_mem_block_size(),
             _OffloadConfig.CPU_RATIO: self._context_handle.cpu_ratio(),
-            _OffloadConfig.HBM_RATIO: self._context_handle.hbm_ratio()
+            _OffloadConfig.HBM_RATIO: self._context_handle.hbm_ratio(),
+            _OffloadConfig.OFFLOAD_CHECKPOINT: self._context_handle.offload_checkpoint()
         }
         return offload_config

@@ -257,5 +270,6 @@ _set_offload_context_func_map = {
     _OffloadConfig.AUTO_OFFLOAD: offload_context().set_auto_offload,
     _OffloadConfig.HOST_MEM_BLOCk_SIZE: offload_context().set_host_mem_block_size,
     _OffloadConfig.CPU_RATIO: offload_context().set_cpu_ratio,
-    _OffloadConfig.HBM_RATIO: offload_context().set_hbm_ratio
+    _OffloadConfig.HBM_RATIO: offload_context().set_hbm_ratio,
+    _OffloadConfig.OFFLOAD_CHECKPOINT: offload_context().set_offload_checkpoint
 }
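
The new "offload_checkpoint" key accepts the same "cpu" / "disk" values as "offload_param". A minimal usage sketch, assuming the public mindspore.set_offload_context / get_offload_context entry points dispatch through _set_offload_context_func_map as registered above; the concrete values are illustrative only:

    import mindspore as ms

    ms.set_offload_context(offload_config={
        "offload_param": "cpu",        # existing key: offload parameters to CPU
        "offload_checkpoint": "disk",  # key added in 2.2.0, validated against ["cpu", "disk"]
        "offload_path": "./offload",   # assumed scratch directory
    })
    print(ms.get_offload_context())    # now reports offload_checkpoint as well
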
mindspore/parallel/_parallel_serialization.py

@@ -330,8 +330,8 @@ def _rank_list_for_transform_parallel_checkpoint(rank_id, src_strategy_list, dst
     device_list = list(range(0, np.prod(from_tensor_layout[0])))
     param_rank_list = _get_needed_rank_list_by_layouts(from_tensor_layout, to_tensor_layout, device_list, rank_id)
     param_rank_list_new = [rank % from_device_num for rank in param_rank_list]
-    param_rank_list_new = set(param_rank_list_new)
-    result_list.update(param_rank_list_new)
+    param_rank_set_new = set(param_rank_list_new)
+    result_list.update(param_rank_set_new)
     return list(result_list)
mindspore/parallel/_ps_context.py

@@ -228,3 +228,15 @@ def _enable_distributed_mindrt():
     This method is used to distinguish from old distributed training mode.
     '''
     return ps_context().enable_distributed_mindrt()
+
+
+def _set_checkpoint_load_status(status):
+    return ps_context().set_checkpoint_load_status(status)
+
+
+def _store_warm_up_ptr_by_tensor(param_key, tensor):
+    return ps_context().store_warm_up_ptr_by_tensor(param_key, tensor)
+
+
+def _store_warm_up_ptr_by_tensor_list(param_key, key_tensor, value_tensor, status_tensor):
+    return ps_context().store_warm_up_ptr_by_tensor_list(param_key, key_tensor, value_tensor, status_tensor)
mindspore/parallel/_tensor.py

@@ -17,7 +17,6 @@ from __future__ import division
 from __future__ import absolute_import
 
 import numpy as np
-
 from mindspore.common.tensor import Tensor
 from mindspore.communication.management import get_rank, get_group_size
 from mindspore._c_expression import TensorTransform

@@ -41,7 +40,7 @@ def _get_tensor_strategy(dev_mat, tensor_map):
         if dim == -1:
             tensor_strategy.append(1)
         else:
-            tensor_strategy.append(dev_mat[-dim-1])
+            tensor_strategy.append(dev_mat[-dim - 1])
     return tensor_strategy
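
The `-dim - 1` index maps a tensor dimension to its device-matrix axis, counting from the right. A standalone sketch of the lookup (`tensor_strategy` below is a hypothetical mirror of `_get_tensor_strategy`, not the mindspore function itself):

    def tensor_strategy(dev_mat, tensor_map):
        # -1 means "this dimension is not split"; any other value d selects
        # device-matrix axis dev_mat[-d - 1], i.e. counting from the right.
        return [1 if d == -1 else dev_mat[-d - 1] for d in tensor_map]

    # 8 devices arranged as a 2 x 4 device matrix; split dim 0 over 2 devices.
    assert tensor_strategy([2, 4], [1, -1]) == [2, 1]
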
@@ -198,7 +197,7 @@ def _get_slice_index(dev_mat, tensor_map, opt_shard_group):
     return tensor_slice_index
 
 
-def _load_tensor(tensor, dev_mat, tensor_map):
+def _load_tensor(tensor, dev_mat, tensor_map, rank_id=-1):
     """
     Get the tensor slice of the local device by the device matrix and the tensor map

@@ -216,7 +215,10 @@ def _load_tensor(tensor, dev_mat, tensor_map):
     >>> tensor_map = [1, -1]
     >>> tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
     """
-    rank = get_rank()
+    if rank_id == -1:
+        rank = get_rank()
+    else:
+        rank = rank_id
     tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
     tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
     np_tensor = tensor.asnumpy()

@@ -225,7 +227,7 @@ def _load_tensor(tensor, dev_mat, tensor_map):
     return np_tensor_slice
 
 
-def _load_tensor_by_layout(tensor, layout):
+def _load_tensor_by_layout(tensor, layout, rank_id):
     """
     Load tensor by layout.

@@ -246,13 +248,13 @@ def _load_tensor_by_layout(tensor, layout):
         raise ValueError("The length of layout must be larger than 5! layout is {}".format(layout))
     dev_mat = layout[0]
     tensor_map = layout[1]
-    if len(tensor_map) == 0:
+    if not tensor_map:
         return tensor
     uniform_split = layout[4]
     group = layout[5]
     if uniform_split == 0:
         raise RuntimeError("The load tensor only support uniform split now")
-    tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
+    tensor_slice = _load_tensor(tensor, dev_mat, tensor_map, rank_id)
     if group:
         # get a totally shard tensor slice for parallel optimizer
         rank = get_rank(group)
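
Threading `rank_id` through `_load_tensor` and `_load_tensor_by_layout` lets callers slice a tensor for a rank other than the local one, while the `-1` default preserves the old get_rank() behaviour. A simplified, self-contained stand-in for the slicing that `_get_tensor_slice_index` plus the numpy indexing perform (`slice_for_rank` is illustrative, not the real helper):

    import numpy as np

    def slice_for_rank(np_tensor, strategy, rank_index):
        # Cut the tensor into a grid given per-dimension split counts and
        # pick the block owned by this rank.
        slices = []
        for dim, parts in enumerate(strategy):
            size = np_tensor.shape[dim] // parts
            slices.append(slice(rank_index[dim] * size, (rank_index[dim] + 1) * size))
        return np_tensor[tuple(slices)]

    x = np.arange(16).reshape(4, 4)
    print(slice_for_rank(x, [2, 1], rank_index=[0, 0]))  # rows 0-1
    print(slice_for_rank(x, [2, 1], rank_index=[1, 0]))  # rows 2-3
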
@@ -315,7 +317,6 @@ def _reshape_param_data(param_data, dev_mat, tensor_map):
     return Tensor(tensor_slices_new[0])
 
 
-
 def _extract_layout_item(layout_item):
     dev_matrix = layout_item[0]
     tensor_map = layout_item[1]

@@ -541,6 +542,7 @@ def _check_operator(operator):
 
 def _apply_operator(operator_name):
     """apply transform operator"""
+
     def _apply_reshape_operator(numpy_data, reshape_op):
         """
         Apply reshape operator.

@@ -597,8 +599,8 @@ def _apply_operator(operator_name):
             raise ValueError("The slice operator information is wrong.")
         shape_size = len(slice_op[1]) // 3
         begin = slice_op[1][:shape_size]
-        end = slice_op[1][shape_size:shape_size*2]
-        stride = slice_op[1][shape_size*2:]
+        end = slice_op[1][shape_size:shape_size * 2]
+        stride = slice_op[1][shape_size * 2:]
         slice_index = []
         for begin_i, end_i, strides_i in zip(begin, end, stride):
             s = slice(begin_i, end_i, strides_i)

@@ -637,8 +639,8 @@ def _reshape_param_data_with_weight(param_data, dev_mat, field_size):
     for i in range(len(tensor_slices[0][0])):
         tensor_slices_new = np.array(tensor_slices[0][:, i]).reshape(field_size, -1)
         for j in range(1, device_count):
-            tensor_slices_new = np.concatenate((tensor_slices_new,\
-                np.array(tensor_slices[j][:, i]).reshape(field_size, -1)), axis=1)
+            tensor_slices_new = np.concatenate((tensor_slices_new, \
+                                                np.array(tensor_slices[j][:, i]).reshape(field_size, -1)), axis=1)
         tensor_slices_col.append(tensor_slices_new)
     new_tensor = np.array(tensor_slices_col[0]).reshape(-1, 1)
     for i in range(1, len(tensor_slices_col)):
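
The slice operator packs begin/end/stride for every dimension into one flat list of length 3 * ndim; the hunk above only fixes spacing, but the decoding is easy to get wrong, so here is a standalone sketch (`decode_slice_op` is hypothetical):

    import numpy as np

    def decode_slice_op(flat, data):
        # Rebuild per-dimension slices from a flat [begin..., end..., stride...] list.
        n = len(flat) // 3
        begin, end, stride = flat[:n], flat[n:n * 2], flat[n * 2:]
        return data[tuple(slice(b, e, s) for b, e, s in zip(begin, end, stride))]

    x = np.arange(16).reshape(4, 4)
    # begin=(0, 2), end=(2, 4), stride=(1, 1): the top-right 2x2 block.
    print(decode_slice_op([0, 2, 2, 4, 1, 1], x))  # [[2 3] [6 7]]
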
mindspore/parallel/_transformer/layers.py

@@ -424,9 +424,11 @@ class _Linear(Cell):
         self.out_channels = out_channels
         if not (isinstance(activation, str) or activation is None or issubclass(activation, nn.Cell)):
             raise TypeError(f"For Linear cell, the activation should str type or nn.Cell type, but got {activation}.")
-        if isinstance(weight_init, Tensor) and (weight_init.ndim != 2 or weight_init.shape[0] != out_channels or
-                                                weight_init.shape[1] != in_channels):
-            raise ValueError("The shape of parameter 'weight_init' is error, please check shape of 'weight_init'.")
+
+        if isinstance(weight_init, Tensor):
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels \
+                    or weight_init.shape[1] != in_channels:
+                raise ValueError("The shape of parameter 'weight_init' is error, please check shape of 'weight_init'.")
         weight_shape = [out_channels, in_channels] if transpose_b else [in_channels, out_channels]
         self.expert_num = expert_num
         self.outer_batch = outer_batch
mindspore/parallel/_transformer/loss.py

@@ -139,6 +139,7 @@ class _NLLLoss(Cell):
         self.add = P.Add().shard(((dp, mp), ()))
 
     def construct(self, softmax_result, one_hot_label):
+        """The forward of _NLLLoss"""
         log_softmax_result = self.log(self.add(softmax_result, self.eps_const))
         loss = self.mul(log_softmax_result, one_hot_label)
         loss_unsum = self.neg(loss)
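
The forward pass computes -one_hot * log(softmax + eps), where the eps constant keeps the log away from zero. A numpy sketch of the same arithmetic (the eps value is assumed for illustration):

    import numpy as np

    probs = np.array([[0.7, 0.2, 0.1]])    # softmax_result for one sample
    one_hot = np.array([[1.0, 0.0, 0.0]])  # true class is index 0
    eps = 1e-24                            # assumed; guards log(0)

    loss_unsum = -np.log(probs + eps) * one_hot
    print(loss_unsum.sum())                # ~0.3567, i.e. -log(0.7)
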
mindspore/parallel/_transformer/moe.py

@@ -273,7 +273,7 @@ class MoE(Cell):
         if self.group_wise_a2a:
             # If capacity can't div by mp, pad for mp shard.
             if capacity % self.mp != 0:
-                pad_size = self.mp-(capacity % self.mp)
+                pad_size = self.mp - (capacity % self.mp)
                 if pad_size != 0:
                     capacity += pad_size
                     pad_tensor = self.stride_slice_dp(expert_input, (0, 0, 0, 0),

@@ -330,7 +330,7 @@ class MoE(Cell):
         # Pad capacity for comp_comm_parallel_degree split.
         pad_size = 0
         if capacity % self.comp_comm_parallel_degree != 0:
-            pad_size = self.comp_comm_parallel_degree-(capacity % self.comp_comm_parallel_degree)
+            pad_size = self.comp_comm_parallel_degree - (capacity % self.comp_comm_parallel_degree)
             capacity += pad_size
             pad_tensor = self.stride_slice_dp(expert_input, (0, 0, 0, 0),
                                               (self.expert_dim, self.dp_group, pad_size, self.hidden_size),
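
Both hunks only adjust spacing, but the padding rule is worth spelling out: pad_size = k - (capacity % k) rounds capacity up to the next multiple of k so the capacity axis can be split evenly into k shards. A tiny worked sketch (`pad_to_multiple` is illustrative):

    def pad_to_multiple(capacity, k):
        # Round capacity up to the next multiple of k, as MoE does before
        # splitting the capacity axis into k shards.
        if capacity % k != 0:
            capacity += k - (capacity % k)
        return capacity

    assert pad_to_multiple(10, 4) == 12  # pad_size = 4 - (10 % 4) = 2
    assert pad_to_multiple(12, 4) == 12  # already divisible, unchanged
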
mindspore/parallel/_transformer/op_parallel_config.py

@@ -147,10 +147,11 @@ class _PipeLineConfig(_Config):
     >>> config=_PipeLineConfig(pipeline_stage=1, micro_batch_num=1)
     """
 
-    def __init__(self, pipeline_stage=1, micro_batch_num=1):
+    def __init__(self, pipeline_stage=1, micro_batch_num=1, pipeline_segment=1):
         Validator.check_positive_int(pipeline_stage, "pipeline_stage")
         Validator.check_positive_int(micro_batch_num, "micro_batch_num")
         self.pipeline_stage = pipeline_stage
+        self.pipeline_segment = pipeline_segment
         self.micro_batch_num = micro_batch_num
 
     @property

@@ -163,6 +164,16 @@ class _PipeLineConfig(_Config):
         self._pipeline_stage = value
         context.set_auto_parallel_context(pipeline_stages=value)
 
+    @property
+    def pipeline_segment(self):
+        return self._pipeline_segment
+
+    @pipeline_segment.setter
+    def pipeline_segment(self, value):
+        Validator.check_positive_int(value, "pipeline_segment")
+        self._pipeline_segment = value
+        context.set_auto_parallel_context(pipeline_segments=value)
+
     @property
     def micro_batch_num(self):
         return self._micro_batch_num
mindspore/parallel/_transformer/transformer.py

@@ -226,7 +226,8 @@ class TransformerOpParallelConfig(_Config):
     >>> config=TransformerOpParallelConfig(data_parallel=1, model_parallel=1, recompute=recompute_config)
     """
 
-    def __init__(self, data_parallel=1, model_parallel=1, expert_parallel=1, pipeline_stage=1, micro_batch_num=1,
+    def __init__(self, data_parallel=1, model_parallel=1, expert_parallel=1, pipeline_stage=1, pipeline_segment=1,
+                 micro_batch_num=1,
                  recompute=default_transformer_recompute_config,
                  optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True):
         self.recompute = recompute

@@ -234,7 +235,8 @@ class TransformerOpParallelConfig(_Config):
         self.gradient_aggregation_group = gradient_aggregation_group
         self._embed_dp_mp_config = EmbeddingOpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel,
                                                              vocab_emb_dp=vocab_emb_dp)
-        self._pp_config = _PipeLineConfig(pipeline_stage=pipeline_stage, micro_batch_num=micro_batch_num)
+        self._pp_config = _PipeLineConfig(pipeline_stage=pipeline_stage, micro_batch_num=micro_batch_num,
+                                          pipeline_segment=pipeline_segment)
         self._moe_config = MoEParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel,
                                              expert_parallel=expert_parallel)

@@ -309,6 +311,14 @@ class TransformerOpParallelConfig(_Config):
     def pipeline_stage(self, value):
         self._pp_config.pipeline_stage = value
 
+    @property
+    def pipeline_segment(self):
+        return self._pp_config.pipeline_segment
+
+    @pipeline_segment.setter
+    def pipeline_segment(self, value):
+        self._pp_config.pipeline_segment = value
+
     @property
     def optimizer_shard(self):
         return self._optimizer_shard
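
The new pipeline_segment knob flows from TransformerOpParallelConfig through _PipeLineConfig into context.set_auto_parallel_context(pipeline_segments=...). A hedged usage sketch; the import path follows the file layout above but the module is private and its re-exports may differ, and the parallel values are placeholders:

    # Private module path per the diff; public re-exports may differ.
    from mindspore.parallel._transformer.transformer import TransformerOpParallelConfig

    config = TransformerOpParallelConfig(data_parallel=2, model_parallel=2,
                                         pipeline_stage=2, pipeline_segment=1,
                                         micro_batch_num=4)
    print(config.pipeline_segment)  # property added in 2.2.0, delegates to _pp_config
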
@@ -429,6 +439,7 @@ class FeedForward(Cell):
     >>> print(output.shape)
     (2, 20, 15)
     """
+
     @_LogActionOnce(logger=logger, key='FeedForward',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(hidden_size=Validator.check_positive_int,

@@ -622,6 +633,7 @@ class AttentionMask(Cell):
      [1. 1. 1. 0]
      [0. 0. 0. 0]]]
     """
+
     @_LogActionOnce(logger=logger, key='AttentionMask',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(seq_length=Validator.check_positive_int,

@@ -710,6 +722,7 @@ class VocabEmbedding(Cell):
     >>> print(table.shape)
     (30, 30)
     """
+
     @_LogActionOnce(logger=logger, key='VocabEmbedding',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(vocab_size=Validator.check_positive_int,

@@ -866,6 +879,7 @@ class MultiHeadAttention(Cell):
     >>> print(past[1].shape)
     (2, 3, 20, 5)
     """
+
     @_LogActionOnce(logger=logger, key='MultiHeadAttention',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(hidden_size=Validator.check_positive_int,

@@ -1203,7 +1217,8 @@ class MultiHeadAttention(Cell):
     def _get_batch_size_from_query(self, query):
         r"""Get the batch size from query tensor"""
         # For the incremental prediction, the seq length for the input is 1.
-        if len(F.shape(query)) == 2 and ((self.use_past and self.is_first_iteration) or (not self.use_past)):
+        incr_infer = self.use_past and self.is_first_iteration
+        if len(F.shape(query)) == 2 and ((incr_infer) or (not self.use_past)):
             return F.shape(query)[0] // self.src_seq_length
         return F.shape(query)[0]
 
@@ -1459,6 +1474,7 @@ class TransformerEncoderLayer(Cell):
     >>> print(past[1].shape)
     (2, 2, 16, 4)
     """
+
     @_LogActionOnce(logger=logger, key='TransformerEncoderLayer',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(hidden_size=Validator.check_positive_int,

@@ -1848,6 +1864,7 @@ class TransformerDecoderLayer(Cell):
     >>> print(past[3].shape)
     (2, 2, 20, 32)
     """
+
     @_LogActionOnce(logger=logger, key='TransformerDecoderLayer',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(hidden_size=Validator.check_positive_int,

@@ -2379,6 +2396,7 @@ class TransformerEncoder(Cell):
     >>> print(past[0][1].shape)
     (2, 2, 16, 4)
     """
+
     @_LogActionOnce(logger=logger, key='TransformerEncoder',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(batch_size=Validator.check_positive_int,

@@ -2613,6 +2631,7 @@ class TransformerDecoder(Cell):
     >>> print(past[0][3].shape)
     (2, 2, 20, 32)
     """
+
     @_LogActionOnce(logger=logger, key='TransformerDecoder',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(batch_size=Validator.check_positive_int,

@@ -2882,6 +2901,7 @@ class Transformer(Cell):
     >>> print(de_past[0][3].shape)
     (2, 2, 20, 32)
     """
+
     @_LogActionOnce(logger=logger, key='Transformer',
                     no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
     @_args_type_validator_check(batch_size=Validator.check_positive_int,
mindspore/parallel/_utils.py

@@ -100,13 +100,14 @@ def _slice_parameter(parameter, phase, layout):
         parameter.sliced = True
         return
     if not parameter.sliced:
-        new_tensor = _load_tensor_by_layout(parameter, layout)
+        rank = get_rank()
+        new_tensor = _load_tensor_by_layout(parameter, layout, rank)
         parameter.set_data(new_tensor, True)
 
 
-def _slice_tensor(tensor, layout):
+def _slice_tensor(tensor, layout, rank_id):
     """Slice python tensor obj according to the layout."""
-    new_tensor = _load_tensor_by_layout(tensor, layout)
+    new_tensor = _load_tensor_by_layout(tensor, layout, rank_id)
     return new_tensor
 
 
@@ -136,14 +137,17 @@ def _to_full_shapes(shapes, device_num):
                              "dataset strategy item size {}".format(len(shape), len(dataset_strategy[index])))
             new_shape = ()
             for i, item in enumerate(shape):
-                new_shape += (item * dataset_strategy[index][i],)
+                if item > 0:
+                    new_shape += (item * dataset_strategy[index][i],)  # static shape
+                else:
+                    new_shape += (item,)  # dynamic shape
             new_shapes.append(new_shape)
         return new_shapes
     for shape in shapes:
         new_shape = ()
         for i, item in enumerate(shape):
-            if i == 0:
-                new_shape += (item * device_num,)
+            if i == 0 and item > 0:
+                new_shape += (item * device_num,)  # only for static shape
             else:
                 new_shape += (item,)
         new_shapes.append(new_shape)
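
Dynamic dimensions are encoded as negative sizes (typically -1), so they must not be multiplied by the shard count; the hunk scales only positive entries. A standalone sketch of the rule for the non-dataset-strategy branch (`to_full_shape` is a hypothetical mirror of that branch):

    def to_full_shape(shape, device_num):
        # Scale the batch dimension (index 0) by device_num, but leave
        # dynamic dimensions (encoded as -1) untouched.
        return tuple(item * device_num if i == 0 and item > 0 else item
                     for i, item in enumerate(shape))

    assert to_full_shape((32, 224, 224, 3), 8) == (256, 224, 224, 3)  # static batch
    assert to_full_shape((-1, 224, 224, 3), 8) == (-1, 224, 224, 3)   # dynamic batch
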
@@ -201,7 +205,7 @@ def _to_full_tensor(elem, global_device_num, global_rank, scaling_sens=None):
             slice_index += (s,)
         new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(type_))
         new_tensor_numpy[slice_index] = data.asnumpy()
-        new_tensor = Tensor(new_tensor_numpy)
+        new_tensor = Tensor(new_tensor_numpy, dtype=type_)
         lst.append(new_tensor)
     if scaling_sens:
         lst.append(Tensor(scaling_sens, mstype.float32))
mindspore/parallel/algo_parameter_config.py

@@ -229,7 +229,7 @@ def set_algo_parameters(**kwargs):
     """
     Set parameters in the algorithm for parallel strategy searching. See a typical use in
     `test_auto_parallel_resnet.py
-    <https://gitee.com/mindspore/mindspore/blob/r2.1/tests/ut/python/parallel/test_auto_parallel_resnet.py>`_.
+    <https://gitee.com/mindspore/mindspore/blob/r2.2/tests/ut/python/parallel/test_auto_parallel_resnet.py>`_.
 
     Note:
         The attribute name is required. This interface works ONLY in AUTO_PARALLEL mode.

@@ -239,10 +239,10 @@ def set_algo_parameters(**kwargs):
             Default: ``True`` . For example with 8 devices available, if set ``True`` , strategy (4, 1) will not be
             included in ReLU's candidate strategies, because strategy (4, 1) only utilizes 4 devices.
         elementwise_op_strategy_follow (bool): Whether the elementwise operator has the consistent strategies as its
-            subsequent operators. Default: ``False`` . For the example of ReLU followed by Add, where ReLU is
-            elementwise operator, if this flag is set ``True`` , then the searched strategy by the algorithm
-            guarantees that strategies of these two operators are consistent, e.g., ReLU's strategy (8, 1) and Add's
-            strategy ((8, 1), (8, 1)).
+            subsequent operators. Elementwise operators refer to operators that operate on input element by element,
+            such as Add, ReLU, etc. Default: ``False`` . For the example of ReLU followed by Add, if this flag is set
+            ``True`` , then the searched strategy by the algorithm guarantees that strategies of these two operators
+            are consistent, e.g., ReLU's strategy (8, 1) and Add's strategy ((8, 1), (8, 1)).
         enable_algo_approxi (bool): Whether to enable the approximation in the algorithms. Default: ``False`` . Due to
             large solution space in searching parallel strategy for large DNN model, the algorithm takes fairly long
             time in this case. To mitigate it, if this flag is set ``True`` , an approximation is made to discard some

@@ -261,8 +261,87 @@ def set_algo_parameters(**kwargs):
         ValueError: If context keyword is not recognized.
 
     Examples:
+        .. note::
+            Before running the following examples, you need to configure the communication environment variables.
+
+            For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
+            Please see the `rank table startup
+            <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
+            for more details.
+
+            For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
+            <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
+
+            For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
+            Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
+
+        >>> import numpy as np
         >>> import mindspore as ms
+        >>> import mindspore.dataset as ds
+        >>> from mindspore import nn, ops, train
+        >>> from mindspore.communication import init
+        >>> from mindspore.common.initializer import initializer
+        >>>
+        >>> ms.set_context(mode=ms.GRAPH_MODE)
+        >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.AUTO_PARALLEL,
+        ...                              search_mode="sharding_propagation")
+        >>> init()
+        >>> ms.set_algo_parameters(fully_use_devices=True)
         >>> ms.set_algo_parameters(elementwise_op_strategy_follow=True)
+        >>> ms.set_algo_parameters(enable_algo_approxi=True)
+        >>> ms.set_algo_parameters(algo_approxi_epsilon=0.2)
+        >>> ms.set_algo_parameters(tensor_slice_align_enable=True)
+        >>> ms.set_algo_parameters(tensor_slice_align_size=8)
+        >>>
+        >>> # Define the network structure.
+        >>> class Dense(nn.Cell):
+        ...     def __init__(self, in_channels, out_channels):
+        ...         super().__init__()
+        ...         self.weight = ms.Parameter(initializer("normal", [in_channels, out_channels], ms.float32))
+        ...         self.bias = ms.Parameter(initializer("normal", [out_channels], ms.float32))
+        ...         self.matmul = ops.MatMul()
+        ...         self.add = ops.Add()
+        ...
+        ...     def construct(self, x):
+        ...         x = self.matmul(x, self.weight)
+        ...         x = self.add(x, self.bias)
+        ...         return x
+        >>>
+        >>> class FFN(nn.Cell):
+        ...     def __init__(self):
+        ...         super().__init__()
+        ...         self.flatten = ops.Flatten()
+        ...         self.dense1 = Dense(28*28, 64)
+        ...         self.relu = ops.ReLU()
+        ...         self.dense2 = Dense(64, 10)
+        ...
+        ...     def construct(self, x):
+        ...         x = self.flatten(x)
+        ...         x = self.dense1(x)
+        ...         x = self.relu(x)
+        ...         x = self.dense2(x)
+        ...         return x
+        >>> net = FFN()
+        >>> net.dense1.matmul.shard(((2, 1), (1, 2)))
+        >>>
+        >>> # Create dataset.
+        >>> step_per_epoch = 16
+        >>> def get_dataset(*inputs):
+        ...     def generate():
+        ...         for _ in range(step_per_epoch):
+        ...             yield inputs
+        ...     return generate
+        >>>
+        >>> input_data = np.random.rand(1, 28, 28).astype(np.float32)
+        >>> label_data = np.random.rand(1).astype(np.int32)
+        >>> fake_dataset = get_dataset(input_data, label_data)
+        >>> dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
+        >>> # Train network.
+        >>> optimizer = nn.Momentum(net.trainable_params(), 1e-3, 0.1)
+        >>> loss_fn = nn.CrossEntropyLoss()
+        >>> loss_cb = train.LossMonitor()
+        >>> model = ms.Model(network=net, loss_fn=loss_fn, optimizer=optimizer)
+        >>> model.train(epoch=2, train_dataset=dataset, callbacks=[loss_cb])
     """
     for key, value in kwargs.items():
         if key not in set_algo_parameters_config_func_map:

@@ -282,6 +361,7 @@ def get_algo_parameters(attr_key):
         attr_key (str): The key of the attribute. The keys include: "fully_use_devices",
             "elementwise_op_strategy_follow", "enable_algo_approxi", "algo_approxi_epsilon",
             "tensor_slice_align_enable","tensor_slice_align_size".
+            See :func:`mindspore.set_algo_parameters` for more details about the meaning of the attributes.
 
     Returns:
         Return attribute value according to the key.
mindspore/parallel/checkpoint_transform.py

@@ -35,8 +35,7 @@ def merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file):
     """
     Merge parallel strategy between all pipeline stages in pipeline parallel mode.
     For more details about converting distributed Checkpoint, please refer to
-    `Distributed Resilience Training and
-    Inference <https://www.mindspore.cn/tutorials/experts/en/r2.1/parallel/resilience_train_and_predict.html>`_.
+    `Model Transformation <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/model_transformation.html>`_.
 
     Note:
         Strategy file of each pipeline stage should be included in src_strategy_dirs.

@@ -76,17 +75,16 @@ def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=N
     """
     List of original distributed checkpoint rank index for obtaining the target checkpoint of a rank_id
     during the distributed checkpoint conversion. For more details about converting distributed Checkpoint,
-    please refer to `Distributed Resilience Training and
-    Inference <https://www.mindspore.cn/tutorials/experts/en/r2.1/parallel/resilience_train_and_predict.html>`_.
+    please refer to `Model Transformation <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/model_transformation.html>`_.
 
     Args:
         rank_id (int): The rank of which distributed checkpoint needs to be obtained after conversion.
         src_strategy_file (str): Name of source sharding strategy file which saved by
-            'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
+            `mindspore.set_auto_parallel_context(strategy_ckpt_save_file)`.
             when the 'src_strategy_file' is None, it means that the source sharding strategy is
             without any sharing for each parameter. Default:None.
         dst_strategy_file (str): Name of destination sharding strategy file which saved by
-            'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
+            `mindspore.set_auto_parallel_context(strategy_ckpt_save_file)`.
             when the 'dst_strategy_file' is None, it means that the destination sharding strategy
             is without any sharing for each parameter. Default:None.

@@ -139,8 +137,7 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_
     """
     Transform distributed checkpoint from source sharding strategy to destination sharding strategy by rank
     for a network. For more details about converting distributed Checkpoint, please refer to
-    `Distributed Resilience Training and
-    Inference <https://www.mindspore.cn/tutorials/experts/en/r2.1/parallel/resilience_train_and_predict.html>`_.
+    `Model Transformation <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/model_transformation.html>`_.
 
     Args:
         rank_id (int): The rank of which distributed checkpoint needs to be obtained after conversion.

@@ -224,8 +221,7 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
     """
     Transform distributed checkpoint from source sharding strategy to destination sharding strategy for a rank.
     For more details about converting distributed Checkpoint, please refer to
-    `Distributed Resilience Training and
-    Inference <https://www.mindspore.cn/tutorials/experts/en/r2.1/parallel/resilience_train_and_predict.html>`_.
+    `Model Transformation <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/model_transformation.html>`_.
 
     Note:
         The `src_checkpoints_dir` directory structure should be organized like "src_checkpoints_dir/rank_0/a.ckpt", the
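
The functions above form one workflow: ask which source ranks contribute to a target rank, gather those files, then convert. A hedged sketch against the signatures shown in the hunks; all file and directory names are placeholders:

    import mindspore as ms

    # Which source-rank checkpoints feed rank 0 under the new strategy?
    needed = ms.rank_list_for_transform(rank_id=0,
                                        src_strategy_file="src_strategy.ckpt",
                                        dst_strategy_file="dst_strategy.ckpt")
    ckpt_map = {rank: "./src/rank_{}/net.ckpt".format(rank) for rank in needed}

    # Convert that single rank's checkpoint.
    ms.transform_checkpoint_by_rank(0, ckpt_map, "./dst/rank_0/net.ckpt",
                                    "src_strategy.ckpt", "dst_strategy.ckpt")
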
mindspore/parallel/shard.py

@@ -24,7 +24,7 @@ class Shard(Shard_):
 
     def __init__(self):
         """Initialize Shard."""
-        Shard_.__init__(self, 'Shard')
+        super().__init__('Shard')
        self.shard_fn = None
        self.fn = None
        self.in_strategy = None

@@ -159,8 +159,8 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
 
     Note:
         You need to set the execution mode to PyNative mode,
-        set the parallel mode in `set_auto_parallel_context` to "auto_parallel"
-        and the search mode to "sharding_propagation".
+        set the parallel mode in `set_auto_parallel_context` (parallel_mode) to "auto_parallel"
+        and the search mode (search_mode) to "sharding_propagation".
         If the input contain Parameter, its strategy should be set in `in_strategy`.
 
     Args:

@@ -224,7 +224,7 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
 
     Tutorial Examples:
         - `Functional Operator Sharding
-          <https://www.mindspore.cn/docs/en/r2.1/api_python/samples/mindspore/pynative_shard_function_parallel.html>`_
+          <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/pynative_shard_function_parallel.html>`_
     """
     if not isinstance(fn, (ms.nn.Cell)):
         logger.warning("'fn' is not a mindspore.nn.Cell, and its definition cannot involve Parameter; "
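
Per the Note in the hunk above, shard requires PyNative mode with parallel_mode "auto_parallel" and search_mode "sharding_propagation". A hedged sketch of a call under those settings (an 8-device layout is assumed; communication init and device configuration are omitted):

    import mindspore as ms
    from mindspore import nn, ops

    ms.set_context(mode=ms.PYNATIVE_MODE)
    ms.set_auto_parallel_context(parallel_mode="auto_parallel",
                                 search_mode="sharding_propagation")

    class MatMulNet(nn.Cell):
        def construct(self, x, y):
            return ops.matmul(x, y)

    # Split x row-wise across 8 devices; leave y unsplit.
    sharded = ms.shard(MatMulNet(), in_strategy=((8, 1), (1, 1)))
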
mindspore/profiler/common/struct_type.py

@@ -72,7 +72,7 @@ class StructType(Enum):
         """
         Parse the binary data to get the unpacked data.
 
-        Args
+        Args:
             data_struct (dict): Key is the data name, value is StructType.
             binary_data (str): This value should be a binary string.
             special_func (Callable): This is a callable function,

@@ -105,14 +105,14 @@ class StructType(Enum):
         for name, data_type in data_struct.items():
             data_size = StructType.sizeof(data_type)
             if special_process_func:
-                unpack_data, success = special_process_func(binary_data[cursor:cursor+data_size], name,
+                unpack_data, success = special_process_func(binary_data[cursor:cursor + data_size], name,
                                                             data_type, unpacked_data)
                 if success:
                     cursor += data_size
                     unpacked_data[name] = unpack_data
                     continue
 
-            unpack_data = struct.unpack(data_type.value, binary_data[cursor: cursor+data_size])[0]
+            unpack_data = struct.unpack(data_type.value, binary_data[cursor: cursor + data_size])[0]
             cursor += data_size
             unpacked_data[name] = unpack_data
         return unpacked_data
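
StructType.unpack walks the byte buffer with a cursor, slicing binary_data[cursor:cursor + data_size] per field. A self-contained stdlib sketch of the same pattern without the special-case hook; the field layout here is invented for illustration:

    import struct

    def unpack_fields(data_struct, binary_data):
        # Decode each (name, fmt) field in order, advancing the cursor by
        # the field's packed size, like StructType.unpack above.
        unpacked, cursor = {}, 0
        for name, fmt in data_struct.items():
            size = struct.calcsize(fmt)
            unpacked[name] = struct.unpack(fmt, binary_data[cursor:cursor + size])[0]
            cursor += size
        return unpacked

    buf = struct.pack("<IHQ", 7, 2, 123456789)
    print(unpack_fields({"op_id": "<I", "stream": "<H", "ts": "<Q"}, buf))
    # {'op_id': 7, 'stream': 2, 'ts': 123456789}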