mindspore 2.4.10__cp310-cp310-win_amd64.whl → 2.6.0rc1__cp310-cp310-win_amd64.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (602)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +13 -6
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -38
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +6 -7
  15. mindspore/_extends/parse/compile_config.py +83 -0
  16. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  17. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  18. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  19. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  20. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  21. mindspore/_extends/parse/parser.py +46 -197
  22. mindspore/_extends/parse/resources.py +1 -5
  23. mindspore/_extends/parse/standard_method.py +217 -98
  24. mindspore/_extends/pijit/__init__.py +2 -2
  25. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  26. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  27. mindspore/_extends/utils.py +1 -1
  28. mindspore/amp.py +11 -5
  29. mindspore/atlprov.dll +0 -0
  30. mindspore/avcodec-59.dll +0 -0
  31. mindspore/avdevice-59.dll +0 -0
  32. mindspore/avfilter-8.dll +0 -0
  33. mindspore/avformat-59.dll +0 -0
  34. mindspore/avutil-57.dll +0 -0
  35. mindspore/boost/__init__.py +2 -2
  36. mindspore/boost/base.py +3 -7
  37. mindspore/boost/boost_cell_wrapper.py +138 -43
  38. mindspore/c1.dll +0 -0
  39. mindspore/c1xx.dll +0 -0
  40. mindspore/c2.dll +0 -0
  41. mindspore/common/__init__.py +6 -3
  42. mindspore/common/_grad_function.py +56 -0
  43. mindspore/common/_pijit_context.py +14 -5
  44. mindspore/common/_register_for_tensor.py +1 -2
  45. mindspore/common/_stub_tensor.py +30 -14
  46. mindspore/common/_tensor_cpp_method.py +17 -0
  47. mindspore/common/_tensor_docs.py +4760 -0
  48. mindspore/common/api.py +435 -371
  49. mindspore/common/auto_dynamic_shape.py +41 -44
  50. mindspore/common/dtype.py +39 -36
  51. mindspore/common/dump.py +9 -6
  52. mindspore/common/file_system.py +9 -1
  53. mindspore/common/generator.py +2 -0
  54. mindspore/common/hook_handle.py +6 -2
  55. mindspore/common/initializer.py +13 -10
  56. mindspore/common/jit_begin_end.py +94 -0
  57. mindspore/common/jit_config.py +6 -1
  58. mindspore/common/jit_context.py +76 -0
  59. mindspore/common/jit_trace.py +378 -0
  60. mindspore/common/lazy_inline.py +9 -3
  61. mindspore/common/mindir_util.py +10 -2
  62. mindspore/common/mutable.py +5 -4
  63. mindspore/common/parameter.py +135 -52
  64. mindspore/common/seed.py +2 -2
  65. mindspore/common/sparse_tensor.py +23 -17
  66. mindspore/common/tensor.py +951 -1992
  67. mindspore/communication/__init__.py +7 -5
  68. mindspore/communication/_comm_helper.py +52 -2
  69. mindspore/communication/comm_func.py +240 -181
  70. mindspore/communication/management.py +95 -26
  71. mindspore/context.py +314 -566
  72. mindspore/dataset/__init__.py +65 -37
  73. mindspore/dataset/audio/__init__.py +2 -8
  74. mindspore/dataset/audio/transforms.py +3 -17
  75. mindspore/dataset/callback/ds_callback.py +2 -1
  76. mindspore/dataset/core/config.py +87 -6
  77. mindspore/dataset/engine/cache_admin.py +3 -3
  78. mindspore/dataset/engine/cache_client.py +6 -5
  79. mindspore/dataset/engine/datasets.py +292 -267
  80. mindspore/dataset/engine/datasets_audio.py +22 -8
  81. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  82. mindspore/dataset/engine/datasets_text.py +78 -48
  83. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  84. mindspore/dataset/engine/datasets_vision.py +120 -44
  85. mindspore/dataset/engine/iterators.py +283 -63
  86. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  87. mindspore/dataset/engine/obs/util.py +8 -0
  88. mindspore/dataset/engine/queue.py +40 -0
  89. mindspore/dataset/engine/samplers.py +289 -43
  90. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  91. mindspore/dataset/engine/validators.py +53 -11
  92. mindspore/dataset/text/__init__.py +7 -6
  93. mindspore/dataset/text/transforms.py +6 -5
  94. mindspore/dataset/text/utils.py +3 -3
  95. mindspore/dataset/transforms/__init__.py +0 -9
  96. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  97. mindspore/dataset/transforms/transforms.py +31 -14
  98. mindspore/dataset/utils/browse_dataset.py +1 -1
  99. mindspore/dataset/vision/__init__.py +2 -9
  100. mindspore/dataset/vision/transforms.py +202 -158
  101. mindspore/dataset/vision/utils.py +7 -5
  102. mindspore/dataset/vision/validators.py +1 -2
  103. mindspore/device_context/__init__.py +21 -0
  104. mindspore/device_context/ascend/__init__.py +25 -0
  105. mindspore/device_context/ascend/device.py +72 -0
  106. mindspore/device_context/ascend/op_debug.py +153 -0
  107. mindspore/device_context/ascend/op_precision.py +193 -0
  108. mindspore/device_context/ascend/op_tuning.py +123 -0
  109. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  110. mindspore/device_context/cpu/device.py +62 -0
  111. mindspore/device_context/cpu/op_tuning.py +43 -0
  112. mindspore/device_context/gpu/__init__.py +21 -0
  113. mindspore/device_context/gpu/device.py +70 -0
  114. mindspore/device_context/gpu/op_precision.py +67 -0
  115. mindspore/device_context/gpu/op_tuning.py +175 -0
  116. mindspore/device_manager.py +170 -0
  117. mindspore/dnnl.dll +0 -0
  118. mindspore/dpcmi.dll +0 -0
  119. mindspore/experimental/es/embedding_service.py +35 -27
  120. mindspore/experimental/llm_boost/__init__.py +1 -0
  121. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  122. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  123. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  124. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  125. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  126. mindspore/experimental/llm_boost/register.py +1 -0
  127. mindspore/experimental/map_parameter.py +4 -4
  128. mindspore/experimental/optim/adadelta.py +6 -6
  129. mindspore/experimental/optim/adagrad.py +4 -4
  130. mindspore/experimental/optim/adam.py +7 -0
  131. mindspore/experimental/optim/adamax.py +4 -4
  132. mindspore/experimental/optim/adamw.py +4 -0
  133. mindspore/experimental/optim/asgd.py +1 -1
  134. mindspore/experimental/optim/lr_scheduler.py +73 -46
  135. mindspore/experimental/optim/radam.py +34 -31
  136. mindspore/experimental/optim/rprop.py +1 -1
  137. mindspore/experimental/optim/sgd.py +1 -1
  138. mindspore/hal/contiguous_tensors_handle.py +6 -10
  139. mindspore/hal/device.py +55 -53
  140. mindspore/hal/event.py +52 -52
  141. mindspore/hal/memory.py +157 -117
  142. mindspore/hal/stream.py +150 -109
  143. mindspore/include/api/context.h +0 -1
  144. mindspore/include/dataset/constants.h +7 -4
  145. mindspore/include/dataset/execute.h +2 -2
  146. mindspore/jpeg62.dll +0 -0
  147. mindspore/log.py +50 -0
  148. mindspore/mindrecord/__init__.py +21 -8
  149. mindspore/mindrecord/config.py +17 -316
  150. mindspore/mindrecord/filereader.py +1 -9
  151. mindspore/mindrecord/filewriter.py +5 -15
  152. mindspore/mindrecord/mindpage.py +1 -9
  153. mindspore/mindspore_backend_common.dll +0 -0
  154. mindspore/mindspore_backend_manager.dll +0 -0
  155. mindspore/mindspore_common.dll +0 -0
  156. mindspore/mindspore_core.dll +0 -0
  157. mindspore/mindspore_dump.dll +0 -0
  158. mindspore/mindspore_frontend.dll +0 -0
  159. mindspore/mindspore_glog.dll +0 -0
  160. mindspore/mindspore_memory_pool.dll +0 -0
  161. mindspore/mindspore_ms_backend.dll +0 -0
  162. mindspore/mindspore_ops.dll +0 -0
  163. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  164. mindspore/mindspore_ops_kernel_common.dll +0 -0
  165. mindspore/mindspore_profiler.dll +0 -0
  166. mindspore/mindspore_pyboost.dll +0 -0
  167. mindspore/mindspore_pynative.dll +0 -0
  168. mindspore/mindspore_res_manager.dll +0 -0
  169. mindspore/mindspore_runtime_pipeline.dll +0 -0
  170. mindspore/mint/__init__.py +796 -759
  171. mindspore/mint/distributed/__init__.py +70 -4
  172. mindspore/mint/distributed/distributed.py +2679 -44
  173. mindspore/mint/linalg/__init__.py +8 -0
  174. mindspore/mint/nn/__init__.py +743 -22
  175. mindspore/mint/nn/functional.py +716 -23
  176. mindspore/mint/nn/layer/__init__.py +21 -4
  177. mindspore/mint/nn/layer/_functions.py +334 -0
  178. mindspore/mint/nn/layer/activation.py +276 -1
  179. mindspore/mint/nn/layer/basic.py +123 -0
  180. mindspore/mint/nn/layer/conv.py +921 -0
  181. mindspore/mint/nn/layer/normalization.py +223 -28
  182. mindspore/mint/nn/layer/padding.py +797 -0
  183. mindspore/mint/nn/layer/pooling.py +235 -0
  184. mindspore/mint/optim/__init__.py +3 -1
  185. mindspore/mint/optim/adam.py +223 -0
  186. mindspore/mint/optim/adamw.py +26 -19
  187. mindspore/mint/optim/sgd.py +171 -0
  188. mindspore/mint/special/__init__.py +2 -1
  189. mindspore/msobj140.dll +0 -0
  190. mindspore/mspdb140.dll +0 -0
  191. mindspore/mspdbcore.dll +0 -0
  192. mindspore/mspdbst.dll +0 -0
  193. mindspore/mspft140.dll +0 -0
  194. mindspore/msvcdis140.dll +0 -0
  195. mindspore/msvcp140_1.dll +0 -0
  196. mindspore/msvcp140_2.dll +0 -0
  197. mindspore/msvcp140_atomic_wait.dll +0 -0
  198. mindspore/msvcp140_codecvt_ids.dll +0 -0
  199. mindspore/multiprocessing/__init__.py +5 -0
  200. mindspore/nn/__init__.py +4 -1
  201. mindspore/nn/cell.py +1370 -189
  202. mindspore/nn/dynamic_lr.py +2 -1
  203. mindspore/nn/layer/activation.py +29 -27
  204. mindspore/nn/layer/basic.py +51 -35
  205. mindspore/nn/layer/channel_shuffle.py +3 -3
  206. mindspore/nn/layer/container.py +1 -1
  207. mindspore/nn/layer/conv.py +22 -17
  208. mindspore/nn/layer/embedding.py +12 -11
  209. mindspore/nn/layer/normalization.py +56 -49
  210. mindspore/nn/layer/padding.py +4 -3
  211. mindspore/nn/layer/pooling.py +120 -42
  212. mindspore/nn/layer/rnn_cells.py +1 -1
  213. mindspore/nn/layer/rnns.py +2 -1
  214. mindspore/nn/layer/timedistributed.py +5 -5
  215. mindspore/nn/layer/transformer.py +59 -36
  216. mindspore/nn/learning_rate_schedule.py +8 -4
  217. mindspore/nn/loss/loss.py +58 -55
  218. mindspore/nn/optim/ada_grad.py +7 -5
  219. mindspore/nn/optim/adadelta.py +11 -9
  220. mindspore/nn/optim/adafactor.py +1 -1
  221. mindspore/nn/optim/adam.py +17 -13
  222. mindspore/nn/optim/adamax.py +8 -7
  223. mindspore/nn/optim/adasum.py +5 -5
  224. mindspore/nn/optim/asgd.py +1 -1
  225. mindspore/nn/optim/ftrl.py +11 -9
  226. mindspore/nn/optim/lamb.py +1 -1
  227. mindspore/nn/optim/lars.py +1 -4
  228. mindspore/nn/optim/lazyadam.py +12 -10
  229. mindspore/nn/optim/momentum.py +7 -6
  230. mindspore/nn/optim/optimizer.py +3 -3
  231. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  232. mindspore/nn/optim/rmsprop.py +13 -12
  233. mindspore/nn/optim/rprop.py +11 -9
  234. mindspore/nn/optim/sgd.py +9 -6
  235. mindspore/nn/optim/tft_wrapper.py +5 -2
  236. mindspore/nn/optim/thor.py +2 -1
  237. mindspore/nn/probability/bijector/bijector.py +17 -11
  238. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  239. mindspore/nn/probability/bijector/invert.py +2 -2
  240. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  241. mindspore/nn/probability/bijector/softplus.py +3 -2
  242. mindspore/nn/probability/distribution/beta.py +3 -3
  243. mindspore/nn/probability/distribution/categorical.py +1 -1
  244. mindspore/nn/probability/distribution/cauchy.py +4 -2
  245. mindspore/nn/probability/distribution/exponential.py +6 -7
  246. mindspore/nn/probability/distribution/gamma.py +2 -2
  247. mindspore/nn/probability/distribution/gumbel.py +2 -2
  248. mindspore/nn/probability/distribution/half_normal.py +5 -3
  249. mindspore/nn/probability/distribution/logistic.py +5 -3
  250. mindspore/nn/probability/distribution/poisson.py +1 -1
  251. mindspore/nn/probability/distribution/uniform.py +5 -3
  252. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  253. mindspore/nn/reinforcement/tensor_array.py +1 -1
  254. mindspore/nn/utils/init.py +13 -11
  255. mindspore/nn/wrap/__init__.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +181 -122
  257. mindspore/nn/wrap/grad_reducer.py +45 -36
  258. mindspore/nn/wrap/loss_scale.py +6 -7
  259. mindspore/numpy/array_creations.py +63 -65
  260. mindspore/numpy/array_ops.py +149 -144
  261. mindspore/numpy/logic_ops.py +41 -42
  262. mindspore/numpy/math_ops.py +365 -363
  263. mindspore/numpy/utils.py +17 -18
  264. mindspore/numpy/utils_const.py +5 -6
  265. mindspore/opencv_core452.dll +0 -0
  266. mindspore/opencv_imgcodecs452.dll +0 -0
  267. mindspore/opencv_imgproc452.dll +0 -0
  268. mindspore/ops/__init__.py +5 -3
  269. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  270. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  271. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  272. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  273. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  274. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  275. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  276. mindspore/ops/_register_for_op.py +0 -11
  277. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  278. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  279. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  280. mindspore/ops/_vmap/vmap_base.py +0 -2
  281. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  282. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  283. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  284. mindspore/ops/auto_generate/__init__.py +4 -3
  285. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  286. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  287. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  288. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  289. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  290. mindspore/ops/composite/__init__.py +2 -1
  291. mindspore/ops/composite/base.py +20 -25
  292. mindspore/ops/composite/math_ops.py +6 -16
  293. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  294. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  295. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  296. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  301. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  302. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  304. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  305. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  306. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  308. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  309. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  310. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  311. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  312. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  313. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  314. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  315. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  316. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  317. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  320. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  321. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  324. mindspore/ops/function/__init__.py +40 -2
  325. mindspore/ops/function/_add_attr_func.py +58 -0
  326. mindspore/ops/function/array_func.py +2089 -2403
  327. mindspore/ops/function/clip_func.py +80 -23
  328. mindspore/ops/function/debug_func.py +57 -57
  329. mindspore/ops/function/grad/__init__.py +1 -0
  330. mindspore/ops/function/grad/grad_func.py +104 -71
  331. mindspore/ops/function/image_func.py +2 -2
  332. mindspore/ops/function/linalg_func.py +47 -78
  333. mindspore/ops/function/math_func.py +4501 -3802
  334. mindspore/ops/function/nn_func.py +1726 -620
  335. mindspore/ops/function/other_func.py +159 -1
  336. mindspore/ops/function/parameter_func.py +18 -84
  337. mindspore/ops/function/random_func.py +440 -387
  338. mindspore/ops/function/reshard_func.py +4 -70
  339. mindspore/ops/function/sparse_func.py +3 -3
  340. mindspore/ops/function/sparse_unary_func.py +6 -6
  341. mindspore/ops/function/spectral_func.py +25 -58
  342. mindspore/ops/function/vmap_func.py +24 -17
  343. mindspore/ops/functional.py +22 -7
  344. mindspore/ops/functional_overload.py +1440 -0
  345. mindspore/ops/op_info_register.py +32 -244
  346. mindspore/ops/operations/__init__.py +13 -7
  347. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  348. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  349. mindspore/ops/operations/_grad_ops.py +2 -43
  350. mindspore/ops/operations/_infer_ops.py +2 -1
  351. mindspore/ops/operations/_inner_ops.py +43 -84
  352. mindspore/ops/operations/_ms_kernel.py +4 -10
  353. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  354. mindspore/ops/operations/_scalar_ops.py +3 -2
  355. mindspore/ops/operations/_sequence_ops.py +1 -1
  356. mindspore/ops/operations/_tensor_array.py +1 -1
  357. mindspore/ops/operations/array_ops.py +81 -324
  358. mindspore/ops/operations/comm_ops.py +154 -108
  359. mindspore/ops/operations/custom_ops.py +232 -78
  360. mindspore/ops/operations/debug_ops.py +153 -59
  361. mindspore/ops/operations/inner_ops.py +7 -5
  362. mindspore/ops/operations/linalg_ops.py +1 -57
  363. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  364. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  365. mindspore/ops/operations/math_ops.py +32 -234
  366. mindspore/ops/operations/nn_ops.py +210 -498
  367. mindspore/ops/operations/other_ops.py +62 -9
  368. mindspore/ops/operations/random_ops.py +13 -7
  369. mindspore/ops/operations/reshard_ops.py +1 -1
  370. mindspore/ops/operations/sparse_ops.py +2 -2
  371. mindspore/ops/primitive.py +66 -53
  372. mindspore/ops/tensor_method.py +1888 -0
  373. mindspore/ops_generate/__init__.py +0 -5
  374. mindspore/ops_generate/aclnn/__init__.py +0 -0
  375. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  376. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  377. mindspore/ops_generate/api/__init__.py +0 -0
  378. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  379. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  380. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  381. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  382. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  383. mindspore/ops_generate/api/gen_api.py +103 -0
  384. mindspore/ops_generate/api/op_api_proto.py +235 -0
  385. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  386. mindspore/ops_generate/common/__init__.py +0 -0
  387. mindspore/ops_generate/common/base_generator.py +11 -0
  388. mindspore/ops_generate/common/gen_constants.py +91 -0
  389. mindspore/ops_generate/common/gen_utils.py +348 -0
  390. mindspore/ops_generate/common/op_proto.py +473 -0
  391. mindspore/ops_generate/common/template.py +523 -0
  392. mindspore/ops_generate/gen_ops.py +22 -1069
  393. mindspore/ops_generate/op_def/__init__.py +0 -0
  394. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  395. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  396. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  397. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  398. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  399. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  400. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  401. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  402. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  403. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  404. mindspore/ops_generate/pyboost/__init__.py +0 -0
  405. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  406. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  407. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  408. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  409. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  410. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  411. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  412. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  413. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  414. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  415. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  416. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  417. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  418. mindspore/ops_generate/resources/__init__.py +0 -0
  419. mindspore/ops_generate/resources/resource_list.py +30 -0
  420. mindspore/ops_generate/resources/resource_loader.py +36 -0
  421. mindspore/ops_generate/resources/resource_manager.py +64 -0
  422. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  423. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  424. mindspore/parallel/__init__.py +7 -3
  425. mindspore/parallel/_auto_parallel_context.py +152 -34
  426. mindspore/parallel/_cell_wrapper.py +130 -15
  427. mindspore/parallel/_parallel_serialization.py +107 -5
  428. mindspore/parallel/_ps_context.py +1 -1
  429. mindspore/parallel/_recovery_context.py +7 -2
  430. mindspore/parallel/_tensor.py +142 -18
  431. mindspore/parallel/_utils.py +199 -23
  432. mindspore/parallel/algo_parameter_config.py +4 -4
  433. mindspore/parallel/auto_parallel.py +732 -0
  434. mindspore/parallel/checkpoint_convert.py +159 -0
  435. mindspore/parallel/checkpoint_transform.py +698 -35
  436. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  437. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  438. mindspore/parallel/cluster/run.py +21 -4
  439. mindspore/parallel/function/__init__.py +24 -0
  440. mindspore/parallel/function/reshard_func.py +259 -0
  441. mindspore/parallel/nn/__init__.py +25 -0
  442. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  443. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  444. mindspore/parallel/parameter_broadcast.py +25 -14
  445. mindspore/parallel/shard.py +137 -58
  446. mindspore/parallel/transform_safetensors.py +363 -305
  447. mindspore/pgodb140.dll +0 -0
  448. mindspore/pgort140.dll +0 -0
  449. mindspore/profiler/__init__.py +22 -5
  450. mindspore/profiler/analysis/__init__.py +0 -0
  451. mindspore/profiler/analysis/parser/__init__.py +0 -0
  452. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  453. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  454. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  455. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  456. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  457. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  458. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  459. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  460. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  461. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  462. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  463. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  464. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  465. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  466. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  467. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  468. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  469. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  470. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  471. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  472. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  473. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  474. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  475. mindspore/profiler/analysis/task_manager.py +131 -0
  476. mindspore/profiler/analysis/time_converter.py +84 -0
  477. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  478. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  479. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  480. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  481. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  482. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  483. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  484. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  485. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  486. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  487. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  488. mindspore/profiler/analysis/work_flow.py +73 -0
  489. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  490. mindspore/profiler/common/command_executor.py +90 -0
  491. mindspore/profiler/common/constant.py +186 -3
  492. mindspore/profiler/common/file_manager.py +208 -0
  493. mindspore/profiler/common/log.py +130 -0
  494. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  495. mindspore/profiler/common/path_manager.py +395 -0
  496. mindspore/profiler/common/process_bar.py +168 -0
  497. mindspore/profiler/common/process_pool.py +9 -3
  498. mindspore/profiler/common/profiler_context.py +500 -0
  499. mindspore/profiler/common/profiler_info.py +304 -0
  500. mindspore/profiler/common/profiler_meta_data.py +74 -0
  501. mindspore/profiler/common/profiler_output_path.py +284 -0
  502. mindspore/profiler/common/profiler_parameters.py +251 -0
  503. mindspore/profiler/common/profiler_path_manager.py +179 -0
  504. mindspore/profiler/common/record_function.py +76 -0
  505. mindspore/profiler/common/tlv_decoder.py +76 -0
  506. mindspore/profiler/common/util.py +75 -2
  507. mindspore/profiler/dynamic_profiler.py +341 -75
  508. mindspore/profiler/envprofiler.py +163 -0
  509. mindspore/profiler/experimental_config.py +197 -0
  510. mindspore/profiler/mstx.py +242 -0
  511. mindspore/profiler/platform/__init__.py +21 -0
  512. mindspore/profiler/platform/base_profiler.py +40 -0
  513. mindspore/profiler/platform/cpu_profiler.py +124 -0
  514. mindspore/profiler/platform/gpu_profiler.py +74 -0
  515. mindspore/profiler/platform/npu_profiler.py +335 -0
  516. mindspore/profiler/profiler.py +1073 -90
  517. mindspore/profiler/profiler_action_controller.py +187 -0
  518. mindspore/profiler/profiler_interface.py +118 -0
  519. mindspore/profiler/schedule.py +243 -0
  520. mindspore/rewrite/api/node.py +15 -13
  521. mindspore/rewrite/api/symbol_tree.py +2 -3
  522. mindspore/run_check/_check_version.py +27 -20
  523. mindspore/run_check/run_check.py +1 -1
  524. mindspore/runtime/__init__.py +37 -0
  525. mindspore/runtime/device.py +27 -0
  526. mindspore/runtime/event.py +209 -0
  527. mindspore/runtime/executor.py +177 -0
  528. mindspore/runtime/memory.py +409 -0
  529. mindspore/runtime/stream.py +460 -0
  530. mindspore/runtime/thread_bind_core.py +401 -0
  531. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  532. mindspore/swresample-4.dll +0 -0
  533. mindspore/swscale-6.dll +0 -0
  534. mindspore/tbbmalloc.dll +0 -0
  535. mindspore/tinyxml2.dll +0 -0
  536. mindspore/train/__init__.py +8 -8
  537. mindspore/train/_utils.py +88 -25
  538. mindspore/train/amp.py +9 -5
  539. mindspore/train/callback/__init__.py +2 -2
  540. mindspore/train/callback/_callback.py +2 -16
  541. mindspore/train/callback/_checkpoint.py +53 -55
  542. mindspore/train/callback/_cluster_monitor.py +14 -18
  543. mindspore/train/callback/_early_stop.py +1 -1
  544. mindspore/train/callback/_flops_collector.py +103 -68
  545. mindspore/train/callback/_history.py +8 -5
  546. mindspore/train/callback/_lambda_callback.py +2 -2
  547. mindspore/train/callback/_landscape.py +0 -3
  548. mindspore/train/callback/_loss_monitor.py +2 -1
  549. mindspore/train/callback/_on_request_exit.py +6 -5
  550. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  551. mindspore/train/callback/_summary_collector.py +52 -19
  552. mindspore/train/callback/_time_monitor.py +2 -1
  553. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  554. mindspore/train/data_sink.py +25 -2
  555. mindspore/train/dataset_helper.py +15 -16
  556. mindspore/train/loss_scale_manager.py +8 -7
  557. mindspore/train/metrics/accuracy.py +3 -3
  558. mindspore/train/metrics/confusion_matrix.py +9 -9
  559. mindspore/train/metrics/error.py +3 -3
  560. mindspore/train/metrics/hausdorff_distance.py +4 -4
  561. mindspore/train/metrics/mean_surface_distance.py +3 -3
  562. mindspore/train/metrics/metric.py +0 -12
  563. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  564. mindspore/train/metrics/precision.py +11 -10
  565. mindspore/train/metrics/recall.py +9 -9
  566. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  567. mindspore/train/mind_ir_pb2.py +174 -46
  568. mindspore/train/model.py +184 -113
  569. mindspore/train/serialization.py +622 -978
  570. mindspore/train/summary/_summary_adapter.py +2 -2
  571. mindspore/train/summary/summary_record.py +2 -3
  572. mindspore/train/train_thor/model_thor.py +1 -1
  573. mindspore/turbojpeg.dll +0 -0
  574. mindspore/utils/__init__.py +6 -3
  575. mindspore/utils/dryrun.py +140 -0
  576. mindspore/utils/hooks.py +81 -0
  577. mindspore/utils/runtime_execution_order_check.py +550 -0
  578. mindspore/utils/utils.py +138 -4
  579. mindspore/vcmeta.dll +0 -0
  580. mindspore/vcruntime140.dll +0 -0
  581. mindspore/vcruntime140_1.dll +0 -0
  582. mindspore/version.py +1 -1
  583. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  584. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
  585. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  586. mindspore/_install_custom.py +0 -43
  587. mindspore/common/_register_for_adapter.py +0 -74
  588. mindspore/common/_tensor_overload.py +0 -139
  589. mindspore/mindspore_np_dtype.dll +0 -0
  590. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  591. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  592. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  593. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  594. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  595. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  596. mindspore/ops_generate/gen_utils.py +0 -209
  597. mindspore/ops_generate/op_proto.py +0 -145
  598. mindspore/ops_generate/template.py +0 -261
  599. mindspore/profiler/envprofiling.py +0 -254
  600. mindspore/profiler/profiling.py +0 -1926
  601. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  602. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
mindspore/train/_utils.py CHANGED
@@ -16,22 +16,23 @@
 from __future__ import absolute_import
 
 import os
-import threading
-from datetime import datetime
+import sys
 import json
 from collections.abc import Iterable
 
+import time
 import numpy as np
 
 from mindspore.common.tensor import Tensor
-from mindspore._c_expression import Tensor as Tensor_
+from mindspore._c_expression import TensorPy as Tensor_
+from mindspore._c_expression import MSContext, ms_ctx_param
 from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
 from mindspore.common import dtype as mstype
 from mindspore import context
 from mindspore import log as logger
 from mindspore import _checkparam as Validator
 from mindspore.common.api import _cell_graph_executor
-from mindspore.communication import get_group_size
+from mindspore.communication.management import get_rank, get_group_size
 from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
 from mindspore.train.checkpoint_pb2 import Checkpoint
 from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy
@@ -65,6 +66,11 @@ def _get_types_and_shapes(dataset):
     return dataset_types, dataset_shapes
 
 
+def enable_data_broadcast():
+    """Get status to indicate if enable dataset broadcast."""
+    return MSContext.get_instance().get_param(ms_ctx_param.dataset_broadcast_opt_level) > 0
+
+
 def _exec_datagraph(exec_dataset, dataset_size, phase='dataset', create_data_info_queue=False):
     """Initialize and execute the dataset graph."""
     batch_size = exec_dataset.get_batch_size()
@@ -77,15 +83,12 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset', create_data_inf
     if queue_name is None:
         queue_name = str("")
 
+    # Don't enable dynamic shape(multi-subgraph) feature in pp/data_broadcast mode,
+    # otherwise get_data_info will stuck since some rank do not consume data.
     use_pipeline_parallel = (context.get_auto_parallel_context("pipeline_stages") > 1)
+    data_broadcast = enable_data_broadcast()
 
-    # temp env to disable dynamic feature of sink size 1
-    dynamic_sink1_env = os.getenv("MS_DEV_DYNAMIC_SINK1", None)
-    dynamic_sink1 = True
-    if dynamic_sink1_env and dynamic_sink1_env.strip() in ['False', 'false']:
-        dynamic_sink1 = False
-
-    if use_pipeline_parallel or not dynamic_sink1:
+    if use_pipeline_parallel or data_broadcast:
         create_data_info_queue = False
 
     exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end,
@@ -374,20 +377,40 @@ def _get_parameter_redundancy_without_opt_shard(parameter_layout, param_redundan
         param_redundancy_dict[key] = tuple(redundancy_list)
 
 
-def get_parameter_redundancy(layout_obj, initial_rank=0):
+def _get_initial_rank(parameter_layout):
+    """Get the initial rank of pp."""
+    for k, _ in parameter_layout.items():
+        dev_matrix = parameter_layout[k][0]
+        break
+    dev_num = 1
+    if dev_matrix:
+        for i in dev_matrix:
+            dev_num *= i
+    rank_id = get_rank()
+    initial_rank = (rank_id // dev_num) * dev_num
+    return initial_rank
+
+
+def _get_pp_size_from_redundancy_map(param_redundancy):
+    """Get pp size from redundancy map."""
+    for _, v in param_redundancy.items():
+        return len(v) * len(v[0])
+
+
+def get_parameter_redundancy(layout_obj, initial_rank=None):
     """
     Get parameter redundancy map.
 
     Args:
         layout_obj (Union[str, layout): File name of `strategy.ckpt` or net.parameter_layout_dict.
-        initial_rank (int): Start rank id for each pipeline. Default: 0.
+        initial_rank (int): Start rank id for each pipeline. Default: ``None``.
 
     Returns:
         Dict, dict of parameter redundancy info.
 
     Examples:
         >>> from mindspore.train.utils import get_parameter_redundancy
-        >>> param_redundancy_dict = get_parameter_redundancy("/path/to/strategy.ckpt")
+        >>> param_redundancy_dict = get_parameter_redundancy("/path/to/strategy.ckpt", initial_rank=0)
        {'param1': ((0, 1, 2, 3, 4, 5, 6, 7),),
         'param2': ((0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15)),
         'param3': ((0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15)),
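To make the two new helpers concrete, here is a small worked example with made-up numbers (the dev_matrix, rank, and redundancy map below are hypothetical, not taken from this diff):

# Hypothetical walk-through of _get_initial_rank: suppose the first
# parameter's dev_matrix is [2, 2, 4] and this process is global rank 21.
dev_matrix = [2, 2, 4]
dev_num = 1
for i in dev_matrix:
    dev_num *= i                                # dev_num == 16 devices per pipeline stage
rank_id = 21
initial_rank = (rank_id // dev_num) * dev_num   # (21 // 16) * 16 == 16
# _get_pp_size_from_redundancy_map multiplies the group count by the group
# length of any one entry, e.g. {'p': ((0, 4), (1, 5))} gives 2 * 2 == 4.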
@@ -404,7 +427,8 @@ def get_parameter_redundancy(layout_obj, initial_rank=0):
         from mindspore.communication.management import get_process_group_ranks
         groups_ranks = (tuple(get_process_group_ranks()),)
         param_redundancy_dict = {param.name: groups_ranks for _, param in layout_obj.parameters_and_names()}
-        return param_redundancy_dict
+        sorted_param_redundancy_dict = {key: param_redundancy_dict[key] for key in sorted(param_redundancy_dict.keys())}
+        return sorted_param_redundancy_dict
     else:
         parameter_layout = {}
         for k, v in layout_obj.items():
@@ -412,6 +436,9 @@ def get_parameter_redundancy(layout_obj, initial_rank=0):
 
         param_redundancy_dict = {}
 
+        if initial_rank is None:
+            initial_rank = _get_initial_rank(parameter_layout)
+
         _get_parameter_redundancy_without_opt_shard(parameter_layout, param_redundancy_dict, initial_rank)
 
         if isinstance(layout_obj, str):
@@ -419,7 +446,8 @@ def get_parameter_redundancy(layout_obj, initial_rank=0):
         else:
             _get_layout_opt_shard(layout_obj, param_redundancy_dict)
 
-        return param_redundancy_dict
+        sorted_param_redundancy_dict = {key: param_redundancy_dict[key] for key in sorted(param_redundancy_dict.keys())}
+        return sorted_param_redundancy_dict
 
 
 def _collect_settings_by_rank(redundancy_map):
@@ -514,12 +542,47 @@ def parse_hccl_file(hccl_file_path):
     return rankid_dict
 
 
-def vlog_print(level, module, file, line, message):
-    '''Read environment variable VLOG_v and print to log'''
-    if os.environ.get("VLOG_v") == level:
-        now = datetime.now()
-        formatted_time = now.strftime("%Y-%m-%d-%H:%M:%S.%f")[:-3] + f".{now.microsecond // 1000}"
-        path = 'mindspore' + file.split("mindspore")[-1]
-        pid = os.getpid()
-        thread_id = threading.get_ident()
-        print(f"[V{level}] {module}({pid},{thread_id},python):{formatted_time} [{path}:{line}] {message}", flush=True)
+def _progress_bar(iterable, total=None):
+    """
+    Decorate an iterable object, returning an iterator which acts exactly
+    like the original iterable, but prints a dynamically updating
+    progressbar every time a value is requested.
+    """
+    if total is None:
+        total = len(iterable)
+
+    start_time = time.time()
+
+    def print_progress_bar(iteration):
+        percent = f"{100 * (iteration / float(total)):.1f}"
+        bar_length = 40
+        filled_length = int(bar_length * iteration // total)
+        bar = '█' * filled_length + '-' * (bar_length - filled_length)
+
+        elapsed_time = time.time() - start_time
+        estimated_total_time = elapsed_time / iteration * total
+        remaining_time = estimated_total_time - elapsed_time
+
+        elapsed_time_str = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
+        remaining_time_str = time.strftime("%H:%M:%S", time.gmtime(remaining_time))
+
+        sys.stdout.reconfigure(encoding="utf-8")
+        print(f'\r{percent}%|{bar}|[{elapsed_time_str}<{remaining_time_str}]', end='')
+        if iteration == total:
+            print()
+
+    for i, item in enumerate(iterable, start=1):
+        yield item
+        print_progress_bar(i)
+
+
+def _load_and_transform(path, name_map, load_func, transform_func):
+    if load_func is not None:
+        param_dict = load_func(path)
+    else:
+        param_dict = path
+    transform_dict = {}
+    for k, v in param_dict.items():
+        new_name = name_map.get(k, k) if name_map is not None else k
+        transform_dict[new_name] = transform_func(v, new_name)
+    return transform_dict
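For orientation, a minimal sketch of how the new `_load_and_transform` helper behaves; the loader and transform callables below are hypothetical stand-ins, not MindSpore APIs:

# Hypothetical usage of _load_and_transform; my_loader and my_cast are
# illustrative stand-ins for the load_func/transform_func callables.
def my_loader(path):
    return {"old.weight": 1.0}        # pretend checkpoint loader

def my_cast(value, name):
    return float(value)               # pretend per-parameter transform

out = _load_and_transform("net.ckpt", {"old.weight": "new.weight"}, my_loader, my_cast)
# out == {"new.weight": 1.0}. With load_func=None, `path` is treated as an
# already-loaded dict; with name_map=None, the original keys are kept.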
mindspore/train/amp.py CHANGED
@@ -101,6 +101,7 @@ AMP_AUTO_BLACK_LIST = [
     P.LayerNorm,
     gen.LayerNormExt,
     P.BatchNorm,
+    gen.BatchNormExt,
     gen.GroupNorm,
     P.KLDivLoss,
     P.SmoothL1Loss,
@@ -112,6 +113,7 @@ AMP_AUTO_BLACK_LIST = [
     P.Pdist,
     P.Cdist,
     P.Renorm,
+    gen.MSELossExt,
 ]
 
 # Indicates which inputs of primitives need to be converted
@@ -428,15 +430,15 @@ def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
 
     ``Pow``, ``ACos``, ``Asin``, ``Cosh``, ``Erfinv``, ``Exp``, ``Expm1``, ``Log``, ``Log1p``, ``Reciprocal``,
     ``Rsqrt``, ``Sinh``, ``Tan``, ``Softplus``, ``SoftplusExt``, ``LayerNorm``, ``LayerNormExt``, ``BatchNorm``,
-    ``GroupNorm``, ``KLDivLoss``, ``SmoothL1Loss``, ``MultilabelMarginLoss``, ``SoftMarginLoss``,
+    ``BatchNormExt``, ``GroupNorm``, ``KLDivLoss``, ``SmoothL1Loss``, ``MultilabelMarginLoss``, ``SoftMarginLoss``,
     ``TripletMarginLoss``, ``MultiMarginLoss``, ``BCEWithLogitsLoss``, ``Pdist``, ``Cdist``, ``Renorm``,
     ``ReduceProd``, ``Softmax``, ``LogSoftmax``, ``CumProd``, ``CumSum``, ``CumsumExt``, ``ProdExt``, ``SumExt``,
-    ``Norm``
+    ``Norm``, ``MSELossExt``
 
     Operators in `promote_list` are:
 
     ``Addcdiv``, ``Addcmul``, ``Cross``, ``_PyboostCrossPrim``, ``Dot``, ``GridSampler2D``, ``GridSampler3D``,
-    ``BiasAdd``
+    ``BiasAdd``, ``AddN``, ``Concat``
 
     For details on automatic mixed precision, refer to
     `Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/master/beginner/mixed_precision.html>`_ .
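A minimal, hedged usage sketch of the interface documented above; the tiny network is a stand-in, and the connection between AMP_AUTO_BLACK_LIST and the ``auto`` level is inferred from the list's name:

import mindspore as ms
from mindspore import amp, nn

# Hedged sketch: nn.Dense(16, 8) is a stand-in network; the assumption
# (from the list's name) is that AMP_AUTO_BLACK_LIST is consulted when
# amp_level="auto", keeping ops like BatchNormExt and MSELossExt in float32.
net = nn.Dense(16, 8)
amp_net = amp.auto_mixed_precision(net, amp_level="auto", dtype=ms.float16)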
@@ -636,7 +638,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
 
 
 def _is_grad_accumulation(mcell):
-    if mcell.cls_name == "GradAccumulationCell":
+    if mcell.cls_name == "GradAccumulationCell" or mcell.cls_name == "GradAccumulation":
         return True
     for cell in mcell.cells():
         if _is_grad_accumulation(cell):
@@ -837,12 +839,14 @@ def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=m
     - Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
       can result in a larger network hierarchy and slower performance.
     - If interfaces like `Model` and `build_train_network` is used to train the network which is converted by
-      mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level`
+      mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level` or `level`
       need to be configured to ``O0`` to avoid the duplicated accuracy conversion.
     - Primitives for blacklist is not support yet.
 
     Args:
         network (Cell): Definition of the network.
+
+    Keyword Args:
         white_list (list[Primitive, Cell], optional): White list of custom mixed precision. Defaults: ``None`` , means
             white list is not used.
         black_list (list[Cell], optional): Black list of custom mixed precision. Defaults: ``None`` , means
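Since the signature above makes `white_list` and `black_list` keyword-only (note the bare `*`), a hedged usage sketch:

import mindspore as ms
from mindspore import amp, nn

# Hedged sketch: the lists must be passed by keyword; the Dense cell is a
# stand-in network, not from this diff.
net = nn.Dense(16, 8)
amp_net = amp.custom_mixed_precision(net, white_list=[nn.Dense], dtype=ms.float16)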
mindspore/train/callback/__init__.py CHANGED
@@ -36,9 +36,9 @@ from mindspore.train.callback._reduce_lr_on_plateau import ReduceLROnPlateau
 from mindspore.train.callback._on_request_exit import OnRequestExit
 from mindspore.train.callback._backup_and_restore import BackupAndRestore
 from mindspore.train.callback._flops_collector import FlopsUtilizationCollector
-from mindspore.train.callback._tft_register import TFTRegister
+from mindspore.train.callback._train_fault_tolerance import TrainFaultTolerance
 
 __all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", "FlopsUtilizationCollector",
            "SummaryCollector", "CheckpointConfig", "RunContext", "LearningRateScheduler", "SummaryLandscape",
            "History", "LambdaCallback", "ReduceLROnPlateau", "EarlyStopping", "OnRequestExit", "BackupAndRestore",
-           "TFTRegister"]
+           "TrainFaultTolerance"]
mindspore/train/callback/_callback.py CHANGED
@@ -121,10 +121,7 @@ class Callback:
     When creating a custom Callback, model context information can be obtained in Callback
     methods by calling `RunContext.original_args()`, which is a dictionary varivable
     recording current attributes. Users can add custimized attributes to the information.
-    Training process can also be stopped by calling `request_stop` method. For details
-    of custom Callback, please check
-    `Callback tutorial <https://www.mindspore.cn/docs/en/master/model_train/train_process/model/
-    callback.html#customized-callback-mechanism>`_.
+    Training process can also be stopped by calling `request_stop` method.
 
     Examples:
         >>> import numpy as np
@@ -491,9 +488,7 @@ class RunContext:
 
     Callback objects not only can obtain the Model context information by calling by
     `RunContext.original_args()` and add extra attributes to the information, but also can stop the
-    training process by calling `request_stop` method. For details of custom Callback,
-    please check
-    `Callback Mechanism <https://www.mindspore.cn/docs/en/master/model_train/train_process/model/callback.html>`_.
+    training process by calling `request_stop` method.
 
     `RunContext.original_args()` holds the model context information as a dictionary variable, and
     different attributes of the dictionary are stored in training or eval process. Details are as follows:
@@ -572,10 +567,6 @@ class RunContext:
 
         Returns:
             Dict, an object that holds the original arguments of model.
-
-        Tutorial Examples:
-            - `Callback Mechanism - Customized Callback Mechanism
-              <https://mindspore.cn/docs/en/master/model_train/train_process/model/callback.html#customized-callback-mechanism>`_
         """
         return self._original_args
 
@@ -585,11 +576,6 @@ class RunContext:
 
         Callbacks can use this function to request stop of iterations.
         model.train() checks whether this is called or not.
-
-        Tutorial Examples:
-            - `Callback Mechanism - Customized Training Termination Time
-              <https://mindspore.cn/docs/en/master/model_train/train_process/model/callback.html#
-              customized-training-termination-time>`_
         """
         self._stop_requested = True
 
mindspore/train/callback/_checkpoint.py CHANGED
@@ -18,25 +18,22 @@ from __future__ import absolute_import
 import os
 import stat
 import time
-import threading
 
 import mindspore.context as context
 from mindspore import log as logger
 from mindspore import nn
 from mindspore import _checkparam as Validator
 from mindspore.train._utils import _make_directory
-from mindspore.train.serialization import save_checkpoint, _save_graph
+from mindspore.train.serialization import save_checkpoint, _save_graph, _wait_async_process_save_ckpt, \
+    _wait_async_thread_save_ckpt, _check_async_save
 from mindspore.parallel._cell_wrapper import destroy_allgather_cell
 from mindspore.parallel._recovery_context import _set_recovery_context, _get_recovery_context
-from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context
-from mindspore.parallel._utils import _get_device_num
-from mindspore.communication.management import get_rank
-from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy
-from mindspore.train.callback._callback import Callback, set_cur_net
+from mindspore.communication.management import get_rank, get_group_size
+from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy, _get_pp_size_from_redundancy_map
+from mindspore.train.callback._callback import Callback
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.common.generator import Generator
-from mindspore.common.api import _cell_graph_executor
 from mindspore._c_expression import collect_host_info, get_clock_syscnt
 
 _cur_dir = os.getcwd()
@@ -44,15 +41,6 @@ SAVE_DIR = _cur_dir
 _info_list = ["epoch_num", "step_num"]
 
 
-def _wait_async_save_ckpt(async_save=False):
-    """Waiting for asynchronous saving of ckpt to complete."""
-    if async_save:
-        thread_list = threading.enumerate()
-        for thread in thread_list:
-            if thread.getName() == "asyn_save_ckpt":
-                thread.join()
-
-
 def _get_dp_tp_from_redundancy(redundancy_tuple):
     """From redundancy get dp and tp"""
     dp = []
@@ -76,6 +64,15 @@ def _get_dp_tp_from_layout(parameter_redundancy_dict):
     return dp, tp
 
 
+def _wait_async_save_ckpt(async_save=False):
+    """Waiting for asynchronous saving of ckpt to complete."""
+    if async_save:
+        if async_save == "process":
+            _wait_async_process_save_ckpt()
+        else:
+            _wait_async_thread_save_ckpt()
+
+
 def _chg_ckpt_file_name_if_same_exist(directory, prefix, exception=False):
     """Check if there is a file with the same name."""
     if callable(prefix) or callable(directory):
@@ -87,7 +84,7 @@ def _chg_ckpt_file_name_if_same_exist(directory, prefix, exception=False):
             name_ext = os.path.splitext(filename)
             if exception and filename[-16:] != "_breakpoint.ckpt":
                 continue
-            if not exception and (name_ext[-1] != ".ckpt" or filename[-16:] == "_breakpoint.ckpt"):
+            if not exception and (name_ext[-1] not in (".ckpt", ".safetensors") or filename[-16:] == "_breakpoint.ckpt"):
                 continue
             # find same prefix file
             if filename.find(prefix) == 0 and not filename[pre_len].isalpha():
@@ -106,10 +103,10 @@ def _chg_ckpt_file_name_if_same_exist(directory, prefix, exception=False):
     return prefix
 
 
-def _check_format_and_other_params(format, enc_key, enc_mode, crc_check=False, async_save=False, exception_save=False,
+def _check_format_and_other_params(format, enc_key, enc_mode, crc_check=False, exception_save=False,
                                    map_param_inc=False, global_step_num=None):
-    param_not_default = (enc_key is not None or enc_mode != "AES-GCM" or crc_check or async_save
-                         or exception_save or map_param_inc or global_step_num is not None)
+    param_not_default = (enc_key is not None or enc_mode != "AES-GCM" or crc_check or exception_save or map_param_inc
+                         or global_step_num is not None)
     if format == "safetensors" and param_not_default:
         raise ValueError("For 'save_checkpoint', when format is 'safetensors', other param must be default.")
 
@@ -139,7 +136,10 @@ class CheckpointConfig:
         integrated_save (bool): Whether to merge and save the split Tensor in the automatic parallel scenario.
             Integrated save function is only supported in automatic parallel scene, not supported
             in manual parallel. Default: ``True`` .
-        async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: ``False`` .
+        async_save (Union[bool, str], optional):Whether to use asynchronous saving of the checkpoint file or
+            safetensors file, if True, the asynchronous thread is used by default. If the type
+            is string, the method of asynchronous saving, it can be "process" or "thread".
+            Default: ``False`` .
         saved_network (Cell): Network to be saved in checkpoint file. If the saved_network has no relation
             with the network in training, the initial value of saved_network will be saved. Default: ``None`` .
         append_info (list): The information save to checkpoint file. Support "epoch_num", "step_num" and
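A short usage sketch of the widened `async_save` option; the prefix, directory, and step count below are illustrative:

from mindspore.train import CheckpointConfig, ModelCheckpoint

# Illustrative values: async_save now accepts True/"thread" (background
# thread) or "process" (separate save process), per the docstring above.
config = CheckpointConfig(save_checkpoint_steps=100, async_save="process")
ckpt_cb = ModelCheckpoint(prefix="net", directory="./checkpoints", config=config)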
@@ -247,7 +247,7 @@ class CheckpointConfig:
         self._keep_checkpoint_max = 1
 
         self._integrated_save = Validator.check_bool(integrated_save)
-        self._async_save = Validator.check_bool(async_save)
+        self._async_save = _check_async_save(async_save)
         self._saved_network = saved_network
         self._append_dict = self._handle_append_info(append_info)
         self._enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
@@ -258,8 +258,7 @@ class CheckpointConfig:
         self.enable_redundance = kwargs.get('enable_redundance', False)
         self.remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
 
-        _check_format_and_other_params(format, enc_key, enc_mode, crc_check, async_save, exception_save,
-                                       self._map_param_inc)
+        _check_format_and_other_params(format, enc_key, enc_mode, crc_check, exception_save, self._map_param_inc)
 
     @property
     def save_checkpoint_steps(self):
@@ -313,10 +312,10 @@ class CheckpointConfig:
     @property
     def async_save(self):
         """
-        Get the value of whether asynchronous execution saves the checkpoint to a file.
+        Get the value of whether or how asynchronous execution saves the checkpoint to a file.
 
         Returns:
-            bool, whether asynchronous execution saves the checkpoint to a file.
+            (bool, str), whether or how asynchronous execution saves the checkpoint to a file.
         """
        return self._async_save
 
@@ -449,8 +448,9 @@ class ModelCheckpoint(Callback):
     Note:
         In the distributed training scenario, please specify different directories for each training process
         to save the checkpoint file. Otherwise, the training may fail.
-        If this callback is used in the `model` function, the checkpoint file will saved
-        parameters of the optimizer by default.
+        If this callback is used in the
+        `Model <https://www.mindspore.cn/docs/en/master/api_python/train/mindspore.train.Model.html>`_ function,
+        the checkpoint file will saved parameters of the optimizer by default.
 
     Args:
         prefix (Union[str, callable object]): The prefix name or callable object to generate name of checkpoint files.
@@ -511,7 +511,7 @@ class ModelCheckpoint(Callback):
         if callable(prefix):
             self._prefix_func = prefix
 
-        if _get_recovery_context("enable_recovery"):
+        if context.get_context("device_target") == "GPU" and _get_recovery_context("enable_recovery"):
             _set_recovery_context(ckpt_path=self._directory)
 
         if config is None:
@@ -538,6 +538,8 @@ class ModelCheckpoint(Callback):
         self._graph_saved = False
         self._need_flush_from_cache = True
         self._map_param_inc = self._config.map_param_inc
+        self._d2h_async = os.environ.get("MS_ENABLE_CKPT_D2H_ASYNC") == "1"
+        self._run_mode = context.get_context("mode")
 
     def step_end(self, run_context):
         """
@@ -551,19 +553,17 @@ class ModelCheckpoint(Callback):
  from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
  ckpt_storage_path = self._directory
  rank_id = get_rank()
- stage_num = _get_auto_parallel_context("pipeline_stages")
- stage_rank_num = _get_device_num() // stage_num
+ device_num = get_group_size()
  param_layout = cb_params.train_network.parameter_layout_dict
  if not param_layout:
- layout = {"stage_num": stage_num, "stage_rank_num": stage_rank_num, "stage_layout": None}
+ layout = {"stage_num": 1, "stage_rank_num": device_num, "stage_layout": None}
  aiturbo.init(ckpt_storage_path, rank_id, layout, None, False, None)
  else:
- device_num = _get_device_num()
- chunk_size = device_num // stage_num
- initial_rank = (rank_id // chunk_size) * chunk_size
- param_redundancy_dict = get_parameter_redundancy(param_layout, initial_rank)
+ param_redundancy_dict = get_parameter_redundancy(param_layout)
+ pp_size = _get_pp_size_from_redundancy_map(param_redundancy_dict)
+ stage_num = device_num // pp_size
  dp, _ = _get_dp_tp_from_layout(param_redundancy_dict)
- layout = {"stage_num": stage_num, "stage_rank_num": stage_rank_num,
+ layout = {"stage_num": stage_num, "stage_rank_num": pp_size,
  "stage_layout": param_redundancy_dict}
  single_params = remove_param_redundancy(param_redundancy_dict)
  single_params = {device_id: list(params) for device_id, params in single_params.items()}
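
The pipeline topology is now inferred from the redundancy map rather than the auto-parallel context: `pp_size` ranks form one pipeline group, so `stage_num = device_num // pp_size`. A toy check with illustrative numbers:

    device_num = 8   # e.g. get_group_size() on an 8-device job
    pp_size = 4      # e.g. _get_pp_size_from_redundancy_map(...)
    stage_num = device_num // pp_size
    assert stage_num == 2   # two pipeline stages of four ranks each
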
@@ -632,6 +632,13 @@ class ModelCheckpoint(Callback):
  if "step_num" in self._append_dict:
  self._append_dict["step_num"] = self._append_step_num + step_num

+ def _update_save_step(self, cb_params):
+ """Update the step number within the epoch when the async D2H copy is enabled."""
+ step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
+ if self._d2h_async and self._run_mode == context.GRAPH_MODE:
+ step_num_in_epoch -= 1
+ return step_num_in_epoch
+
  def _save_ckpt(self, cb_params, force_to_save=False):
  """Save checkpoint files."""
  if cb_params.cur_step_num == self._last_triggered_step:
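
`_update_save_step` keeps the old one-based step-in-epoch arithmetic and, when the D2H copy is asynchronous in graph mode, reports one step less (presumably because the host copy of the latest step has not landed yet). A worked example with illustrative numbers:

    cur_step_num, batch_num = 250, 100
    step_num_in_epoch = int((cur_step_num - 1) % batch_num + 1)
    assert step_num_in_epoch == 50
    # With MS_ENABLE_CKPT_D2H_ASYNC=1 in graph mode, the method returns 49.
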
@@ -642,10 +649,12 @@ class ModelCheckpoint(Callback):
  self._flush_from_cache(cb_params)

  save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
- step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
+ step_num_in_epoch = self._update_save_step(cb_params)

  if save_ckpt:
+
  _wait_async_save_ckpt(self._config.async_save)
+
  if self._prefix_func:
  cur_ckpoint_file = self._prefix + f".{self._config.format}"
  else:
@@ -670,12 +679,6 @@ class ModelCheckpoint(Callback):
  self._last_time_for_keep = time.time()
  self._last_triggered_step = cb_params.cur_step_num

- # TODO(MS_DISABLE_REF_MODE): Delete when remove MS_DISABLE_REF_MODE env.
- if context.get_context("enable_ge") and os.getenv('MS_DISABLE_REF_MODE') \
- and context.get_context("mode") == context.GRAPH_MODE:
- set_cur_net(cb_params.train_network)
- cb_params.train_network.add_flags(ge_sync_data=True)
- _cell_graph_executor(cb_params.train_network, phase='save')
  self._append_dict_content(cb_params.cur_epoch_num, cb_params.cur_step_num)
  network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
  if os.getenv("AITURBO") == "1":
@@ -684,18 +687,13 @@ class ModelCheckpoint(Callback):
  crc_check=self._config.crc_check, incremental=self._map_param_inc,
  global_step_num=cb_params.cur_step_num)
  elif self._config.remove_redundancy:
- parallel_mode = context.get_auto_parallel_context("parallel_mode")
- if parallel_mode == "stand_alone":
+ if get_group_size() == 1:
  raise TypeError(f"The deduplication feature for saving checkpoint can only be used "
- f"in parallel scenarios, but got {parallel_mode}.")
+ f"in parallel scenarios, but got 'stand_alone'.")
  param_layout = network.parameter_layout_dict
  rank_id = get_rank()
  if param_layout:
- device_num = _get_device_num()
- stage_num = _get_auto_parallel_context("pipeline_stages")
- chunk_size = device_num // stage_num
- initial_rank = (rank_id // chunk_size) * chunk_size
- param_redundancy_dict = get_parameter_redundancy(param_layout, initial_rank)
+ param_redundancy_dict = get_parameter_redundancy(param_layout)
  single_params = remove_param_redundancy(param_redundancy_dict)
  save_param_names = single_params.get(rank_id)
  param_layout_set = set(param_layout.keys())
@@ -704,14 +702,14 @@ class ModelCheckpoint(Callback):
  f"For remove_redundancy save checkpoint, the saved parameters are non-redundant.")

  def choice_func(x):
- return x not in param_layout_set or x in save_param_names
+ return x not in param_layout_set or (save_param_names is not None and x in save_param_names)
  else:
  param_redundancy_dict = get_parameter_redundancy(network)
  single_params = remove_param_redundancy(param_redundancy_dict)
  save_param_names = single_params.get(rank_id)

  def choice_func(x):
- return x in save_param_names
+ return save_param_names is not None and x in save_param_names
  save_checkpoint(network, cur_file, False, self._config.async_save,
  self._append_dict, self._config.enc_key, self._config.enc_mode,
  crc_check=self._config.crc_check, format=self._config.format,
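
Both `choice_func` changes guard against `single_params.get(rank_id)` returning `None`, which happens when the current rank owns no parameters after deduplication; the old membership test then raised `TypeError: argument of type 'NoneType' is not iterable`. A standalone illustration:

    single_params = {0: {"w1", "w2"}}
    save_param_names = single_params.get(1)   # None: rank 1 owns nothing

    # Old form: "w1" in save_param_names -> TypeError on None.
    # Guarded form short-circuits to False instead:
    keep = save_param_names is not None and "w1" in save_param_names
    assert keep is False
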
@@ -24,9 +24,8 @@ from threading import RLock
  from mindspore.train.callback._callback import Callback
  from mindspore.communication.management import get_rank, get_local_rank
  from mindspore import log as logger
- from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context
  from mindspore.parallel._utils import _get_device_num
- from mindspore.train._utils import get_parameter_redundancy
+ from mindspore.train._utils import get_parameter_redundancy, _get_pp_size_from_redundancy_map

  _perf_mutex = RLock()

@@ -42,7 +41,7 @@ def _get_dp_tp_from_redundancy(redundancy_tuple):
  return dp, tp


- def _get_dp_tp_from_layout(parameter_layout_dict, initial_rank=0):
+ def _get_dp_tp_from_layout(parameter_layout_dict, initial_rank=None):
  """From layout dict get dp and tp"""
  tp = []
  dp = []
@@ -132,21 +131,9 @@ class ClusterMonitor(Callback):
  self.full_path = self.log_path + self.log_name

  self.write_dp_tp_flag = True
- self.initial_rank = 0

  def begin(self, run_context):
  _remove_pre_log()
- pp_num = _get_auto_parallel_context("pipeline_stages")
- device_num = _get_device_num()
-
- original_list = list(range(device_num))
- chunk_size = device_num // pp_num
- split_pp_lists = []
- for i in range(0, device_num, chunk_size):
- end_index = i + chunk_size if i + chunk_size <= device_num else device_num
- split_pp_lists.append(original_list[i:end_index])
-
- self.initial_rank = (self.global_rank // chunk_size) * chunk_size
  with _perf_mutex:
  dir_path = os.path.dirname(self.full_path)
  if not os.path.exists(dir_path):
@@ -157,8 +144,6 @@ class ClusterMonitor(Callback):
  with open(self.full_path, 'w') as file:
  log_message = f'UUID:{self.uuid_value}\nFRAMEWORK:{self.frame_work}\nGLOBAL RANKID:{self.global_rank}\n'
  file.write(log_message)
- for _, split_pp_list in enumerate(split_pp_lists):
- file.write(f'PP:{split_pp_list}\n')
  os.chmod(self.full_path, stat.S_IRUSR)

  def step_begin(self, run_context):
@@ -183,10 +168,21 @@ class ClusterMonitor(Callback):
  if self.enabled and self.enabled_dtp_group and self.write_dp_tp_flag:
  cb_params = run_context.original_args()
  param_layout_dict = cb_params.train_network.parameter_layout_dict
- dp, tp = _get_dp_tp_from_layout(param_layout_dict, self.initial_rank)
+ device_num = _get_device_num()
+ original_list = list(range(device_num))
+ param_redundancy_dict = get_parameter_redundancy(param_layout_dict)
+ pp_size = _get_pp_size_from_redundancy_map(param_redundancy_dict)
+ split_pp_lists = []
+ for i in range(0, device_num, pp_size):
+ end_index = i + pp_size if i + pp_size <= device_num else device_num
+ split_pp_lists.append(original_list[i:end_index])
+ dp, tp = _get_dp_tp_from_layout(param_layout_dict)
+
  with _perf_mutex:
  os.chmod(self.full_path, stat.S_IWUSR)
  with open(self.full_path, 'a') as file:
+ for _, split_pp_list in enumerate(split_pp_lists):
+ file.write(f'PP:{split_pp_list}\n')
  for dp_value in dp:
  file.write(f'dp:{dp_value}\n')
  for tp_value in tp:
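
The split loop partitions the global ranks into contiguous pipeline groups of `pp_size`; replaying it standalone with illustrative values:

    device_num, pp_size = 8, 4
    original_list = list(range(device_num))
    split_pp_lists = []
    for i in range(0, device_num, pp_size):
        end_index = i + pp_size if i + pp_size <= device_num else device_num
        split_pp_lists.append(original_list[i:end_index])
    assert split_pp_lists == [[0, 1, 2, 3], [4, 5, 6, 7]]
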
@@ -198,7 +198,7 @@ class EarlyStopping(Callback):
  """
  Get the monitor value at the end of epoch during training.

- If `mindspore.train.callback.ReduceLROnPlateau` used with `model.train`, no evaluation process
+ If :class:`mindspore.train.callback.ReduceLROnPlateau` is used with `model.train`, no evaluation process
  occurs during training, and only monitor="loss" is valid; if it is used with `model.fit`, evaluation will be
  performed at the end of each epoch, and valid monitors are "loss", "eval_loss" and the metrics passed to `Model`.