mindspore 2.4.10__cp39-cp39-win_amd64.whl → 2.6.0__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (579)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +13 -6
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +3 -0
  7. mindspore/_checkparam.py +3 -38
  8. mindspore/_deprecated/__init__.py +17 -0
  9. mindspore/_deprecated/jit.py +198 -0
  10. mindspore/_extends/builtin_operations.py +1 -1
  11. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  12. mindspore/_extends/parse/__init__.py +6 -7
  13. mindspore/_extends/parse/compile_config.py +83 -0
  14. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +47 -198
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +229 -99
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +11 -5
  27. mindspore/avcodec-59.dll +0 -0
  28. mindspore/avdevice-59.dll +0 -0
  29. mindspore/avfilter-8.dll +0 -0
  30. mindspore/avformat-59.dll +0 -0
  31. mindspore/avutil-57.dll +0 -0
  32. mindspore/boost/__init__.py +2 -2
  33. mindspore/boost/base.py +3 -7
  34. mindspore/boost/boost_cell_wrapper.py +138 -43
  35. mindspore/common/__init__.py +6 -3
  36. mindspore/common/_grad_function.py +56 -0
  37. mindspore/common/_pijit_context.py +14 -5
  38. mindspore/common/_register_for_tensor.py +1 -2
  39. mindspore/common/_stub_tensor.py +30 -14
  40. mindspore/common/_tensor_cpp_method.py +17 -0
  41. mindspore/common/_tensor_docs.py +4760 -0
  42. mindspore/common/api.py +480 -372
  43. mindspore/common/auto_dynamic_shape.py +41 -44
  44. mindspore/common/dtype.py +39 -36
  45. mindspore/common/dump.py +9 -6
  46. mindspore/common/file_system.py +9 -1
  47. mindspore/common/generator.py +5 -0
  48. mindspore/common/hook_handle.py +6 -2
  49. mindspore/common/initializer.py +13 -10
  50. mindspore/common/jit_begin_end.py +94 -0
  51. mindspore/common/jit_config.py +6 -1
  52. mindspore/common/jit_context.py +76 -0
  53. mindspore/common/jit_trace.py +378 -0
  54. mindspore/common/lazy_inline.py +9 -3
  55. mindspore/common/mindir_util.py +10 -2
  56. mindspore/common/mutable.py +5 -4
  57. mindspore/common/parameter.py +135 -52
  58. mindspore/common/seed.py +2 -2
  59. mindspore/common/sparse_tensor.py +23 -17
  60. mindspore/common/tensor.py +975 -1981
  61. mindspore/communication/__init__.py +7 -5
  62. mindspore/communication/_comm_helper.py +52 -2
  63. mindspore/communication/comm_func.py +240 -181
  64. mindspore/communication/management.py +95 -26
  65. mindspore/context.py +324 -573
  66. mindspore/dataset/__init__.py +65 -37
  67. mindspore/dataset/audio/__init__.py +2 -8
  68. mindspore/dataset/audio/transforms.py +3 -17
  69. mindspore/dataset/callback/ds_callback.py +2 -1
  70. mindspore/dataset/core/config.py +87 -6
  71. mindspore/dataset/engine/cache_admin.py +3 -3
  72. mindspore/dataset/engine/cache_client.py +6 -5
  73. mindspore/dataset/engine/datasets.py +292 -267
  74. mindspore/dataset/engine/datasets_audio.py +22 -8
  75. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  76. mindspore/dataset/engine/datasets_text.py +78 -48
  77. mindspore/dataset/engine/datasets_user_defined.py +183 -117
  78. mindspore/dataset/engine/datasets_vision.py +120 -44
  79. mindspore/dataset/engine/iterators.py +283 -63
  80. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  81. mindspore/dataset/engine/obs/util.py +8 -0
  82. mindspore/dataset/engine/queue.py +40 -0
  83. mindspore/dataset/engine/samplers.py +289 -43
  84. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  85. mindspore/dataset/engine/validators.py +53 -11
  86. mindspore/dataset/text/__init__.py +7 -6
  87. mindspore/dataset/text/transforms.py +6 -5
  88. mindspore/dataset/text/utils.py +3 -3
  89. mindspore/dataset/transforms/__init__.py +0 -9
  90. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  91. mindspore/dataset/transforms/transforms.py +31 -14
  92. mindspore/dataset/utils/browse_dataset.py +1 -1
  93. mindspore/dataset/vision/__init__.py +2 -9
  94. mindspore/dataset/vision/transforms.py +202 -158
  95. mindspore/dataset/vision/utils.py +7 -5
  96. mindspore/dataset/vision/validators.py +1 -2
  97. mindspore/device_context/__init__.py +21 -0
  98. mindspore/device_context/ascend/__init__.py +25 -0
  99. mindspore/device_context/ascend/device.py +72 -0
  100. mindspore/device_context/ascend/op_debug.py +153 -0
  101. mindspore/device_context/ascend/op_precision.py +193 -0
  102. mindspore/device_context/ascend/op_tuning.py +123 -0
  103. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  104. mindspore/device_context/cpu/device.py +62 -0
  105. mindspore/device_context/cpu/op_tuning.py +43 -0
  106. mindspore/device_context/gpu/__init__.py +21 -0
  107. mindspore/device_context/gpu/device.py +70 -0
  108. mindspore/device_context/gpu/op_precision.py +67 -0
  109. mindspore/device_context/gpu/op_tuning.py +175 -0
  110. mindspore/device_manager.py +170 -0
  111. mindspore/dnnl.dll +0 -0
  112. mindspore/experimental/es/embedding_service.py +35 -27
  113. mindspore/experimental/llm_boost/__init__.py +1 -0
  114. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  115. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
  116. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  117. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  118. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  119. mindspore/experimental/llm_boost/register.py +1 -0
  120. mindspore/experimental/map_parameter.py +4 -4
  121. mindspore/experimental/optim/adadelta.py +6 -6
  122. mindspore/experimental/optim/adagrad.py +4 -4
  123. mindspore/experimental/optim/adam.py +7 -0
  124. mindspore/experimental/optim/adamax.py +4 -4
  125. mindspore/experimental/optim/adamw.py +4 -0
  126. mindspore/experimental/optim/asgd.py +1 -1
  127. mindspore/experimental/optim/lr_scheduler.py +73 -46
  128. mindspore/experimental/optim/radam.py +34 -31
  129. mindspore/experimental/optim/rprop.py +1 -1
  130. mindspore/experimental/optim/sgd.py +1 -1
  131. mindspore/hal/contiguous_tensors_handle.py +6 -10
  132. mindspore/hal/device.py +55 -53
  133. mindspore/hal/event.py +52 -52
  134. mindspore/hal/memory.py +179 -120
  135. mindspore/hal/stream.py +150 -109
  136. mindspore/include/api/context.h +0 -1
  137. mindspore/include/dataset/constants.h +7 -4
  138. mindspore/include/dataset/execute.h +2 -2
  139. mindspore/jpeg62.dll +0 -0
  140. mindspore/log.py +50 -0
  141. mindspore/mindrecord/__init__.py +21 -8
  142. mindspore/mindrecord/config.py +17 -316
  143. mindspore/mindrecord/filereader.py +1 -9
  144. mindspore/mindrecord/filewriter.py +5 -15
  145. mindspore/mindrecord/mindpage.py +1 -9
  146. mindspore/mindspore_backend_common.dll +0 -0
  147. mindspore/mindspore_backend_manager.dll +0 -0
  148. mindspore/mindspore_common.dll +0 -0
  149. mindspore/mindspore_core.dll +0 -0
  150. mindspore/mindspore_dump.dll +0 -0
  151. mindspore/mindspore_frontend.dll +0 -0
  152. mindspore/mindspore_glog.dll +0 -0
  153. mindspore/mindspore_memory_pool.dll +0 -0
  154. mindspore/mindspore_ms_backend.dll +0 -0
  155. mindspore/mindspore_ops.dll +0 -0
  156. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  157. mindspore/mindspore_ops_kernel_common.dll +0 -0
  158. mindspore/mindspore_profiler.dll +0 -0
  159. mindspore/mindspore_pyboost.dll +0 -0
  160. mindspore/mindspore_pynative.dll +0 -0
  161. mindspore/mindspore_res_manager.dll +0 -0
  162. mindspore/mindspore_runtime_pipeline.dll +0 -0
  163. mindspore/mint/__init__.py +798 -761
  164. mindspore/mint/distributed/__init__.py +70 -4
  165. mindspore/mint/distributed/distributed.py +2679 -44
  166. mindspore/mint/linalg/__init__.py +8 -0
  167. mindspore/mint/nn/__init__.py +743 -22
  168. mindspore/mint/nn/functional.py +716 -23
  169. mindspore/mint/nn/layer/__init__.py +21 -4
  170. mindspore/mint/nn/layer/_functions.py +334 -0
  171. mindspore/mint/nn/layer/activation.py +276 -1
  172. mindspore/mint/nn/layer/basic.py +123 -0
  173. mindspore/mint/nn/layer/conv.py +933 -0
  174. mindspore/mint/nn/layer/normalization.py +223 -28
  175. mindspore/mint/nn/layer/padding.py +797 -0
  176. mindspore/mint/nn/layer/pooling.py +235 -0
  177. mindspore/mint/optim/__init__.py +3 -1
  178. mindspore/mint/optim/adam.py +223 -0
  179. mindspore/mint/optim/adamw.py +26 -19
  180. mindspore/mint/optim/sgd.py +171 -0
  181. mindspore/mint/special/__init__.py +2 -1
  182. mindspore/multiprocessing/__init__.py +5 -0
  183. mindspore/nn/__init__.py +4 -1
  184. mindspore/nn/cell.py +1373 -192
  185. mindspore/nn/dynamic_lr.py +2 -1
  186. mindspore/nn/layer/activation.py +29 -27
  187. mindspore/nn/layer/basic.py +51 -35
  188. mindspore/nn/layer/channel_shuffle.py +3 -3
  189. mindspore/nn/layer/container.py +1 -1
  190. mindspore/nn/layer/conv.py +53 -42
  191. mindspore/nn/layer/embedding.py +12 -11
  192. mindspore/nn/layer/normalization.py +56 -49
  193. mindspore/nn/layer/padding.py +4 -3
  194. mindspore/nn/layer/pooling.py +120 -42
  195. mindspore/nn/layer/rnn_cells.py +1 -1
  196. mindspore/nn/layer/rnns.py +2 -1
  197. mindspore/nn/layer/timedistributed.py +5 -5
  198. mindspore/nn/layer/transformer.py +59 -36
  199. mindspore/nn/learning_rate_schedule.py +8 -4
  200. mindspore/nn/loss/loss.py +58 -55
  201. mindspore/nn/optim/ada_grad.py +7 -5
  202. mindspore/nn/optim/adadelta.py +11 -9
  203. mindspore/nn/optim/adafactor.py +1 -1
  204. mindspore/nn/optim/adam.py +19 -15
  205. mindspore/nn/optim/adamax.py +8 -7
  206. mindspore/nn/optim/adasum.py +5 -5
  207. mindspore/nn/optim/asgd.py +3 -1
  208. mindspore/nn/optim/ftrl.py +11 -9
  209. mindspore/nn/optim/lamb.py +1 -1
  210. mindspore/nn/optim/lars.py +1 -4
  211. mindspore/nn/optim/lazyadam.py +12 -10
  212. mindspore/nn/optim/momentum.py +7 -6
  213. mindspore/nn/optim/optimizer.py +3 -3
  214. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  215. mindspore/nn/optim/rmsprop.py +13 -12
  216. mindspore/nn/optim/rprop.py +11 -9
  217. mindspore/nn/optim/sgd.py +9 -6
  218. mindspore/nn/optim/tft_wrapper.py +5 -2
  219. mindspore/nn/optim/thor.py +2 -1
  220. mindspore/nn/probability/bijector/bijector.py +17 -11
  221. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  222. mindspore/nn/probability/bijector/invert.py +2 -2
  223. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  224. mindspore/nn/probability/bijector/softplus.py +3 -2
  225. mindspore/nn/probability/distribution/beta.py +3 -3
  226. mindspore/nn/probability/distribution/categorical.py +1 -1
  227. mindspore/nn/probability/distribution/cauchy.py +4 -2
  228. mindspore/nn/probability/distribution/exponential.py +6 -7
  229. mindspore/nn/probability/distribution/gamma.py +2 -2
  230. mindspore/nn/probability/distribution/gumbel.py +2 -2
  231. mindspore/nn/probability/distribution/half_normal.py +5 -3
  232. mindspore/nn/probability/distribution/logistic.py +5 -3
  233. mindspore/nn/probability/distribution/poisson.py +1 -1
  234. mindspore/nn/probability/distribution/uniform.py +5 -3
  235. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  236. mindspore/nn/reinforcement/tensor_array.py +1 -1
  237. mindspore/nn/utils/init.py +13 -11
  238. mindspore/nn/wrap/__init__.py +6 -6
  239. mindspore/nn/wrap/cell_wrapper.py +181 -122
  240. mindspore/nn/wrap/grad_reducer.py +45 -36
  241. mindspore/nn/wrap/loss_scale.py +6 -7
  242. mindspore/numpy/array_creations.py +63 -65
  243. mindspore/numpy/array_ops.py +149 -144
  244. mindspore/numpy/logic_ops.py +41 -42
  245. mindspore/numpy/math_ops.py +361 -359
  246. mindspore/numpy/utils.py +17 -18
  247. mindspore/numpy/utils_const.py +5 -6
  248. mindspore/opencv_core452.dll +0 -0
  249. mindspore/opencv_imgcodecs452.dll +0 -0
  250. mindspore/opencv_imgproc452.dll +0 -0
  251. mindspore/ops/__init__.py +5 -3
  252. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  253. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  254. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  255. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  256. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  257. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  258. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  259. mindspore/ops/_register_for_op.py +0 -11
  260. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  261. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  262. mindspore/ops/_vmap/vmap_array_ops.py +52 -25
  263. mindspore/ops/_vmap/vmap_base.py +0 -2
  264. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  265. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  266. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  267. mindspore/ops/auto_generate/__init__.py +4 -3
  268. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
  269. mindspore/ops/auto_generate/gen_extend_func.py +757 -185
  270. mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
  271. mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
  272. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  273. mindspore/ops/composite/__init__.py +2 -1
  274. mindspore/ops/composite/base.py +20 -25
  275. mindspore/ops/composite/math_ops.py +6 -16
  276. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  277. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  278. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  279. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  280. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  281. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  282. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  283. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  284. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  285. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  286. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  287. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  288. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  289. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  290. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  291. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  292. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  293. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  294. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  295. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  296. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  299. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  301. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  302. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  303. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  304. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  305. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  306. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  307. mindspore/ops/function/__init__.py +40 -2
  308. mindspore/ops/function/_add_attr_func.py +58 -0
  309. mindspore/ops/function/array_func.py +2089 -2403
  310. mindspore/ops/function/clip_func.py +80 -23
  311. mindspore/ops/function/debug_func.py +57 -57
  312. mindspore/ops/function/grad/__init__.py +1 -0
  313. mindspore/ops/function/grad/grad_func.py +104 -71
  314. mindspore/ops/function/image_func.py +2 -2
  315. mindspore/ops/function/linalg_func.py +47 -78
  316. mindspore/ops/function/math_func.py +4351 -3813
  317. mindspore/ops/function/nn_func.py +1712 -637
  318. mindspore/ops/function/other_func.py +159 -1
  319. mindspore/ops/function/parameter_func.py +18 -84
  320. mindspore/ops/function/random_func.py +452 -387
  321. mindspore/ops/function/reshard_func.py +4 -70
  322. mindspore/ops/function/sparse_func.py +3 -3
  323. mindspore/ops/function/sparse_unary_func.py +6 -6
  324. mindspore/ops/function/spectral_func.py +25 -58
  325. mindspore/ops/function/vmap_func.py +26 -18
  326. mindspore/ops/functional.py +23 -7
  327. mindspore/ops/functional_overload.py +1548 -0
  328. mindspore/ops/op_info_register.py +32 -244
  329. mindspore/ops/operations/__init__.py +23 -15
  330. mindspore/ops/operations/_custom_ops_utils.py +235 -0
  331. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  332. mindspore/ops/operations/_grad_ops.py +2 -43
  333. mindspore/ops/operations/_infer_ops.py +2 -1
  334. mindspore/ops/operations/_inner_ops.py +43 -84
  335. mindspore/ops/operations/_ms_kernel.py +4 -10
  336. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  337. mindspore/ops/operations/_scalar_ops.py +3 -2
  338. mindspore/ops/operations/_sequence_ops.py +1 -1
  339. mindspore/ops/operations/_tensor_array.py +1 -1
  340. mindspore/ops/operations/array_ops.py +81 -324
  341. mindspore/ops/operations/comm_ops.py +154 -108
  342. mindspore/ops/operations/custom_ops.py +298 -87
  343. mindspore/ops/operations/debug_ops.py +157 -59
  344. mindspore/ops/operations/inner_ops.py +7 -5
  345. mindspore/ops/operations/linalg_ops.py +1 -57
  346. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  347. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  348. mindspore/ops/operations/math_ops.py +32 -234
  349. mindspore/ops/operations/nn_ops.py +212 -531
  350. mindspore/ops/operations/other_ops.py +62 -9
  351. mindspore/ops/operations/random_ops.py +13 -7
  352. mindspore/ops/operations/reshard_ops.py +1 -1
  353. mindspore/ops/operations/sparse_ops.py +2 -2
  354. mindspore/ops/primitive.py +66 -53
  355. mindspore/ops/tensor_method.py +1895 -0
  356. mindspore/ops_generate/__init__.py +0 -5
  357. mindspore/ops_generate/aclnn/__init__.py +0 -0
  358. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  359. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  360. mindspore/ops_generate/api/__init__.py +0 -0
  361. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  362. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  363. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  364. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  365. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  366. mindspore/ops_generate/api/gen_api.py +103 -0
  367. mindspore/ops_generate/api/op_api_proto.py +235 -0
  368. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  369. mindspore/ops_generate/common/__init__.py +0 -0
  370. mindspore/ops_generate/common/base_generator.py +11 -0
  371. mindspore/ops_generate/common/gen_constants.py +91 -0
  372. mindspore/ops_generate/common/gen_utils.py +348 -0
  373. mindspore/ops_generate/common/op_proto.py +473 -0
  374. mindspore/ops_generate/common/template.py +523 -0
  375. mindspore/ops_generate/gen_ops.py +22 -1069
  376. mindspore/ops_generate/op_def/__init__.py +0 -0
  377. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  378. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  379. mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
  380. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  381. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  382. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  383. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  384. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  385. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  386. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  387. mindspore/ops_generate/pyboost/__init__.py +0 -0
  388. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  389. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  390. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  391. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  392. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  393. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  394. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  395. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  396. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  397. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  398. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  399. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  400. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  401. mindspore/ops_generate/resources/__init__.py +0 -0
  402. mindspore/ops_generate/resources/resource_list.py +30 -0
  403. mindspore/ops_generate/resources/resource_loader.py +36 -0
  404. mindspore/ops_generate/resources/resource_manager.py +64 -0
  405. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  406. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  407. mindspore/parallel/__init__.py +7 -3
  408. mindspore/parallel/_auto_parallel_context.py +159 -40
  409. mindspore/parallel/_cell_wrapper.py +132 -15
  410. mindspore/parallel/_parallel_serialization.py +107 -5
  411. mindspore/parallel/_ps_context.py +1 -1
  412. mindspore/parallel/_recovery_context.py +7 -2
  413. mindspore/parallel/_tensor.py +142 -18
  414. mindspore/parallel/_utils.py +199 -23
  415. mindspore/parallel/algo_parameter_config.py +4 -4
  416. mindspore/parallel/auto_parallel.py +732 -0
  417. mindspore/parallel/checkpoint_convert.py +159 -0
  418. mindspore/parallel/checkpoint_transform.py +700 -35
  419. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  420. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  421. mindspore/parallel/cluster/run.py +21 -4
  422. mindspore/parallel/function/__init__.py +24 -0
  423. mindspore/parallel/function/reshard_func.py +258 -0
  424. mindspore/parallel/nn/__init__.py +25 -0
  425. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  426. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  427. mindspore/parallel/parameter_broadcast.py +25 -14
  428. mindspore/parallel/shard.py +137 -59
  429. mindspore/parallel/transform_safetensors.py +364 -305
  430. mindspore/profiler/__init__.py +22 -5
  431. mindspore/profiler/analysis/__init__.py +0 -0
  432. mindspore/profiler/analysis/parser/__init__.py +0 -0
  433. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  434. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  435. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  436. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  437. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  438. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  439. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  440. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  441. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
  442. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  443. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  444. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  445. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  446. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  447. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  448. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  449. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  450. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  451. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  452. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  453. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  454. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  455. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  456. mindspore/profiler/analysis/task_manager.py +131 -0
  457. mindspore/profiler/analysis/time_converter.py +84 -0
  458. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  459. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  460. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  461. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  462. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  463. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  464. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  465. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  466. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  467. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  468. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  469. mindspore/profiler/analysis/work_flow.py +73 -0
  470. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  471. mindspore/profiler/common/command_executor.py +90 -0
  472. mindspore/profiler/common/constant.py +186 -3
  473. mindspore/profiler/common/file_manager.py +208 -0
  474. mindspore/profiler/common/log.py +130 -0
  475. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  476. mindspore/profiler/common/path_manager.py +395 -0
  477. mindspore/profiler/common/process_bar.py +168 -0
  478. mindspore/profiler/common/process_pool.py +9 -3
  479. mindspore/profiler/common/profiler_context.py +500 -0
  480. mindspore/profiler/common/profiler_info.py +304 -0
  481. mindspore/profiler/common/profiler_meta_data.py +74 -0
  482. mindspore/profiler/common/profiler_output_path.py +284 -0
  483. mindspore/profiler/common/profiler_parameters.py +251 -0
  484. mindspore/profiler/common/profiler_path_manager.py +179 -0
  485. mindspore/profiler/common/record_function.py +76 -0
  486. mindspore/profiler/common/tlv_decoder.py +76 -0
  487. mindspore/profiler/common/util.py +75 -2
  488. mindspore/profiler/dynamic_profiler.py +341 -75
  489. mindspore/profiler/envprofiler.py +163 -0
  490. mindspore/profiler/experimental_config.py +197 -0
  491. mindspore/profiler/mstx.py +242 -0
  492. mindspore/profiler/platform/__init__.py +21 -0
  493. mindspore/profiler/platform/base_profiler.py +40 -0
  494. mindspore/profiler/platform/cpu_profiler.py +124 -0
  495. mindspore/profiler/platform/gpu_profiler.py +74 -0
  496. mindspore/profiler/platform/npu_profiler.py +335 -0
  497. mindspore/profiler/profiler.py +1073 -90
  498. mindspore/profiler/profiler_action_controller.py +187 -0
  499. mindspore/profiler/profiler_interface.py +118 -0
  500. mindspore/profiler/schedule.py +243 -0
  501. mindspore/rewrite/api/node.py +15 -13
  502. mindspore/rewrite/api/symbol_tree.py +2 -3
  503. mindspore/run_check/_check_version.py +27 -20
  504. mindspore/run_check/run_check.py +1 -1
  505. mindspore/runtime/__init__.py +37 -0
  506. mindspore/runtime/device.py +27 -0
  507. mindspore/runtime/event.py +209 -0
  508. mindspore/runtime/executor.py +177 -0
  509. mindspore/runtime/memory.py +416 -0
  510. mindspore/runtime/stream.py +460 -0
  511. mindspore/runtime/thread_bind_core.py +401 -0
  512. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  513. mindspore/swresample-4.dll +0 -0
  514. mindspore/swscale-6.dll +0 -0
  515. mindspore/tinyxml2.dll +0 -0
  516. mindspore/train/__init__.py +8 -8
  517. mindspore/train/_utils.py +96 -27
  518. mindspore/train/amp.py +9 -5
  519. mindspore/train/callback/__init__.py +2 -2
  520. mindspore/train/callback/_callback.py +2 -16
  521. mindspore/train/callback/_checkpoint.py +53 -55
  522. mindspore/train/callback/_cluster_monitor.py +14 -18
  523. mindspore/train/callback/_early_stop.py +1 -1
  524. mindspore/train/callback/_flops_collector.py +103 -68
  525. mindspore/train/callback/_history.py +8 -5
  526. mindspore/train/callback/_lambda_callback.py +2 -2
  527. mindspore/train/callback/_landscape.py +0 -3
  528. mindspore/train/callback/_loss_monitor.py +2 -1
  529. mindspore/train/callback/_on_request_exit.py +6 -5
  530. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  531. mindspore/train/callback/_summary_collector.py +52 -19
  532. mindspore/train/callback/_time_monitor.py +2 -1
  533. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
  534. mindspore/train/data_sink.py +25 -2
  535. mindspore/train/dataset_helper.py +15 -16
  536. mindspore/train/loss_scale_manager.py +8 -7
  537. mindspore/train/metrics/accuracy.py +3 -3
  538. mindspore/train/metrics/confusion_matrix.py +9 -9
  539. mindspore/train/metrics/error.py +3 -3
  540. mindspore/train/metrics/hausdorff_distance.py +4 -4
  541. mindspore/train/metrics/mean_surface_distance.py +3 -3
  542. mindspore/train/metrics/metric.py +0 -12
  543. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  544. mindspore/train/metrics/precision.py +11 -10
  545. mindspore/train/metrics/recall.py +9 -9
  546. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  547. mindspore/train/mind_ir_pb2.py +174 -46
  548. mindspore/train/model.py +269 -136
  549. mindspore/train/serialization.py +622 -978
  550. mindspore/train/summary/_summary_adapter.py +2 -2
  551. mindspore/train/summary/summary_record.py +2 -3
  552. mindspore/train/train_thor/model_thor.py +1 -1
  553. mindspore/turbojpeg.dll +0 -0
  554. mindspore/utils/__init__.py +6 -3
  555. mindspore/utils/dryrun.py +140 -0
  556. mindspore/utils/hooks.py +81 -0
  557. mindspore/utils/runtime_execution_order_check.py +552 -0
  558. mindspore/utils/utils.py +138 -4
  559. mindspore/version.py +1 -1
  560. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
  561. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +564 -395
  562. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
  563. mindspore/_install_custom.py +0 -43
  564. mindspore/common/_register_for_adapter.py +0 -74
  565. mindspore/common/_tensor_overload.py +0 -139
  566. mindspore/mindspore_np_dtype.dll +0 -0
  567. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  568. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  569. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  570. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  571. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  572. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  573. mindspore/ops_generate/gen_utils.py +0 -209
  574. mindspore/ops_generate/op_proto.py +0 -145
  575. mindspore/ops_generate/template.py +0 -261
  576. mindspore/profiler/envprofiling.py +0 -254
  577. mindspore/profiler/profiling.py +0 -1926
  578. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
  579. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,169 @@
+ # Copyright 2025 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ============================================================================
+ """parallel serialization"""
+ from __future__ import absolute_import
+
+ from mindspore import context
+ from mindspore.nn.cell import Cell
+ from mindspore.ops import functional as F, composite as C, operations as P
+ import mindspore.common.dtype as mstype
+ from mindspore.common.sparse_tensor import Tensor
+ from mindspore.common.api import jit
+ from mindspore.common.parameter import Parameter
+ from mindspore.nn.layer import Identity
+ from mindspore.parallel._utils import _get_enable_parallel_optimizer
+
+ __all__ = ['PipelineGradReducer']
+
+
+ grad_scale = C.MultitypeFuncGraph("grad_scale")
+ shard_grad_scale = C.MultitypeFuncGraph("shard_grad_scale")
+ reciprocal = P.Reciprocal()
+
+
+ @grad_scale.register("Tensor", "Tensor", "Tensor")
+ def tensor_grad_scale_pipeline(scale, grad, accu_grad):
+     accu_grad = F.depend(accu_grad, grad)
+     new_grad = accu_grad * reciprocal(scale)
+     accu_grad = F.depend(accu_grad, new_grad)
+     zeros = F.tensor_mul(accu_grad, 0.0)
+     new_grad = F.depend(new_grad, F.assign(accu_grad, zeros))
+     return new_grad
+
+
+ @shard_grad_scale.register("Tensor", "Tensor", "Tensor")
+ def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad):
+     new_grad = grad * reciprocal(scale)
+     accu_grad = F.depend(accu_grad, new_grad)
+     new_grad = F.depend(new_grad, F.assign(accu_grad, F.zeros_like(accu_grad)))
+     return new_grad
+
+
+ class PipelineGradReducer(Cell):
+     """
+     Functional training scenarios for gradient statute and accumulation of pipeline parallel.
+
+     Args:
+         parameters (list): the parameters to be updated.
+         scale_sense (float, optional): the scale sense of the gradient. Default: 1.0.
+         opt_shard(bool, optional): if use parallel optimizer, set opt_shard True. Default: ``None``.
+
+     Raise:
+         RuntimeError: If the mode is not graph mode.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU``
+
+     Examples:
+         .. note::
+             Before running the following examples, you need to configure the communication environment variables.
+
+             For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
+             Please see the `rank table Startup
+             <https://www.mindspore.cn/tutorials/en/master/parallel/rank_table.html>`_
+             for more details.
+
+             This example should be run with multiple devices.
+
+         >>> import numpy as np
+         >>> import mindspore as ms
+         >>> from mindspore import nn, ops, Tensor
+         >>> from mindspore.communication import init
+         >>>
+         >>> ms.set_context(mode=ms.GRAPH_MODE)
+         >>> ms.reset_auto_parallel_context()
+         >>> init()
+         >>> ms.set_seed(1)
+         >>>
+         >>> class Network(nn.Cell):
+         ...     def __init__(self, in_features, out_features, sens=1.0):
+         ...         super().__init__()
+         ...         self.layer1 = nn.Dense(in_features, 16)
+         ...         self.relu1 = nn.ReLU()
+         ...         self.layer2 = nn.Dense(16, 16)
+         ...         self.relu2 = nn.ReLU()
+         ...         self.layer3 = nn.Dense(16, out_features)
+         ...
+         ...     def construct(self, x):
+         ...         x = self.layer1(x)
+         ...         x = self.relu1(x)
+         ...         x = self.layer2(x)
+         ...         x = self.relu2(x)
+         ...         logits = self.layer3(x)
+         ...         return logits
+         >>>
+         >>> size, in_features, out_features = 16, 32, 10
+         >>> net = Network(in_features, out_features)
+         >>> net.layer1.pipeline_stage = 0
+         >>> net.relu1.pipeline_stage = 0
+         >>> net.layer2.pipeline_stage = 0
+         >>> net.relu2.pipeline_stage = 1
+         >>> net.layer3.pipeline_stage = 1
+         >>> loss_fn = nn.CrossEntropyLoss()
+         >>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
+         >>> net_with_loss = nn.Pipeline(nn.WithLossCell(net, loss_fn), 2)
+         >>> net_with_loss.set_train()
+         >>> def forward_fn(inputs, target):
+         ...     loss = net_with_loss(inputs, target)
+         ...     return loss
+         >>>
+         >>> grad_fn = ops.value_and_grad(forward_fn, None, net_with_loss.trainable_params())
+         >>> pp_grad_reducer = nn.PipelineGradReducer(optimizer.parameters)
+         >>>
+         >>> @ms.jit
+         >>> def train_one_step(inputs, target):
+         ...     loss, grads = grad_fn(inputs, target)
+         ...     grads = pp_grad_reducer(grads)
+         ...     optimizer(grads)
+         ...     return loss, grads
+         >>>
+         >>> parallel_net = AutoParallel(train_one_step, parallel_mode="semi_auto")
+         >>> parallel_net.pipeline(stages=2)
+         >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
+         >>> label = Tensor(np.ones([size, out_features]).astype(np.float32))
+         >>> loss, _ = train_one_step(inputs, label)
+         >>> print(loss)
+         46.36721
+     """
+     def __init__(self, parameters, scale_sense=1.0, opt_shard=None):
+         super(PipelineGradReducer, self).__init__(auto_prefix=False)
+         self._check_mode()
+         self.accu_grads = parameters.clone(prefix="accu_grads", init="zeros")
+         self.grad_reducer = Identity()
+         self.degree = Tensor(1, mstype.float32)
+         self.scale_sense = Parameter(scale_sense, name='scale_sense')
+         self.hyper_map = C.HyperMap()
+         if opt_shard is None:
+             self.opt_shard = _get_enable_parallel_optimizer()
+         else:
+             self.opt_shard = opt_shard
+
+     @jit
+     def construct(self, grads):
+         new_grads = None
+         if self.opt_shard:
+             grads = self.grad_reducer(grads)
+             new_grads = self.hyper_map(F.partial(shard_grad_scale, self.scale_sense * self.degree),
+                                        grads, self.accu_grads)
+         else:
+             accu_grads = self.grad_reducer(self.accu_grads)
+             new_grads = self.hyper_map(F.partial(grad_scale, self.scale_sense * self.degree), grads, accu_grads)
+         return new_grads
+
+     def _check_mode(self):
+         """check parallel mode"""
+         mode = context.get_context('mode')
+         if mode != context.GRAPH_MODE:
+             raise RuntimeError(f"PipelineGradReducer only support graph mode, but get {mode}")
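The hunk above is the new module listed in the file list as mindspore/parallel/nn/parallel_grad_reducer.py (+169). Its two registered scale functions carry the core arithmetic: scale the (accumulated) gradient by the reciprocal of the loss scale, then zero the accumulation buffer; the F.depend calls only pin execution order in graph mode. A minimal NumPy sketch of that arithmetic, for orientation only (not MindSpore API):

import numpy as np

def grad_scale_sketch(scale, accu_grad):
    # Mirrors tensor_grad_scale_pipeline: the accumulated gradient is divided by
    # the loss scale, then the accumulation buffer is reset to zero.
    new_grad = accu_grad * (1.0 / scale)
    accu_grad[...] = 0.0
    return new_grad

def shard_grad_scale_sketch(scale, grad, accu_grad):
    # Mirrors tensor_shard_grad_scale_pipeline: with optimizer sharding the freshly
    # reduced gradient itself is scaled; the accumulation buffer is still cleared.
    new_grad = grad * (1.0 / scale)
    accu_grad[...] = 0.0
    return new_grad

scale = 2.0
accu = np.full((2, 2), 8.0)
print(grad_scale_sketch(scale, accu))                              # every entry 4.0 (8 / 2), accu zeroed
print(shard_grad_scale_sketch(scale, np.full((2, 2), 4.0), accu))  # every entry 2.0 (4 / 2)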
@@ -19,7 +19,10 @@ __all__ = ["parameter_broadcast"]
 
  import numpy as np
  import mindspore as ms
- from mindspore.communication import get_rank, create_group, get_group_size
+ from mindspore.communication import create_group, get_group_size
+ from mindspore.parallel._utils import _get_auto_parallel_net, _parallel_mode_map, _check_rank
+ # disable pylint too broad Exception
+ # pylint: disable=W0212
 
 
  def parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
@@ -34,7 +37,8 @@ def parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
          layout (Dict): Parameter layout dictionary. Come from
              :func:`mindspore.nn.Cell.parameter_layout_dict`
              or read from file(for example: "strategy.ckpt" saved by using the
-             `strategy_ckpt_config` parameter of :func:`mindspore.set_auto_parallel_context`).
+             `strategy_ckpt_config` parameter of
+             :func:`mindspore.parallel.auto_parallel.AutoParallel.save_param_strategy_file` ).
              The key is param name, the value is the layout of this parameter.
          cur_rank (int, optional): current rank id. Default: ``0``.
          initial_rank (int, optional): Start rank id for each pipeline. Default: ``0``.
@@ -45,6 +49,9 @@ def parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
          ValueError: Parameter name in `layout` can not be found in
              :func:`mindspore.nn.Cell.parameters_dict`.
 
+     Supported Platforms:
+         ``Ascend``
+
      Examples:
          >>> import os
          >>> import mindspore as ms
@@ -53,11 +60,11 @@ def parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
          >>> from mindspore.communication import init
          >>> from mindspore.common.initializer import initializer
          >>> from mindspore.train import Model
-         >>> from mindspore.parallel.parameter_broadcast import parameter_broadcast
          >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
+         >>> from mindspore.parallel.auto_parallel import AutoParallel
+         >>> from mindspore.parallel import parameter_broadcast
          >>> ms.set_context(mode=ms.GRAPH_MODE)
-         >>> ms.set_context(max_device_memory="28GB")
-         >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
+         >>> ms.runtime.set_memory(max_size="28GB")
          >>> init()
          >>> ms.set_seed(1)
          >>> class Network(nn.Cell):
@@ -90,7 +97,8 @@ def parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
          >>> dataset = create_dataset()
          >>> optim = nn.SGD(net.trainable_params(), 1e-2)
          >>> loss = nn.CrossEntropyLoss()
-         >>> model = Model(net, loss_fn=loss, optimizer=optim)
+         >>> parallel_net = AutoParallel(net)
+         >>> model = Model(parallel_net, loss_fn=loss, optimizer=optim)
          >>> model.train(1, dataset)
          >>> ms.save_checkpoint(net, "./simple.ckpt", False)
          >>> layout = model.train_network.parameter_layout_dict
@@ -104,17 +112,20 @@ def parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
          ...         print("step end, cur step num: ", cb_params.cur_step_num, flush=True)
          >>> model.train(1, dataset, callbacks=[LossCallBack()])
      """
-     if not layout:
+     if not layout or get_group_size() <= 1:
          return
      from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy
      from mindspore.nn.wrap.cell_wrapper import AllreduceGraph
-     origin_parallel_mode = ms.get_auto_parallel_context("parallel_mode")
-     if origin_parallel_mode not in ("semi_auto_parallel", "auto_parallel"):
-         return
-     if cur_rank != get_rank():
-         raise ValueError(f"For parameter broadcast, the cur_rank: {cur_rank} is wrong.")
-     if initial_rank % (get_group_size() / ms.get_auto_parallel_context("pipeline_stages")) != 0:
-         raise ValueError(f"For parameter broadcast, the initial_rank: {initial_rank} is wrong.")
+     origin_parallel_mode = ""
+     pipeline_stages = 1
+     parallel_net = _get_auto_parallel_net(net)
+     if type(parallel_net).__name__ == 'AutoParallel':
+         origin_parallel_mode = _parallel_mode_map(parallel_net._parallel_mode)
+         pipeline_stages = parallel_net._pipeline_stages
+     else:
+         origin_parallel_mode = ms.get_auto_parallel_context("parallel_mode")
+         pipeline_stages = ms.get_auto_parallel_context("pipeline_stages")
+     _check_rank(cur_rank, initial_rank, pipeline_stages)
      param_redundancy = get_parameter_redundancy(layout, initial_rank)
      if not param_redundancy:
          return
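In the mindspore/parallel/parameter_broadcast.py hunks above, the inline rank checks are replaced by a call to the new _check_rank helper imported from mindspore.parallel._utils; that helper's implementation is not part of this diff. A hedged sketch of validation equivalent to the removed lines, with the communication-layer rank and group size passed in explicitly (hypothetical standalone function, not the real helper's signature):

def check_rank_sketch(cur_rank, initial_rank, pipeline_stages, rank, group_size):
    # The caller-supplied cur_rank must match the rank reported by the communication layer.
    if cur_rank != rank:
        raise ValueError(f"For parameter broadcast, the cur_rank: {cur_rank} is wrong.")
    # initial_rank must land on a pipeline-stage boundary of the device group.
    if initial_rank % (group_size / pipeline_stages) != 0:
        raise ValueError(f"For parameter broadcast, the initial_rank: {initial_rank} is wrong.")

# With 8 devices split into 2 pipeline stages, valid initial_rank values are 0 and 4.
check_rank_sketch(cur_rank=5, initial_rank=4, pipeline_stages=2, rank=5, group_size=8)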
@@ -15,14 +15,77 @@
  """shard"""
 
  import copy
+ import numpy as np
  import mindspore as ms
  from mindspore import log as logger
  from mindspore._c_expression import Shard_
 
 
+ class _DistributedTensorInfo:
+     """
+     Describe the distributed information of a tensor.
+
+     Args:
+         distributed_info (Union[Layout, DeviceMesh]): The distributed information of a tensor.
+
+     Raises:
+         TypeError: If `distributed_info` is not a Layout type.
+
+     Examples:
+         >>> from mindspore import _DistributedTensorInfo, Layout
+         >>> layout = Layout((2, 2), ("dp", "mp"))
+         >>> src_layout = layout("dp", "mp")
+         >>> distributed_info = _DistributedTensorInfo(src_layout)
+         >>> print(distributed_info.sharding_strategy)
+         [2, 2]
+     """
+
+     def __init__(self, distributed_info):
+         if isinstance(distributed_info, Layout):
+             self._layout = distributed_info
+             self._distributed_info = distributed_info
+         else:
+             raise TypeError(
+                 f"DistributedTensorInfo only supports Layout or DeviceMesh as input, but got {type(distributed_info)}")
+         self._sharding_strategy = None
+
+     @property
+     def layout(self):
+         """return layout of current tensor"""
+         return self._layout
+
+     @property
+     def distributed_info(self):
+         """return the distributed info, it depends on user's input """
+         return self._distributed_info
+
+     @property
+     def sharding_strategy(self):
+         """return the sharding strategy of current tensor"""
+         if self._sharding_strategy is None:
+             layout_info = self._layout.to_dict()
+             device_matrix = layout_info["device_matrix"]
+             tensor_map = layout_info["tensor_map"]
+             sharding_strategy = []
+             for map_value in tensor_map:
+                 if isinstance(map_value, (tuple, list)):
+                     shard_size = 1
+                     for value in map_value:
+                         if value != -1:
+                             shard_size *= device_matrix[len(device_matrix) - value - 1]
+                     sharding_strategy.append(shard_size)
+                 else:
+                     if map_value != -1:
+                         sharding_strategy.append(device_matrix[len(device_matrix) - map_value - 1])
+                     else:
+                         sharding_strategy.append(1)
+             self._sharding_strategy = sharding_strategy
+         return self._sharding_strategy
+
+
  class Layout:
      """
-     Parallel layout describes the detailed sharding information.
+     Topological abstraction describing cluster devices for tensor slice placement on the cluster.
 
      Note:
          - It is valid only in semi auto parallel or auto parallel mode.
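The sharding_strategy property added above (mindspore/parallel/shard.py) walks the tensor_map and, for each tensor dimension, looks up the matching device_matrix axis from the right, with -1 meaning the dimension is not sharded. A standalone sketch of the same lookup, reproducing the [2, 2] result from the _DistributedTensorInfo docstring (assuming that layout's tensor_map is (1, 0), following the pattern shown in the Layout docstring):

def sharding_strategy_sketch(device_matrix, tensor_map):
    # Mirrors _DistributedTensorInfo.sharding_strategy: each tensor_map entry indexes
    # the device_matrix from the right; -1 means the dimension is not sharded.
    strategy = []
    for map_value in tensor_map:
        if isinstance(map_value, (tuple, list)):
            shard_size = 1
            for value in map_value:
                if value != -1:
                    shard_size *= device_matrix[len(device_matrix) - value - 1]
            strategy.append(shard_size)
        else:
            strategy.append(device_matrix[len(device_matrix) - map_value - 1] if map_value != -1 else 1)
    return strategy

# Layout((2, 2), ("dp", "mp")) applied as layout("dp", "mp") -> tensor_map (1, 0)
print(sharding_strategy_sketch((2, 2), (1, 0)))  # [2, 2], matching the docstring example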
@@ -35,28 +98,35 @@ class Layout:
          alias_name (tuple): The alias name for each axis of device_matrix, its length should be equal to the length of device_matrix and its element type is string.
              When using "interleaved_parallel" as an alias name, the tensor would be split into multiple
              copies on the corresponding partition dimension on a single card.
+         rank_list (list, optional): Data is allocated to the device according to rank_list. Default: ``None``.
+
      Raises:
          TypeError: `device_matrix` is not a tuple type.
          TypeError: `alias_name` is not a tuple type.
+         TypeError: 'rank_list' is not a list type.
          ValueError: `device_matrix` length is not equal to `alias_name` length.
          TypeError: The element of `device_matrix` is not int type.
          TypeError: The element of `alias_name` is not a str type.
+         TypeError: The element of `rank_list` is not int type.
          ValueError: The element of `alias_name` is an empty str.
          ValueError: The element of `alias_name` is "None".
          ValueError: `alias_name` contains repeated element.
 
+     Supported Platforms:
+         ``Ascend``
+
      Examples:
-         >>> from mindspore import Layout
+         >>> from mindspore.parallel import Layout
          >>> layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
          >>> layout0 = layout("dp", "mp")
          >>> print(layout0.to_dict())
-         {"device_matrix": (2, 2, 2), "tensor_map": (2, 0), "interleaved_parallel": False}
-         >>> # Total device num is 4, but split the tensor in local device into two copies.
+         {'device_matrix': (2, 2, 2), 'tensor_map': (2, 0), 'interleaved_parallel': False,
+         'alias_name': {'dp', 'sp', 'mp'}, 'rank_list': [0, 1, 2, 3, 4, 5, 6, 7]}
          >>> layout = Layout((2, 2, 2), ("dp", "sp", "interleaved_parallel"))
          >>> layout1 = layout(("dp", "interleaved_parallel"), "sp")
      """
 
-     def __init__(self, device_matrix, alias_name):
+     def __init__(self, device_matrix, alias_name, rank_list=None):
          if not isinstance(device_matrix, tuple):
              raise TypeError(f'device_matrix must be tuple type, but got:{type(device_matrix)}')
          if not isinstance(alias_name, tuple):
@@ -82,6 +152,20 @@ class Layout:
          self._device_shape = device_matrix
          self._alias_name = alias_name
          self._tensor_map = None
+         self._rank_list = list(range(np.prod(np.array(self._device_shape))))
+         if rank_list is not None:
+             if not isinstance(rank_list, list):
+                 raise TypeError(f"The rank_list should be a list, but got {type(rank_list).__name__}.")
+             for in_ele in rank_list:
+                 if not isinstance(in_ele, int):
+                     raise TypeError(f"The element of rank_list should be int, but got {type(in_ele).__name__}.")
+             if len(np.array(rank_list).shape) != 1:
+                 raise ValueError(
+                     f"The rank_list should be a 1-D list, but got {len(np.array(rank_list).shape)}-D list.")
+             if len(rank_list) != np.prod(np.array(self._device_shape)):
+                 raise ValueError(f"The length of rank_list should be equal to the product of device_matrix, "
+                                  f"but got {len(rank_list)} and {np.prod(np.array(self._device_shape))}.")
+             self._rank_list = rank_list
 
      def __call__(self, *tensor_map):
          self._tensor_map = ()
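The rank_list argument added to Layout above defaults to range(prod(device_matrix)) and, per the to_dict change just below, is surfaced in the layout dictionary. A hedged usage sketch based only on the signatures and docstring shown in this diff (requires mindspore 2.6.0 on a configured cluster):

from mindspore.parallel import Layout  # 2.6.0 import path used in the updated docstring

# A 2 x 2 device matrix backed by global ranks 4..7 instead of the default 0..3.
layout = Layout((2, 2), ("dp", "mp"), rank_list=[4, 5, 6, 7])
in_layout = layout("dp", "mp")
info = in_layout.to_dict()
print(info["device_matrix"])  # (2, 2)
print(info["rank_list"])      # [4, 5, 6, 7]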
@@ -122,8 +206,8 @@
              raise ValueError("The tensor_map of layout is None")
          interleaved_parallel = "interleaved_parallel" in self._alias_name
          return {"device_matrix": self._device_shape, "tensor_map": self._tensor_map,
-                 "interleaved_parallel": interleaved_parallel, "alias_name": self._alias_name}
-
+                 "interleaved_parallel": interleaved_parallel, "alias_name": self._alias_name,
+                 "rank_list": self._rank_list}
 
 
  class Shard(Shard_):
@@ -141,18 +225,6 @@ class Shard(Shard_):
          self.level = None
 
      def __call__(self, fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
-         parallel_mode = ms.context.get_auto_parallel_context("parallel_mode")
-         if parallel_mode not in ("auto_parallel", "semi_auto_parallel"):
-             raise AssertionError(
-                 f"Cell shard only supports auto parallel and semi auto parallel.")
-         if ms.context.get_context("device_target") not in ("Ascend", "GPU"):
-             raise AssertionError(
-                 f"'Shard' now only supports 'Ascend' and 'GPU'")
-         if parallel_mode == "auto_parallel" and \
-                 ms.context.get_auto_parallel_context("search_mode") != "sharding_propagation":
-             raise AssertionError(f"'search_mode' must be 'sharding_propagation' for 'Shard' when the "
-                                  f"'parallel_mode' is 'auto_parallel.'")
-
          if not isinstance(in_strategy, tuple):
              raise TypeError(
                  f"For 'Shard', the 'in_strategy' should be a tuple, but got {type(in_strategy).__name__}.")
@@ -181,7 +253,8 @@
                             "will be overwritten as False.")
              ms.set_algo_parameters(fully_use_devices=False)
 
-         if ms.context.get_auto_parallel_context("full_batch_is_set") is False:
+         if ms.context.get_auto_parallel_context("full_batch_is_set") is False and \
+                 ms.context.get_context("mode") == ms.context.PYNATIVE_MODE:
              logger.warning("When calling the shard interface, "
                             "'dataset_strategy' or 'full_batch' is not manually set by the user, "
                             "and the 'dataset_strategy' will be set to 'full_batch'.")
@@ -193,13 +266,13 @@
 
          if isinstance(fn, ms.nn.Cell):
              for param in fn.trainable_params():
-                 param.is_in_shard = True
+                 param.param_info.is_in_pynative_shard = True
 
          # Set parameter layout to corresponding parameter
          self._set_param_layout_into_parameter(fn, parameter_plan)
 
          def shard_fn(*args):
-             @ms.common.jit(hash_args=fn)
+             @ms.common.jit(hash_args=fn, backend="ms_backend")
              def after_shard(*args):
                  return shard_(fn, in_strategy, out_strategy, device, level)(*args)
 
@@ -290,7 +363,7 @@
          for stra in strategy:
              if not isinstance(stra, (tuple, Layout)):
                  raise TypeError(
-                     f"The '{log_info}' should be a tuple(tuple(int)) or tuple(mindspore.Layout), "
+                     f"The '{log_info}' should be a tuple(tuple(int)) or tuple(mindspore.parallel.Layout), "
                      f"but got {type(stra).__name__}")
              if isinstance(stra, Layout):
                  strategy_set.add("layout")
@@ -312,7 +385,7 @@
          for in_ele in layout:
              if not isinstance(in_ele, Layout):
                  raise TypeError(f"The {log_info} item should be a object of class Layout.")
-             layout_value += (in_ele.to_dict(),)
+             layout_value += ({k: v for k, v in in_ele.to_dict().items() if k != "rank_list"},)
          return layout_value
 
      def _check_tuple_strategy(self, dim_strategy):
@@ -323,8 +396,8 @@
 
  def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
      """
-     Defining the input and output layouts of this cell and the parallel strategies of remaining ops will be
-     generated by sharding propagation. In PyNative mode, use this method to specify a Cell for distributed
+     Specify the input and output slicing strategy for a Cell or function.
+     In PyNative mode, use this method to specify a Cell for distributed
      execution in graph mode. In Graph mode, use this method to specify distribution strategy for a Cell,
      strategy for others will be set by sharding propagation.
      in_strategy and out_strategy define the input and output layout respectively.
@@ -334,33 +407,37 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
      The parallel strategies of remaining operators are derived from the strategy specified by the input and output.
 
      Note:
-         If ms.shard is called, the parallel mode in `set_auto_parallel_context` (parallel_mode) will be set to
-         "auto_parallel" and the search mode (search_mode) to "sharding_propagation".
-         If the input contain Parameter, its strategy should be set in `in_strategy`.
+         - It is valid only in semi auto parallel or auto parallel mode.
+           In other parallel modes, strategies set here will be ignored.
+         - If the input contain Parameter, its strategy should be set in `in_strategy`.
+         - This method currently does not support dynamic shapes.
 
      Args:
          fn (Union[Cell, Function]): Function to be executed in parallel.
-             Its arguments and return value must be Tensor or Parameter.
+             Its arguments and return value must be Tensor.
              If `fn` is a Cell with parameters, `fn` needs to be an instantiated object,
              otherwise its arguments cannot be accessed.
          in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple(int) or
-             tuple(mindspore.Layout).
+             tuple(mindspore.parallel.Layout).
              Tuple defines the layout of the corresponding input.
-         out_strategy (Union[tuple, None]): Define the layout of outputs similar with `in_strategy`.
-             It is not in use right now. Default: ``None`` .
-         parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict
+         out_strategy (Union[tuple, None], optional): Define the layout of outputs similar with `in_strategy`.
+             Default: ``None`` .
+         parameter_plan (Union[dict, None], optional): Define the layout for the specified parameters.
+             Each element in dict
              defines the layout of the parameter like "param_name: layout".
              The key is a parameter name of type 'str'.
-             The value is a 1-D integer tuple or a 1-D mindspore.Layout tuple,
+             The value is a 1-D integer tuple or a 1-D mindspore.parallel.Layout tuple,
              indicating the corresponding layout.
              If the parameter name is incorrect or the corresponding parameter
-             has been set, the parameter setting will be ignored.
+             has been set, the parameter setting will be ignored. Supported
+             only when `fn` is a Cell with parameters.
              Default: ``None`` .
-         device (string): Select a certain `device` target. It is not in use right now.
-             Support ["CPU", "GPU", "Ascend"]. Default: ``"Ascend"`` .
-         level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation
-             over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
-             use right now. Support [0, 1, 2]. Default: ``0`` .
+         device (str, optional): Select a certain `device` target. It is not in use right now.
+             Support ["CPU", "GPU", "Ascend"]. Default: ``"Ascend"`` .
+         level (int, optional): Option for parallel strategy infer algorithm, namely the object function,
+             maximize computation
+             over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
+             use right now. Support [0, 1, 2]. Default: ``0`` .
 
      Returns:
          Function, return the function that will be executed under auto parallel process.
@@ -370,26 +447,28 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
          AssertionError: If device_target it not "Ascend" or "GPU".
          TypeError: If `in_strategy` is not a tuple.
          TypeError: If `out_strategy` is not a tuple or None.
-         TypeError: If any element in `in_strategy` is not a tuple(int) or tuple(mindspore.Layout).
-         TypeError: If any element in `out_strategy` is not a tuple(int) or tuple(mindspore.Layout).
+         TypeError: If any element in `in_strategy` is not a tuple(int) or tuple(mindspore.parallel.Layout).
+         TypeError: If any element in `out_strategy` is not a tuple(int) or tuple(mindspore.parallel.Layout).
          TypeError: If `parameter_plan` is not a dict or None.
          TypeError: If any key in `parameter_plan` is not a str.
-         TypeError: If any value in `parameter_plan` is not a tuple(int) or a tuple(mindspore.Layout).
+         TypeError: If any value in `parameter_plan` is not a tuple(int) or a tuple(mindspore.parallel.Layout).
          TypeError: If `device` is not a str.
          TypeError: If `level` is not an integer.
 
      Supported Platforms:
-         ``Ascend`` ``GPU``
+         ``Ascend``
 
      Examples:
          >>> import numpy as np
          >>> import mindspore as ms
-         >>> from mindspore import Tensor, nn
+         >>> from mindspore import Tensor, nn, ops
          >>> from mindspore.communication import init
+         >>> from mindspore.parallel import shard
+         >>> from mindspore.parallel import Layout
+         >>> from mindspore.nn.utils import no_init_parameters
+         >>> from mindspore.parallel.auto_parallel import AutoParallel
          >>> ms.set_context(mode=ms.GRAPH_MODE)
          >>> init()
-         >>> ms.set_auto_parallel_context(parallel_mode="auto_parallel", search_mode="sharding_propagation",
-         ...                              device_num=8)
          >>>
          >>> # Case 1: cell uses functional
          >>> class BasicBlock(nn.Cell):
@@ -401,7 +480,7 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
          >>>         x = ops.abs(x)
          >>>         return x + y
          >>>     # shard a function with tuple(int) strategies
-         >>>     self.shard_my_add = ms.shard(my_add, in_strategy=((2, 2), (1, 4)), out_strategy=((4, 1),))
+         >>>     self.shard_my_add = shard(my_add, in_strategy=((2, 2), (1, 4)), out_strategy=((4, 1),))
          >>>
          >>>     def construct(self, x, u):
          >>>         x = self.gelu(x)
@@ -429,7 +508,7 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
          >>>         super(Net, self).__init__()
          >>>         # setting cell sharding strategy and parameter_plan by tuple(int)
          >>>         self.layer_net1 = NetForward()
-         >>>         self.layer_net1_shard = ms.shard(self.layer_net1, in_strategy=((4, 2), (2, 1)),
+         >>>         self.layer_net1_shard = shard(self.layer_net1, in_strategy=((4, 2), (2, 1)),
          ...                                          parameter_plan={"self.layer_net1.block1.weight": (4, 1)})
          >>>
          >>>         # setting cell sharding strategy and parameter_plan by tuple(ms.Layout)
@@ -437,7 +516,7 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
          >>>         layout = Layout((4, 2, 1), ("dp", "mp", "sp"))
          >>>         in_layout = (layout("dp", "mp"), layout("mp", "sp"))
          >>>         param_layout = layout("dp", "sp")
-         >>>         self.layer_net2_shard = ms.shard(self.layer_net2, in_strategy=in_layout,
+         >>>         self.layer_net2_shard = shard(self.layer_net2, in_strategy=in_layout,
          ...                                          parameter_plan={"self.layer_net2.block2.weight": param_layout})
          >>>         self.flatten = nn.Flatten()
          >>>         self.layer1 = nn.Dense(64, 64)
@@ -455,26 +534,25 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
          >>>         x = self.matmul(x, Tensor(np.ones(shape=(32, 32)), dtype=ms.float32))
          >>>         return x
          >>>
-         >>> net = Net()
+         >>> with no_init_parameters():
+         >>>     net = Net()
          >>> x = Tensor(np.ones(shape=(64, 1, 8, 8)), dtype=ms.float32)
          >>> y = Tensor(np.ones(shape=(64, 1, 8, 8)), dtype=ms.float32)
-         >>> net(x, y)
+         >>> parallel_net = AutoParallel(net, parallel_mode='sharding_propagation')
+         >>> parallel_net(x, y)
          >>>
          >>> # Case 2: function uses functional sharding
          >>> def test_shard(x, y):
          ...     return x + y
          >>> x = Tensor(np.ones(shape=(32, 10)), dtype=ms.float32)
          >>> y = Tensor(np.ones(shape=(32, 10)), dtype=ms.float32)
-         >>> output = ms.shard(test_shard, in_strategy=((4, 2), (4, 2)))(x, y)
+         >>> output = shard(test_shard, in_strategy=((4, 2), (4, 2)))(x, y)
          >>> print(output.shape)
          (32, 10)
 
-     Tutorial Examples:
-         - `Functional Operator Sharding
-           <https://www.mindspore.cn/docs/en/master/model_train/parallel/shard_function_parallel.html>`_
-         - `mindspore.Layout
-           <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.Layout.html>`_
      """
+     if ms.communication.management.get_group_size() == 1:
+         return fn
      if not isinstance(fn, (ms.nn.Cell)):
          logger.warning("'fn' is not a mindspore.nn.Cell, and its definition cannot involve Parameter; "
                         "otherwise, the result may be incorrect.")