mindspore-2.7.0-cp310-cp310-win_amd64.whl → mindspore-2.7.1-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +8 -1
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +275 -64
- mindspore/common/_utils.py +0 -44
- mindspore/common/api.py +285 -35
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
- mindspore/common/hook_handle.py +60 -0
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/parameter.py +13 -107
- mindspore/common/recompute.py +4 -11
- mindspore/common/tensor.py +16 -169
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +5 -85
- mindspore/dataset/engine/datasets.py +8 -4
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dnnl.dll +0 -0
- mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +1 -1
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/distributed.py +182 -62
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +4 -4
- mindspore/mint/nn/layer/normalization.py +8 -13
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +16 -66
- mindspore/nn/layer/basic.py +49 -1
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +31 -124
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +0 -1
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +0 -1
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/wrap/grad_reducer.py +4 -74
- mindspore/numpy/array_creations.py +2 -2
- mindspore/numpy/fft.py +9 -9
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +0 -5
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
- mindspore/ops/auto_generate/gen_extend_func.py +2 -7
- mindspore/ops/auto_generate/gen_ops_def.py +98 -141
- mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +15 -1
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +1 -0
- mindspore/ops/function/array_func.py +14 -12
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +3 -4
- mindspore/ops/function/math_func.py +45 -54
- mindspore/ops/function/nn_func.py +75 -294
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +2 -0
- mindspore/ops/functional_overload.py +354 -18
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +1 -38
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +1 -0
- mindspore/ops/operations/comm_ops.py +94 -2
- mindspore/ops/operations/custom_ops.py +228 -19
- mindspore/ops/operations/debug_ops.py +27 -29
- mindspore/ops/operations/manually_defined/ops_def.py +27 -306
- mindspore/ops/operations/nn_ops.py +2 -2
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +1 -17
- mindspore/ops/tensor_method.py +72 -3
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +1 -4
- mindspore/parallel/_utils.py +29 -6
- mindspore/parallel/checkpoint_transform.py +18 -2
- mindspore/parallel/cluster/process_entity/_api.py +24 -32
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +117 -16
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +38 -2
- mindspore/profiler/common/path_manager.py +56 -24
- mindspore/profiler/common/profiler_context.py +2 -12
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/experimental_config.py +2 -1
- mindspore/profiler/platform/npu_profiler.py +33 -6
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +3 -2
- mindspore/runtime/executor.py +11 -3
- mindspore/runtime/memory.py +112 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +5 -18
- mindspore/train/amp.py +6 -4
- mindspore/train/callback/_checkpoint.py +0 -9
- mindspore/train/callback/_train_fault_tolerance.py +69 -18
- mindspore/train/data_sink.py +1 -5
- mindspore/train/model.py +38 -211
- mindspore/train/serialization.py +126 -387
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +144 -8
- mindspore/version.py +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/train/memory_profiling_pb2.py +0 -298
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0

mindspore/train/callback/_train_fault_tolerance.py
CHANGED

@@ -28,15 +28,15 @@ from mindspore._c_expression import _repair_device, _stop_device, _tft_sem_post,

 from mindspore._c_expression import _rebuild_world_group, _rebuild_sub_group, _finalize_comm, _clean_rootinfo
 from mindspore._c_expression import clean_tdt_channel
 from mindspore._c_expression import _pre_launch_send_recv
-from mindspore._c_expression import send_recv, reset_params
+from mindspore._c_expression import send_recv, reset_params, direct_copy_to_host
+from mindspore._c_expression import _reg_snapshot_params, _reset_snapshot_state, _clear_snapshot_saving_flag
 from mindspore._c_expression import CollectiveManager
 from mindspore._c_expression import _get_uce_process_strategy, _get_uce_mem_info
-from mindspore._c_expression import TensorPy as Tensor_
 from mindspore.ops.operations.manually_defined._inner import TensorReport
 import mindspore
 import mindspore.common.dtype as mstype
-from mindspore.parallel._recovery_context import _set_recovery_context
 from mindspore import runtime
+from mindspore._c_expression import set_is_arf


 def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):

@@ -157,6 +157,7 @@ def _tft_clean_callback(is_uce_error, args, ctx):

         CollectiveManager.get_instance().resume_hccl_comm()
     logger.warning("Finish _tft_clean_callback, ret: {}".format(ret))
     if ctx.tft.tft_get_repair_type() == "recover":
+        _reset_snapshot_state()
         logger.warning(f"Destroy hcom")
         _finalize_comm()
         logger.warning(f"Destroy hcom end")

@@ -166,10 +167,10 @@ def _tft_clean_callback(is_uce_error, args, ctx):

 def _tft_stop_callback(args, cb_ctx):
     """ Callback used for TFT stop function."""
     logger.warning(f"Enter _tft_stop_callback device_id: {cb_ctx.device_id}")
-    _stop_device(cb_ctx.device_id)
     if (not cb_ctx.is_uce_rank) and (not cb_ctx._is_params_consistent()):  # pylint: disable=W0212
         raise RuntimeError("Can't stop device, because training parameters are left in inconsistent state!")
     cb_ctx.is_uce_rank = False
+    _stop_device(cb_ctx.device_id)
     if cb_ctx.tft.tft_get_repair_type() == "recover":
         logger.warning(f"Reset limit step")
         cb_ctx.tft.tft_reset_limit_step()

@@ -181,7 +182,7 @@ def _tft_rebuild_sub_groups(fault_ranks, args, ctx):

     logger.warning(f"Enter _tft_rebuild_sub_groups, device id: {ctx.device_id}")
     _rebuild_world_group()
     _rebuild_sub_group()
-
+    set_is_arf(True)
     logger.warning(f"try to pre launch send recv before real launch")
     _pre_launch_send_recv(context.get_context('device_id'))
     logger.warning(f"Pre launch send recv before real launch end")

@@ -201,7 +202,10 @@ class TrainFaultTolerance(Callback):

         ckpt_save_path (str): Checkpoint save directory when failure occurs. When saved,
             a new directory named 'ttp_saved_checkpoints-step_{cur_step_num}'
             is created in that directory. Default: ``None``.
-        kwargs (dict): Other dictionary type parameters.
+        kwargs (dict): Other dictionary type parameters. When argument `ckpt_save_path` is ``None``, `kwargs` must
+            provide a parameter named `ckpt_save_fn`, which points to a function used to save checkpoint. The
+            prototype of `ckpt_save_fn` is ``def save_ckpt(cb_params, append_dict)``. When both `ckpt_save_path`
+            and `ckpt_save_fn` are provided, `ckpt_save_fn` is used in priority.

     Raises:
         Exception: TFT init failed.

@@ -328,7 +332,7 @@ class TrainFaultTolerance(Callback):

         # `def load_checkpoint() -> tuple(dict, bool)`, the return value is a tuple containing 2 values,
         # i.e. (param_dict, remove_redundancy)
         self.ckpt_load_func = kwargs.get("ckpt_load_fn", None)
-        if self._only_enable_tre():
+        if self._only_enable_tre() or self._only_enable_ckpt_d2h_async():
             return
         self.tft = _tft_handler.get_tft()
         self._check_init()

@@ -341,8 +345,7 @@ class TrainFaultTolerance(Callback):

         self.is_uce_rank = False

         self.assign = mindspore.ops.Assign()
-        self.g_one =
-        self.s1 = mindspore.hal.Stream()
+        self.g_one = Tensor([1], dtype=mstype.int32)
         _tft_sem_enable()
         self._tft_register()


@@ -352,7 +355,21 @@ class TrainFaultTolerance(Callback):

         non_tre_flags = ["TTP:1", "UCE:1", "ARF:1"]
         if any(flag in env_enable for flag in non_tre_flags):
             return False
-        return "TRE:1" in env_enable
+        return "TRE:1" in env_enable or "TRE:2" in env_enable
+
+    @staticmethod
+    def _only_enable_ckpt_d2h_async():
+        """Check whether only set MS_ENABLE_CKPT_D2H_ASYNC=1 without setting MS_ENABLE_TFT"""
+        if os.getenv("MS_ENABLE_TFT", "") != "":
+            return False
+        return os.getenv("MS_ENABLE_CKPT_D2H_ASYNC") == "1"
+
+    @staticmethod
+    def _enable_snapshot():
+        """Check whether parameter snapshot enabled"""
+        enable_step_tre = "TRE:2" in os.getenv("MS_ENABLE_TFT", "")
+        enable_ckpt_d2h_async = os.getenv("MS_ENABLE_CKPT_D2H_ASYNC") == "1"
+        return enable_step_tre or enable_ckpt_d2h_async

     def _only_enable_tsp(self):
         """Check if only configured MS_ENABLE_TFT='{TSP:1}'"""

@@ -387,9 +404,7 @@ class TrainFaultTolerance(Callback):

     def _is_params_consistent(self):
         for key, param in self.cb_params.train_network.parameters_and_names():
             if "tft_g_one_flag" in key:
-
-                tft_g_one_flag = Tensor(Tensor_.move_to(param, "CPU", False))
-                self.s1.synchronize()
+                tft_g_one_flag = direct_copy_to_host(param)
                 return int(tft_g_one_flag) == 1
         return False


@@ -434,7 +449,7 @@ class TrainFaultTolerance(Callback):

                 super(TFTOptSubCls, self).__init__(*args, **kwargs)
                 self.report = TensorReport()
                 self.report_end = TensorReport()
-                self.report_end.add_prim_attr("
+                self.report_end.add_prim_attr("optimizer_end", True)
                 self.depend = ops.Depend()
                 self.allreduce_sum = ops.AllReduce()
                 self.allreduce_sum.add_prim_attr("tft_report_before", True)

@@ -448,7 +463,27 @@ class TrainFaultTolerance(Callback):

                 self.report_end("tft_report", self.tft_g_one_flag)
                 return opt_ret

-
+        class TFTOptSnapShotCls(origin_opt_cls):
+            """
+            Optimizer wrapper class when using tft.
+            """
+
+            def __init__(self, *args, **kwargs):
+                super(TFTOptSnapShotCls, self).__init__(*args, **kwargs)
+                self.report = TensorReport()
+                self.report.add_prim_attr("side_effect_mem", True).add_prim_attr("snapshot", True)
+                self.dummy_input = Tensor([1], dtype=mstype.int32)
+
+            def construct(self, gradients, **kwargs):
+                """Add fake op TensorReport to insert wait event for copying parameters"""
+                self.report("tft_report", self.dummy_input)
+                opt_ret = super(TFTOptSnapShotCls, self).construct(gradients, **kwargs)
+                return opt_ret
+
+        env_tft = os.getenv('MS_ENABLE_TFT', '')
+        features = ['TTP:1', 'UCE:1', 'ARF:1']
+        need_redundancy = any([env_tft.find(feat) >= 0 for feat in features])
+        return TFTOptSubCls if need_redundancy else TFTOptSnapShotCls

     def _tft_register(self):
         """Register callback functions."""

@@ -476,6 +511,17 @@ class TrainFaultTolerance(Callback):

             _clean_rootinfo()
             self.clean_unique_id = True

+    def on_train_step_begin(self, run_context):
+        """
+        Clear saving snapshot state at each step begin.
+
+        Args:
+            run_context (RunContext): Context of the train running. Refer to
+                :class:`mindspore.train.RunContext` for detail.
+        """
+        if self._enable_snapshot():
+            _clear_snapshot_saving_flag()
+
     def on_train_step_end(self, run_context):
         """
         Report status to MindIO TFT after every step finished.

@@ -484,7 +530,7 @@ class TrainFaultTolerance(Callback):

             run_context (RunContext): Context of the train running. Refer to
                 :class:`mindspore.train.RunContext` for detail.
         """
-        if self._only_enable_tre():
+        if self._only_enable_tre() or self._only_enable_ckpt_d2h_async():
             return

         cb_params = run_context.original_args()

@@ -524,10 +570,15 @@ class TrainFaultTolerance(Callback):

             run_context (RunContext): Context of the train running. Refer to
                 :class:`mindspore.train.RunContext` for detail.
         """
+        cb_params = run_context.original_args()
+        if self._enable_snapshot():
+            param_dict = {}
+            for param in cb_params.train_network.trainable_params():
+                param_dict[param.name] = param
+            _reg_snapshot_params(param_dict)
         if self._only_enable_tsp():
             return
-
-        if self._only_enable_tre():
+        if self._only_enable_tre() or self._only_enable_ckpt_d2h_async():
             self.cb_params = cb_params
             return
         sink_size = cb_params.get("sink_size", 0)
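
Note: the new `ckpt_save_fn` / `ckpt_load_fn` hooks and the `MS_ENABLE_TFT` / `MS_ENABLE_CKPT_D2H_ASYNC` flags referenced above are driven entirely by user code and environment configuration. The following is a minimal sketch of how they might be wired up; the checkpoint path is illustrative, and it assumes MindIO TFT is installed and configured, otherwise constructing the callback raises.

    import os
    import mindspore as ms
    from mindspore.train import TrainFaultTolerance

    # Feature flags read by the callback; "TRE:2" or MS_ENABLE_CKPT_D2H_ASYNC=1
    # enable the parameter-snapshot path added in 2.7.1.
    os.environ["MS_ENABLE_TFT"] = "{TTP:1,UCE:1,ARF:1}"

    def save_ckpt(cb_params, append_dict):
        # Used instead of directory-based saving when ckpt_save_path is None.
        ms.save_checkpoint(cb_params.train_network, "./tft_recover.ckpt", append_dict=append_dict)

    def load_ckpt():
        # Must return (param_dict, remove_redundancy), as the callback's comments describe.
        return ms.load_checkpoint("./tft_recover.ckpt"), False

    tft_cb = TrainFaultTolerance(ckpt_save_path=None, ckpt_save_fn=save_ckpt, ckpt_load_fn=load_ckpt)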

mindspore/train/data_sink.py
CHANGED

@@ -238,12 +238,8 @@ def data_sink(fn, dataset, sink_size=1, jit_config=None, input_signature=None):


         real_sink_fun = _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config)

-        loop = sink_size
-        if jit_config is not None and context.get_context('mode') == context.GRAPH_MODE:
-            loop = 1
-
         out = None
-        for _ in range(
+        for _ in range(sink_size):
             out = real_sink_fun()

         return out
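
After this change the function returned by `data_sink` always executes `sink_size` sink iterations per call; the special case that forced a single iteration when a `jit_config` was supplied in graph mode is gone. A minimal usage sketch with toy data and a stand-in step function, assuming the public `mindspore.data_sink` wrapper that re-exports this module's function:

    import numpy as np
    import mindspore as ms
    from mindspore import dataset as ds

    def step(x, y):
        # Stand-in for a real training step that would run a network and optimizer.
        return (x * y).sum()

    data = ds.NumpySlicesDataset(
        {"x": np.ones((8, 4), np.float32), "y": np.ones((8, 4), np.float32)}, shuffle=False)

    sink_step = ms.data_sink(step, data, sink_size=4)
    out = sink_step()  # runs 4 sink iterations and returns the last step's output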

mindspore/train/model.py
CHANGED

@@ -28,7 +28,7 @@ import numpy as np


 import mindspore
 from mindspore import log as logger
-from mindspore.train.serialization import save_checkpoint
+from mindspore.train.serialization import save_checkpoint
 from mindspore.train.callback._checkpoint import ModelCheckpoint, _chg_ckpt_file_name_if_same_exist
 from mindspore.common.tensor import Tensor
 from mindspore.train.metrics import get_metrics, get_metric_fn

@@ -40,16 +40,12 @@ from mindspore.train.callback import __all__ as internal_cb_names

 from mindspore.train.callback._cluster_monitor import ClusterMonitor
 from mindspore import context
 from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_parameter_broadcast, \
-    _device_number_check, _parameter_broadcast_check, _parallel_predict_check
-    _reset_op_id_with_offset
-from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_ps_mode, \
-    _cache_enable, _enable_distributed_mindrt
+    _device_number_check, _parameter_broadcast_check, _parallel_predict_check
 from mindspore.train.metrics import Loss
 from mindspore.log import vlog_print
 from mindspore import nn
 from mindspore.boost import AutoBoost
 from mindspore.context import ParallelMode
-from mindspore.parallel._recovery_context import _set_recovery_context, _get_recovery_context
 from mindspore.train.dataset_helper import DatasetHelper, connect_network_with_dataset
 from mindspore.common.api import _pynative_executor, ARG_SPECIFIED, TOTAL_ARG_LEN
 from mindspore.dataset.core.config import get_debug_mode

@@ -57,7 +53,8 @@ from mindspore.dataset.engine.datasets import _set_training_dataset, _reset_trai

 from mindspore.train import amp
 from mindspore._c_expression import _framework_profiler_step_start, _framework_profiler_step_end
 from mindspore._c_expression import _get_optimzer_timestamps
-from mindspore._c_expression import clean_tdt_channel, _clean_rootinfo
+from mindspore._c_expression import clean_tdt_channel, _clean_rootinfo, check_is_arf, set_is_arf
+from mindspore._c_expression import _get_snapshot_params, _is_snapshot_valid

 from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
 from .serialization import load_param_into_net

@@ -163,7 +160,7 @@ def _handle_exception_info(obj, uce_env, tft, e):

         tft.tft_report_error(force_stop_err)
     elif "ARF FINISH" in e_str:
         logger.warning(f"ARF FINISH")
-
+        set_is_arf(True)
         tft.tft_report_error(tft.ReportState.RS_PREREPAIR_FINISH.value)
     else:
         logger.error("uce wrapper caught other RuntimeError, enter MindIO TTP process.", exc_info=True)

@@ -175,7 +172,12 @@ def _handle_training_result_error(model, tft_obj):

     """
     Handle training result error for resuming training.
     """
-
+    def load_snapshot_params():
+        param_dict = {}
+        for name, tensor in _get_snapshot_params().items():
+            param_dict[name] = mindspore.Parameter(tensor, name=name)
+        return (param_dict, False)
+    ckpt_load_fn = load_snapshot_params if _is_snapshot_valid() else tft_obj.ckpt_load_func
     train_network = tft_obj.cb_params.train_network
     logger.warning("Process training result error start.")
     # 1. Clear tdt channel

@@ -234,6 +236,20 @@ def _update_ckpt_callback_info(resume_train_step, **kwargs):

             ckpt_obj._append_step_num = resume_train_step


+def _get_tft_obj(**kwargs):
+    """
+    Get TrainFaultTolerance from kwargs of callback
+    """
+    obj = None
+    if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), TrainFaultTolerance):
+        obj = kwargs.get('callbacks')
+    if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
+        for item in kwargs.get('callbacks'):
+            if isinstance(item, TrainFaultTolerance):
+                obj = item
+    return obj
+
+
 def _handle_tft(func):
     """
     Decorator function, which starts uce handle process when an exception occurs during training.

@@ -241,17 +257,11 @@ def _handle_tft(func):


     @wraps(func)
     def wrapper(self, *args, **kwargs):
-        obj =
-        if
-            obj = kwargs.get('callbacks')
-        if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
-            for item in kwargs.get('callbacks'):
-                if isinstance(item, TrainFaultTolerance):
-                    obj = item
-        if obj:
+        obj = _get_tft_obj(**kwargs)
+        if obj and not TrainFaultTolerance._only_enable_ckpt_d2h_async():
             tft_env = os.getenv("MS_ENABLE_TFT", "")
             uce_env = "UCE:1" in tft_env or "ARF:1" in tft_env or "HCCE:1" in tft_env
-            tre_env = "TRE:1" in tft_env
+            tre_env = "TRE:1" in tft_env or "TRE:2" in tft_env
         while True:
             try:
                 return func(self, *args, **kwargs)

@@ -556,9 +566,7 @@ class Model:

         self._current_epoch_num = 0
         self._current_step_num = 0
         self.epoch_iter = 0
-        self.enable_recovery = False
         self._backbone_is_train = True
-        self.need_load_ckpt = False
         self._lite_full_predictor = None
         self._lite_incremental_predictor = None
         self._mindspore_lite = None

@@ -731,10 +739,7 @@ class Model:

         metrics = dict()
         # There's no need for server to execute eval, just give fake metrics.
         for key, value in self._metric_fns.items():
-
-                metrics[key] = value.eval()
-            else:
-                metrics[key] = 1
+            metrics[key] = value.eval()
         return metrics

     def _get_scaling_sens(self):

@@ -768,7 +773,7 @@ class Model:

         logger.info("Begin to connect network with dataset.")
         network = connect_network_with_dataset(network, dataset_helper)

-        if
+        if self._need_reset_data and is_train:
             _set_training_dataset(dataset_helper)

         network.set_train(is_train)

@@ -810,9 +815,7 @@ class Model:

         :param cb_params: callback params
         :return: none
         """
-        if
-            return
-        if context.get_context("device_target") == "Ascend":
+        if TrainFaultTolerance._enable_snapshot() and context.get_context("device_target") == "Ascend":
             cb_params.need_ckpt, cb_params.save_checkpoint_steps, \
                 cb_params.last_triggered_step = self._check_need_ckpt(cb_params.list_callback)
             logger.info(f"need_ckpt:{cb_params.need_ckpt},"

@@ -1018,13 +1021,10 @@ class Model:

         callbacks = cb_params.list_callback
         cb_params.train_dataset_element = None
         cb_params.network = self._network
-        # Embedding cache server only run one step.
-        if _is_role_pserver() and _cache_enable():
-            epoch = 1
         cb_params.last_save_ckpt_step = None
         cb_params.latest_ckpt_file = None
         cb_params.loss_scale_mananger = self._loss_scale_manager
-        cb_params.is_arf =
+        cb_params.is_arf = check_is_arf()
         cb_params.initial_step = self._initial_step

         # build callback list

@@ -1086,12 +1086,6 @@ class Model:

         dataset_helper = train_dataset._dataset_helper

         self.epoch_iter = 0
-        self._check_enable_recovery()
-        # Used to check whether need perform recovery for process which is restarted.
-        self._check_need_load_ckpt(cb_params, dataset_size, sink_size)
-        # Check whether this process is embedding cache server.
-        is_embedding_cache_server = _is_role_pserver() and _cache_enable()
-
         while self.epoch_iter < (epoch - initial_epoch):
             cb_params.cur_epoch_num = self.epoch_iter + 1 + initial_epoch
             self._current_epoch_num = cb_params.cur_epoch_num

@@ -1107,11 +1101,6 @@ class Model:

             cb_params.train_network = train_network
             cb_params.dataset_helper = dataset_helper

-            # Perform recovery for process which is restarted.
-            self._reset_training_step_for_abnormal_process(cb_params, dataset_helper)
-            # Perform recovery for process which is not restarted.
-            self._reset_training_step_for_normal_process(cb_params, dataset_helper)
-
             # For data sink dataset_helper only iter once, other wise iter epoch_size times.
             for inputs in dataset_helper:
                 if is_graph:

@@ -1126,36 +1115,17 @@ class Model:

                 outputs = train_network(*inputs)
                 cb_params.net_outputs = outputs

-
-
-                if need_exec_callback_step_end:
-                    list_callback.on_train_step_end(run_context)
+                list_callback.on_train_step_end(run_context)
+
                 if cb_params.is_arf:
                     cb_params.is_arf = False
-
+                    set_is_arf(False)
                     _clean_rootinfo()

-                # Embedding cache server only run one step.
-                if is_embedding_cache_server:
-                    break
-
             dataset_helper.continue_send()

-            # When it's distributed training and using MindRT,
-            # the node id should be reset to start from 0.
-            # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case(or 'fit').
-            if _enable_distributed_mindrt():
-                _reset_op_id_with_offset()
-
             self._eval_during_train(valid_infos, cb_params, list_callback)
-
-            # In disaster recovery scenarios, need not to execute callbacks if this epoch executes failed.
-            # Embedding cache server need not do epoch end callback, this process only run one step.
-            need_exec_callback_epoch_end = not ((self.enable_recovery and _get_recovery_context("need_reset"))
-                                                or is_embedding_cache_server)
-
-            if need_exec_callback_epoch_end:
-                list_callback.on_train_epoch_end(run_context)
+            list_callback.on_train_epoch_end(run_context)
             if "metrics" in cb_params or "eval_results" in cb_params:
                 cb_params.pop("metrics", None)
                 cb_params.pop("eval_results", None)

@@ -1164,12 +1134,7 @@ class Model:

             if should_stop:
                 break

-            need_reset_to_beginning = self.enable_recovery and _get_recovery_context("need_reset") \
-                                      and not _get_recovery_context("latest_ckpt_file")
             self.epoch_iter += 1
-            if need_reset_to_beginning:
-                self.epoch_iter = 0
-                cb_params.cur_step_num = 0

         dataset_helper.stop_send()
         dataset_helper.release()

@@ -1203,93 +1168,6 @@ class Model:

         cb_params.dataset_sink_mode = train_dataset_sink_mode
         cb_params.net_outputs = train_net_outputs

-    def _check_enable_recovery(self):
-        """
-        Check whether enable recovery and execution mode consistency.
-        """
-
-        enable_recovery = _get_recovery_context("enable_recovery") and context.get_context("device_target") == "GPU"
-        if not enable_recovery:
-            self.enable_recovery = False
-        else:
-            self.enable_recovery = enable_recovery and _is_role_worker()
-
-    def _check_need_load_ckpt(self, cb_params, dataset_size, sink_size=-1):
-        """
-        Check whether need to load checkpoint after abnormal process restart.
-
-        Args:
-            cb_params (_InternalCallbackParam): Callback parameters.
-            dataset_size (int): The number of batches in a dataset.
-            sink_size (int): Control the amount of data in each sink. Default: -1.
-        """
-        if context.get_context("device_target") != "GPU":
-            return
-        if not self.enable_recovery:
-            self.need_load_ckpt = False
-
-        cb_params.latest_ckpt_file = _get_recovery_context("latest_ckpt_file")
-        if cb_params.latest_ckpt_file:
-            recovery_epoch_num = _get_recovery_context("latest_ckpt_epoch")
-            recovery_step_num = _get_recovery_context("latest_ckpt_step")
-            dataset_sink_size = sink_size if sink_size > 0 else dataset_size
-            cb_params.cur_step_num = (recovery_epoch_num - 1) * dataset_sink_size + recovery_step_num
-            cb_params.last_save_ckpt_step = cb_params.cur_step_num
-            self.epoch_iter = recovery_epoch_num
-            self.need_load_ckpt = True
-        else:
-            self.need_load_ckpt = False
-
-    def _reset_training_step_for_abnormal_process(self, cb_params, dataset_helper):
-        """
-        Execute recovery for abnormal exit process when restart.
-
-        Args:
-            cb_params (_InternalCallbackParam): Callback parameters.
-        """
-
-        if self.need_load_ckpt:
-            try:
-                load_checkpoint(cb_params.latest_ckpt_file, cb_params.train_network)
-            except BaseException as e:
-                os.remove(cb_params.latest_ckpt_file)
-                raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: " \
-                                   + cb_params.latest_ckpt_file) from e
-            _reset_training_dataset(cb_params.cur_step_num, dataset_helper.iter.dataset.get_dataset_size())
-            self.need_load_ckpt = False
-
-    def _reset_training_step_for_normal_process(self, cb_params, dataset_helper):
-        """
-        Execute recovery for normal process when there is process exit abnormally.
-
-        Args:
-            cb_params (_InternalCallbackParam): Callback parameters.
-            dataset_helper (DatasetHelper): A class to process the MindData dataset,
-                it provides the type, shape and queue name of the dataset to wrap the `GetNext`.
-        """
-
-        if self.enable_recovery and _get_recovery_context("need_reset"):
-            cb_params.latest_ckpt_file = _get_recovery_context("latest_ckpt_file")
-            if cb_params.latest_ckpt_file:
-                try:
-                    load_checkpoint(cb_params.latest_ckpt_file, cb_params.train_network)
-                except BaseException as e:
-                    os.remove(cb_params.latest_ckpt_file)
-                    raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: "\
-                                       + cb_params.latest_ckpt_file) from e
-
-                recovery_epoch_num = _get_recovery_context("latest_ckpt_epoch")
-                recovery_step_num = _get_recovery_context("latest_ckpt_step")
-                cb_params.cur_step_num = (recovery_epoch_num - 1) * dataset_helper.sink_size() + recovery_step_num
-                self.epoch_iter = recovery_epoch_num
-                cb_params.cur_epoch_num = self.epoch_iter + 1
-                cb_params.last_save_ckpt_step = cb_params.cur_step_num
-                _reset_training_dataset(cb_params.cur_step_num, dataset_helper.iter.dataset.get_dataset_size())
-            else:
-                _reset_training_dataset(0, dataset_helper.iter.dataset.get_dataset_size())
-
-            _set_recovery_context(need_reset=False)
-
     def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None, initial_epoch=0,
                        valid_infos=None):
         """

@@ -1314,7 +1192,6 @@ class Model:

         cb_params.dataset_sink_mode = False
         run_context = RunContext(cb_params)
         list_callback.on_train_begin(run_context)
-        is_embedding_cache_server = _is_role_pserver() and _cache_enable()

         for i in range(initial_epoch, epoch):
             cb_params.cur_epoch_num = i + 1

@@ -1345,21 +1222,12 @@ class Model:

                 list_callback.on_train_step_end(run_context)
                 if cb_params.is_arf:
                     cb_params.is_arf = False
-
+                    set_is_arf(False)
                     _clean_rootinfo()
-                # Embedding cache server only run one step.
-                if is_embedding_cache_server:
-                    break
                 should_stop = run_context.get_stop_requested()
                 if should_stop:
                     break

-            # When it's distributed training and using MindRT,
-            # the node id should be reset to start from 0.
-            # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case(or 'fit').
-            if _enable_distributed_mindrt():
-                _reset_op_id_with_offset()
-
             self._eval_during_train(valid_infos, cb_params, list_callback)

             train_dataset.reset()

@@ -1367,9 +1235,7 @@ class Model:

             # if param is cache enable, flush data from cache to host before epoch end
             self._flush_from_cache(cb_params)

-
-            if not is_embedding_cache_server:
-                list_callback.on_train_epoch_end(run_context)
+            list_callback.on_train_epoch_end(run_context)
             if "metrics" in cb_params or "eval_results" in cb_params:
                 cb_params.pop("metrics", None)
                 cb_params.pop("eval_results", None)

@@ -1446,10 +1312,6 @@ class Model:

         """
         _init_auto_parallel_context(self._network)
         _check_tft()
-        device_target = context.get_context("device_target")
-        if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
-            logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
-            dataset_sink_mode = False

         Validator.check_bool(dataset_sink_mode)
         if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode:

@@ -1461,11 +1323,6 @@ class Model:

                            "the value of epoch in train {} separately."
                            .format(train_dataset._warmup_epoch, epoch))

-        # Parameter server and embedding cache mode check.
-        if _is_ps_mode():
-            if not dataset_sink_mode and _cache_enable():
-                raise ValueError("Embedding cache mode should run with 'dataset_sink_mode=True'.")
-
         self._check_sink_mode_for_ds_debug_mode(dataset_sink_mode)

         Validator.check_is_int(sink_size)

@@ -1496,12 +1353,6 @@ class Model:

                                     sink_size=sink_size,
                                     initial_epoch=initial_epoch)

-        # When it's distributed training and using MindRT,
-        # the node id should be reset to start from 0.
-        # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case(or 'fit').
-        if _enable_distributed_mindrt():
-            _reset_op_id_with_offset()
-
         _clear_auto_parallel_context(self._network)

     @staticmethod

@@ -1599,10 +1450,6 @@ class Model:

         >>> model.fit(2, train_dataset, valid_dataset)
         """
         _init_auto_parallel_context(self._network)
-        device_target = context.get_context("device_target")
-        if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
-            logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
-            dataset_sink_mode = False

         dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
         valid_dataset_sink_mode = Validator.check_bool(valid_dataset_sink_mode)

@@ -1896,13 +1743,6 @@ class Model:


         self._clear_metrics()

-        # Embedding cache server as a storage service, no need to execute eval.
-        is_embedding_cache_server = _is_role_pserver() and _cache_enable()
-        if is_embedding_cache_server:
-            metrics = self._get_metrics()
-            cb_params.metrics = metrics
-            return metrics
-
         if context.get_context("device_target") == "CPU" and dataset_sink_mode:
             dataset_sink_mode = False
             logger.info("CPU cannot support dataset sink mode currently."

@@ -1914,13 +1754,7 @@ class Model:

         else:
             eval_result = self._eval_process(valid_dataset, list_callback, cb_params)

-        # When it's distributed training and using MindRT,
-        # the node id should be reset to start from 0.
-        # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case(or 'fit').
-        if _enable_distributed_mindrt():
-            _reset_op_id_with_offset()
         _clear_auto_parallel_context(self._network)
-
         return eval_result

     def _predict_lite(self, *predict_data, config=None):

@@ -2171,13 +2005,6 @@ class Model:

         result = self._predict_network(*predict_data)

         check_output_data(result)
-
-        # When it's distributed training and using MindRT,
-        # the node id should be reset to start from 0.
-        # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case(or 'fit').
-        if _enable_distributed_mindrt():
-            _reset_op_id_with_offset()
-
         return result

     def _infer_train_check(self, train_dataset, dataset_sink_mode, sink_size):
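
For context, the `_get_tft_obj` helper factored out above looks for a `TrainFaultTolerance` instance in the `callbacks` argument of `Model.train` / `Model.fit`. A self-contained sketch with a toy network and dataset; it assumes `MS_ENABLE_TFT` is configured as in the earlier example, since the callback's constructor raises when TFT initialization fails:

    import numpy as np
    import mindspore as ms
    from mindspore import nn, dataset as ds
    from mindspore.train import Model, TrainFaultTolerance

    net = nn.Dense(4, 1)
    opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
    data = ds.NumpySlicesDataset(
        {"x": np.random.rand(32, 4).astype(np.float32),
         "y": np.random.rand(32, 1).astype(np.float32)},
        shuffle=False).batch(8)

    model = Model(net, loss_fn=nn.MSELoss(), optimizer=opt)
    tft_cb = TrainFaultTolerance(ckpt_save_path="./tft_ckpt")
    # The decorator around Model.train picks tft_cb out of this list and, unless only
    # MS_ENABLE_CKPT_D2H_ASYNC is set, drives the UCE/ARF/TRE retry loop around training.
    model.train(2, data, callbacks=[tft_cb], dataset_sink_mode=True)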