mindspore-2.7.0rc1-cp311-cp311-win_amd64.whl → mindspore-2.7.1-cp311-cp311-win_amd64.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (370)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +5 -2
  3. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +2 -2
  7. mindspore/_extends/builtin_operations.py +3 -3
  8. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  9. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  10. mindspore/_extends/parse/__init__.py +3 -3
  11. mindspore/_extends/parse/compile_config.py +24 -1
  12. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
  13. mindspore/_extends/parse/parser.py +28 -22
  14. mindspore/_extends/parse/resources.py +1 -1
  15. mindspore/_extends/parse/standard_method.py +23 -2
  16. mindspore/_extends/parse/trope.py +2 -1
  17. mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
  18. mindspore/amp.py +0 -18
  19. mindspore/avcodec-59.dll +0 -0
  20. mindspore/avdevice-59.dll +0 -0
  21. mindspore/avfilter-8.dll +0 -0
  22. mindspore/avformat-59.dll +0 -0
  23. mindspore/avutil-57.dll +0 -0
  24. mindspore/boost/base.py +29 -2
  25. mindspore/common/__init__.py +18 -12
  26. mindspore/common/_decorator.py +3 -2
  27. mindspore/common/_grad_function.py +3 -1
  28. mindspore/common/_tensor_cpp_method.py +1 -1
  29. mindspore/common/_tensor_docs.py +371 -96
  30. mindspore/common/_utils.py +7 -43
  31. mindspore/common/api.py +434 -135
  32. mindspore/common/dtype.py +98 -57
  33. mindspore/common/dump.py +7 -108
  34. mindspore/common/dynamic_shape/__init__.py +0 -0
  35. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
  36. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  37. mindspore/common/file_system.py +59 -9
  38. mindspore/common/hook_handle.py +82 -3
  39. mindspore/common/jit_config.py +5 -1
  40. mindspore/common/jit_trace.py +27 -12
  41. mindspore/common/lazy_inline.py +5 -3
  42. mindspore/common/np_dtype.py +3 -3
  43. mindspore/common/parameter.py +17 -127
  44. mindspore/common/recompute.py +4 -13
  45. mindspore/common/tensor.py +50 -217
  46. mindspore/communication/_comm_helper.py +11 -1
  47. mindspore/communication/comm_func.py +138 -4
  48. mindspore/communication/management.py +85 -1
  49. mindspore/config/op_info.config +0 -15
  50. mindspore/context.py +20 -106
  51. mindspore/dataset/__init__.py +1 -1
  52. mindspore/dataset/audio/transforms.py +1 -1
  53. mindspore/dataset/core/config.py +35 -1
  54. mindspore/dataset/engine/datasets.py +338 -319
  55. mindspore/dataset/engine/datasets_user_defined.py +38 -22
  56. mindspore/dataset/engine/datasets_vision.py +1 -1
  57. mindspore/dataset/engine/validators.py +1 -15
  58. mindspore/dataset/transforms/c_transforms.py +2 -2
  59. mindspore/dataset/transforms/transforms.py +3 -3
  60. mindspore/dataset/vision/__init__.py +1 -1
  61. mindspore/dataset/vision/py_transforms.py +8 -8
  62. mindspore/dataset/vision/transforms.py +17 -5
  63. mindspore/dataset/vision/utils.py +632 -21
  64. mindspore/device_context/ascend/op_tuning.py +35 -1
  65. mindspore/dnnl.dll +0 -0
  66. mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
  67. mindspore/graph/custom_pass.py +55 -0
  68. mindspore/include/api/cell.h +28 -4
  69. mindspore/include/api/cfg.h +24 -7
  70. mindspore/include/api/context.h +1 -0
  71. mindspore/include/api/delegate.h +0 -2
  72. mindspore/include/api/dual_abi_helper.h +100 -19
  73. mindspore/include/api/graph.h +14 -1
  74. mindspore/include/api/kernel.h +16 -3
  75. mindspore/include/api/kernel_api.h +9 -1
  76. mindspore/include/api/metrics/accuracy.h +9 -0
  77. mindspore/include/api/model.h +5 -1
  78. mindspore/include/api/model_group.h +4 -0
  79. mindspore/include/api/model_parallel_runner.h +2 -0
  80. mindspore/include/api/status.h +48 -10
  81. mindspore/include/api/types.h +6 -1
  82. mindspore/include/dataset/constants.h +9 -0
  83. mindspore/include/dataset/execute.h +2 -2
  84. mindspore/jpeg62.dll +0 -0
  85. mindspore/mindrecord/__init__.py +3 -3
  86. mindspore/mindrecord/common/exceptions.py +1 -0
  87. mindspore/mindrecord/config.py +1 -1
  88. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  89. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  90. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  91. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  92. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  93. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  94. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  95. mindspore/mindrecord/filereader.py +4 -4
  96. mindspore/mindrecord/filewriter.py +5 -5
  97. mindspore/mindrecord/mindpage.py +2 -2
  98. mindspore/mindrecord/tools/cifar10.py +4 -3
  99. mindspore/mindrecord/tools/cifar100.py +1 -1
  100. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  101. mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
  102. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  103. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  104. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  105. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  106. mindspore/mindspore_backend_common.dll +0 -0
  107. mindspore/mindspore_backend_manager.dll +0 -0
  108. mindspore/mindspore_cluster.dll +0 -0
  109. mindspore/mindspore_common.dll +0 -0
  110. mindspore/mindspore_core.dll +0 -0
  111. mindspore/mindspore_cpu.dll +0 -0
  112. mindspore/mindspore_dump.dll +0 -0
  113. mindspore/mindspore_frontend.dll +0 -0
  114. mindspore/mindspore_glog.dll +0 -0
  115. mindspore/mindspore_hardware_abstract.dll +0 -0
  116. mindspore/mindspore_memory_pool.dll +0 -0
  117. mindspore/mindspore_ms_backend.dll +0 -0
  118. mindspore/mindspore_ops.dll +0 -0
  119. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  120. mindspore/mindspore_profiler.dll +0 -0
  121. mindspore/mindspore_pyboost.dll +0 -0
  122. mindspore/mindspore_pynative.dll +0 -0
  123. mindspore/mindspore_runtime_pipeline.dll +0 -0
  124. mindspore/mindspore_runtime_utils.dll +0 -0
  125. mindspore/mindspore_tools.dll +0 -0
  126. mindspore/mint/__init__.py +15 -10
  127. mindspore/mint/distributed/__init__.py +4 -0
  128. mindspore/mint/distributed/distributed.py +392 -69
  129. mindspore/mint/nn/__init__.py +2 -16
  130. mindspore/mint/nn/functional.py +4 -110
  131. mindspore/mint/nn/layer/__init__.py +0 -2
  132. mindspore/mint/nn/layer/_functions.py +1 -2
  133. mindspore/mint/nn/layer/activation.py +0 -6
  134. mindspore/mint/nn/layer/basic.py +0 -47
  135. mindspore/mint/nn/layer/conv.py +10 -10
  136. mindspore/mint/nn/layer/normalization.py +11 -16
  137. mindspore/mint/nn/layer/pooling.py +0 -4
  138. mindspore/nn/__init__.py +1 -3
  139. mindspore/nn/cell.py +231 -239
  140. mindspore/nn/layer/activation.py +4 -2
  141. mindspore/nn/layer/basic.py +56 -14
  142. mindspore/nn/layer/container.py +16 -0
  143. mindspore/nn/layer/embedding.py +4 -169
  144. mindspore/nn/layer/image.py +1 -1
  145. mindspore/nn/layer/normalization.py +2 -1
  146. mindspore/nn/layer/thor_layer.py +4 -85
  147. mindspore/nn/optim/ada_grad.py +0 -1
  148. mindspore/nn/optim/adafactor.py +0 -1
  149. mindspore/nn/optim/adam.py +32 -127
  150. mindspore/nn/optim/adamax.py +0 -1
  151. mindspore/nn/optim/asgd.py +0 -1
  152. mindspore/nn/optim/ftrl.py +8 -102
  153. mindspore/nn/optim/lamb.py +1 -4
  154. mindspore/nn/optim/lars.py +0 -3
  155. mindspore/nn/optim/lazyadam.py +25 -218
  156. mindspore/nn/optim/momentum.py +5 -43
  157. mindspore/nn/optim/optimizer.py +6 -55
  158. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  159. mindspore/nn/optim/rmsprop.py +0 -1
  160. mindspore/nn/optim/rprop.py +0 -1
  161. mindspore/nn/optim/sgd.py +0 -1
  162. mindspore/nn/optim/tft_wrapper.py +2 -4
  163. mindspore/nn/optim/thor.py +0 -2
  164. mindspore/nn/probability/bijector/bijector.py +7 -8
  165. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  166. mindspore/nn/probability/bijector/power_transform.py +20 -21
  167. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  168. mindspore/nn/probability/bijector/softplus.py +13 -14
  169. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  170. mindspore/nn/wrap/cell_wrapper.py +39 -5
  171. mindspore/nn/wrap/grad_reducer.py +4 -89
  172. mindspore/numpy/array_creations.py +4 -4
  173. mindspore/numpy/fft.py +9 -9
  174. mindspore/numpy/utils_const.py +1 -1
  175. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  176. mindspore/onnx/onnx_export.py +137 -0
  177. mindspore/opencv_core4110.dll +0 -0
  178. mindspore/opencv_imgcodecs4110.dll +0 -0
  179. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  180. mindspore/ops/__init__.py +2 -0
  181. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  182. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  183. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  184. mindspore/ops/_op_impl/cpu/__init__.py +1 -5
  185. mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
  186. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
  187. mindspore/ops/auto_generate/gen_extend_func.py +6 -11
  188. mindspore/ops/auto_generate/gen_ops_def.py +385 -154
  189. mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
  190. mindspore/ops/communication.py +97 -0
  191. mindspore/ops/composite/__init__.py +5 -2
  192. mindspore/ops/composite/base.py +16 -2
  193. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  194. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  195. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  196. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  197. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  198. mindspore/ops/function/__init__.py +2 -0
  199. mindspore/ops/function/array_func.py +24 -18
  200. mindspore/ops/function/comm_func.py +3883 -0
  201. mindspore/ops/function/debug_func.py +7 -6
  202. mindspore/ops/function/grad/grad_func.py +4 -12
  203. mindspore/ops/function/math_func.py +89 -86
  204. mindspore/ops/function/nn_func.py +92 -313
  205. mindspore/ops/function/random_func.py +9 -18
  206. mindspore/ops/functional.py +4 -1
  207. mindspore/ops/functional_overload.py +377 -30
  208. mindspore/ops/operations/__init__.py +2 -5
  209. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  210. mindspore/ops/operations/_inner_ops.py +12 -50
  211. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  212. mindspore/ops/operations/array_ops.py +5 -50
  213. mindspore/ops/operations/comm_ops.py +95 -17
  214. mindspore/ops/operations/custom_ops.py +237 -22
  215. mindspore/ops/operations/debug_ops.py +33 -35
  216. mindspore/ops/operations/manually_defined/ops_def.py +39 -318
  217. mindspore/ops/operations/math_ops.py +5 -5
  218. mindspore/ops/operations/nn_ops.py +3 -3
  219. mindspore/ops/operations/sparse_ops.py +0 -83
  220. mindspore/ops/primitive.py +4 -27
  221. mindspore/ops/tensor_method.py +88 -10
  222. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  223. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  224. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  225. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  226. mindspore/ops_generate/common/gen_constants.py +11 -10
  227. mindspore/ops_generate/common/op_proto.py +18 -1
  228. mindspore/ops_generate/common/template.py +102 -245
  229. mindspore/ops_generate/common/template_utils.py +212 -0
  230. mindspore/ops_generate/gen_custom_ops.py +69 -0
  231. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  232. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  233. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  234. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  235. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  236. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  237. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  238. mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
  239. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  240. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  241. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  242. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  243. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  244. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  245. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  246. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  247. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  248. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  249. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  250. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  251. mindspore/parallel/_auto_parallel_context.py +5 -15
  252. mindspore/parallel/_cell_wrapper.py +1 -1
  253. mindspore/parallel/_parallel_serialization.py +4 -6
  254. mindspore/parallel/_ps_context.py +2 -2
  255. mindspore/parallel/_utils.py +34 -17
  256. mindspore/parallel/auto_parallel.py +23 -9
  257. mindspore/parallel/checkpoint_transform.py +20 -2
  258. mindspore/parallel/cluster/process_entity/_api.py +28 -33
  259. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  260. mindspore/parallel/cluster/run.py +5 -3
  261. mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
  262. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  263. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  264. mindspore/parallel/function/reshard_func.py +6 -5
  265. mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
  266. mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
  267. mindspore/parallel/shard.py +7 -21
  268. mindspore/parallel/strategy.py +336 -0
  269. mindspore/parallel/transform_safetensors.py +127 -20
  270. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
  271. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
  272. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  273. mindspore/profiler/common/constant.py +5 -0
  274. mindspore/profiler/common/file_manager.py +9 -0
  275. mindspore/profiler/common/msprof_cmd_tool.py +40 -4
  276. mindspore/profiler/common/path_manager.py +65 -24
  277. mindspore/profiler/common/profiler_context.py +27 -14
  278. mindspore/profiler/common/profiler_info.py +3 -3
  279. mindspore/profiler/common/profiler_meta_data.py +1 -0
  280. mindspore/profiler/common/profiler_op_analyse.py +10 -6
  281. mindspore/profiler/common/profiler_path_manager.py +13 -0
  282. mindspore/profiler/common/util.py +30 -3
  283. mindspore/profiler/dynamic_profiler.py +91 -46
  284. mindspore/profiler/envprofiler.py +30 -5
  285. mindspore/profiler/experimental_config.py +18 -2
  286. mindspore/profiler/platform/cpu_profiler.py +10 -4
  287. mindspore/profiler/platform/npu_profiler.py +34 -7
  288. mindspore/profiler/profiler.py +193 -145
  289. mindspore/profiler/profiler_action_controller.py +1 -1
  290. mindspore/profiler/profiler_interface.py +2 -2
  291. mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
  292. mindspore/run_check/_check_version.py +108 -24
  293. mindspore/runtime/__init__.py +9 -6
  294. mindspore/runtime/executor.py +35 -0
  295. mindspore/runtime/memory.py +113 -0
  296. mindspore/runtime/thread_bind_core.py +1 -1
  297. mindspore/swresample-4.dll +0 -0
  298. mindspore/swscale-6.dll +0 -0
  299. mindspore/tinyxml2.dll +0 -0
  300. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  301. mindspore/tools/data_dump.py +130 -0
  302. mindspore/tools/sdc_detect.py +91 -0
  303. mindspore/tools/stress_detect.py +63 -0
  304. mindspore/train/__init__.py +6 -6
  305. mindspore/train/_utils.py +8 -21
  306. mindspore/train/amp.py +6 -7
  307. mindspore/train/callback/_callback.py +2 -1
  308. mindspore/train/callback/_checkpoint.py +1 -17
  309. mindspore/train/callback/_flops_collector.py +10 -6
  310. mindspore/train/callback/_train_fault_tolerance.py +72 -25
  311. mindspore/train/data_sink.py +5 -9
  312. mindspore/train/dataset_helper.py +5 -5
  313. mindspore/train/model.py +41 -230
  314. mindspore/train/serialization.py +160 -401
  315. mindspore/train/train_thor/model_thor.py +2 -2
  316. mindspore/turbojpeg.dll +0 -0
  317. mindspore/utils/__init__.py +6 -3
  318. mindspore/utils/dlpack.py +92 -0
  319. mindspore/utils/dryrun.py +1 -1
  320. mindspore/utils/runtime_execution_order_check.py +10 -0
  321. mindspore/utils/sdc_detect.py +14 -12
  322. mindspore/utils/stress_detect.py +43 -0
  323. mindspore/utils/utils.py +152 -16
  324. mindspore/version.py +1 -1
  325. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  326. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
  327. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  328. mindspore/communication/_hccl_management.py +0 -297
  329. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
  330. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  331. mindspore/experimental/llm_boost/atb/__init__.py +0 -23
  332. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  333. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  334. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  335. mindspore/experimental/llm_boost/register.py +0 -130
  336. mindspore/experimental/llm_boost/utils.py +0 -31
  337. mindspore/include/OWNERS +0 -7
  338. mindspore/mindspore_cpu_res_manager.dll +0 -0
  339. mindspore/mindspore_ops_kernel_common.dll +0 -0
  340. mindspore/mindspore_res_manager.dll +0 -0
  341. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  342. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  343. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  344. mindspore/nn/reinforcement/tensor_array.py +0 -145
  345. mindspore/opencv_core452.dll +0 -0
  346. mindspore/opencv_imgcodecs452.dll +0 -0
  347. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  348. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  349. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  350. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  351. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  352. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  353. mindspore/ops/operations/_tensor_array.py +0 -359
  354. mindspore/ops/operations/rl_ops.py +0 -288
  355. mindspore/parallel/_offload_context.py +0 -275
  356. mindspore/parallel/_recovery_context.py +0 -115
  357. mindspore/parallel/_transformer/__init__.py +0 -35
  358. mindspore/parallel/_transformer/layers.py +0 -765
  359. mindspore/parallel/_transformer/loss.py +0 -251
  360. mindspore/parallel/_transformer/moe.py +0 -693
  361. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  362. mindspore/parallel/_transformer/transformer.py +0 -3124
  363. mindspore/parallel/mpi/_mpi_config.py +0 -116
  364. mindspore/profiler/common/validator/validate_path.py +0 -84
  365. mindspore/train/memory_profiling_pb2.py +0 -298
  366. mindspore/utils/hooks.py +0 -81
  367. /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  368. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  369. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  370. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,63 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Stress detect."""
+from mindspore import _c_expression
+from mindspore import log as logger
+from mindspore.communication import init, create_group, get_rank
+from mindspore.communication import get_local_rank_size
+
+
+def stress_detect(detect_type="aic"):
+    """
+    Used to detect whether there are faults in hardware accuracy or communication between links.
+    The common usage scenario is to initiate a new thread or call this interface through a Callback function
+    at each step or when saving checkpoints, to check whether hardware malfunctions could affect accuracy.
+
+    Args:
+        detect_type (str, optional): The type of stress test to perform. There are two options available: ``'aic'`` and
+            ``'hccs'``, which perform AiCore and HCCS link stress tests on the device, respectively. Default: "aic".
+
+    Returns:
+        int, the return value represents the error type. 0 indicates normal. 1 indicates failure to start some or
+        all test cases. 2 indicates a hardware failure, and it is recommended to replace the device.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore.tools import stress_detect
+        >>> ret = stress_detect()
+        >>> print(ret)
+        0
+    """
+    if detect_type not in ["aic", "hccs"]:
+        logger.error(f"For stress detect, detection type must be 'aic' or 'hccs'."
+                     f"But got {detect_type}. Exiting stress detect.")
+        return 1
+
+    if detect_type == "aic":
+        return _c_expression.stress_detect("aic")
+
+    init()
+    local_ranks = []
+    local_rank_size = get_local_rank_size()
+    node_num = get_rank() // local_rank_size
+    for i in range(local_rank_size):
+        local_ranks.append(local_rank_size * node_num + i)
+    if get_rank() in local_ranks:
+        group = f"new_group_{node_num}"
+        create_group(group, local_ranks)
+
+    return _c_expression.stress_detect(group)
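The new `mindspore.tools.stress_detect` entry point above wraps `_c_expression.stress_detect`. As a rough illustration of how a training script might call it periodically, a minimal sketch follows; the helper name, check interval, and logging are assumptions for illustration only and are not part of this diff.

    from mindspore.tools import stress_detect
    from mindspore import log as logger

    def check_hardware(step, interval=1000, detect_type="aic"):
        """Run a stress test every `interval` steps and log the result."""
        if step % interval != 0:
            return
        ret = stress_detect(detect_type)  # 0: normal, 1: cases failed to start, 2: hardware fault
        if ret == 2:
            logger.error("stress_detect reports a hardware fault; the device may need to be replaced.")
        elif ret == 1:
            logger.warning("stress_detect could not start some or all test cases.")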
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2025 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -25,8 +25,8 @@ from mindspore.train import amp
 from mindspore.train.amp import build_train_network
 from mindspore.train.loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager
 from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, \
-    load, parse_print, async_ckpt_thread_status, convert_model, export_split_mindir, \
-    load_checkpoint_async, check_checkpoint, get_ckpt_path_with_strategy, ckpt_to_safetensors, safetensors_to_ckpt, \
+    load, async_ckpt_thread_status, export_split_mindir, \
+    load_checkpoint_async, get_ckpt_path_with_strategy, ckpt_to_safetensors, safetensors_to_ckpt, \
     build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint, restore_group_info_list
 from mindspore.train.callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryCollector, \
     CheckpointConfig, RunContext, LearningRateScheduler, SummaryLandscape, FlopsUtilizationCollector, \
@@ -37,9 +37,9 @@ from mindspore.train.metrics import *
 from mindspore.train.data_sink import data_sink
 
 __all__ = ["Model", "DatasetHelper", "connect_network_with_dataset", "build_train_network", "LossScaleManager",
-           "FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint", "check_checkpoint",
-           "load_param_into_net", "export", "load", "export_split_mindir", "parse_print", "async_ckpt_thread_status",
-           "convert_model", "data_sink", "load_checkpoint_async", "get_ckpt_path_with_strategy", "ckpt_to_safetensors",
+           "FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
+           "load_param_into_net", "export", "load", "export_split_mindir", "async_ckpt_thread_status",
+           "data_sink", "load_checkpoint_async", "get_ckpt_path_with_strategy", "ckpt_to_safetensors",
            "safetensors_to_ckpt", "build_searched_strategy", "merge_sliced_parameter", "load_distributed_checkpoint",
            "restore_group_info_list"]
 __all__.extend(callback.__all__)
mindspore/train/_utils.py CHANGED
@@ -26,7 +26,7 @@ import numpy as np
 from mindspore.common.tensor import Tensor
 from mindspore._c_expression import TensorPy as Tensor_
 from mindspore._c_expression import MSContext, ms_ctx_param
-from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
+from mindspore.common.dtype import _dtype_to_nptype, _pytype_to_dtype
 from mindspore.common import dtype as mstype
 from mindspore import context
 from mindspore import log as logger
@@ -54,7 +54,7 @@ def _convert_type(types):
     """
     ms_types = []
     for np_type in types:
-        ms_type = pytype_to_dtype(np_type)
+        ms_type = _pytype_to_dtype(np_type)  # pylint:disable=protected-access
        ms_types.append(ms_type)
    return ms_types
 
@@ -131,7 +131,7 @@ def _construct_tensor_list(types, shapes, batch_expand_num=1):
                new_shape += (item * batch_expand_num,)
            else:
                new_shape += (item,)
-        tensor = Tensor(np.zeros(new_shape, dtype_to_nptype(type_)), dtype=type_)
+        tensor = Tensor(np.zeros(new_shape, _dtype_to_nptype(type_)), dtype=type_)  # pylint:disable=protected-access
        tensor.virtual_flag = True
        tensor_list.append(tensor)
    return tensor_list
@@ -344,15 +344,7 @@ def _get_layout_opt_shard(layout_obj, param_redundancy_dict):
     """Layout ckpt append opt shard."""
     for key, value in layout_obj.items():
         if value[5]:
-            world_groups = ("hccl_world_group", "nccl_world_group", "mccl_world_group")
-            if value[5] in world_groups:
-                opt_para_num = get_group_size()
-            elif "-" in value[5]:
-                opt_para_str = value[5].split("-")[0]
-                opt_para_num = int(opt_para_str)
-            else:
-                raise ValueError(f"For get_parameter_redundancy, the format of the parallel communication domain for "
-                                 f"the optimizer is incorrect.")
+            opt_para_num = get_group_size(value[5])
             param_redundancy_ranks = param_redundancy_dict.get(key)
             res = []
             for param_ranks in param_redundancy_ranks:
@@ -582,17 +574,12 @@
         print_progress_bar(i)
 
 
-def _load_and_transform(path, name_map, load_func, transform_func=None):
+def _load_and_transform(path, name_map, load_func):
     """use load_func to load and use transform_func to convert"""
-    if load_func is not None:
-        param_dict = load_func(path)
-    else:
-        param_dict = path
+    param_dict = load_func(path)
     transform_dict = {}
+
     for k, v in param_dict.items():
         new_name = name_map.get(k, k) if name_map is not None else k
-        if transform_func is not None:
-            transform_dict[new_name] = transform_func(v, new_name)
-        else:
-            transform_dict[new_name] = v
+        transform_dict[new_name] = v
     return transform_dict
mindspore/train/amp.py CHANGED
@@ -463,9 +463,6 @@ def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
     ``Addcdiv``, ``Addcmul``, ``Cross``, ``_PyboostCrossPrim``, ``Dot``, ``GridSampler2D``, ``GridSampler3D``,
     ``BiasAdd``, ``AddN``, ``Concat``
 
-    For details on automatic mixed precision, refer to
-    `Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/master/beginner/mixed_precision.html>`_ .
-
     Note:
         - Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
          can result in a larger network hierarchy and slower performance.
@@ -821,8 +818,10 @@ def get_white_list():
        <class 'mindspore.ops.operations.nn_ops.Conv2DTranspose'>,
        <class 'mindspore.ops.operations.nn_ops.Conv3DTranspose'>,
        <class 'mindspore.ops.operations.nn_ops.Conv2DBackpropInput'>,
-        <class 'mindspore.ops.operations.math_ops.MatMul'>, <class 'mindspore.ops.operations.math_ops.BatchMatMul'>,
-        <class 'mindspore.ops.operations.nn_ops.PReLU'>, <class 'mindspore.ops.operations.nn_ops.ReLU'>,
+        <class 'mindspore.ops.auto_generate.gen_ops_prim.MatMul'>,
+        <class 'mindspore.ops.auto_generate.gen_ops_prim.BatchMatMul'>,
+        <class 'mindspore.ops.auto_generate.gen_ops_prim.PReLU'>,
+        <class 'mindspore.ops.auto_generate.gen_ops_prim.ReLU'>,
        <class 'mindspore.ops.operations.math_ops.Ger'>]
     """
     white_list = AMP_WHITE_LIST.copy()
@@ -874,8 +873,8 @@ def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=m
            white list is not used.
        black_list (list[Cell], optional): Black list of custom mixed precision. Defaults: ``None`` , means
            black list is not used.
-        dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or ``mstype.bfloat16`` ,
-            default: ``mstype.float16`` .
+        dtype (Type, optional): The type used in lower precision calculations, can be ``mstype.float16`` or
+            ``mstype.bfloat16`` , default: ``mstype.float16`` .
 
     Returns:
        network (Cell), A network supporting mixed precision.
@@ -60,7 +60,8 @@ def _fill_param_into_net(net, parameter_list):
         if np_val.shape == (1,):
             parameter_dict[param_name] = Parameter(np_val, name=param_name)
         elif np_val.shape == ():
-            parameter_dict[param_name] = Parameter(Tensor(np_val.tolist(), mstype.pytype_to_dtype(np_val.dtype)),
+            # pylint:disable=protected-access
+            parameter_dict[param_name] = Parameter(Tensor(np_val.tolist(), mstype._pytype_to_dtype(np_val.dtype)),
                                                    name=param_name)
         else:
             parameter_dict[param_name] = Parameter(Tensor(np_val), name=param_name)
@@ -27,7 +27,6 @@ from mindspore.train._utils import _make_directory
 from mindspore.train.serialization import save_checkpoint, _save_graph, _wait_async_process_save_ckpt, \
     _wait_async_thread_save_ckpt, _check_async_save
 from mindspore.parallel._cell_wrapper import destroy_allgather_cell
-from mindspore.parallel._recovery_context import _set_recovery_context, _get_recovery_context
 from mindspore.communication.management import get_rank, get_group_size
 from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy, _get_pp_size_from_redundancy_map
 from mindspore.train.callback._callback import Callback
@@ -509,9 +508,6 @@ class ModelCheckpoint(Callback):
         if callable(prefix):
             self._prefix_func = prefix
 
-        if context.get_context("device_target") == "GPU" and _get_recovery_context("enable_recovery"):
-            _set_recovery_context(ckpt_path=self._directory)
-
         if config is None:
             self._config = CheckpointConfig()
         else:
@@ -577,11 +573,6 @@
             self._directory = self._directory_func(cb_params)
             _make_directory(self._directory)
         collect_host_info("Callback", "ModelCheckpoint", "step_end", start_time=get_clock_syscnt(), level=1)
-        # In disaster recovery scenario, the training process may be rolled back to the last step where
-        # the ckpt was successfully saved, so the _last_triggered_step should be updated.
-        if _get_recovery_context("enable_recovery") and cb_params.last_save_ckpt_step is not None:
-            self._last_triggered_step = cb_params.last_save_ckpt_step
-            cb_params.last_save_ckpt_step = None
 
         # save graph (only once)
         if not self._graph_saved:
@@ -628,13 +619,6 @@
         if "step_num" in self._append_dict:
             self._append_dict["step_num"] = self._append_step_num + step_num
 
-    def _update_save_step(self, cb_params):
-        """update step if used async d2h copy"""
-        step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
-        if self._d2h_async and self._run_mode == context.GRAPH_MODE:
-            step_num_in_epoch -= 1
-        return step_num_in_epoch
-
     def _save_ckpt(self, cb_params, force_to_save=False):
         """Save checkpoint files."""
         if cb_params.cur_step_num == self._last_triggered_step:
@@ -645,7 +629,7 @@
         self._flush_from_cache(cb_params)
 
         save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
-        step_num_in_epoch = self._update_save_step(cb_params)
+        step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
 
         if save_ckpt:
 
@@ -31,7 +31,6 @@ from mindspore.communication.management import (create_group, get_group_size,
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.ops import operations as P
 from mindspore.common import Tensor
-from mindspore import context
 import mindspore.nn as nn
 
 
@@ -152,16 +151,21 @@
         """
         Check whether FlopsUtilizationCollector is working in the current environment
         """
-        if context.get_context("mode") != context.GRAPH_MODE:
-            if self.verbose:
-                raise ValueError("FlopsUtilizationCollector now only support graph mode.")
-            logger.info("FlopsUtilizationCollector now only support graph mode.")
-            return False
         cb_params = run_context.original_args()
         if cb_params.mode == 'train':
             network = cb_params.train_network
+            if not network.compiled:
+                if self.verbose:
+                    raise ValueError("FlopsUtilizationCollector now only support graph mode.")
+                logger.info("FlopsUtilizationCollector now only support graph mode.")
+                return False
         elif cb_params.mode == 'eval':
             network = cb_params.eval_network
+            if not network.compiled:
+                if self.verbose:
+                    raise ValueError("FlopsUtilizationCollector now only support graph mode.")
+                logger.info("FlopsUtilizationCollector now only support graph mode.")
+                return False
         else:
             if self.verbose:
                 raise ValueError('FlopsUtilizationCollector only support train and eval mode!')
@@ -28,15 +28,15 @@ from mindspore._c_expression import _repair_device, _stop_device, _tft_sem_post,
 from mindspore._c_expression import _rebuild_world_group, _rebuild_sub_group, _finalize_comm, _clean_rootinfo
 from mindspore._c_expression import clean_tdt_channel
 from mindspore._c_expression import _pre_launch_send_recv
-from mindspore._c_expression import send_recv, reset_params
+from mindspore._c_expression import send_recv, reset_params, direct_copy_to_host
+from mindspore._c_expression import _reg_snapshot_params, _reset_snapshot_state, _clear_snapshot_saving_flag
 from mindspore._c_expression import CollectiveManager
 from mindspore._c_expression import _get_uce_process_strategy, _get_uce_mem_info
-from mindspore._c_expression import TensorPy as Tensor_
 from mindspore.ops.operations.manually_defined._inner import TensorReport
 import mindspore
 import mindspore.common.dtype as mstype
-from mindspore.parallel._recovery_context import _set_recovery_context
 from mindspore import runtime
+from mindspore._c_expression import set_is_arf
 
 
 def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
@@ -157,6 +157,7 @@ def _tft_clean_callback(is_uce_error, args, ctx):
         CollectiveManager.get_instance().resume_hccl_comm()
     logger.warning("Finish _tft_clean_callback, ret: {}".format(ret))
     if ctx.tft.tft_get_repair_type() == "recover":
+        _reset_snapshot_state()
         logger.warning(f"Destroy hcom")
         _finalize_comm()
         logger.warning(f"Destroy hcom end")
@@ -166,11 +167,10 @@
 def _tft_stop_callback(args, cb_ctx):
     """ Callback used for TFT stop function."""
     logger.warning(f"Enter _tft_stop_callback device_id: {cb_ctx.device_id}")
-    _stop_device(cb_ctx.device_id)
-    cb_ctx.stop_been_called = True
     if (not cb_ctx.is_uce_rank) and (not cb_ctx._is_params_consistent()):  # pylint: disable=W0212
         raise RuntimeError("Can't stop device, because training parameters are left in inconsistent state!")
     cb_ctx.is_uce_rank = False
+    _stop_device(cb_ctx.device_id)
     if cb_ctx.tft.tft_get_repair_type() == "recover":
         logger.warning(f"Reset limit step")
         cb_ctx.tft.tft_reset_limit_step()
@@ -182,7 +182,7 @@ def _tft_rebuild_sub_groups(fault_ranks, args, ctx):
     logger.warning(f"Enter _tft_rebuild_sub_groups, device id: {ctx.device_id}")
     _rebuild_world_group()
     _rebuild_sub_group()
-    _set_recovery_context(is_arf=True)
+    set_is_arf(True)
     logger.warning(f"try to pre launch send recv before real launch")
     _pre_launch_send_recv(context.get_context('device_id'))
     logger.warning(f"Pre launch send recv before real launch end")
@@ -192,7 +192,7 @@
 class TrainFaultTolerance(Callback):
     """
     This callback is used to enable the TFT feature
-    `MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/mindio/mindiottp/mindiottp001.html>`_
+    `MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/600/clusterscheduling/ref/mindiottp/mindiotft001.html>`_
     and will execute TFT operations during training process, such as TFT init, report and exception handle.
 
     Note:
@@ -202,7 +202,10 @@ class TrainFaultTolerance(Callback):
        ckpt_save_path (str): Checkpoint save directory when failure occurs. When saved,
            a new directory named 'ttp_saved_checkpoints-step_{cur_step_num}'
            is created in that directory. Default: ``None``.
-        kwargs (dict): Other dictionary type parameters.
+        kwargs (dict): Other dictionary type parameters. When argument `ckpt_save_path` is ``None``, `kwargs` must
+            provide a parameter named `ckpt_save_fn`, which points to a function used to save checkpoint. The
+            prototype of `ckpt_save_fn` is ``def save_ckpt(cb_params, append_dict)``. When both `ckpt_save_path`
+            and `ckpt_save_fn` are provided, `ckpt_save_fn` is used in priority.
 
     Raises:
        Exception: TFT init failed.
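The updated docstring above only fixes the `ckpt_save_fn` signature. As a rough illustration (the save path, helper body, and use of `mindspore.save_checkpoint` are assumptions for illustration, not part of this diff; it also assumes `TrainFaultTolerance` is importable from `mindspore.train` as in current releases), a custom save hook could look like:

    import os
    import mindspore as ms
    from mindspore.train import TrainFaultTolerance

    def save_ckpt(cb_params, append_dict):
        # Illustrative only: persist the training network when MindIO TFT requests a save.
        path = os.path.join("./tft_ckpt", f"step_{cb_params.cur_step_num}.ckpt")
        ms.save_checkpoint(cb_params.train_network, path, append_dict=append_dict)

    # Passed through kwargs when no ckpt_save_path is given.
    tft_cb = TrainFaultTolerance(ckpt_save_path=None, ckpt_save_fn=save_ckpt)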
@@ -329,7 +332,7 @@ class TrainFaultTolerance(Callback):
        # `def load_checkpoint() -> tuple(dict, bool)`, the return value is a tuple containing 2 values,
        # i.e. (param_dict, remove_redundancy)
        self.ckpt_load_func = kwargs.get("ckpt_load_fn", None)
-        if self._only_enable_tre():
+        if self._only_enable_tre() or self._only_enable_ckpt_d2h_async():
            return
        self.tft = _tft_handler.get_tft()
        self._check_init()
@@ -340,11 +343,9 @@ class TrainFaultTolerance(Callback):
        self.learning_rate = None
        self.has_init_replica = False
        self.is_uce_rank = False
-        self.stop_been_called = False
 
        self.assign = mindspore.ops.Assign()
-        self.g_one = Parameter(Tensor([1], dtype=mstype.int32))
-        self.s1 = mindspore.hal.Stream()
+        self.g_one = Tensor([1], dtype=mstype.int32)
        _tft_sem_enable()
        self._tft_register()
 
@@ -354,7 +355,21 @@
        non_tre_flags = ["TTP:1", "UCE:1", "ARF:1"]
        if any(flag in env_enable for flag in non_tre_flags):
            return False
-        return "TRE:1" in env_enable
+        return "TRE:1" in env_enable or "TRE:2" in env_enable
+
+    @staticmethod
+    def _only_enable_ckpt_d2h_async():
+        """Check whether only set MS_ENABLE_CKPT_D2H_ASYNC=1 without setting MS_ENABLE_TFT"""
+        if os.getenv("MS_ENABLE_TFT", "") != "":
+            return False
+        return os.getenv("MS_ENABLE_CKPT_D2H_ASYNC") == "1"
+
+    @staticmethod
+    def _enable_snapshot():
+        """Check whether parameter snapshot enabled"""
+        enable_step_tre = "TRE:2" in os.getenv("MS_ENABLE_TFT", "")
+        enable_ckpt_d2h_async = os.getenv("MS_ENABLE_CKPT_D2H_ASYNC") == "1"
+        return enable_step_tre or enable_ckpt_d2h_async
 
    def _only_enable_tsp(self):
        """Check if only configured MS_ENABLE_TFT='{TSP:1}'"""
@@ -382,18 +397,14 @@ class TrainFaultTolerance(Callback):
        _tft_handler.init(config=None)
        self.tft = _tft_handler.get_tft()
        logger.warning(f"TFT handle init ok.")
-        mode = context.get_context("mode")
        device_target = context.get_context("device_target")
-        if device_target != "Ascend" or mode != context.GRAPH_MODE:
-            raise ValueError(f"MindIO adataper only support on Ascend device with GRAPH Mode!"
-                             f"device:{device_target}, run mode: {mode}")
+        if device_target != "Ascend":
+            raise ValueError(f"MindIO adataper only support on Ascend device but got device {device_target}!")
 
    def _is_params_consistent(self):
        for key, param in self.cb_params.train_network.parameters_and_names():
            if "tft_g_one_flag" in key:
-                with mindspore.hal.StreamCtx(self.s1):
-                    tft_g_one_flag = Tensor(Tensor_.move_to(param, "CPU", False))
-                self.s1.synchronize()
+                tft_g_one_flag = direct_copy_to_host(param)
                return int(tft_g_one_flag) == 1
        return False
 
@@ -438,7 +449,7 @@ class TrainFaultTolerance(Callback):
                super(TFTOptSubCls, self).__init__(*args, **kwargs)
                self.report = TensorReport()
                self.report_end = TensorReport()
-                self.report_end.add_prim_attr("side_effect_mem", True).add_prim_attr("optimizer_end", True)
+                self.report_end.add_prim_attr("optimizer_end", True)
                self.depend = ops.Depend()
                self.allreduce_sum = ops.AllReduce()
                self.allreduce_sum.add_prim_attr("tft_report_before", True)
@@ -452,7 +463,27 @@ class TrainFaultTolerance(Callback):
                self.report_end("tft_report", self.tft_g_one_flag)
                return opt_ret
 
-        return TFTOptSubCls
+        class TFTOptSnapShotCls(origin_opt_cls):
+            """
+            Optimizer wrapper class when using tft.
+            """
+
+            def __init__(self, *args, **kwargs):
+                super(TFTOptSnapShotCls, self).__init__(*args, **kwargs)
+                self.report = TensorReport()
+                self.report.add_prim_attr("side_effect_mem", True).add_prim_attr("snapshot", True)
+                self.dummy_input = Tensor([1], dtype=mstype.int32)
+
+            def construct(self, gradients, **kwargs):
+                """Add fake op TensorReport to insert wait event for copying parameters"""
+                self.report("tft_report", self.dummy_input)
+                opt_ret = super(TFTOptSnapShotCls, self).construct(gradients, **kwargs)
+                return opt_ret
+
+        env_tft = os.getenv('MS_ENABLE_TFT', '')
+        features = ['TTP:1', 'UCE:1', 'ARF:1']
+        need_redundancy = any([env_tft.find(feat) >= 0 for feat in features])
+        return TFTOptSubCls if need_redundancy else TFTOptSnapShotCls
 
    def _tft_register(self):
        """Register callback functions."""
@@ -480,6 +511,17 @@ class TrainFaultTolerance(Callback):
            _clean_rootinfo()
            self.clean_unique_id = True
 
+    def on_train_step_begin(self, run_context):
+        """
+        Clear saving snapshot state at each step begin.
+
+        Args:
+            run_context (RunContext): Context of the train running. Refer to
+                :class:`mindspore.train.RunContext` for detail.
+        """
+        if self._enable_snapshot():
+            _clear_snapshot_saving_flag()
+
    def on_train_step_end(self, run_context):
        """
        Report status to MindIO TFT after every step finished.
@@ -488,7 +530,7 @@
            run_context (RunContext): Context of the train running. Refer to
                :class:`mindspore.train.RunContext` for detail.
        """
-        if self._only_enable_tre():
+        if self._only_enable_tre() or self._only_enable_ckpt_d2h_async():
            return
 
        cb_params = run_context.original_args()
@@ -528,10 +570,15 @@
            run_context (RunContext): Context of the train running. Refer to
                :class:`mindspore.train.RunContext` for detail.
        """
+        cb_params = run_context.original_args()
+        if self._enable_snapshot():
+            param_dict = {}
+            for param in cb_params.train_network.trainable_params():
+                param_dict[param.name] = param
+            _reg_snapshot_params(param_dict)
        if self._only_enable_tsp():
            return
-        cb_params = run_context.original_args()
-        if self._only_enable_tre():
+        if self._only_enable_tre() or self._only_enable_ckpt_d2h_async():
            self.cb_params = cb_params
            return
        sink_size = cb_params.get("sink_size", 0)
@@ -16,7 +16,7 @@
 from functools import wraps
 import mindspore.ops as ops
 from mindspore import context
-from mindspore.common.dtype import pytype_to_dtype
+from mindspore.common.dtype import _pytype_to_dtype
 from mindspore.common.api import jit
 from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, enable_data_broadcast
 from mindspore.train.dataset_helper import _has_dynamic_shape, _check_inputs
@@ -61,7 +61,7 @@ def _init_sink_dataset(dataset, sink_size, input_signature, create_info):
        _check_inputs(input_signature, dataset_shapes, dataset_types)
 
    queue_name = transfer_dataset.queue_name
-    if _need_to_full() and context.get_context('mode') == context.GRAPH_MODE:
+    if _need_to_full():
        device_num = _get_device_num() // _get_pipeline_stages()
        dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
    next_op = ops.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
@@ -94,12 +94,12 @@ def _get_next_op(dataset, ori_next_op, is_info_queue):
 
    queue_name = dataset.__transfer_dataset__.queue_name
    dataset_types, dataset_shapes = dataset.__transfer_dataset__.get_data_info()
-    dataset_types = [pytype_to_dtype(x) for x in dataset_types]
+    dataset_types = [_pytype_to_dtype(x) for x in dataset_types]  # pylint:disable=protected-access
    key = str(dataset_types) + str(dataset_shapes)
    if key in dataset.__sink_aux__.next_ops:
        next_op = dataset.__sink_aux__.next_ops[key]
    else:
-        if _need_to_full() and context.get_context('mode') == context.GRAPH_MODE:
+        if _need_to_full():
            device_num = _get_device_num() // _get_pipeline_stages()
            dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
            next_op = ops.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
@@ -238,12 +238,8 @@ def data_sink(fn, dataset, sink_size=1, jit_config=None, input_signature=None):
 
        real_sink_fun = _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config)
 
-        loop = sink_size
-        if jit_config is not None and context.get_context('mode') == context.GRAPH_MODE:
-            loop = 1
-
        out = None
-        for _ in range(loop):
+        for _ in range(sink_size):
            out = real_sink_fun()
 
        return out
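For context, `data_sink` wraps a function so that each call drains batches from the dataset; with the change above the wrapped call now runs `sink_size` iterations regardless of the JIT config. A minimal usage sketch follows, assuming the documented top-level `mindspore.data_sink` API; the network, loss, and placeholder data are illustrative and not part of this diff.

    import numpy as np
    import mindspore as ms
    import mindspore.dataset as ds
    from mindspore import nn

    # Placeholder data and network, for illustration only.
    dataset = ds.NumpySlicesDataset(
        {"x": np.random.randn(32, 4).astype(np.float32),
         "y": np.random.randn(32, 1).astype(np.float32)},
        shuffle=False).batch(8)

    net = nn.Dense(4, 1)
    loss_fn = nn.MSELoss()

    def step(x, y):
        # Forward-only step; a real training step would also update weights.
        return loss_fn(net(x), y)

    # Each call to sinked() now consumes sink_size batches from the dataset.
    sinked = ms.data_sink(step, dataset, sink_size=4)
    loss = sinked()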
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2025 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,8 +20,8 @@ import copy
 
 from mindspore import _checkparam as Validator
 from mindspore import log as logger
-from mindspore.common._auto_dynamic import is_auto_dynamic, convert_new_shapes
-from mindspore.common.dtype import pytype_to_dtype
+from mindspore.common.dynamic_shape._auto_dynamic import is_auto_dynamic, convert_new_shapes
+from mindspore.common.dtype import _pytype_to_dtype
 from mindspore.common.api import _cell_graph_executor, _is_args_fullmode, ARG_SPECIFIED
 from mindspore.common._utils import is_shape_unknown
 from mindspore.dataset.core import config as dataset_config
@@ -34,7 +34,7 @@ from mindspore.parallel._utils import _get_device_num, _get_global_rank, _need_t
    _origin_shapes, _dynamic_shape_for_dataset
 from mindspore.parallel._ps_context import _is_role_sched
 from mindspore.ops import operations as P
-from mindspore.common.auto_dynamic_shape import _auto_dynamic_shape
+from mindspore.common.dynamic_shape.auto_dynamic_shape import _auto_dynamic_shape
 
 
 def _send_data(dataset, epoch_num):
@@ -275,7 +275,7 @@ def connect_network_with_dataset(network, dataset_helper):
        # Need to do full_batch for shapes which also do in the _DatasetIterMSLoopSink
        if _need_to_full():
            dataset_shapes = _to_full_shapes(dataset_shapes, _get_device_num() // _get_pipeline_stages())
-        dataset_types = [pytype_to_dtype(x) for x in dataset_types]
+        dataset_types = [_pytype_to_dtype(x) for x in dataset_types]  # pylint:disable=protected-access
        if not is_dynamic:
            dataset_shapes = _auto_dynamic_shape.auto_dynamic_generate_compile_args(dataset_shapes, True)
        key = str(dataset_types) + str(dataset_shapes)