mindspore 2.7.0-cp311-cp311-win_amd64.whl → 2.7.1-cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore has been flagged as possibly problematic.

Files changed (290)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  6. mindspore/_extends/parse/compile_config.py +24 -1
  7. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
  8. mindspore/_extends/parse/resources.py +1 -1
  9. mindspore/_extends/parse/standard_method.py +8 -1
  10. mindspore/_extends/parse/trope.py +2 -1
  11. mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/boost/base.py +29 -2
  18. mindspore/common/_decorator.py +3 -2
  19. mindspore/common/_grad_function.py +3 -1
  20. mindspore/common/_tensor_cpp_method.py +1 -1
  21. mindspore/common/_tensor_docs.py +275 -64
  22. mindspore/common/_utils.py +0 -44
  23. mindspore/common/api.py +285 -35
  24. mindspore/common/dump.py +7 -108
  25. mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
  26. mindspore/common/hook_handle.py +60 -0
  27. mindspore/common/jit_config.py +5 -1
  28. mindspore/common/jit_trace.py +27 -12
  29. mindspore/common/lazy_inline.py +5 -3
  30. mindspore/common/parameter.py +13 -107
  31. mindspore/common/recompute.py +4 -11
  32. mindspore/common/tensor.py +16 -169
  33. mindspore/communication/_comm_helper.py +11 -1
  34. mindspore/communication/comm_func.py +138 -4
  35. mindspore/communication/management.py +85 -1
  36. mindspore/config/op_info.config +0 -15
  37. mindspore/context.py +5 -85
  38. mindspore/dataset/engine/datasets.py +8 -4
  39. mindspore/dataset/engine/datasets_vision.py +1 -1
  40. mindspore/dataset/engine/validators.py +1 -15
  41. mindspore/dnnl.dll +0 -0
  42. mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
  43. mindspore/graph/custom_pass.py +55 -0
  44. mindspore/include/dataset/execute.h +2 -2
  45. mindspore/jpeg62.dll +0 -0
  46. mindspore/mindrecord/__init__.py +3 -3
  47. mindspore/mindrecord/common/exceptions.py +1 -0
  48. mindspore/mindrecord/config.py +1 -1
  49. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  50. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  51. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  52. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  53. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  54. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  55. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  56. mindspore/mindrecord/filereader.py +4 -4
  57. mindspore/mindrecord/filewriter.py +5 -5
  58. mindspore/mindrecord/mindpage.py +2 -2
  59. mindspore/mindrecord/tools/cifar10.py +1 -1
  60. mindspore/mindrecord/tools/cifar100.py +1 -1
  61. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  62. mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
  63. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  64. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  65. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  66. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  67. mindspore/mindspore_backend_common.dll +0 -0
  68. mindspore/mindspore_backend_manager.dll +0 -0
  69. mindspore/mindspore_cluster.dll +0 -0
  70. mindspore/mindspore_common.dll +0 -0
  71. mindspore/mindspore_core.dll +0 -0
  72. mindspore/mindspore_cpu.dll +0 -0
  73. mindspore/mindspore_dump.dll +0 -0
  74. mindspore/mindspore_frontend.dll +0 -0
  75. mindspore/mindspore_glog.dll +0 -0
  76. mindspore/mindspore_hardware_abstract.dll +0 -0
  77. mindspore/mindspore_memory_pool.dll +0 -0
  78. mindspore/mindspore_ms_backend.dll +0 -0
  79. mindspore/mindspore_ops.dll +0 -0
  80. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  81. mindspore/mindspore_profiler.dll +0 -0
  82. mindspore/mindspore_pyboost.dll +0 -0
  83. mindspore/mindspore_pynative.dll +0 -0
  84. mindspore/mindspore_runtime_pipeline.dll +0 -0
  85. mindspore/mindspore_runtime_utils.dll +0 -0
  86. mindspore/mindspore_tools.dll +0 -0
  87. mindspore/mint/__init__.py +15 -10
  88. mindspore/mint/distributed/distributed.py +182 -62
  89. mindspore/mint/nn/__init__.py +2 -16
  90. mindspore/mint/nn/functional.py +4 -110
  91. mindspore/mint/nn/layer/__init__.py +0 -2
  92. mindspore/mint/nn/layer/activation.py +0 -6
  93. mindspore/mint/nn/layer/basic.py +0 -47
  94. mindspore/mint/nn/layer/conv.py +4 -4
  95. mindspore/mint/nn/layer/normalization.py +8 -13
  96. mindspore/mint/nn/layer/pooling.py +0 -4
  97. mindspore/nn/__init__.py +1 -3
  98. mindspore/nn/cell.py +16 -66
  99. mindspore/nn/layer/basic.py +49 -1
  100. mindspore/nn/layer/container.py +16 -0
  101. mindspore/nn/layer/embedding.py +4 -169
  102. mindspore/nn/layer/normalization.py +2 -1
  103. mindspore/nn/layer/thor_layer.py +4 -85
  104. mindspore/nn/optim/ada_grad.py +0 -1
  105. mindspore/nn/optim/adafactor.py +0 -1
  106. mindspore/nn/optim/adam.py +31 -124
  107. mindspore/nn/optim/adamax.py +0 -1
  108. mindspore/nn/optim/asgd.py +0 -1
  109. mindspore/nn/optim/ftrl.py +8 -102
  110. mindspore/nn/optim/lamb.py +0 -1
  111. mindspore/nn/optim/lars.py +0 -3
  112. mindspore/nn/optim/lazyadam.py +25 -218
  113. mindspore/nn/optim/momentum.py +5 -43
  114. mindspore/nn/optim/optimizer.py +6 -55
  115. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  116. mindspore/nn/optim/rmsprop.py +0 -1
  117. mindspore/nn/optim/rprop.py +0 -1
  118. mindspore/nn/optim/sgd.py +0 -1
  119. mindspore/nn/optim/tft_wrapper.py +0 -1
  120. mindspore/nn/optim/thor.py +0 -2
  121. mindspore/nn/probability/bijector/bijector.py +7 -8
  122. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  123. mindspore/nn/probability/bijector/power_transform.py +20 -21
  124. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  125. mindspore/nn/probability/bijector/softplus.py +13 -14
  126. mindspore/nn/wrap/grad_reducer.py +4 -74
  127. mindspore/numpy/array_creations.py +2 -2
  128. mindspore/numpy/fft.py +9 -9
  129. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  130. mindspore/onnx/onnx_export.py +137 -0
  131. mindspore/opencv_core4110.dll +0 -0
  132. mindspore/opencv_imgcodecs4110.dll +0 -0
  133. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  134. mindspore/ops/__init__.py +2 -0
  135. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  136. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  137. mindspore/ops/_op_impl/cpu/__init__.py +0 -5
  138. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
  139. mindspore/ops/auto_generate/gen_extend_func.py +2 -7
  140. mindspore/ops/auto_generate/gen_ops_def.py +98 -141
  141. mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
  142. mindspore/ops/communication.py +97 -0
  143. mindspore/ops/composite/__init__.py +5 -2
  144. mindspore/ops/composite/base.py +15 -1
  145. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  146. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  147. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  148. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  149. mindspore/ops/function/__init__.py +1 -0
  150. mindspore/ops/function/array_func.py +14 -12
  151. mindspore/ops/function/comm_func.py +3883 -0
  152. mindspore/ops/function/debug_func.py +3 -4
  153. mindspore/ops/function/math_func.py +45 -54
  154. mindspore/ops/function/nn_func.py +75 -294
  155. mindspore/ops/function/random_func.py +9 -18
  156. mindspore/ops/functional.py +2 -0
  157. mindspore/ops/functional_overload.py +354 -18
  158. mindspore/ops/operations/__init__.py +2 -5
  159. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  160. mindspore/ops/operations/_inner_ops.py +1 -38
  161. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  162. mindspore/ops/operations/array_ops.py +1 -0
  163. mindspore/ops/operations/comm_ops.py +94 -2
  164. mindspore/ops/operations/custom_ops.py +228 -19
  165. mindspore/ops/operations/debug_ops.py +27 -29
  166. mindspore/ops/operations/manually_defined/ops_def.py +27 -306
  167. mindspore/ops/operations/nn_ops.py +2 -2
  168. mindspore/ops/operations/sparse_ops.py +0 -83
  169. mindspore/ops/primitive.py +1 -17
  170. mindspore/ops/tensor_method.py +72 -3
  171. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  172. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  173. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  174. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  175. mindspore/ops_generate/common/gen_constants.py +11 -10
  176. mindspore/ops_generate/common/op_proto.py +18 -1
  177. mindspore/ops_generate/common/template.py +102 -245
  178. mindspore/ops_generate/common/template_utils.py +212 -0
  179. mindspore/ops_generate/gen_custom_ops.py +69 -0
  180. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  181. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  182. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  183. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  184. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  185. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  186. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  187. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  188. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  189. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  190. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  191. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  192. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  193. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  194. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  195. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  196. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  197. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  198. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  199. mindspore/parallel/_cell_wrapper.py +1 -1
  200. mindspore/parallel/_parallel_serialization.py +1 -4
  201. mindspore/parallel/_utils.py +29 -6
  202. mindspore/parallel/checkpoint_transform.py +18 -2
  203. mindspore/parallel/cluster/process_entity/_api.py +24 -32
  204. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  205. mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
  206. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  207. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  208. mindspore/parallel/strategy.py +336 -0
  209. mindspore/parallel/transform_safetensors.py +117 -16
  210. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
  211. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  212. mindspore/profiler/common/constant.py +5 -0
  213. mindspore/profiler/common/file_manager.py +9 -0
  214. mindspore/profiler/common/msprof_cmd_tool.py +38 -2
  215. mindspore/profiler/common/path_manager.py +56 -24
  216. mindspore/profiler/common/profiler_context.py +2 -12
  217. mindspore/profiler/common/profiler_info.py +3 -3
  218. mindspore/profiler/common/profiler_path_manager.py +13 -0
  219. mindspore/profiler/common/util.py +30 -3
  220. mindspore/profiler/experimental_config.py +2 -1
  221. mindspore/profiler/platform/npu_profiler.py +33 -6
  222. mindspore/run_check/_check_version.py +108 -24
  223. mindspore/runtime/__init__.py +3 -2
  224. mindspore/runtime/executor.py +11 -3
  225. mindspore/runtime/memory.py +112 -0
  226. mindspore/swresample-4.dll +0 -0
  227. mindspore/swscale-6.dll +0 -0
  228. mindspore/tinyxml2.dll +0 -0
  229. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  230. mindspore/tools/data_dump.py +130 -0
  231. mindspore/tools/sdc_detect.py +91 -0
  232. mindspore/tools/stress_detect.py +63 -0
  233. mindspore/train/__init__.py +6 -6
  234. mindspore/train/_utils.py +5 -18
  235. mindspore/train/amp.py +6 -4
  236. mindspore/train/callback/_checkpoint.py +0 -9
  237. mindspore/train/callback/_train_fault_tolerance.py +69 -18
  238. mindspore/train/data_sink.py +1 -5
  239. mindspore/train/model.py +38 -211
  240. mindspore/train/serialization.py +126 -387
  241. mindspore/turbojpeg.dll +0 -0
  242. mindspore/utils/__init__.py +6 -3
  243. mindspore/utils/dlpack.py +92 -0
  244. mindspore/utils/dryrun.py +1 -1
  245. mindspore/utils/runtime_execution_order_check.py +10 -0
  246. mindspore/utils/sdc_detect.py +14 -12
  247. mindspore/utils/stress_detect.py +43 -0
  248. mindspore/utils/utils.py +144 -8
  249. mindspore/version.py +1 -1
  250. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  251. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
  252. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
  253. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  254. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  255. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  256. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  257. mindspore/experimental/llm_boost/register.py +0 -130
  258. mindspore/experimental/llm_boost/utils.py +0 -31
  259. mindspore/include/OWNERS +0 -7
  260. mindspore/mindspore_cpu_res_manager.dll +0 -0
  261. mindspore/mindspore_ops_kernel_common.dll +0 -0
  262. mindspore/mindspore_res_manager.dll +0 -0
  263. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  264. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  265. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  266. mindspore/nn/reinforcement/tensor_array.py +0 -145
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  270. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  271. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  272. mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
  273. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  274. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  275. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  276. mindspore/ops/operations/_tensor_array.py +0 -359
  277. mindspore/ops/operations/rl_ops.py +0 -288
  278. mindspore/parallel/_offload_context.py +0 -275
  279. mindspore/parallel/_recovery_context.py +0 -115
  280. mindspore/parallel/_transformer/__init__.py +0 -35
  281. mindspore/parallel/_transformer/layers.py +0 -765
  282. mindspore/parallel/_transformer/loss.py +0 -251
  283. mindspore/parallel/_transformer/moe.py +0 -693
  284. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  285. mindspore/parallel/_transformer/transformer.py +0 -3124
  286. mindspore/parallel/mpi/_mpi_config.py +0 -116
  287. mindspore/train/memory_profiling_pb2.py +0 -298
  288. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  289. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  290. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/train/callback/_train_fault_tolerance.py CHANGED
@@ -28,15 +28,15 @@ from mindspore._c_expression import _repair_device, _stop_device, _tft_sem_post,
  from mindspore._c_expression import _rebuild_world_group, _rebuild_sub_group, _finalize_comm, _clean_rootinfo
  from mindspore._c_expression import clean_tdt_channel
  from mindspore._c_expression import _pre_launch_send_recv
- from mindspore._c_expression import send_recv, reset_params
+ from mindspore._c_expression import send_recv, reset_params, direct_copy_to_host
+ from mindspore._c_expression import _reg_snapshot_params, _reset_snapshot_state, _clear_snapshot_saving_flag
  from mindspore._c_expression import CollectiveManager
  from mindspore._c_expression import _get_uce_process_strategy, _get_uce_mem_info
- from mindspore._c_expression import TensorPy as Tensor_
  from mindspore.ops.operations.manually_defined._inner import TensorReport
  import mindspore
  import mindspore.common.dtype as mstype
- from mindspore.parallel._recovery_context import _set_recovery_context
  from mindspore import runtime
+ from mindspore._c_expression import set_is_arf


  def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
@@ -157,6 +157,7 @@ def _tft_clean_callback(is_uce_error, args, ctx):
  CollectiveManager.get_instance().resume_hccl_comm()
  logger.warning("Finish _tft_clean_callback, ret: {}".format(ret))
  if ctx.tft.tft_get_repair_type() == "recover":
+ _reset_snapshot_state()
  logger.warning(f"Destroy hcom")
  _finalize_comm()
  logger.warning(f"Destroy hcom end")
@@ -166,10 +167,10 @@ def _tft_clean_callback(is_uce_error, args, ctx):
  def _tft_stop_callback(args, cb_ctx):
  """ Callback used for TFT stop function."""
  logger.warning(f"Enter _tft_stop_callback device_id: {cb_ctx.device_id}")
- _stop_device(cb_ctx.device_id)
  if (not cb_ctx.is_uce_rank) and (not cb_ctx._is_params_consistent()): # pylint: disable=W0212
  raise RuntimeError("Can't stop device, because training parameters are left in inconsistent state!")
  cb_ctx.is_uce_rank = False
+ _stop_device(cb_ctx.device_id)
  if cb_ctx.tft.tft_get_repair_type() == "recover":
  logger.warning(f"Reset limit step")
  cb_ctx.tft.tft_reset_limit_step()
@@ -181,7 +182,7 @@ def _tft_rebuild_sub_groups(fault_ranks, args, ctx):
  logger.warning(f"Enter _tft_rebuild_sub_groups, device id: {ctx.device_id}")
  _rebuild_world_group()
  _rebuild_sub_group()
- _set_recovery_context(is_arf=True)
+ set_is_arf(True)
  logger.warning(f"try to pre launch send recv before real launch")
  _pre_launch_send_recv(context.get_context('device_id'))
  logger.warning(f"Pre launch send recv before real launch end")
@@ -201,7 +202,10 @@ class TrainFaultTolerance(Callback):
  ckpt_save_path (str): Checkpoint save directory when failure occurs. When saved,
  a new directory named 'ttp_saved_checkpoints-step_{cur_step_num}'
  is created in that directory. Default: ``None``.
- kwargs (dict): Other dictionary type parameters.
+ kwargs (dict): Other dictionary type parameters. When argument `ckpt_save_path` is ``None``, `kwargs` must
+ provide a parameter named `ckpt_save_fn`, which points to a function used to save checkpoint. The
+ prototype of `ckpt_save_fn` is ``def save_ckpt(cb_params, append_dict)``. When both `ckpt_save_path`
+ and `ckpt_save_fn` are provided, `ckpt_save_fn` is used in priority.

  Raises:
  Exception: TFT init failed.
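For illustration, a minimal sketch of a `ckpt_save_fn` matching the prototype documented above (``def save_ckpt(cb_params, append_dict)``); the checkpoint file name and the `cb_params` fields used here are illustrative assumptions, not part of this release:

import mindspore

def save_ckpt(cb_params, append_dict):
    # cur_step_num and train_network are standard fields on the callback parameters object.
    ckpt_file = f"tft_saved-step_{cb_params.cur_step_num}.ckpt"
    mindspore.save_checkpoint(cb_params.train_network, ckpt_file, append_dict=append_dict)

# Hypothetical usage: TrainFaultTolerance(ckpt_save_path=None, ckpt_save_fn=save_ckpt)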
@@ -328,7 +332,7 @@ class TrainFaultTolerance(Callback):
  # `def load_checkpoint() -> tuple(dict, bool)`, the return value is a tuple containing 2 values,
  # i.e. (param_dict, remove_redundancy)
  self.ckpt_load_func = kwargs.get("ckpt_load_fn", None)
- if self._only_enable_tre():
+ if self._only_enable_tre() or self._only_enable_ckpt_d2h_async():
  return
  self.tft = _tft_handler.get_tft()
  self._check_init()
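Likewise, a minimal sketch of a `ckpt_load_fn` matching the prototype in the comment above (``def load_checkpoint() -> tuple(dict, bool)``, returning ``(param_dict, remove_redundancy)``); the checkpoint path is an illustrative assumption:

import mindspore

def load_ckpt():
    # Hypothetical path; return the parameter dict plus the remove_redundancy flag.
    param_dict = mindspore.load_checkpoint("ttp_saved_checkpoints-step_100/rank_0.ckpt")
    return param_dict, False

# Hypothetical usage: TrainFaultTolerance(ckpt_load_fn=load_ckpt, ...)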
@@ -341,8 +345,7 @@ class TrainFaultTolerance(Callback):
  self.is_uce_rank = False

  self.assign = mindspore.ops.Assign()
- self.g_one = Parameter(Tensor([1], dtype=mstype.int32))
- self.s1 = mindspore.hal.Stream()
+ self.g_one = Tensor([1], dtype=mstype.int32)
  _tft_sem_enable()
  self._tft_register()

@@ -352,7 +355,21 @@ class TrainFaultTolerance(Callback):
  non_tre_flags = ["TTP:1", "UCE:1", "ARF:1"]
  if any(flag in env_enable for flag in non_tre_flags):
  return False
- return "TRE:1" in env_enable
+ return "TRE:1" in env_enable or "TRE:2" in env_enable
+
+ @staticmethod
+ def _only_enable_ckpt_d2h_async():
+ """Check whether only set MS_ENABLE_CKPT_D2H_ASYNC=1 without setting MS_ENABLE_TFT"""
+ if os.getenv("MS_ENABLE_TFT", "") != "":
+ return False
+ return os.getenv("MS_ENABLE_CKPT_D2H_ASYNC") == "1"
+
+ @staticmethod
+ def _enable_snapshot():
+ """Check whether parameter snapshot enabled"""
+ enable_step_tre = "TRE:2" in os.getenv("MS_ENABLE_TFT", "")
+ enable_ckpt_d2h_async = os.getenv("MS_ENABLE_CKPT_D2H_ASYNC") == "1"
+ return enable_step_tre or enable_ckpt_d2h_async

  def _only_enable_tsp(self):
  """Check if only configured MS_ENABLE_TFT='{TSP:1}'"""
@@ -387,9 +404,7 @@ class TrainFaultTolerance(Callback):
  def _is_params_consistent(self):
  for key, param in self.cb_params.train_network.parameters_and_names():
  if "tft_g_one_flag" in key:
- with mindspore.hal.StreamCtx(self.s1):
- tft_g_one_flag = Tensor(Tensor_.move_to(param, "CPU", False))
- self.s1.synchronize()
+ tft_g_one_flag = direct_copy_to_host(param)
  return int(tft_g_one_flag) == 1
  return False

@@ -434,7 +449,7 @@ class TrainFaultTolerance(Callback):
  super(TFTOptSubCls, self).__init__(*args, **kwargs)
  self.report = TensorReport()
  self.report_end = TensorReport()
- self.report_end.add_prim_attr("side_effect_mem", True).add_prim_attr("optimizer_end", True)
+ self.report_end.add_prim_attr("optimizer_end", True)
  self.depend = ops.Depend()
  self.allreduce_sum = ops.AllReduce()
  self.allreduce_sum.add_prim_attr("tft_report_before", True)
@@ -448,7 +463,27 @@ class TrainFaultTolerance(Callback):
  self.report_end("tft_report", self.tft_g_one_flag)
  return opt_ret

- return TFTOptSubCls
+ class TFTOptSnapShotCls(origin_opt_cls):
+ """
+ Optimizer wrapper class when using tft.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(TFTOptSnapShotCls, self).__init__(*args, **kwargs)
+ self.report = TensorReport()
+ self.report.add_prim_attr("side_effect_mem", True).add_prim_attr("snapshot", True)
+ self.dummy_input = Tensor([1], dtype=mstype.int32)
+
+ def construct(self, gradients, **kwargs):
+ """Add fake op TensorReport to insert wait event for copying parameters"""
+ self.report("tft_report", self.dummy_input)
+ opt_ret = super(TFTOptSnapShotCls, self).construct(gradients, **kwargs)
+ return opt_ret
+
+ env_tft = os.getenv('MS_ENABLE_TFT', '')
+ features = ['TTP:1', 'UCE:1', 'ARF:1']
+ need_redundancy = any([env_tft.find(feat) >= 0 for feat in features])
+ return TFTOptSubCls if need_redundancy else TFTOptSnapShotCls

  def _tft_register(self):
  """Register callback functions."""
@@ -476,6 +511,17 @@ class TrainFaultTolerance(Callback):
  _clean_rootinfo()
  self.clean_unique_id = True

+ def on_train_step_begin(self, run_context):
+ """
+ Clear saving snapshot state at each step begin.
+
+ Args:
+ run_context (RunContext): Context of the train running. Refer to
+ :class:`mindspore.train.RunContext` for detail.
+ """
+ if self._enable_snapshot():
+ _clear_snapshot_saving_flag()
+
  def on_train_step_end(self, run_context):
  """
  Report status to MindIO TFT after every step finished.
@@ -484,7 +530,7 @@ class TrainFaultTolerance(Callback):
  run_context (RunContext): Context of the train running. Refer to
  :class:`mindspore.train.RunContext` for detail.
  """
- if self._only_enable_tre():
+ if self._only_enable_tre() or self._only_enable_ckpt_d2h_async():
  return

  cb_params = run_context.original_args()
@@ -524,10 +570,15 @@ class TrainFaultTolerance(Callback):
  run_context (RunContext): Context of the train running. Refer to
  :class:`mindspore.train.RunContext` for detail.
  """
+ cb_params = run_context.original_args()
+ if self._enable_snapshot():
+ param_dict = {}
+ for param in cb_params.train_network.trainable_params():
+ param_dict[param.name] = param
+ _reg_snapshot_params(param_dict)
  if self._only_enable_tsp():
  return
- cb_params = run_context.original_args()
- if self._only_enable_tre():
+ if self._only_enable_tre() or self._only_enable_ckpt_d2h_async():
  self.cb_params = cb_params
  return
  sink_size = cb_params.get("sink_size", 0)
mindspore/train/data_sink.py CHANGED
@@ -238,12 +238,8 @@ def data_sink(fn, dataset, sink_size=1, jit_config=None, input_signature=None):

  real_sink_fun = _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config)

- loop = sink_size
- if jit_config is not None and context.get_context('mode') == context.GRAPH_MODE:
- loop = 1
-
  out = None
- for _ in range(loop):
+ for _ in range(sink_size):
  out = real_sink_fun()

  return out
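A small usage sketch of `mindspore.data_sink` reflecting the simplified loop above: the wrapped function is launched `sink_size` times per call, and the output of the last launch is returned. The dataset and step function below are illustrative assumptions:

import numpy as np
import mindspore as ms

ds = ms.dataset.NumpySlicesDataset({"x": np.ones((8, 4), np.float32)}, shuffle=False).batch(2)

def step(x):
    return x.sum()

sink_step = ms.data_sink(step, ds, sink_size=4)
out = sink_step()  # runs `step` on 4 sunk batches; `out` comes from the last one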
mindspore/train/model.py CHANGED
@@ -28,7 +28,7 @@ import numpy as np

  import mindspore
  from mindspore import log as logger
- from mindspore.train.serialization import save_checkpoint, load_checkpoint
+ from mindspore.train.serialization import save_checkpoint
  from mindspore.train.callback._checkpoint import ModelCheckpoint, _chg_ckpt_file_name_if_same_exist
  from mindspore.common.tensor import Tensor
  from mindspore.train.metrics import get_metrics, get_metric_fn
@@ -40,16 +40,12 @@ from mindspore.train.callback import __all__ as internal_cb_names
  from mindspore.train.callback._cluster_monitor import ClusterMonitor
  from mindspore import context
  from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_parameter_broadcast, \
- _device_number_check, _parameter_broadcast_check, _parallel_predict_check, \
- _reset_op_id_with_offset
- from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_ps_mode, \
- _cache_enable, _enable_distributed_mindrt
+ _device_number_check, _parameter_broadcast_check, _parallel_predict_check
  from mindspore.train.metrics import Loss
  from mindspore.log import vlog_print
  from mindspore import nn
  from mindspore.boost import AutoBoost
  from mindspore.context import ParallelMode
- from mindspore.parallel._recovery_context import _set_recovery_context, _get_recovery_context
  from mindspore.train.dataset_helper import DatasetHelper, connect_network_with_dataset
  from mindspore.common.api import _pynative_executor, ARG_SPECIFIED, TOTAL_ARG_LEN
  from mindspore.dataset.core.config import get_debug_mode
@@ -57,7 +53,8 @@ from mindspore.dataset.engine.datasets import _set_training_dataset, _reset_trai
  from mindspore.train import amp
  from mindspore._c_expression import _framework_profiler_step_start, _framework_profiler_step_end
  from mindspore._c_expression import _get_optimzer_timestamps
- from mindspore._c_expression import clean_tdt_channel, _clean_rootinfo
+ from mindspore._c_expression import clean_tdt_channel, _clean_rootinfo, check_is_arf, set_is_arf
+ from mindspore._c_expression import _get_snapshot_params, _is_snapshot_valid

  from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
  from .serialization import load_param_into_net
@@ -163,7 +160,7 @@ def _handle_exception_info(obj, uce_env, tft, e):
  tft.tft_report_error(force_stop_err)
  elif "ARF FINISH" in e_str:
  logger.warning(f"ARF FINISH")
- _set_recovery_context(is_arf=True)
+ set_is_arf(True)
  tft.tft_report_error(tft.ReportState.RS_PREREPAIR_FINISH.value)
  else:
  logger.error("uce wrapper caught other RuntimeError, enter MindIO TTP process.", exc_info=True)
@@ -175,7 +172,12 @@ def _handle_training_result_error(model, tft_obj):
  """
  Handle training result error for resuming training.
  """
- ckpt_load_fn = tft_obj.ckpt_load_func
+ def load_snapshot_params():
+ param_dict = {}
+ for name, tensor in _get_snapshot_params().items():
+ param_dict[name] = mindspore.Parameter(tensor, name=name)
+ return (param_dict, False)
+ ckpt_load_fn = load_snapshot_params if _is_snapshot_valid() else tft_obj.ckpt_load_func
  train_network = tft_obj.cb_params.train_network
  logger.warning("Process training result error start.")
  # 1. Clear tdt channel
@@ -234,6 +236,20 @@ def _update_ckpt_callback_info(resume_train_step, **kwargs):
  ckpt_obj._append_step_num = resume_train_step


+ def _get_tft_obj(**kwargs):
+ """
+ Get TrainFaultTolerance from kwargs of callback
+ """
+ obj = None
+ if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), TrainFaultTolerance):
+ obj = kwargs.get('callbacks')
+ if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
+ for item in kwargs.get('callbacks'):
+ if isinstance(item, TrainFaultTolerance):
+ obj = item
+ return obj
+
+

  def _handle_tft(func):
  """
@@ -241,17 +257,11 @@

  @wraps(func)
  def wrapper(self, *args, **kwargs):
- obj = None
- if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), TrainFaultTolerance):
- obj = kwargs.get('callbacks')
- if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
- for item in kwargs.get('callbacks'):
- if isinstance(item, TrainFaultTolerance):
- obj = item
- if obj:
+ obj = _get_tft_obj(**kwargs)
+ if obj and not TrainFaultTolerance._only_enable_ckpt_d2h_async():
  tft_env = os.getenv("MS_ENABLE_TFT", "")
  uce_env = "UCE:1" in tft_env or "ARF:1" in tft_env or "HCCE:1" in tft_env
- tre_env = "TRE:1" in tft_env
+ tre_env = "TRE:1" in tft_env or "TRE:2" in tft_env
  while True:
  try:
  return func(self, *args, **kwargs)
@@ -556,9 +566,7 @@ class Model:
  self._current_epoch_num = 0
  self._current_step_num = 0
  self.epoch_iter = 0
- self.enable_recovery = False
  self._backbone_is_train = True
- self.need_load_ckpt = False
  self._lite_full_predictor = None
  self._lite_incremental_predictor = None
  self._mindspore_lite = None
@@ -731,10 +739,7 @@ class Model:
  metrics = dict()
  # There's no need for server to execute eval, just give fake metrics.
  for key, value in self._metric_fns.items():
- if not _is_role_pserver():
- metrics[key] = value.eval()
- else:
- metrics[key] = 1
+ metrics[key] = value.eval()
  return metrics

  def _get_scaling_sens(self):
@@ -768,7 +773,7 @@ class Model:
  logger.info("Begin to connect network with dataset.")
  network = connect_network_with_dataset(network, dataset_helper)

- if (_get_recovery_context("enable_recovery") or self._need_reset_data) and is_train:
+ if self._need_reset_data and is_train:
  _set_training_dataset(dataset_helper)

  network.set_train(is_train)
@@ -810,9 +815,7 @@ class Model:
  :param cb_params: callback params
  :return: none
  """
- if os.environ.get("MS_ENABLE_CKPT_D2H_ASYNC") != "1":
- return
- if context.get_context("device_target") == "Ascend":
+ if TrainFaultTolerance._enable_snapshot() and context.get_context("device_target") == "Ascend":
  cb_params.need_ckpt, cb_params.save_checkpoint_steps, \
  cb_params.last_triggered_step = self._check_need_ckpt(cb_params.list_callback)
  logger.info(f"need_ckpt:{cb_params.need_ckpt},"
@@ -1018,13 +1021,10 @@ class Model:
  callbacks = cb_params.list_callback
  cb_params.train_dataset_element = None
  cb_params.network = self._network
- # Embedding cache server only run one step.
- if _is_role_pserver() and _cache_enable():
- epoch = 1
  cb_params.last_save_ckpt_step = None
  cb_params.latest_ckpt_file = None
  cb_params.loss_scale_mananger = self._loss_scale_manager
- cb_params.is_arf = _get_recovery_context("is_arf")
+ cb_params.is_arf = check_is_arf()
  cb_params.initial_step = self._initial_step

  # build callback list
@@ -1086,12 +1086,6 @@ class Model:
  dataset_helper = train_dataset._dataset_helper

  self.epoch_iter = 0
- self._check_enable_recovery()
- # Used to check whether need perform recovery for process which is restarted.
- self._check_need_load_ckpt(cb_params, dataset_size, sink_size)
- # Check whether this process is embedding cache server.
- is_embedding_cache_server = _is_role_pserver() and _cache_enable()
-
  while self.epoch_iter < (epoch - initial_epoch):
  cb_params.cur_epoch_num = self.epoch_iter + 1 + initial_epoch
  self._current_epoch_num = cb_params.cur_epoch_num
@@ -1107,11 +1101,6 @@ class Model:
  cb_params.train_network = train_network
  cb_params.dataset_helper = dataset_helper

- # Perform recovery for process which is restarted.
- self._reset_training_step_for_abnormal_process(cb_params, dataset_helper)
- # Perform recovery for process which is not restarted.
- self._reset_training_step_for_normal_process(cb_params, dataset_helper)
-
  # For data sink dataset_helper only iter once, other wise iter epoch_size times.
  for inputs in dataset_helper:
  if is_graph:
@@ -1126,36 +1115,17 @@ class Model:
  outputs = train_network(*inputs)
  cb_params.net_outputs = outputs

- # In disaster recovery scenarios, need not to execute callbacks if this step executes failed.
- need_exec_callback_step_end = not (self.enable_recovery and _get_recovery_context("need_reset"))
- if need_exec_callback_step_end:
- list_callback.on_train_step_end(run_context)
+ list_callback.on_train_step_end(run_context)
+
  if cb_params.is_arf:
  cb_params.is_arf = False
- _set_recovery_context(is_arf=False)
+ set_is_arf(False)
  _clean_rootinfo()

- # Embedding cache server only run one step.
- if is_embedding_cache_server:
- break
-
  dataset_helper.continue_send()

- # When it's distributed training and using MindRT,
- # the node id should be reset to start from 0.
- # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case(or 'fit').
- if _enable_distributed_mindrt():
- _reset_op_id_with_offset()
-
  self._eval_during_train(valid_infos, cb_params, list_callback)
-
- # In disaster recovery scenarios, need not to execute callbacks if this epoch executes failed.
- # Embedding cache server need not do epoch end callback, this process only run one step.
- need_exec_callback_epoch_end = not ((self.enable_recovery and _get_recovery_context("need_reset"))
- or is_embedding_cache_server)
-
- if need_exec_callback_epoch_end:
- list_callback.on_train_epoch_end(run_context)
+ list_callback.on_train_epoch_end(run_context)
  if "metrics" in cb_params or "eval_results" in cb_params:
  cb_params.pop("metrics", None)
  cb_params.pop("eval_results", None)
@@ -1164,12 +1134,7 @@ class Model:
  if should_stop:
  break

- need_reset_to_beginning = self.enable_recovery and _get_recovery_context("need_reset") \
- and not _get_recovery_context("latest_ckpt_file")
  self.epoch_iter += 1
- if need_reset_to_beginning:
- self.epoch_iter = 0
- cb_params.cur_step_num = 0

  dataset_helper.stop_send()
  dataset_helper.release()
@@ -1203,93 +1168,6 @@ class Model:
  cb_params.dataset_sink_mode = train_dataset_sink_mode
  cb_params.net_outputs = train_net_outputs

- def _check_enable_recovery(self):
- """
- Check whether enable recovery and execution mode consistency.
- """
-
- enable_recovery = _get_recovery_context("enable_recovery") and context.get_context("device_target") == "GPU"
- if not enable_recovery:
- self.enable_recovery = False
- else:
- self.enable_recovery = enable_recovery and _is_role_worker()
-
- def _check_need_load_ckpt(self, cb_params, dataset_size, sink_size=-1):
- """
- Check whether need to load checkpoint after abnormal process restart.
-
- Args:
- cb_params (_InternalCallbackParam): Callback parameters.
- dataset_size (int): The number of batches in a dataset.
- sink_size (int): Control the amount of data in each sink. Default: -1.
- """
- if context.get_context("device_target") != "GPU":
- return
- if not self.enable_recovery:
- self.need_load_ckpt = False
-
- cb_params.latest_ckpt_file = _get_recovery_context("latest_ckpt_file")
- if cb_params.latest_ckpt_file:
- recovery_epoch_num = _get_recovery_context("latest_ckpt_epoch")
- recovery_step_num = _get_recovery_context("latest_ckpt_step")
- dataset_sink_size = sink_size if sink_size > 0 else dataset_size
- cb_params.cur_step_num = (recovery_epoch_num - 1) * dataset_sink_size + recovery_step_num
- cb_params.last_save_ckpt_step = cb_params.cur_step_num
- self.epoch_iter = recovery_epoch_num
- self.need_load_ckpt = True
- else:
- self.need_load_ckpt = False
-
- def _reset_training_step_for_abnormal_process(self, cb_params, dataset_helper):
- """
- Execute recovery for abnormal exit process when restart.
-
- Args:
- cb_params (_InternalCallbackParam): Callback parameters.
- """
-
- if self.need_load_ckpt:
- try:
- load_checkpoint(cb_params.latest_ckpt_file, cb_params.train_network)
- except BaseException as e:
- os.remove(cb_params.latest_ckpt_file)
- raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: " \
- + cb_params.latest_ckpt_file) from e
- _reset_training_dataset(cb_params.cur_step_num, dataset_helper.iter.dataset.get_dataset_size())
- self.need_load_ckpt = False
-
- def _reset_training_step_for_normal_process(self, cb_params, dataset_helper):
- """
- Execute recovery for normal process when there is process exit abnormally.
-
- Args:
- cb_params (_InternalCallbackParam): Callback parameters.
- dataset_helper (DatasetHelper): A class to process the MindData dataset,
- it provides the type, shape and queue name of the dataset to wrap the `GetNext`.
- """
-
- if self.enable_recovery and _get_recovery_context("need_reset"):
- cb_params.latest_ckpt_file = _get_recovery_context("latest_ckpt_file")
- if cb_params.latest_ckpt_file:
- try:
- load_checkpoint(cb_params.latest_ckpt_file, cb_params.train_network)
- except BaseException as e:
- os.remove(cb_params.latest_ckpt_file)
- raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: "\
- + cb_params.latest_ckpt_file) from e
-
- recovery_epoch_num = _get_recovery_context("latest_ckpt_epoch")
- recovery_step_num = _get_recovery_context("latest_ckpt_step")
- cb_params.cur_step_num = (recovery_epoch_num - 1) * dataset_helper.sink_size() + recovery_step_num
- self.epoch_iter = recovery_epoch_num
- cb_params.cur_epoch_num = self.epoch_iter + 1
- cb_params.last_save_ckpt_step = cb_params.cur_step_num
- _reset_training_dataset(cb_params.cur_step_num, dataset_helper.iter.dataset.get_dataset_size())
- else:
- _reset_training_dataset(0, dataset_helper.iter.dataset.get_dataset_size())
-
- _set_recovery_context(need_reset=False)
-
  def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None, initial_epoch=0,
  valid_infos=None):
  """
@@ -1314,7 +1192,6 @@ class Model:
  cb_params.dataset_sink_mode = False
  run_context = RunContext(cb_params)
  list_callback.on_train_begin(run_context)
- is_embedding_cache_server = _is_role_pserver() and _cache_enable()

  for i in range(initial_epoch, epoch):
  cb_params.cur_epoch_num = i + 1
@@ -1345,21 +1222,12 @@ class Model:
  list_callback.on_train_step_end(run_context)
  if cb_params.is_arf:
  cb_params.is_arf = False
- _set_recovery_context(is_arf=False)
+ set_is_arf(False)
  _clean_rootinfo()
- # Embedding cache server only run one step.
- if is_embedding_cache_server:
- break
  should_stop = run_context.get_stop_requested()
  if should_stop:
  break

- # When it's distributed training and using MindRT,
- # the node id should be reset to start from 0.
- # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case(or 'fit').
- if _enable_distributed_mindrt():
- _reset_op_id_with_offset()
-
  self._eval_during_train(valid_infos, cb_params, list_callback)

  train_dataset.reset()
@@ -1367,9 +1235,7 @@ class Model:
  # if param is cache enable, flush data from cache to host before epoch end
  self._flush_from_cache(cb_params)

- # Embedding cache server need not do epoch end callback, this process only run one step.
- if not is_embedding_cache_server:
- list_callback.on_train_epoch_end(run_context)
+ list_callback.on_train_epoch_end(run_context)
  if "metrics" in cb_params or "eval_results" in cb_params:
  cb_params.pop("metrics", None)
  cb_params.pop("eval_results", None)
@@ -1446,10 +1312,6 @@ class Model:
  """
  _init_auto_parallel_context(self._network)
  _check_tft()
- device_target = context.get_context("device_target")
- if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
- logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
- dataset_sink_mode = False

  Validator.check_bool(dataset_sink_mode)
  if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode:
@@ -1461,11 +1323,6 @@ class Model:
  "the value of epoch in train {} separately."
  .format(train_dataset._warmup_epoch, epoch))

- # Parameter server and embedding cache mode check.
- if _is_ps_mode():
- if not dataset_sink_mode and _cache_enable():
- raise ValueError("Embedding cache mode should run with 'dataset_sink_mode=True'.")
-
  self._check_sink_mode_for_ds_debug_mode(dataset_sink_mode)

  Validator.check_is_int(sink_size)
@@ -1496,12 +1353,6 @@ class Model:
  sink_size=sink_size,
  initial_epoch=initial_epoch)

- # When it's distributed training and using MindRT,
- # the node id should be reset to start from 0.
- # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case(or 'fit').
- if _enable_distributed_mindrt():
- _reset_op_id_with_offset()
-
  _clear_auto_parallel_context(self._network)

  @staticmethod
@@ -1599,10 +1450,6 @@ class Model:
  >>> model.fit(2, train_dataset, valid_dataset)
  """
  _init_auto_parallel_context(self._network)
- device_target = context.get_context("device_target")
- if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
- logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
- dataset_sink_mode = False

  dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
  valid_dataset_sink_mode = Validator.check_bool(valid_dataset_sink_mode)
@@ -1896,13 +1743,6 @@ class Model:

  self._clear_metrics()

- # Embedding cache server as a storage service, no need to execute eval.
- is_embedding_cache_server = _is_role_pserver() and _cache_enable()
- if is_embedding_cache_server:
- metrics = self._get_metrics()
- cb_params.metrics = metrics
- return metrics
-
  if context.get_context("device_target") == "CPU" and dataset_sink_mode:
  dataset_sink_mode = False
  logger.info("CPU cannot support dataset sink mode currently."
@@ -1914,13 +1754,7 @@ class Model:
  else:
  eval_result = self._eval_process(valid_dataset, list_callback, cb_params)

- # When it's distributed training and using MindRT,
- # the node id should be reset to start from 0.
- # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case(or 'fit').
- if _enable_distributed_mindrt():
- _reset_op_id_with_offset()
  _clear_auto_parallel_context(self._network)
-
  return eval_result

  def _predict_lite(self, *predict_data, config=None):
@@ -2171,13 +2005,6 @@ class Model:
  result = self._predict_network(*predict_data)

  check_output_data(result)
-
- # When it's distributed training and using MindRT,
- # the node id should be reset to start from 0.
- # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case(or 'fit').
- if _enable_distributed_mindrt():
- _reset_op_id_with_offset()
-
  return result

  def _infer_train_check(self, train_dataset, dataset_sink_mode, sink_size):