mindspore 2.4.10__cp311-cp311-win_amd64.whl → 2.6.0rc1__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mindspore might be problematic.

Files changed (602)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +13 -6
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -38
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +6 -7
  15. mindspore/_extends/parse/compile_config.py +83 -0
  16. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  17. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  18. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  19. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  20. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  21. mindspore/_extends/parse/parser.py +46 -197
  22. mindspore/_extends/parse/resources.py +1 -5
  23. mindspore/_extends/parse/standard_method.py +217 -98
  24. mindspore/_extends/pijit/__init__.py +2 -2
  25. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  26. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  27. mindspore/_extends/utils.py +1 -1
  28. mindspore/amp.py +11 -5
  29. mindspore/atlprov.dll +0 -0
  30. mindspore/avcodec-59.dll +0 -0
  31. mindspore/avdevice-59.dll +0 -0
  32. mindspore/avfilter-8.dll +0 -0
  33. mindspore/avformat-59.dll +0 -0
  34. mindspore/avutil-57.dll +0 -0
  35. mindspore/boost/__init__.py +2 -2
  36. mindspore/boost/base.py +3 -7
  37. mindspore/boost/boost_cell_wrapper.py +138 -43
  38. mindspore/c1.dll +0 -0
  39. mindspore/c1xx.dll +0 -0
  40. mindspore/c2.dll +0 -0
  41. mindspore/common/__init__.py +6 -3
  42. mindspore/common/_grad_function.py +56 -0
  43. mindspore/common/_pijit_context.py +14 -5
  44. mindspore/common/_register_for_tensor.py +1 -2
  45. mindspore/common/_stub_tensor.py +30 -14
  46. mindspore/common/_tensor_cpp_method.py +17 -0
  47. mindspore/common/_tensor_docs.py +4760 -0
  48. mindspore/common/api.py +435 -371
  49. mindspore/common/auto_dynamic_shape.py +41 -44
  50. mindspore/common/dtype.py +39 -36
  51. mindspore/common/dump.py +9 -6
  52. mindspore/common/file_system.py +9 -1
  53. mindspore/common/generator.py +2 -0
  54. mindspore/common/hook_handle.py +6 -2
  55. mindspore/common/initializer.py +13 -10
  56. mindspore/common/jit_begin_end.py +94 -0
  57. mindspore/common/jit_config.py +6 -1
  58. mindspore/common/jit_context.py +76 -0
  59. mindspore/common/jit_trace.py +378 -0
  60. mindspore/common/lazy_inline.py +9 -3
  61. mindspore/common/mindir_util.py +10 -2
  62. mindspore/common/mutable.py +5 -4
  63. mindspore/common/parameter.py +135 -52
  64. mindspore/common/seed.py +2 -2
  65. mindspore/common/sparse_tensor.py +23 -17
  66. mindspore/common/tensor.py +951 -1992
  67. mindspore/communication/__init__.py +7 -5
  68. mindspore/communication/_comm_helper.py +52 -2
  69. mindspore/communication/comm_func.py +240 -181
  70. mindspore/communication/management.py +95 -26
  71. mindspore/context.py +314 -566
  72. mindspore/dataset/__init__.py +65 -37
  73. mindspore/dataset/audio/__init__.py +2 -8
  74. mindspore/dataset/audio/transforms.py +3 -17
  75. mindspore/dataset/callback/ds_callback.py +2 -1
  76. mindspore/dataset/core/config.py +87 -6
  77. mindspore/dataset/engine/cache_admin.py +3 -3
  78. mindspore/dataset/engine/cache_client.py +6 -5
  79. mindspore/dataset/engine/datasets.py +292 -267
  80. mindspore/dataset/engine/datasets_audio.py +22 -8
  81. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  82. mindspore/dataset/engine/datasets_text.py +78 -48
  83. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  84. mindspore/dataset/engine/datasets_vision.py +120 -44
  85. mindspore/dataset/engine/iterators.py +283 -63
  86. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  87. mindspore/dataset/engine/obs/util.py +8 -0
  88. mindspore/dataset/engine/queue.py +40 -0
  89. mindspore/dataset/engine/samplers.py +289 -43
  90. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  91. mindspore/dataset/engine/validators.py +53 -11
  92. mindspore/dataset/text/__init__.py +7 -6
  93. mindspore/dataset/text/transforms.py +6 -5
  94. mindspore/dataset/text/utils.py +3 -3
  95. mindspore/dataset/transforms/__init__.py +0 -9
  96. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  97. mindspore/dataset/transforms/transforms.py +31 -14
  98. mindspore/dataset/utils/browse_dataset.py +1 -1
  99. mindspore/dataset/vision/__init__.py +2 -9
  100. mindspore/dataset/vision/transforms.py +202 -158
  101. mindspore/dataset/vision/utils.py +7 -5
  102. mindspore/dataset/vision/validators.py +1 -2
  103. mindspore/device_context/__init__.py +21 -0
  104. mindspore/device_context/ascend/__init__.py +25 -0
  105. mindspore/device_context/ascend/device.py +72 -0
  106. mindspore/device_context/ascend/op_debug.py +153 -0
  107. mindspore/device_context/ascend/op_precision.py +193 -0
  108. mindspore/device_context/ascend/op_tuning.py +123 -0
  109. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  110. mindspore/device_context/cpu/device.py +62 -0
  111. mindspore/device_context/cpu/op_tuning.py +43 -0
  112. mindspore/device_context/gpu/__init__.py +21 -0
  113. mindspore/device_context/gpu/device.py +70 -0
  114. mindspore/device_context/gpu/op_precision.py +67 -0
  115. mindspore/device_context/gpu/op_tuning.py +175 -0
  116. mindspore/device_manager.py +170 -0
  117. mindspore/dnnl.dll +0 -0
  118. mindspore/dpcmi.dll +0 -0
  119. mindspore/experimental/es/embedding_service.py +35 -27
  120. mindspore/experimental/llm_boost/__init__.py +1 -0
  121. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  122. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  123. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  124. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  125. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  126. mindspore/experimental/llm_boost/register.py +1 -0
  127. mindspore/experimental/map_parameter.py +4 -4
  128. mindspore/experimental/optim/adadelta.py +6 -6
  129. mindspore/experimental/optim/adagrad.py +4 -4
  130. mindspore/experimental/optim/adam.py +7 -0
  131. mindspore/experimental/optim/adamax.py +4 -4
  132. mindspore/experimental/optim/adamw.py +4 -0
  133. mindspore/experimental/optim/asgd.py +1 -1
  134. mindspore/experimental/optim/lr_scheduler.py +73 -46
  135. mindspore/experimental/optim/radam.py +34 -31
  136. mindspore/experimental/optim/rprop.py +1 -1
  137. mindspore/experimental/optim/sgd.py +1 -1
  138. mindspore/hal/contiguous_tensors_handle.py +6 -10
  139. mindspore/hal/device.py +55 -53
  140. mindspore/hal/event.py +52 -52
  141. mindspore/hal/memory.py +157 -117
  142. mindspore/hal/stream.py +150 -109
  143. mindspore/include/api/context.h +0 -1
  144. mindspore/include/dataset/constants.h +7 -4
  145. mindspore/include/dataset/execute.h +2 -2
  146. mindspore/jpeg62.dll +0 -0
  147. mindspore/log.py +50 -0
  148. mindspore/mindrecord/__init__.py +21 -8
  149. mindspore/mindrecord/config.py +17 -316
  150. mindspore/mindrecord/filereader.py +1 -9
  151. mindspore/mindrecord/filewriter.py +5 -15
  152. mindspore/mindrecord/mindpage.py +1 -9
  153. mindspore/mindspore_backend_common.dll +0 -0
  154. mindspore/mindspore_backend_manager.dll +0 -0
  155. mindspore/mindspore_common.dll +0 -0
  156. mindspore/mindspore_core.dll +0 -0
  157. mindspore/mindspore_dump.dll +0 -0
  158. mindspore/mindspore_frontend.dll +0 -0
  159. mindspore/mindspore_glog.dll +0 -0
  160. mindspore/mindspore_memory_pool.dll +0 -0
  161. mindspore/mindspore_ms_backend.dll +0 -0
  162. mindspore/mindspore_ops.dll +0 -0
  163. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  164. mindspore/mindspore_ops_kernel_common.dll +0 -0
  165. mindspore/mindspore_profiler.dll +0 -0
  166. mindspore/mindspore_pyboost.dll +0 -0
  167. mindspore/mindspore_pynative.dll +0 -0
  168. mindspore/mindspore_res_manager.dll +0 -0
  169. mindspore/mindspore_runtime_pipeline.dll +0 -0
  170. mindspore/mint/__init__.py +796 -759
  171. mindspore/mint/distributed/__init__.py +70 -4
  172. mindspore/mint/distributed/distributed.py +2679 -44
  173. mindspore/mint/linalg/__init__.py +8 -0
  174. mindspore/mint/nn/__init__.py +743 -22
  175. mindspore/mint/nn/functional.py +716 -23
  176. mindspore/mint/nn/layer/__init__.py +21 -4
  177. mindspore/mint/nn/layer/_functions.py +334 -0
  178. mindspore/mint/nn/layer/activation.py +276 -1
  179. mindspore/mint/nn/layer/basic.py +123 -0
  180. mindspore/mint/nn/layer/conv.py +921 -0
  181. mindspore/mint/nn/layer/normalization.py +223 -28
  182. mindspore/mint/nn/layer/padding.py +797 -0
  183. mindspore/mint/nn/layer/pooling.py +235 -0
  184. mindspore/mint/optim/__init__.py +3 -1
  185. mindspore/mint/optim/adam.py +223 -0
  186. mindspore/mint/optim/adamw.py +26 -19
  187. mindspore/mint/optim/sgd.py +171 -0
  188. mindspore/mint/special/__init__.py +2 -1
  189. mindspore/msobj140.dll +0 -0
  190. mindspore/mspdb140.dll +0 -0
  191. mindspore/mspdbcore.dll +0 -0
  192. mindspore/mspdbst.dll +0 -0
  193. mindspore/mspft140.dll +0 -0
  194. mindspore/msvcdis140.dll +0 -0
  195. mindspore/msvcp140_1.dll +0 -0
  196. mindspore/msvcp140_2.dll +0 -0
  197. mindspore/msvcp140_atomic_wait.dll +0 -0
  198. mindspore/msvcp140_codecvt_ids.dll +0 -0
  199. mindspore/multiprocessing/__init__.py +5 -0
  200. mindspore/nn/__init__.py +4 -1
  201. mindspore/nn/cell.py +1370 -189
  202. mindspore/nn/dynamic_lr.py +2 -1
  203. mindspore/nn/layer/activation.py +29 -27
  204. mindspore/nn/layer/basic.py +51 -35
  205. mindspore/nn/layer/channel_shuffle.py +3 -3
  206. mindspore/nn/layer/container.py +1 -1
  207. mindspore/nn/layer/conv.py +22 -17
  208. mindspore/nn/layer/embedding.py +12 -11
  209. mindspore/nn/layer/normalization.py +56 -49
  210. mindspore/nn/layer/padding.py +4 -3
  211. mindspore/nn/layer/pooling.py +120 -42
  212. mindspore/nn/layer/rnn_cells.py +1 -1
  213. mindspore/nn/layer/rnns.py +2 -1
  214. mindspore/nn/layer/timedistributed.py +5 -5
  215. mindspore/nn/layer/transformer.py +59 -36
  216. mindspore/nn/learning_rate_schedule.py +8 -4
  217. mindspore/nn/loss/loss.py +58 -55
  218. mindspore/nn/optim/ada_grad.py +7 -5
  219. mindspore/nn/optim/adadelta.py +11 -9
  220. mindspore/nn/optim/adafactor.py +1 -1
  221. mindspore/nn/optim/adam.py +17 -13
  222. mindspore/nn/optim/adamax.py +8 -7
  223. mindspore/nn/optim/adasum.py +5 -5
  224. mindspore/nn/optim/asgd.py +1 -1
  225. mindspore/nn/optim/ftrl.py +11 -9
  226. mindspore/nn/optim/lamb.py +1 -1
  227. mindspore/nn/optim/lars.py +1 -4
  228. mindspore/nn/optim/lazyadam.py +12 -10
  229. mindspore/nn/optim/momentum.py +7 -6
  230. mindspore/nn/optim/optimizer.py +3 -3
  231. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  232. mindspore/nn/optim/rmsprop.py +13 -12
  233. mindspore/nn/optim/rprop.py +11 -9
  234. mindspore/nn/optim/sgd.py +9 -6
  235. mindspore/nn/optim/tft_wrapper.py +5 -2
  236. mindspore/nn/optim/thor.py +2 -1
  237. mindspore/nn/probability/bijector/bijector.py +17 -11
  238. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  239. mindspore/nn/probability/bijector/invert.py +2 -2
  240. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  241. mindspore/nn/probability/bijector/softplus.py +3 -2
  242. mindspore/nn/probability/distribution/beta.py +3 -3
  243. mindspore/nn/probability/distribution/categorical.py +1 -1
  244. mindspore/nn/probability/distribution/cauchy.py +4 -2
  245. mindspore/nn/probability/distribution/exponential.py +6 -7
  246. mindspore/nn/probability/distribution/gamma.py +2 -2
  247. mindspore/nn/probability/distribution/gumbel.py +2 -2
  248. mindspore/nn/probability/distribution/half_normal.py +5 -3
  249. mindspore/nn/probability/distribution/logistic.py +5 -3
  250. mindspore/nn/probability/distribution/poisson.py +1 -1
  251. mindspore/nn/probability/distribution/uniform.py +5 -3
  252. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  253. mindspore/nn/reinforcement/tensor_array.py +1 -1
  254. mindspore/nn/utils/init.py +13 -11
  255. mindspore/nn/wrap/__init__.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +181 -122
  257. mindspore/nn/wrap/grad_reducer.py +45 -36
  258. mindspore/nn/wrap/loss_scale.py +6 -7
  259. mindspore/numpy/array_creations.py +63 -65
  260. mindspore/numpy/array_ops.py +149 -144
  261. mindspore/numpy/logic_ops.py +41 -42
  262. mindspore/numpy/math_ops.py +365 -363
  263. mindspore/numpy/utils.py +17 -18
  264. mindspore/numpy/utils_const.py +5 -6
  265. mindspore/opencv_core452.dll +0 -0
  266. mindspore/opencv_imgcodecs452.dll +0 -0
  267. mindspore/opencv_imgproc452.dll +0 -0
  268. mindspore/ops/__init__.py +5 -3
  269. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  270. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  271. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  272. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  273. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  274. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  275. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  276. mindspore/ops/_register_for_op.py +0 -11
  277. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  278. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  279. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  280. mindspore/ops/_vmap/vmap_base.py +0 -2
  281. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  282. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  283. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  284. mindspore/ops/auto_generate/__init__.py +4 -3
  285. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  286. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  287. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  288. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  289. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  290. mindspore/ops/composite/__init__.py +2 -1
  291. mindspore/ops/composite/base.py +20 -25
  292. mindspore/ops/composite/math_ops.py +6 -16
  293. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  294. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  295. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  296. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  301. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  302. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  304. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  305. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  306. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  308. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  309. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  310. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  311. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  312. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  313. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  314. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  315. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  316. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  317. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  320. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  321. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  324. mindspore/ops/function/__init__.py +40 -2
  325. mindspore/ops/function/_add_attr_func.py +58 -0
  326. mindspore/ops/function/array_func.py +2089 -2403
  327. mindspore/ops/function/clip_func.py +80 -23
  328. mindspore/ops/function/debug_func.py +57 -57
  329. mindspore/ops/function/grad/__init__.py +1 -0
  330. mindspore/ops/function/grad/grad_func.py +104 -71
  331. mindspore/ops/function/image_func.py +2 -2
  332. mindspore/ops/function/linalg_func.py +47 -78
  333. mindspore/ops/function/math_func.py +4501 -3802
  334. mindspore/ops/function/nn_func.py +1726 -620
  335. mindspore/ops/function/other_func.py +159 -1
  336. mindspore/ops/function/parameter_func.py +18 -84
  337. mindspore/ops/function/random_func.py +440 -387
  338. mindspore/ops/function/reshard_func.py +4 -70
  339. mindspore/ops/function/sparse_func.py +3 -3
  340. mindspore/ops/function/sparse_unary_func.py +6 -6
  341. mindspore/ops/function/spectral_func.py +25 -58
  342. mindspore/ops/function/vmap_func.py +24 -17
  343. mindspore/ops/functional.py +22 -7
  344. mindspore/ops/functional_overload.py +1440 -0
  345. mindspore/ops/op_info_register.py +32 -244
  346. mindspore/ops/operations/__init__.py +13 -7
  347. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  348. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  349. mindspore/ops/operations/_grad_ops.py +2 -43
  350. mindspore/ops/operations/_infer_ops.py +2 -1
  351. mindspore/ops/operations/_inner_ops.py +43 -84
  352. mindspore/ops/operations/_ms_kernel.py +4 -10
  353. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  354. mindspore/ops/operations/_scalar_ops.py +3 -2
  355. mindspore/ops/operations/_sequence_ops.py +1 -1
  356. mindspore/ops/operations/_tensor_array.py +1 -1
  357. mindspore/ops/operations/array_ops.py +81 -324
  358. mindspore/ops/operations/comm_ops.py +154 -108
  359. mindspore/ops/operations/custom_ops.py +232 -78
  360. mindspore/ops/operations/debug_ops.py +153 -59
  361. mindspore/ops/operations/inner_ops.py +7 -5
  362. mindspore/ops/operations/linalg_ops.py +1 -57
  363. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  364. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  365. mindspore/ops/operations/math_ops.py +32 -234
  366. mindspore/ops/operations/nn_ops.py +210 -498
  367. mindspore/ops/operations/other_ops.py +62 -9
  368. mindspore/ops/operations/random_ops.py +13 -7
  369. mindspore/ops/operations/reshard_ops.py +1 -1
  370. mindspore/ops/operations/sparse_ops.py +2 -2
  371. mindspore/ops/primitive.py +66 -53
  372. mindspore/ops/tensor_method.py +1888 -0
  373. mindspore/ops_generate/__init__.py +0 -5
  374. mindspore/ops_generate/aclnn/__init__.py +0 -0
  375. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  376. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  377. mindspore/ops_generate/api/__init__.py +0 -0
  378. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  379. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  380. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  381. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  382. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  383. mindspore/ops_generate/api/gen_api.py +103 -0
  384. mindspore/ops_generate/api/op_api_proto.py +235 -0
  385. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  386. mindspore/ops_generate/common/__init__.py +0 -0
  387. mindspore/ops_generate/common/base_generator.py +11 -0
  388. mindspore/ops_generate/common/gen_constants.py +91 -0
  389. mindspore/ops_generate/common/gen_utils.py +348 -0
  390. mindspore/ops_generate/common/op_proto.py +473 -0
  391. mindspore/ops_generate/common/template.py +523 -0
  392. mindspore/ops_generate/gen_ops.py +22 -1069
  393. mindspore/ops_generate/op_def/__init__.py +0 -0
  394. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  395. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  396. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  397. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  398. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  399. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  400. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  401. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  402. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  403. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  404. mindspore/ops_generate/pyboost/__init__.py +0 -0
  405. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  406. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  407. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  408. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  409. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  410. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  411. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  412. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  413. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  414. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  415. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  416. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  417. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  418. mindspore/ops_generate/resources/__init__.py +0 -0
  419. mindspore/ops_generate/resources/resource_list.py +30 -0
  420. mindspore/ops_generate/resources/resource_loader.py +36 -0
  421. mindspore/ops_generate/resources/resource_manager.py +64 -0
  422. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  423. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  424. mindspore/parallel/__init__.py +7 -3
  425. mindspore/parallel/_auto_parallel_context.py +152 -34
  426. mindspore/parallel/_cell_wrapper.py +130 -15
  427. mindspore/parallel/_parallel_serialization.py +107 -5
  428. mindspore/parallel/_ps_context.py +1 -1
  429. mindspore/parallel/_recovery_context.py +7 -2
  430. mindspore/parallel/_tensor.py +142 -18
  431. mindspore/parallel/_utils.py +199 -23
  432. mindspore/parallel/algo_parameter_config.py +4 -4
  433. mindspore/parallel/auto_parallel.py +732 -0
  434. mindspore/parallel/checkpoint_convert.py +159 -0
  435. mindspore/parallel/checkpoint_transform.py +698 -35
  436. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  437. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  438. mindspore/parallel/cluster/run.py +21 -4
  439. mindspore/parallel/function/__init__.py +24 -0
  440. mindspore/parallel/function/reshard_func.py +259 -0
  441. mindspore/parallel/nn/__init__.py +25 -0
  442. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  443. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  444. mindspore/parallel/parameter_broadcast.py +25 -14
  445. mindspore/parallel/shard.py +137 -58
  446. mindspore/parallel/transform_safetensors.py +363 -305
  447. mindspore/pgodb140.dll +0 -0
  448. mindspore/pgort140.dll +0 -0
  449. mindspore/profiler/__init__.py +22 -5
  450. mindspore/profiler/analysis/__init__.py +0 -0
  451. mindspore/profiler/analysis/parser/__init__.py +0 -0
  452. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  453. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  454. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  455. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  456. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  457. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  458. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  459. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  460. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  461. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  462. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  463. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  464. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  465. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  466. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  467. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  468. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  469. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  470. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  471. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  472. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  473. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  474. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  475. mindspore/profiler/analysis/task_manager.py +131 -0
  476. mindspore/profiler/analysis/time_converter.py +84 -0
  477. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  478. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  479. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  480. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  481. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  482. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  483. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  484. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  485. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  486. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  487. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  488. mindspore/profiler/analysis/work_flow.py +73 -0
  489. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  490. mindspore/profiler/common/command_executor.py +90 -0
  491. mindspore/profiler/common/constant.py +186 -3
  492. mindspore/profiler/common/file_manager.py +208 -0
  493. mindspore/profiler/common/log.py +130 -0
  494. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  495. mindspore/profiler/common/path_manager.py +395 -0
  496. mindspore/profiler/common/process_bar.py +168 -0
  497. mindspore/profiler/common/process_pool.py +9 -3
  498. mindspore/profiler/common/profiler_context.py +500 -0
  499. mindspore/profiler/common/profiler_info.py +304 -0
  500. mindspore/profiler/common/profiler_meta_data.py +74 -0
  501. mindspore/profiler/common/profiler_output_path.py +284 -0
  502. mindspore/profiler/common/profiler_parameters.py +251 -0
  503. mindspore/profiler/common/profiler_path_manager.py +179 -0
  504. mindspore/profiler/common/record_function.py +76 -0
  505. mindspore/profiler/common/tlv_decoder.py +76 -0
  506. mindspore/profiler/common/util.py +75 -2
  507. mindspore/profiler/dynamic_profiler.py +341 -75
  508. mindspore/profiler/envprofiler.py +163 -0
  509. mindspore/profiler/experimental_config.py +197 -0
  510. mindspore/profiler/mstx.py +242 -0
  511. mindspore/profiler/platform/__init__.py +21 -0
  512. mindspore/profiler/platform/base_profiler.py +40 -0
  513. mindspore/profiler/platform/cpu_profiler.py +124 -0
  514. mindspore/profiler/platform/gpu_profiler.py +74 -0
  515. mindspore/profiler/platform/npu_profiler.py +335 -0
  516. mindspore/profiler/profiler.py +1073 -90
  517. mindspore/profiler/profiler_action_controller.py +187 -0
  518. mindspore/profiler/profiler_interface.py +118 -0
  519. mindspore/profiler/schedule.py +243 -0
  520. mindspore/rewrite/api/node.py +15 -13
  521. mindspore/rewrite/api/symbol_tree.py +2 -3
  522. mindspore/run_check/_check_version.py +27 -20
  523. mindspore/run_check/run_check.py +1 -1
  524. mindspore/runtime/__init__.py +37 -0
  525. mindspore/runtime/device.py +27 -0
  526. mindspore/runtime/event.py +209 -0
  527. mindspore/runtime/executor.py +177 -0
  528. mindspore/runtime/memory.py +409 -0
  529. mindspore/runtime/stream.py +460 -0
  530. mindspore/runtime/thread_bind_core.py +401 -0
  531. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  532. mindspore/swresample-4.dll +0 -0
  533. mindspore/swscale-6.dll +0 -0
  534. mindspore/tbbmalloc.dll +0 -0
  535. mindspore/tinyxml2.dll +0 -0
  536. mindspore/train/__init__.py +8 -8
  537. mindspore/train/_utils.py +88 -25
  538. mindspore/train/amp.py +9 -5
  539. mindspore/train/callback/__init__.py +2 -2
  540. mindspore/train/callback/_callback.py +2 -16
  541. mindspore/train/callback/_checkpoint.py +53 -55
  542. mindspore/train/callback/_cluster_monitor.py +14 -18
  543. mindspore/train/callback/_early_stop.py +1 -1
  544. mindspore/train/callback/_flops_collector.py +103 -68
  545. mindspore/train/callback/_history.py +8 -5
  546. mindspore/train/callback/_lambda_callback.py +2 -2
  547. mindspore/train/callback/_landscape.py +0 -3
  548. mindspore/train/callback/_loss_monitor.py +2 -1
  549. mindspore/train/callback/_on_request_exit.py +6 -5
  550. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  551. mindspore/train/callback/_summary_collector.py +52 -19
  552. mindspore/train/callback/_time_monitor.py +2 -1
  553. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  554. mindspore/train/data_sink.py +25 -2
  555. mindspore/train/dataset_helper.py +15 -16
  556. mindspore/train/loss_scale_manager.py +8 -7
  557. mindspore/train/metrics/accuracy.py +3 -3
  558. mindspore/train/metrics/confusion_matrix.py +9 -9
  559. mindspore/train/metrics/error.py +3 -3
  560. mindspore/train/metrics/hausdorff_distance.py +4 -4
  561. mindspore/train/metrics/mean_surface_distance.py +3 -3
  562. mindspore/train/metrics/metric.py +0 -12
  563. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  564. mindspore/train/metrics/precision.py +11 -10
  565. mindspore/train/metrics/recall.py +9 -9
  566. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  567. mindspore/train/mind_ir_pb2.py +174 -46
  568. mindspore/train/model.py +184 -113
  569. mindspore/train/serialization.py +622 -978
  570. mindspore/train/summary/_summary_adapter.py +2 -2
  571. mindspore/train/summary/summary_record.py +2 -3
  572. mindspore/train/train_thor/model_thor.py +1 -1
  573. mindspore/turbojpeg.dll +0 -0
  574. mindspore/utils/__init__.py +6 -3
  575. mindspore/utils/dryrun.py +140 -0
  576. mindspore/utils/hooks.py +81 -0
  577. mindspore/utils/runtime_execution_order_check.py +550 -0
  578. mindspore/utils/utils.py +138 -4
  579. mindspore/vcmeta.dll +0 -0
  580. mindspore/vcruntime140.dll +0 -0
  581. mindspore/vcruntime140_1.dll +0 -0
  582. mindspore/version.py +1 -1
  583. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  584. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
  585. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  586. mindspore/_install_custom.py +0 -43
  587. mindspore/common/_register_for_adapter.py +0 -74
  588. mindspore/common/_tensor_overload.py +0 -139
  589. mindspore/mindspore_np_dtype.dll +0 -0
  590. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  591. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  592. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  593. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  594. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  595. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  596. mindspore/ops_generate/gen_utils.py +0 -209
  597. mindspore/ops_generate/op_proto.py +0 -145
  598. mindspore/ops_generate/template.py +0 -261
  599. mindspore/profiler/envprofiling.py +0 -254
  600. mindspore/profiler/profiling.py +0 -1926
  601. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  602. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,169 @@
+ # Copyright 2025 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ============================================================================
+ """parallel serialization"""
+ from __future__ import absolute_import
+
+ from mindspore import context
+ from mindspore.nn.cell import Cell
+ from mindspore.ops import functional as F, composite as C, operations as P
+ import mindspore.common.dtype as mstype
+ from mindspore.common.sparse_tensor import Tensor
+ from mindspore.common.api import jit
+ from mindspore.common.parameter import Parameter
+ from mindspore.nn.layer import Identity
+ from mindspore.parallel._utils import _get_enable_parallel_optimizer
+
+ __all__ = ['PipelineGradReducer']
+
+
+ grad_scale = C.MultitypeFuncGraph("grad_scale")
+ shard_grad_scale = C.MultitypeFuncGraph("shard_grad_scale")
+ reciprocal = P.Reciprocal()
+
+
+ @grad_scale.register("Tensor", "Tensor", "Tensor")
+ def tensor_grad_scale_pipeline(scale, grad, accu_grad):
+     accu_grad = F.depend(accu_grad, grad)
+     new_grad = accu_grad * reciprocal(scale)
+     accu_grad = F.depend(accu_grad, new_grad)
+     zeros = F.tensor_mul(accu_grad, 0.0)
+     new_grad = F.depend(new_grad, F.assign(accu_grad, zeros))
+     return new_grad
+
+
+ @shard_grad_scale.register("Tensor", "Tensor", "Tensor")
+ def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad):
+     new_grad = grad * reciprocal(scale)
+     accu_grad = F.depend(accu_grad, new_grad)
+     new_grad = F.depend(new_grad, F.assign(accu_grad, F.zeros_like(accu_grad)))
+     return new_grad
+
+
+ class PipelineGradReducer(Cell):
+     """
+     Gradient reduction and accumulation for pipeline parallel in functional training scenarios.
+
+     Args:
+         parameters (list): the parameters to be updated.
+         scale_sense (float, optional): the scale sense of the gradient. Default: ``1.0``.
+         opt_shard (bool, optional): set opt_shard to True when the parallel optimizer is used. Default: ``None``.
+
+     Raises:
+         RuntimeError: If the mode is not graph mode.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU``
+
+     Examples:
+         .. note::
+             Before running the following examples, you need to configure the communication environment variables.
+
+             For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
+             Please see the `rank table Startup
+             <https://www.mindspore.cn/tutorials/en/master/parallel/rank_table.html>`_
+             for more details.
+
+             This example should be run with multiple devices.
+
+         >>> import numpy as np
+         >>> import mindspore as ms
+         >>> from mindspore import nn, ops, Tensor
+         >>> from mindspore.communication import init
+         >>>
+         >>> ms.set_context(mode=ms.GRAPH_MODE)
+         >>> ms.reset_auto_parallel_context()
+         >>> init()
+         >>> ms.set_seed(1)
+         >>>
+         >>> class Network(nn.Cell):
+         ...     def __init__(self, in_features, out_features, sens=1.0):
+         ...         super().__init__()
+         ...         self.layer1 = nn.Dense(in_features, 16)
+         ...         self.relu1 = nn.ReLU()
+         ...         self.layer2 = nn.Dense(16, 16)
+         ...         self.relu2 = nn.ReLU()
+         ...         self.layer3 = nn.Dense(16, out_features)
+         ...
+         ...     def construct(self, x):
+         ...         x = self.layer1(x)
+         ...         x = self.relu1(x)
+         ...         x = self.layer2(x)
+         ...         x = self.relu2(x)
+         ...         logits = self.layer3(x)
+         ...         return logits
+         >>>
+         >>> size, in_features, out_features = 16, 32, 10
+         >>> net = Network(in_features, out_features)
+         >>> net.layer1.pipeline_stage = 0
+         >>> net.relu1.pipeline_stage = 0
+         >>> net.layer2.pipeline_stage = 0
+         >>> net.relu2.pipeline_stage = 1
+         >>> net.layer3.pipeline_stage = 1
+         >>> loss_fn = nn.CrossEntropyLoss()
+         >>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
+         >>> net_with_loss = nn.Pipeline(nn.WithLossCell(net, loss_fn), 2)
+         >>> net_with_loss.set_train()
+         >>> def forward_fn(inputs, target):
+         ...     loss = net_with_loss(inputs, target)
+         ...     return loss
+         >>>
+         >>> grad_fn = ops.value_and_grad(forward_fn, None, net_with_loss.trainable_params())
+         >>> pp_grad_reducer = nn.PipelineGradReducer(optimizer.parameters)
+         >>>
+         >>> @ms.jit
+         ... def train_one_step(inputs, target):
+         ...     loss, grads = grad_fn(inputs, target)
+         ...     grads = pp_grad_reducer(grads)
+         ...     optimizer(grads)
+         ...     return loss, grads
+         >>>
+         >>> parallel_net = AutoParallel(train_one_step, parallel_mode="semi_auto")
+         >>> parallel_net.pipeline(stages=2)
+         >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
+         >>> label = Tensor(np.ones([size, out_features]).astype(np.float32))
+         >>> loss, _ = train_one_step(inputs, label)
+         >>> print(loss)
+         46.36721
+     """
+     def __init__(self, parameters, scale_sense=1.0, opt_shard=None):
+         super(PipelineGradReducer, self).__init__(auto_prefix=False)
+         self._check_mode()
+         self.accu_grads = parameters.clone(prefix="accu_grads", init="zeros")
+         self.grad_reducer = Identity()
+         self.degree = Tensor(1, mstype.float32)
+         self.scale_sense = Parameter(scale_sense, name='scale_sense')
+         self.hyper_map = C.HyperMap()
+         if opt_shard is None:
+             self.opt_shard = _get_enable_parallel_optimizer()
+         else:
+             self.opt_shard = opt_shard
+
+     @jit
+     def construct(self, grads):
+         new_grads = None
+         if self.opt_shard:
+             grads = self.grad_reducer(grads)
+             new_grads = self.hyper_map(F.partial(shard_grad_scale, self.scale_sense * self.degree),
+                                        grads, self.accu_grads)
+         else:
+             accu_grads = self.grad_reducer(self.accu_grads)
+             new_grads = self.hyper_map(F.partial(grad_scale, self.scale_sense * self.degree), grads, accu_grads)
+         return new_grads
+
+     def _check_mode(self):
+         """check parallel mode"""
+         mode = context.get_context('mode')
+         if mode != context.GRAPH_MODE:
+             raise RuntimeError(f"PipelineGradReducer only support graph mode, but get {mode}")
@@ -19,7 +19,10 @@ __all__ = ["parameter_broadcast"]
  
  import numpy as np
  import mindspore as ms
- from mindspore.communication import get_rank, create_group, get_group_size
+ from mindspore.communication import create_group, get_group_size
+ from mindspore.parallel._utils import _get_auto_parallel_net, _parallel_mode_map, _check_rank
+ # disable pylint too broad Exception
+ # pylint: disable=W0212
  
  
  def parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
@@ -34,7 +37,8 @@ def parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
          layout (Dict): Parameter layout dictionary. Come from
              :func:`mindspore.nn.Cell.parameter_layout_dict`
              or read from file(for example: "strategy.ckpt" saved by using the
-             `strategy_ckpt_config` parameter of :func:`mindspore.set_auto_parallel_context`).
+             `strategy_ckpt_config` parameter of
+             :func:`mindspore.parallel.auto_parallel.AutoParallel.save_param_strategy_file` ).
              The key is param name, the value is the layout of this parameter.
          cur_rank (int, optional): current rank id. Default: ``0``.
          initial_rank (int, optional): Start rank id for each pipeline. Default: ``0``.
@@ -45,6 +49,9 @@ def parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
          ValueError: Parameter name in `layout` can not be found in
              :func:`mindspore.nn.Cell.parameters_dict`.
  
+     Supported Platforms:
+         ``Ascend``
+
      Examples:
          >>> import os
          >>> import mindspore as ms
@@ -53,11 +60,11 @@ def parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
          >>> from mindspore.communication import init
          >>> from mindspore.common.initializer import initializer
          >>> from mindspore.train import Model
-         >>> from mindspore.parallel.parameter_broadcast import parameter_broadcast
          >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
+         >>> from mindspore.parallel.auto_parallel import AutoParallel
+         >>> from mindspore.parallel import parameter_broadcast
          >>> ms.set_context(mode=ms.GRAPH_MODE)
-         >>> ms.set_context(max_device_memory="28GB")
-         >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
+         >>> ms.runtime.set_memory(max_size="28GB")
          >>> init()
          >>> ms.set_seed(1)
          >>> class Network(nn.Cell):
@@ -90,7 +97,8 @@ def parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
          >>> dataset = create_dataset()
          >>> optim = nn.SGD(net.trainable_params(), 1e-2)
          >>> loss = nn.CrossEntropyLoss()
-         >>> model = Model(net, loss_fn=loss, optimizer=optim)
+         >>> parallel_net = AutoParallel(net)
+         >>> model = Model(parallel_net, loss_fn=loss, optimizer=optim)
          >>> model.train(1, dataset)
          >>> ms.save_checkpoint(net, "./simple.ckpt", False)
          >>> layout = model.train_network.parameter_layout_dict
@@ -104,17 +112,20 @@ def parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
          ...         print("step end, cur step num: ", cb_params.cur_step_num, flush=True)
          >>> model.train(1, dataset, callbacks=[LossCallBack()])
      """
-     if not layout:
+     if not layout or get_group_size() <= 1:
          return
      from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy
      from mindspore.nn.wrap.cell_wrapper import AllreduceGraph
-     origin_parallel_mode = ms.get_auto_parallel_context("parallel_mode")
-     if origin_parallel_mode not in ("semi_auto_parallel", "auto_parallel"):
-         return
-     if cur_rank != get_rank():
-         raise ValueError(f"For parameter broadcast, the cur_rank: {cur_rank} is wrong.")
-     if initial_rank % (get_group_size() / ms.get_auto_parallel_context("pipeline_stages")) != 0:
-         raise ValueError(f"For parameter broadcast, the initial_rank: {initial_rank} is wrong.")
+     origin_parallel_mode = ""
+     pipeline_stages = 1
+     parallel_net = _get_auto_parallel_net(net)
+     if type(parallel_net).__name__ == 'AutoParallel':
+         origin_parallel_mode = _parallel_mode_map(parallel_net._parallel_mode)
+         pipeline_stages = parallel_net._pipeline_stages
+     else:
+         origin_parallel_mode = ms.get_auto_parallel_context("parallel_mode")
+         pipeline_stages = ms.get_auto_parallel_context("pipeline_stages")
+     _check_rank(cur_rank, initial_rank, pipeline_stages)
      param_redundancy = get_parameter_redundancy(layout, initial_rank)
      if not param_redundancy:
          return
@@ -15,14 +15,77 @@
  """shard"""
  
  import copy
+ import numpy as np
  import mindspore as ms
  from mindspore import log as logger
  from mindspore._c_expression import Shard_
  
  
+ class _DistributedTensorInfo:
+     """
+     Describe the distributed information of a tensor.
+
+     Args:
+         distributed_info (Union[Layout, DeviceMesh]): The distributed information of a tensor.
+
+     Raises:
+         TypeError: If `distributed_info` is not a Layout type.
+
+     Examples:
+         >>> from mindspore import _DistributedTensorInfo, Layout
+         >>> layout = Layout((2, 2), ("dp", "mp"))
+         >>> src_layout = layout("dp", "mp")
+         >>> distributed_info = _DistributedTensorInfo(src_layout)
+         >>> print(distributed_info.sharding_strategy)
+         [2, 2]
+     """
+
+     def __init__(self, distributed_info):
+         if isinstance(distributed_info, Layout):
+             self._layout = distributed_info
+             self._distributed_info = distributed_info
+         else:
+             raise TypeError(
+                 f"DistributedTensorInfo only supports Layout or DeviceMesh as input, but got {type(distributed_info)}")
+         self._sharding_strategy = None
+
+     @property
+     def layout(self):
+         """return layout of current tensor"""
+         return self._layout
+
+     @property
+     def distributed_info(self):
+         """return the distributed info, it depends on user's input """
+         return self._distributed_info
+
+     @property
+     def sharding_strategy(self):
+         """return the sharding strategy of current tensor"""
+         if self._sharding_strategy is None:
+             layout_info = self._layout.to_dict()
+             device_matrix = layout_info["device_matrix"]
+             tensor_map = layout_info["tensor_map"]
+             sharding_strategy = []
+             for map_value in tensor_map:
+                 if isinstance(map_value, (tuple, list)):
+                     shard_size = 1
+                     for value in map_value:
+                         if value != -1:
+                             shard_size *= device_matrix[len(device_matrix) - value - 1]
+                     sharding_strategy.append(shard_size)
+                 else:
+                     if map_value != -1:
+                         sharding_strategy.append(device_matrix[len(device_matrix) - map_value - 1])
+                     else:
+                         sharding_strategy.append(1)
+             self._sharding_strategy = sharding_strategy
+         return self._sharding_strategy
+
+
  class Layout:
      """
-     Parallel layout describes the detailed sharding information.
+     Topological abstraction describing cluster devices for tensor slice placement on the cluster.
  
      Note:
          - It is valid only in semi auto parallel or auto parallel mode.
@@ -35,28 +98,36 @@ class Layout:
          alias_name (tuple): The alias name for each axis of device_matrix, its length should be equal to the length of device_matrix, and its element type is string.
              When using "interleaved_parallel" as an alias name, the tensor would be split into multiple
              copies on the corresponding partition dimension on a single card.
+         rank_list (list, optional): Data is allocated to the device according to rank_list. Default: ``None``.
+
      Raises:
          TypeError: `device_matrix` is not a tuple type.
          TypeError: `alias_name` is not a tuple type.
+         TypeError: 'rank_list' is not a list type.
          ValueError: `device_matrix` length is not equal to `alias_name` length.
          TypeError: The element of `device_matrix` is not int type.
          TypeError: The element of `alias_name` is not a str type.
+         TypeError: The element of `rank_list` is not int type.
          ValueError: The element of `alias_name` is an empty str.
          ValueError: The element of `alias_name` is "None".
          ValueError: `alias_name` contains repeated element.
  
+     Supported Platforms:
+         ``Ascend``
+
      Examples:
-         >>> from mindspore import Layout
+         >>> from mindspore.parallel import Layout
          >>> layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
          >>> layout0 = layout("dp", "mp")
          >>> print(layout0.to_dict())
-         {"device_matrix": (2, 2, 2), "tensor_map": (2, 0), "interleaved_parallel": False}
+         {"device_matrix": (2, 2, 2), "tensor_map": (2, 0), "interleaved_parallel": False,
+         'alias_name': {'dp', 'sp', 'mp'}, "rank_list": [0, 1, 2, 3, 4, 5, 6, 7]}
          >>> # Total device num is 4, but split the tensor in local device into two copies.
          >>> layout = Layout((2, 2, 2), ("dp", "sp", "interleaved_parallel"))
          >>> layout1 = layout(("dp", "interleaved_parallel"), "sp")
      """
  
-     def __init__(self, device_matrix, alias_name):
+     def __init__(self, device_matrix, alias_name, rank_list=None):
          if not isinstance(device_matrix, tuple):
              raise TypeError(f'device_matrix must be tuple type, but got:{type(device_matrix)}')
          if not isinstance(alias_name, tuple):
@@ -82,6 +153,20 @@ class Layout:
          self._device_shape = device_matrix
          self._alias_name = alias_name
          self._tensor_map = None
+         self._rank_list = list(range(np.prod(np.array(self._device_shape))))
+         if rank_list is not None:
+             if not isinstance(rank_list, list):
+                 raise TypeError(f"The rank_list should be a list, but got {type(rank_list).__name__}.")
+             for in_ele in rank_list:
+                 if not isinstance(in_ele, int):
+                     raise TypeError(f"The element of rank_list should be int, but got {type(in_ele).__name__}.")
+             if len(np.array(rank_list).shape) != 1:
+                 raise ValueError(
+                     f"The rank_list should be a 1-D list, but got {len(np.array(rank_list).shape)}-D list.")
+             if len(rank_list) != np.prod(np.array(self._device_shape)):
+                 raise ValueError(f"The length of rank_list should be equal to the product of device_matrix, "
+                                  f"but got {len(rank_list)} and {np.prod(np.array(self._device_shape))}.")
+             self._rank_list = rank_list
  
      def __call__(self, *tensor_map):
          self._tensor_map = ()
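As an aside to the Layout changes above (this note is not part of the package diff): a minimal sketch of how the new rank_list argument could be exercised, based only on the validation and to_dict() behaviour shown in these hunks. The (2, 2) device matrix and the reversed rank order are illustrative values; Layout itself is a pure description object here, so the snippet does not require a launched multi-device job.

    # Sketch only: values are illustrative, not taken from the package.
    from mindspore.parallel import Layout

    # A (2, 2) device matrix over 4 ranks, with an explicit rank order.
    # rank_list must be a flat list of ints whose length equals the
    # product of device_matrix (here 2 * 2 = 4), per the checks above.
    layout = Layout((2, 2), ("dp", "mp"), rank_list=[3, 2, 1, 0])
    tensor_layout = layout("dp", "mp")

    # to_dict() now also reports the rank order alongside device_matrix
    # and tensor_map (see the to_dict hunk below).
    info = tensor_layout.to_dict()
    print(info["device_matrix"])  # (2, 2)
    print(info["tensor_map"])     # (1, 0)
    print(info["rank_list"])      # [3, 2, 1, 0]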
@@ -122,8 +207,8 @@ class Layout:
122
207
  raise ValueError("The tensor_map of layout is None")
123
208
  interleaved_parallel = "interleaved_parallel" in self._alias_name
124
209
  return {"device_matrix": self._device_shape, "tensor_map": self._tensor_map,
125
- "interleaved_parallel": interleaved_parallel, "alias_name": self._alias_name}
126
-
210
+ "interleaved_parallel": interleaved_parallel, "alias_name": self._alias_name,
211
+ "rank_list": self._rank_list}
127
212
 
128
213
 
129
214
  class Shard(Shard_):
@@ -141,18 +226,6 @@ class Shard(Shard_):
141
226
  self.level = None
142
227
 
143
228
  def __call__(self, fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
144
- parallel_mode = ms.context.get_auto_parallel_context("parallel_mode")
145
- if parallel_mode not in ("auto_parallel", "semi_auto_parallel"):
146
- raise AssertionError(
147
- f"Cell shard only supports auto parallel and semi auto parallel.")
148
- if ms.context.get_context("device_target") not in ("Ascend", "GPU"):
149
- raise AssertionError(
150
- f"'Shard' now only supports 'Ascend' and 'GPU'")
151
- if parallel_mode == "auto_parallel" and \
152
- ms.context.get_auto_parallel_context("search_mode") != "sharding_propagation":
153
- raise AssertionError(f"'search_mode' must be 'sharding_propagation' for 'Shard' when the "
154
- f"'parallel_mode' is 'auto_parallel.'")
155
-
156
229
  if not isinstance(in_strategy, tuple):
157
230
  raise TypeError(
158
231
  f"For 'Shard', the 'in_strategy' should be a tuple, but got {type(in_strategy).__name__}.")
@@ -181,7 +254,8 @@ class Shard(Shard_):
181
254
  "will be overwritten as False.")
182
255
  ms.set_algo_parameters(fully_use_devices=False)
183
256
 
184
- if ms.context.get_auto_parallel_context("full_batch_is_set") is False:
257
+ if ms.context.get_auto_parallel_context("full_batch_is_set") is False and \
258
+ ms.context.get_context("mode") == ms.context.PYNATIVE_MODE:
185
259
  logger.warning("When calling the shard interface, "
186
260
  "'dataset_strategy' or 'full_batch' is not manually set by the user, "
187
261
  "and the 'dataset_strategy' will be set to 'full_batch'.")
@@ -193,13 +267,13 @@ class Shard(Shard_):
193
267
 
194
268
  if isinstance(fn, ms.nn.Cell):
195
269
  for param in fn.trainable_params():
196
- param.is_in_shard = True
270
+ param.param_info.is_in_pynative_shard = True
197
271
 
198
272
  # Set parameter layout to corresponding parameter
199
273
  self._set_param_layout_into_parameter(fn, parameter_plan)
200
274
 
201
275
  def shard_fn(*args):
202
- @ms.common.jit(hash_args=fn)
276
+ @ms.common.jit(hash_args=fn, backend="ms_backend")
203
277
  def after_shard(*args):
204
278
  return shard_(fn, in_strategy, out_strategy, device, level)(*args)
205
279
 
@@ -290,7 +364,7 @@ class Shard(Shard_):
290
364
  for stra in strategy:
291
365
  if not isinstance(stra, (tuple, Layout)):
292
366
  raise TypeError(
293
- f"The '{log_info}' should be a tuple(tuple(int)) or tuple(mindspore.Layout), "
367
+ f"The '{log_info}' should be a tuple(tuple(int)) or tuple(mindspore.parallel.Layout), "
294
368
  f"but got {type(stra).__name__}")
295
369
  if isinstance(stra, Layout):
296
370
  strategy_set.add("layout")
@@ -312,7 +386,7 @@ class Shard(Shard_):
312
386
  for in_ele in layout:
313
387
  if not isinstance(in_ele, Layout):
314
388
  raise TypeError(f"The {log_info} item should be a object of class Layout.")
315
- layout_value += (in_ele.to_dict(),)
389
+ layout_value += ({k: v for k, v in in_ele.to_dict().items() if k != "rank_list"},)
316
390
  return layout_value
317
391
 
318
392
  def _check_tuple_strategy(self, dim_strategy):
@@ -323,8 +397,8 @@ class Shard(Shard_):
323
397
 
324
398
  def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
325
399
  """
326
- Defining the input and output layouts of this cell and the parallel strategies of remaining ops will be
327
- generated by sharding propagation. In PyNative mode, use this method to specify a Cell for distributed
400
+ Specify the input and output slicing strategy for a Cell or function.
401
+ In PyNative mode, use this method to specify a Cell for distributed
328
402
  execution in graph mode. In Graph mode, use this method to specify distribution strategy for a Cell,
329
403
  strategy for others will be set by sharding propagation.
330
404
  in_strategy and out_strategy define the input and output layout respectively.
@@ -334,33 +408,37 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
  The parallel strategies of remaining operators are derived from the strategy specified by the input and output.

  Note:
- If ms.shard is called, the parallel mode in `set_auto_parallel_context` (parallel_mode) will be set to
- "auto_parallel" and the search mode (search_mode) to "sharding_propagation".
- If the input contain Parameter, its strategy should be set in `in_strategy`.
+ - If shard is called, the parallel mode in `set_auto_parallel_context` (parallel_mode) will be set to
+ "auto_parallel" and the search mode (search_mode) to "sharding_propagation".
+ - If the input contain Parameter, its strategy should be set in `in_strategy`.
+ - This method currently does not support dynamic shapes.

  Args:
  fn (Union[Cell, Function]): Function to be executed in parallel.
- Its arguments and return value must be Tensor or Parameter.
+ Its arguments and return value must be Tensor.
  If `fn` is a Cell with parameters, `fn` needs to be an instantiated object,
  otherwise its arguments cannot be accessed.
  in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple(int) or
- tuple(mindspore.Layout).
+ tuple(mindspore.parallel.Layout).
  Tuple defines the layout of the corresponding input.
- out_strategy (Union[tuple, None]): Define the layout of outputs similar with `in_strategy`.
- It is not in use right now. Default: ``None`` .
- parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict
+ out_strategy (Union[tuple, None], optional): Define the layout of outputs similar with `in_strategy`.
+ Default: ``None`` .
+ parameter_plan (Union[dict, None], optional): Define the layout for the specified parameters.
+ Each element in dict
  defines the layout of the parameter like "param_name: layout".
  The key is a parameter name of type 'str'.
- The value is a 1-D integer tuple or a 1-D mindspore.Layout tuple,
+ The value is a 1-D integer tuple or a 1-D mindspore.parallel.Layout tuple,
  indicating the corresponding layout.
  If the parameter name is incorrect or the corresponding parameter
- has been set, the parameter setting will be ignored.
+ has been set, the parameter setting will be ignored. Supported
+ only when `fn` is a Cell with parameters.
  Default: ``None`` .
- device (string): Select a certain `device` target. It is not in use right now.
- Support ["CPU", "GPU", "Ascend"]. Default: ``"Ascend"`` .
- level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation
- over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
- use right now. Support [0, 1, 2]. Default: ``0`` .
+ device (str, optional): Select a certain `device` target. It is not in use right now.
+ Support ["CPU", "GPU", "Ascend"]. Default: ``"Ascend"`` .
+ level (int, optional): Option for parallel strategy infer algorithm, namely the object function,
+ maximize computation
+ over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
+ use right now. Support [0, 1, 2]. Default: ``0`` .

  Returns:
  Function, return the function that will be executed under auto parallel process.
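To make the updated Args section concrete, here is a minimal sketch of an `in_strategy` built from the `mindspore.parallel.Layout` namespace the docstring now refers to. The 4x2 device arrangement and the wrapped function are illustrative assumptions, and a distributed job (8 devices, `init()` already called) is presumed:

    from mindspore.parallel import Layout, shard

    # Hedged sketch: 8 devices arranged as 4 (dp) x 2 (mp); values are assumptions.
    layout = Layout((4, 2), ("dp", "mp"))
    in_strategy = (layout("dp", "mp"),)   # one 2-D input, sharded over (dp, mp)

    def scale(x):
        return x * 2

    sharded_scale = shard(scale, in_strategy=in_strategy)
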
@@ -370,26 +448,28 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
  AssertionError: If device_target it not "Ascend" or "GPU".
  TypeError: If `in_strategy` is not a tuple.
  TypeError: If `out_strategy` is not a tuple or None.
- TypeError: If any element in `in_strategy` is not a tuple(int) or tuple(mindspore.Layout).
- TypeError: If any element in `out_strategy` is not a tuple(int) or tuple(mindspore.Layout).
+ TypeError: If any element in `in_strategy` is not a tuple(int) or tuple(mindspore.parallel.Layout).
+ TypeError: If any element in `out_strategy` is not a tuple(int) or tuple(mindspore.parallel.Layout).
  TypeError: If `parameter_plan` is not a dict or None.
  TypeError: If any key in `parameter_plan` is not a str.
- TypeError: If any value in `parameter_plan` is not a tuple(int) or a tuple(mindspore.Layout).
+ TypeError: If any value in `parameter_plan` is not a tuple(int) or a tuple(mindspore.parallel.Layout).
  TypeError: If `device` is not a str.
  TypeError: If `level` is not an integer.

  Supported Platforms:
- ``Ascend`` ``GPU``
+ ``Ascend``

  Examples:
  >>> import numpy as np
  >>> import mindspore as ms
- >>> from mindspore import Tensor, nn
+ >>> from mindspore import Tensor, nn, ops
  >>> from mindspore.communication import init
+ >>> from mindspore.parallel import shard
+ >>> from mindspore.parallel import Layout
+ >>> from mindspore.nn.utils import no_init_parameters
+ >>> from mindspore.parallel.auto_parallel import AutoParallel
  >>> ms.set_context(mode=ms.GRAPH_MODE)
  >>> init()
- >>> ms.set_auto_parallel_context(parallel_mode="auto_parallel", search_mode="sharding_propagation",
- ... device_num=8)
  >>>
  >>> # Case 1: cell uses functional
  >>> class BasicBlock(nn.Cell):
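The removed `ms.set_auto_parallel_context(...)` call is not re-added anywhere in the new example; a later hunk instead wraps the network with `AutoParallel`. A condensed sketch of that 2.6-style setup, using a placeholder `nn.Dense` cell in place of the example's `Net` and assuming an initialized multi-device job:

    import mindspore as ms
    from mindspore import nn
    from mindspore.communication import init
    from mindspore.nn.utils import no_init_parameters
    from mindspore.parallel.auto_parallel import AutoParallel

    ms.set_context(mode=ms.GRAPH_MODE)
    init()
    with no_init_parameters():          # defer parameter init, as in the updated example
        net = nn.Dense(8, 8)            # placeholder for the example's Net()
    parallel_net = AutoParallel(net, parallel_mode='sharding_propagation')
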
@@ -401,7 +481,7 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
  >>> x = ops.abs(x)
  >>> return x + y
  >>> # shard a function with tuple(int) strategies
- >>> self.shard_my_add = ms.shard(my_add, in_strategy=((2, 2), (1, 4)), out_strategy=((4, 1),))
+ >>> self.shard_my_add = shard(my_add, in_strategy=((2, 2), (1, 4)), out_strategy=((4, 1),))
  >>>
  >>> def construct(self, x, u):
  >>> x = self.gelu(x)
@@ -429,7 +509,7 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
  >>> super(Net, self).__init__()
  >>> # setting cell sharding strategy and parameter_plan by tuple(int)
  >>> self.layer_net1 = NetForward()
- >>> self.layer_net1_shard = ms.shard(self.layer_net1, in_strategy=((4, 2), (2, 1)),
+ >>> self.layer_net1_shard = shard(self.layer_net1, in_strategy=((4, 2), (2, 1)),
  ... parameter_plan={"self.layer_net1.block1.weight": (4, 1)})
  >>>
  >>> # setting cell sharding strategy and parameter_plan by tuple(ms.Layout)
@@ -437,7 +517,7 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
  >>> layout = Layout((4, 2, 1), ("dp", "mp", "sp"))
  >>> in_layout = (layout("dp", "mp"), layout("mp", "sp"))
  >>> param_layout = layout("dp", "sp")
- >>> self.layer_net2_shard = ms.shard(self.layer_net2, in_strategy=in_layout,
+ >>> self.layer_net2_shard = shard(self.layer_net2, in_strategy=in_layout,
  ... parameter_plan={"self.layer_net2.block2.weight": param_layout})
  >>> self.flatten = nn.Flatten()
  >>> self.layer1 = nn.Dense(64, 64)
@@ -455,26 +535,25 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
  >>> x = self.matmul(x, Tensor(np.ones(shape=(32, 32)), dtype=ms.float32))
  >>> return x
  >>>
- >>> net = Net()
+ >>> with no_init_parameters():
+ >>> net = Net()
  >>> x = Tensor(np.ones(shape=(64, 1, 8, 8)), dtype=ms.float32)
  >>> y = Tensor(np.ones(shape=(64, 1, 8, 8)), dtype=ms.float32)
- >>> net(x, y)
+ >>> parallel_net = AutoParallel(net, parallel_mode='sharding_propagation')
+ >>> parallel_net(x, y)
  >>>
  >>> # Case 2: function uses functional sharding
  >>> def test_shard(x, y):
  ... return x + y
  >>> x = Tensor(np.ones(shape=(32, 10)), dtype=ms.float32)
  >>> y = Tensor(np.ones(shape=(32, 10)), dtype=ms.float32)
- >>> output = ms.shard(test_shard, in_strategy=((4, 2), (4, 2)))(x, y)
+ >>> output = shard(test_shard, in_strategy=((4, 2), (4, 2)))(x, y)
  >>> print(output.shape)
  (32, 10)

- Tutorial Examples:
- - `Functional Operator Sharding
- <https://www.mindspore.cn/docs/en/master/model_train/parallel/shard_function_parallel.html>`_
- - `mindspore.Layout
- <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.Layout.html>`_
  """
+ if ms.communication.management.get_group_size() == 1:
+ return fn
  if not isinstance(fn, (ms.nn.Cell)):
  logger.warning("'fn' is not a mindspore.nn.Cell, and its definition cannot involve Parameter; "
  "otherwise, the result may be incorrect.")