mindspore-2.7.0rc1-cp311-cp311-win_amd64.whl → mindspore-2.7.1-cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +5 -2
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +2 -2
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
- mindspore/_extends/parse/parser.py +28 -22
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +23 -2
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
- mindspore/amp.py +0 -18
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/__init__.py +18 -12
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +371 -96
- mindspore/common/_utils.py +7 -43
- mindspore/common/api.py +434 -135
- mindspore/common/dtype.py +98 -57
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
- mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
- mindspore/common/file_system.py +59 -9
- mindspore/common/hook_handle.py +82 -3
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +17 -127
- mindspore/common/recompute.py +4 -13
- mindspore/common/tensor.py +50 -217
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +20 -106
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +35 -1
- mindspore/dataset/engine/datasets.py +338 -319
- mindspore/dataset/engine/datasets_user_defined.py +38 -22
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +17 -5
- mindspore/dataset/vision/utils.py +632 -21
- mindspore/device_context/ascend/op_tuning.py +35 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/api/cell.h +28 -4
- mindspore/include/api/cfg.h +24 -7
- mindspore/include/api/context.h +1 -0
- mindspore/include/api/delegate.h +0 -2
- mindspore/include/api/dual_abi_helper.h +100 -19
- mindspore/include/api/graph.h +14 -1
- mindspore/include/api/kernel.h +16 -3
- mindspore/include/api/kernel_api.h +9 -1
- mindspore/include/api/metrics/accuracy.h +9 -0
- mindspore/include/api/model.h +5 -1
- mindspore/include/api/model_group.h +4 -0
- mindspore/include/api/model_parallel_runner.h +2 -0
- mindspore/include/api/status.h +48 -10
- mindspore/include/api/types.h +6 -1
- mindspore/include/dataset/constants.h +9 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +4 -3
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/__init__.py +4 -0
- mindspore/mint/distributed/distributed.py +392 -69
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/_functions.py +1 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +10 -10
- mindspore/mint/nn/layer/normalization.py +11 -16
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +231 -239
- mindspore/nn/layer/activation.py +4 -2
- mindspore/nn/layer/basic.py +56 -14
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/image.py +1 -1
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +32 -127
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +1 -4
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +2 -4
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +39 -5
- mindspore/nn/wrap/grad_reducer.py +4 -89
- mindspore/numpy/array_creations.py +4 -4
- mindspore/numpy/fft.py +9 -9
- mindspore/numpy/utils_const.py +1 -1
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +1 -5
- mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
- mindspore/ops/auto_generate/gen_extend_func.py +6 -11
- mindspore/ops/auto_generate/gen_ops_def.py +385 -154
- mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +16 -2
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +2 -0
- mindspore/ops/function/array_func.py +24 -18
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +7 -6
- mindspore/ops/function/grad/grad_func.py +4 -12
- mindspore/ops/function/math_func.py +89 -86
- mindspore/ops/function/nn_func.py +92 -313
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +4 -1
- mindspore/ops/functional_overload.py +377 -30
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +12 -50
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +5 -50
- mindspore/ops/operations/comm_ops.py +95 -17
- mindspore/ops/operations/custom_ops.py +237 -22
- mindspore/ops/operations/debug_ops.py +33 -35
- mindspore/ops/operations/manually_defined/ops_def.py +39 -318
- mindspore/ops/operations/math_ops.py +5 -5
- mindspore/ops/operations/nn_ops.py +3 -3
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +4 -27
- mindspore/ops/tensor_method.py +88 -10
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_auto_parallel_context.py +5 -15
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +4 -6
- mindspore/parallel/_ps_context.py +2 -2
- mindspore/parallel/_utils.py +34 -17
- mindspore/parallel/auto_parallel.py +23 -9
- mindspore/parallel/checkpoint_transform.py +20 -2
- mindspore/parallel/cluster/process_entity/_api.py +28 -33
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/parallel/cluster/run.py +5 -3
- mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/function/reshard_func.py +6 -5
- mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
- mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
- mindspore/parallel/shard.py +7 -21
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +127 -20
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +40 -4
- mindspore/profiler/common/path_manager.py +65 -24
- mindspore/profiler/common/profiler_context.py +27 -14
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_meta_data.py +1 -0
- mindspore/profiler/common/profiler_op_analyse.py +10 -6
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/dynamic_profiler.py +91 -46
- mindspore/profiler/envprofiler.py +30 -5
- mindspore/profiler/experimental_config.py +18 -2
- mindspore/profiler/platform/cpu_profiler.py +10 -4
- mindspore/profiler/platform/npu_profiler.py +34 -7
- mindspore/profiler/profiler.py +193 -145
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +9 -6
- mindspore/runtime/executor.py +35 -0
- mindspore/runtime/memory.py +113 -0
- mindspore/runtime/thread_bind_core.py +1 -1
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +8 -21
- mindspore/train/amp.py +6 -7
- mindspore/train/callback/_callback.py +2 -1
- mindspore/train/callback/_checkpoint.py +1 -17
- mindspore/train/callback/_flops_collector.py +10 -6
- mindspore/train/callback/_train_fault_tolerance.py +72 -25
- mindspore/train/data_sink.py +5 -9
- mindspore/train/dataset_helper.py +5 -5
- mindspore/train/model.py +41 -230
- mindspore/train/serialization.py +160 -401
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +152 -16
- mindspore/version.py +1 -1
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
- mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
- mindspore/communication/_hccl_management.py +0 -297
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/__init__.py +0 -23
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/profiler/common/validator/validate_path.py +0 -84
- mindspore/train/memory_profiling_pb2.py +0 -298
- mindspore/utils/hooks.py +0 -81
- /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/train/model.py
CHANGED
@@ -28,7 +28,7 @@ import numpy as np
 
 import mindspore
 from mindspore import log as logger
-from mindspore.train.serialization import save_checkpoint, load_checkpoint
+from mindspore.train.serialization import save_checkpoint
 from mindspore.train.callback._checkpoint import ModelCheckpoint, _chg_ckpt_file_name_if_same_exist
 from mindspore.common.tensor import Tensor
 from mindspore.train.metrics import get_metrics, get_metric_fn
@@ -40,16 +40,12 @@ from mindspore.train.callback import __all__ as internal_cb_names
 from mindspore.train.callback._cluster_monitor import ClusterMonitor
 from mindspore import context
 from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_parameter_broadcast, \
-    _device_number_check, _parameter_broadcast_check, _parallel_predict_check, \
-    _reset_op_id_with_offset
-from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_ps_mode, \
-    _cache_enable, _enable_distributed_mindrt
+    _device_number_check, _parameter_broadcast_check, _parallel_predict_check
 from mindspore.train.metrics import Loss
 from mindspore.log import vlog_print
 from mindspore import nn
 from mindspore.boost import AutoBoost
 from mindspore.context import ParallelMode
-from mindspore.parallel._recovery_context import _set_recovery_context, _get_recovery_context
 from mindspore.train.dataset_helper import DatasetHelper, connect_network_with_dataset
 from mindspore.common.api import _pynative_executor, ARG_SPECIFIED, TOTAL_ARG_LEN
 from mindspore.dataset.core.config import get_debug_mode
@@ -57,7 +53,8 @@ from mindspore.dataset.engine.datasets import _set_training_dataset, _reset_training_dataset
 from mindspore.train import amp
 from mindspore._c_expression import _framework_profiler_step_start, _framework_profiler_step_end
 from mindspore._c_expression import _get_optimzer_timestamps
-from mindspore._c_expression import clean_tdt_channel, _clean_rootinfo
+from mindspore._c_expression import clean_tdt_channel, _clean_rootinfo, check_is_arf, set_is_arf
+from mindspore._c_expression import _get_snapshot_params, _is_snapshot_valid
 
 from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
 from .serialization import load_param_into_net
@@ -156,18 +153,14 @@ def _handle_exception_info(obj, uce_env, tft, e):
         tft.tft_report_error(tft.ReportState.RS_UCE.value)
     elif "HCCEError" in e_str:
         logger.warning("uce wrapper caught HCCEError")
-
-            logger.warning("Received HCCEError after force stop been called, so report force stopped error to MindIO.")
-            tft.tft_report_error(tft.ReportState.RS_NORMAL.value)
-        else:
-            tft.tft_report_error(tft.ReportState.RS_HCCL_FAILED.value)
+        tft.tft_report_error(tft.ReportState.RS_HCCL_FAILED.value)
     elif "ForceStopError" in e_str:
         logger.warning("uce wrapper caught RuntimeError ForceStopError")
         force_stop_err = tft.ReportState.RS_NORMAL.value
         tft.tft_report_error(force_stop_err)
     elif "ARF FINISH" in e_str:
         logger.warning(f"ARF FINISH")
-
+        set_is_arf(True)
         tft.tft_report_error(tft.ReportState.RS_PREREPAIR_FINISH.value)
     else:
         logger.error("uce wrapper caught other RuntimeError, enter MindIO TTP process.", exc_info=True)
@@ -179,7 +172,12 @@ def _handle_training_result_error(model, tft_obj):
     """
    Handle training result error for resuming training.
     """
-
+    def load_snapshot_params():
+        param_dict = {}
+        for name, tensor in _get_snapshot_params().items():
+            param_dict[name] = mindspore.Parameter(tensor, name=name)
+        return (param_dict, False)
+    ckpt_load_fn = load_snapshot_params if _is_snapshot_valid() else tft_obj.ckpt_load_func
     train_network = tft_obj.cb_params.train_network
     logger.warning("Process training result error start.")
     # 1. Clear tdt channel
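The added load_snapshot_params helper wraps each tensor returned by the new _get_snapshot_params binding into a named mindspore.Parameter, producing the kind of parameter dict that the existing checkpoint-load path consumes (this module already imports load_param_into_net). A minimal standalone sketch of that conversion pattern, using a toy network and hand-made tensors in place of the real snapshot data (ToyNet and snapshot_tensors are illustrative, not part of MindSpore):

import mindspore
from mindspore import nn
from mindspore.train.serialization import load_param_into_net


class ToyNet(nn.Cell):
    """Toy network used only for this illustration."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Dense(2, 1)

    def construct(self, x):
        return self.fc(x)


# Stand-in for the dict returned by _get_snapshot_params(): plain tensors keyed by parameter name.
snapshot_tensors = {
    "fc.weight": mindspore.Tensor([[1.0, 2.0]], mindspore.float32),
    "fc.bias": mindspore.Tensor([0.5], mindspore.float32),
}

# Same conversion the helper performs: wrap each tensor in a Parameter that keeps its name.
param_dict = {name: mindspore.Parameter(t, name=name) for name, t in snapshot_tensors.items()}

net = ToyNet()
load_param_into_net(net, param_dict)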
@@ -238,6 +236,20 @@ def _update_ckpt_callback_info(resume_train_step, **kwargs):
         ckpt_obj._append_step_num = resume_train_step
 
 
+def _get_tft_obj(**kwargs):
+    """
+    Get TrainFaultTolerance from kwargs of callback
+    """
+    obj = None
+    if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), TrainFaultTolerance):
+        obj = kwargs.get('callbacks')
+    if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
+        for item in kwargs.get('callbacks'):
+            if isinstance(item, TrainFaultTolerance):
+                obj = item
+    return obj
+
+
 def _handle_tft(func):
     """
     Decorator function, which starts uce handle process when an exception occurs during training.
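_get_tft_obj simply scans the callbacks keyword argument (a single callback or a list) for a TrainFaultTolerance instance and returns the last match. A self-contained sketch of the same lookup, using a stand-in class rather than the real MindSpore callback:

class FakeFaultToleranceCallback:
    """Stand-in for TrainFaultTolerance, for illustration only."""


def get_tft_obj(**kwargs):
    # Mirrors the lookup added in this diff: accept either a single callback or a list.
    cbs = kwargs.get('callbacks')
    obj = None
    if cbs and isinstance(cbs, FakeFaultToleranceCallback):
        obj = cbs
    if cbs and isinstance(cbs, list):
        for item in cbs:
            if isinstance(item, FakeFaultToleranceCallback):
                obj = item  # the last matching callback wins
    return obj


tft_cb = FakeFaultToleranceCallback()
assert get_tft_obj(callbacks=tft_cb) is tft_cb
assert get_tft_obj(callbacks=[object(), tft_cb]) is tft_cb
assert get_tft_obj(callbacks=[object()]) is None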
@@ -245,17 +257,11 @@ def _handle_tft(func):
 
     @wraps(func)
     def wrapper(self, *args, **kwargs):
-        obj = None
-        if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), TrainFaultTolerance):
-            obj = kwargs.get('callbacks')
-        if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
-            for item in kwargs.get('callbacks'):
-                if isinstance(item, TrainFaultTolerance):
-                    obj = item
-        if obj:
+        obj = _get_tft_obj(**kwargs)
+        if obj and not TrainFaultTolerance._only_enable_ckpt_d2h_async():
            tft_env = os.getenv("MS_ENABLE_TFT", "")
            uce_env = "UCE:1" in tft_env or "ARF:1" in tft_env or "HCCE:1" in tft_env
-            tre_env = "TRE:1" in tft_env
+            tre_env = "TRE:1" in tft_env or "TRE:2" in tft_env
            while True:
                try:
                    return func(self, *args, **kwargs)
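The retry wrapper keys its behaviour off simple substring checks against the MS_ENABLE_TFT environment variable; the change above additionally treats TRE:2 as enabling the training-result-error path. A small sketch of that parsing, with a purely illustrative variable value:

import os

# Illustrative value only; in practice MS_ENABLE_TFT is set by the user or launcher.
os.environ["MS_ENABLE_TFT"] = "{TTP:1,UCE:1,TRE:2}"

tft_env = os.getenv("MS_ENABLE_TFT", "")
uce_env = "UCE:1" in tft_env or "ARF:1" in tft_env or "HCCE:1" in tft_env
tre_env = "TRE:1" in tft_env or "TRE:2" in tft_env  # TRE:2 is newly recognized in 2.7.1
print(uce_env, tre_env)  # -> True True for the value above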
@@ -270,7 +276,6 @@
                    ret = obj.tft.tft_wait_next_action()
                    if ret == obj.tft.Action.EXIT.value:
                        raise e
-                    obj.stop_been_called = False
                    repair_step = obj.tft.tft_get_repair_step()
                    logger.warning(
                        "uce wrapper caught repair finish REPAIR STEP: {} batch_num:{}".format(repair_step,
@@ -308,9 +313,6 @@ def _check_tft():
     ascend_target = MSContext.get_instance().get_ascend_soc_version()
     if ascend_target == 'ascend910':
         raise ValueError("TFT is not supported when using ascend910")
-    ms_mode = context.get_context("mode")
-    if ms_mode != mindspore.GRAPH_MODE:
-        raise ValueError("TFT is only supported in GRAPH_MODE")
     jit_level = context.get_context("jit_level")
     if jit_level == "O2" and ("UCE:1" in tft_env or "ARF:1" in tft_env):
         raise ValueError("TFT is not supported when using jit_level == O2")
@@ -564,9 +566,7 @@ class Model:
         self._current_epoch_num = 0
         self._current_step_num = 0
         self.epoch_iter = 0
-        self.enable_recovery = False
         self._backbone_is_train = True
-        self.need_load_ckpt = False
         self._lite_full_predictor = None
         self._lite_incremental_predictor = None
         self._mindspore_lite = None
@@ -739,10 +739,7 @@ class Model:
         metrics = dict()
         # There's no need for server to execute eval, just give fake metrics.
         for key, value in self._metric_fns.items():
-
-                metrics[key] = value.eval()
-            else:
-                metrics[key] = 1
+            metrics[key] = value.eval()
         return metrics
 
     def _get_scaling_sens(self):
@@ -776,7 +773,7 @@ class Model:
         logger.info("Begin to connect network with dataset.")
         network = connect_network_with_dataset(network, dataset_helper)
 
-        if
+        if self._need_reset_data and is_train:
             _set_training_dataset(dataset_helper)
 
         network.set_train(is_train)
@@ -818,9 +815,7 @@ class Model:
         :param cb_params: callback params
         :return: none
         """
-        if
-            return
-        if (context.get_context("mode") == context.GRAPH_MODE) and (context.get_context("device_target") == "Ascend"):
+        if TrainFaultTolerance._enable_snapshot() and context.get_context("device_target") == "Ascend":
             cb_params.need_ckpt, cb_params.save_checkpoint_steps, \
                 cb_params.last_triggered_step = self._check_need_ckpt(cb_params.list_callback)
             logger.info(f"need_ckpt:{cb_params.need_ckpt},"
@@ -888,8 +883,8 @@ class Model:
            sink_size (int): Control the amount of data in each sink. Default: -1.
            epoch (int): Total number of iterations on the data. Default: 1.
         """
-        if context.get_context("
-            raise RuntimeError('Pre-init process only supports
+        if context.get_context("device_target") != "Ascend":
+            raise RuntimeError('Pre-init process only supports Ascend target currently.')
 
         if not train_dataset and not valid_dataset:
             raise ValueError("The argument 'train_dataset' and 'valid_dataset' can not both be None or empty.")
@@ -1026,13 +1021,10 @@ class Model:
         callbacks = cb_params.list_callback
         cb_params.train_dataset_element = None
         cb_params.network = self._network
-        # Embedding cache server only run one step.
-        if _is_role_pserver() and _cache_enable():
-            epoch = 1
         cb_params.last_save_ckpt_step = None
         cb_params.latest_ckpt_file = None
         cb_params.loss_scale_mananger = self._loss_scale_manager
-        cb_params.is_arf =
+        cb_params.is_arf = check_is_arf()
         cb_params.initial_step = self._initial_step
 
         # build callback list
@@ -1094,12 +1086,6 @@ class Model:
            dataset_helper = train_dataset._dataset_helper
 
         self.epoch_iter = 0
-        self._check_enable_recovery()
-        # Used to check whether need perform recovery for process which is restarted.
-        self._check_need_load_ckpt(cb_params, dataset_size, sink_size)
-        # Check whether this process is embedding cache server.
-        is_embedding_cache_server = _is_role_pserver() and _cache_enable()
-
         while self.epoch_iter < (epoch - initial_epoch):
            cb_params.cur_epoch_num = self.epoch_iter + 1 + initial_epoch
            self._current_epoch_num = cb_params.cur_epoch_num
@@ -1115,11 +1101,6 @@ class Model:
            cb_params.train_network = train_network
            cb_params.dataset_helper = dataset_helper
 
-            # Perform recovery for process which is restarted.
-            self._reset_training_step_for_abnormal_process(cb_params, dataset_helper)
-            # Perform recovery for process which is not restarted.
-            self._reset_training_step_for_normal_process(cb_params, dataset_helper)
-
            # For data sink dataset_helper only iter once, other wise iter epoch_size times.
            for inputs in dataset_helper:
                if is_graph:
@@ -1134,36 +1115,17 @@ class Model:
                outputs = train_network(*inputs)
                cb_params.net_outputs = outputs
 
-
-
-                if need_exec_callback_step_end:
-                    list_callback.on_train_step_end(run_context)
+                list_callback.on_train_step_end(run_context)
+
                if cb_params.is_arf:
                    cb_params.is_arf = False
-
+                    set_is_arf(False)
                    _clean_rootinfo()
 
-                # Embedding cache server only run one step.
-                if is_embedding_cache_server:
-                    break
-
            dataset_helper.continue_send()
 
-            # When it's distributed training and using MindRT,
-            # the node id should be reset to start from 0.
-            # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case(or 'fit').
-            if _enable_distributed_mindrt():
-                _reset_op_id_with_offset()
-
            self._eval_during_train(valid_infos, cb_params, list_callback)
-
-            # In disaster recovery scenarios, need not to execute callbacks if this epoch executes failed.
-            # Embedding cache server need not do epoch end callback, this process only run one step.
-            need_exec_callback_epoch_end = not ((self.enable_recovery and _get_recovery_context("need_reset"))
-                                                or is_embedding_cache_server)
-
-            if need_exec_callback_epoch_end:
-                list_callback.on_train_epoch_end(run_context)
+            list_callback.on_train_epoch_end(run_context)
            if "metrics" in cb_params or "eval_results" in cb_params:
                cb_params.pop("metrics", None)
                cb_params.pop("eval_results", None)
@@ -1172,12 +1134,7 @@ class Model:
            if should_stop:
                break
 
-            need_reset_to_beginning = self.enable_recovery and _get_recovery_context("need_reset") \
-                and not _get_recovery_context("latest_ckpt_file")
            self.epoch_iter += 1
-            if need_reset_to_beginning:
-                self.epoch_iter = 0
-                cb_params.cur_step_num = 0
 
         dataset_helper.stop_send()
         dataset_helper.release()
@@ -1211,95 +1168,6 @@ class Model:
         cb_params.dataset_sink_mode = train_dataset_sink_mode
         cb_params.net_outputs = train_net_outputs
 
-    def _check_enable_recovery(self):
-        """
-        Check whether enable recovery and execution mode consistency.
-        """
-
-        enable_recovery = _get_recovery_context("enable_recovery") and context.get_context("device_target") == "GPU"
-        if not enable_recovery:
-            self.enable_recovery = False
-        else:
-            if context.get_context("mode") != context.GRAPH_MODE:
-                raise RuntimeError("Recovery for training only support graph mode currently.")
-            self.enable_recovery = enable_recovery and _is_role_worker()
-
-    def _check_need_load_ckpt(self, cb_params, dataset_size, sink_size=-1):
-        """
-        Check whether need to load checkpoint after abnormal process restart.
-
-        Args:
-            cb_params (_InternalCallbackParam): Callback parameters.
-            dataset_size (int): The number of batches in a dataset.
-            sink_size (int): Control the amount of data in each sink. Default: -1.
-        """
-        if context.get_context("device_target") != "GPU":
-            return
-        if not self.enable_recovery:
-            self.need_load_ckpt = False
-
-        cb_params.latest_ckpt_file = _get_recovery_context("latest_ckpt_file")
-        if cb_params.latest_ckpt_file:
-            recovery_epoch_num = _get_recovery_context("latest_ckpt_epoch")
-            recovery_step_num = _get_recovery_context("latest_ckpt_step")
-            dataset_sink_size = sink_size if sink_size > 0 else dataset_size
-            cb_params.cur_step_num = (recovery_epoch_num - 1) * dataset_sink_size + recovery_step_num
-            cb_params.last_save_ckpt_step = cb_params.cur_step_num
-            self.epoch_iter = recovery_epoch_num
-            self.need_load_ckpt = True
-        else:
-            self.need_load_ckpt = False
-
-    def _reset_training_step_for_abnormal_process(self, cb_params, dataset_helper):
-        """
-        Execute recovery for abnormal exit process when restart.
-
-        Args:
-            cb_params (_InternalCallbackParam): Callback parameters.
-        """
-
-        if self.need_load_ckpt:
-            try:
-                load_checkpoint(cb_params.latest_ckpt_file, cb_params.train_network)
-            except BaseException as e:
-                os.remove(cb_params.latest_ckpt_file)
-                raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: " \
-                    + cb_params.latest_ckpt_file) from e
-            _reset_training_dataset(cb_params.cur_step_num, dataset_helper.iter.dataset.get_dataset_size())
-        self.need_load_ckpt = False
-
-    def _reset_training_step_for_normal_process(self, cb_params, dataset_helper):
-        """
-        Execute recovery for normal process when there is process exit abnormally.
-
-        Args:
-            cb_params (_InternalCallbackParam): Callback parameters.
-            dataset_helper (DatasetHelper): A class to process the MindData dataset,
-                it provides the type, shape and queue name of the dataset to wrap the `GetNext`.
-        """
-
-        if self.enable_recovery and _get_recovery_context("need_reset"):
-            cb_params.latest_ckpt_file = _get_recovery_context("latest_ckpt_file")
-            if cb_params.latest_ckpt_file:
-                try:
-                    load_checkpoint(cb_params.latest_ckpt_file, cb_params.train_network)
-                except BaseException as e:
-                    os.remove(cb_params.latest_ckpt_file)
-                    raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: "\
-                        + cb_params.latest_ckpt_file) from e
-
-                recovery_epoch_num = _get_recovery_context("latest_ckpt_epoch")
-                recovery_step_num = _get_recovery_context("latest_ckpt_step")
-                cb_params.cur_step_num = (recovery_epoch_num - 1) * dataset_helper.sink_size() + recovery_step_num
-                self.epoch_iter = recovery_epoch_num
-                cb_params.cur_epoch_num = self.epoch_iter + 1
-                cb_params.last_save_ckpt_step = cb_params.cur_step_num
-                _reset_training_dataset(cb_params.cur_step_num, dataset_helper.iter.dataset.get_dataset_size())
-            else:
-                _reset_training_dataset(0, dataset_helper.iter.dataset.get_dataset_size())
-
-            _set_recovery_context(need_reset=False)
-
     def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None, initial_epoch=0,
                        valid_infos=None):
         """
@@ -1324,7 +1192,6 @@ class Model:
         cb_params.dataset_sink_mode = False
         run_context = RunContext(cb_params)
         list_callback.on_train_begin(run_context)
-        is_embedding_cache_server = _is_role_pserver() and _cache_enable()
 
         for i in range(initial_epoch, epoch):
            cb_params.cur_epoch_num = i + 1
@@ -1355,21 +1222,12 @@ class Model:
                list_callback.on_train_step_end(run_context)
                if cb_params.is_arf:
                    cb_params.is_arf = False
-
+                    set_is_arf(False)
                    _clean_rootinfo()
-                # Embedding cache server only run one step.
-                if is_embedding_cache_server:
-                    break
                should_stop = run_context.get_stop_requested()
                if should_stop:
                    break
 
-            # When it's distributed training and using MindRT,
-            # the node id should be reset to start from 0.
-            # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case(or 'fit').
-            if _enable_distributed_mindrt():
-                _reset_op_id_with_offset()
-
            self._eval_during_train(valid_infos, cb_params, list_callback)
 
            train_dataset.reset()
@@ -1377,9 +1235,7 @@ class Model:
            # if param is cache enable, flush data from cache to host before epoch end
            self._flush_from_cache(cb_params)
 
-
-            if not is_embedding_cache_server:
-                list_callback.on_train_epoch_end(run_context)
+            list_callback.on_train_epoch_end(run_context)
            if "metrics" in cb_params or "eval_results" in cb_params:
                cb_params.pop("metrics", None)
                cb_params.pop("eval_results", None)
@@ -1456,10 +1312,6 @@ class Model:
         """
         _init_auto_parallel_context(self._network)
         _check_tft()
-        device_target = context.get_context("device_target")
-        if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
-            logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
-            dataset_sink_mode = False
 
         Validator.check_bool(dataset_sink_mode)
         if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode:
@@ -1471,11 +1323,6 @@ class Model:
                              "the value of epoch in train {} separately."
                              .format(train_dataset._warmup_epoch, epoch))
 
-        # Parameter server and embedding cache mode check.
-        if _is_ps_mode():
-            if not dataset_sink_mode and _cache_enable():
-                raise ValueError("Embedding cache mode should run with 'dataset_sink_mode=True'.")
-
         self._check_sink_mode_for_ds_debug_mode(dataset_sink_mode)
 
         Validator.check_is_int(sink_size)
@@ -1506,12 +1353,6 @@ class Model:
                    sink_size=sink_size,
                    initial_epoch=initial_epoch)
 
-        # When it's distributed training and using MindRT,
-        # the node id should be reset to start from 0.
-        # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case(or 'fit').
-        if _enable_distributed_mindrt():
-            _reset_op_id_with_offset()
-
         _clear_auto_parallel_context(self._network)
 
     @staticmethod
@@ -1609,10 +1450,6 @@ class Model:
         >>> model.fit(2, train_dataset, valid_dataset)
         """
         _init_auto_parallel_context(self._network)
-        device_target = context.get_context("device_target")
-        if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
-            logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
-            dataset_sink_mode = False
 
         dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
         valid_dataset_sink_mode = Validator.check_bool(valid_dataset_sink_mode)
@@ -1906,13 +1743,6 @@ class Model:
 
         self._clear_metrics()
 
-        # Embedding cache server as a storage service, no need to execute eval.
-        is_embedding_cache_server = _is_role_pserver() and _cache_enable()
-        if is_embedding_cache_server:
-            metrics = self._get_metrics()
-            cb_params.metrics = metrics
-            return metrics
-
         if context.get_context("device_target") == "CPU" and dataset_sink_mode:
            dataset_sink_mode = False
            logger.info("CPU cannot support dataset sink mode currently."
@@ -1924,13 +1754,7 @@ class Model:
         else:
            eval_result = self._eval_process(valid_dataset, list_callback, cb_params)
 
-        # When it's distributed training and using MindRT,
-        # the node id should be reset to start from 0.
-        # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case(or 'fit').
-        if _enable_distributed_mindrt():
-            _reset_op_id_with_offset()
         _clear_auto_parallel_context(self._network)
-
         return eval_result
 
     def _predict_lite(self, *predict_data, config=None):
@@ -2181,13 +2005,6 @@ class Model:
         result = self._predict_network(*predict_data)
 
         check_output_data(result)
-
-        # When it's distributed training and using MindRT,
-        # the node id should be reset to start from 0.
-        # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case(or 'fit').
-        if _enable_distributed_mindrt():
-            _reset_op_id_with_offset()
-
         return result
 
     def _infer_train_check(self, train_dataset, dataset_sink_mode, sink_size):
@@ -2199,9 +2016,6 @@ class Model:
            dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
            sink_size (int): Control the amount of data in each sink.
         """
-        if context.get_context("mode") != context.GRAPH_MODE:
-            raise RuntimeError("Pre-compile process that generate parameter layout for the train network "
-                               "only supports GRAPH MODE and Ascend target currently.")
         if _get_parallel_mode() not in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            raise RuntimeError("'infer_train_layout' only supports 'semi_auto_parallel' and 'auto_parallel' "
                               "mode, but got {}.".format(_get_parallel_mode()))
@@ -2361,9 +2175,6 @@ class Model:
         >>> predict_map = model.infer_predict_layout(inputs)
         """
         _init_auto_parallel_context(self._network)
-        if context.get_context("mode") != context.GRAPH_MODE:
-            raise RuntimeError("Pre-compile process that generate parameter layout for the predict network "
-                               "only supports GRAPH MODE and Ascend target currently.")
         if _get_parallel_mode() not in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            raise RuntimeError('Infer predict layout only supports semi auto parallel and auto parallel mode.')
         _parallel_predict_check()