mindspore 2.7.0rc1__cp310-cp310-win_amd64.whl → 2.7.1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +5 -2
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +2 -2
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
- mindspore/_extends/parse/parser.py +28 -22
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +23 -2
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
- mindspore/amp.py +0 -18
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/__init__.py +18 -12
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +371 -96
- mindspore/common/_utils.py +7 -43
- mindspore/common/api.py +434 -135
- mindspore/common/dtype.py +98 -57
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
- mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
- mindspore/common/file_system.py +59 -9
- mindspore/common/hook_handle.py +82 -3
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +17 -127
- mindspore/common/recompute.py +4 -13
- mindspore/common/tensor.py +50 -217
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +20 -106
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +35 -1
- mindspore/dataset/engine/datasets.py +338 -319
- mindspore/dataset/engine/datasets_user_defined.py +38 -22
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +17 -5
- mindspore/dataset/vision/utils.py +632 -21
- mindspore/device_context/ascend/op_tuning.py +35 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/api/cell.h +28 -4
- mindspore/include/api/cfg.h +24 -7
- mindspore/include/api/context.h +1 -0
- mindspore/include/api/delegate.h +0 -2
- mindspore/include/api/dual_abi_helper.h +100 -19
- mindspore/include/api/graph.h +14 -1
- mindspore/include/api/kernel.h +16 -3
- mindspore/include/api/kernel_api.h +9 -1
- mindspore/include/api/metrics/accuracy.h +9 -0
- mindspore/include/api/model.h +5 -1
- mindspore/include/api/model_group.h +4 -0
- mindspore/include/api/model_parallel_runner.h +2 -0
- mindspore/include/api/status.h +48 -10
- mindspore/include/api/types.h +6 -1
- mindspore/include/dataset/constants.h +9 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +4 -3
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/__init__.py +4 -0
- mindspore/mint/distributed/distributed.py +392 -69
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/_functions.py +1 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +10 -10
- mindspore/mint/nn/layer/normalization.py +11 -16
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +231 -239
- mindspore/nn/layer/activation.py +4 -2
- mindspore/nn/layer/basic.py +56 -14
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/image.py +1 -1
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +32 -127
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +1 -4
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +2 -4
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +39 -5
- mindspore/nn/wrap/grad_reducer.py +4 -89
- mindspore/numpy/array_creations.py +4 -4
- mindspore/numpy/fft.py +9 -9
- mindspore/numpy/utils_const.py +1 -1
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +1 -5
- mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
- mindspore/ops/auto_generate/gen_extend_func.py +6 -11
- mindspore/ops/auto_generate/gen_ops_def.py +385 -154
- mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +16 -2
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +2 -0
- mindspore/ops/function/array_func.py +24 -18
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +7 -6
- mindspore/ops/function/grad/grad_func.py +4 -12
- mindspore/ops/function/math_func.py +89 -86
- mindspore/ops/function/nn_func.py +92 -313
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +4 -1
- mindspore/ops/functional_overload.py +377 -30
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +12 -50
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +5 -50
- mindspore/ops/operations/comm_ops.py +95 -17
- mindspore/ops/operations/custom_ops.py +237 -22
- mindspore/ops/operations/debug_ops.py +33 -35
- mindspore/ops/operations/manually_defined/ops_def.py +39 -318
- mindspore/ops/operations/math_ops.py +5 -5
- mindspore/ops/operations/nn_ops.py +3 -3
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +4 -27
- mindspore/ops/tensor_method.py +88 -10
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_auto_parallel_context.py +5 -15
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +4 -6
- mindspore/parallel/_ps_context.py +2 -2
- mindspore/parallel/_utils.py +34 -17
- mindspore/parallel/auto_parallel.py +23 -9
- mindspore/parallel/checkpoint_transform.py +20 -2
- mindspore/parallel/cluster/process_entity/_api.py +28 -33
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/parallel/cluster/run.py +5 -3
- mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/function/reshard_func.py +6 -5
- mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
- mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
- mindspore/parallel/shard.py +7 -21
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +127 -20
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +40 -4
- mindspore/profiler/common/path_manager.py +65 -24
- mindspore/profiler/common/profiler_context.py +27 -14
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_meta_data.py +1 -0
- mindspore/profiler/common/profiler_op_analyse.py +10 -6
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/dynamic_profiler.py +91 -46
- mindspore/profiler/envprofiler.py +30 -5
- mindspore/profiler/experimental_config.py +18 -2
- mindspore/profiler/platform/cpu_profiler.py +10 -4
- mindspore/profiler/platform/npu_profiler.py +34 -7
- mindspore/profiler/profiler.py +193 -145
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +9 -6
- mindspore/runtime/executor.py +35 -0
- mindspore/runtime/memory.py +113 -0
- mindspore/runtime/thread_bind_core.py +1 -1
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +8 -21
- mindspore/train/amp.py +6 -7
- mindspore/train/callback/_callback.py +2 -1
- mindspore/train/callback/_checkpoint.py +1 -17
- mindspore/train/callback/_flops_collector.py +10 -6
- mindspore/train/callback/_train_fault_tolerance.py +72 -25
- mindspore/train/data_sink.py +5 -9
- mindspore/train/dataset_helper.py +5 -5
- mindspore/train/model.py +41 -230
- mindspore/train/serialization.py +160 -401
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +152 -16
- mindspore/version.py +1 -1
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
- mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
- mindspore/communication/_hccl_management.py +0 -297
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/__init__.py +0 -23
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/profiler/common/validator/validate_path.py +0 -84
- mindspore/train/memory_profiling_pb2.py +0 -298
- mindspore/utils/hooks.py +0 -81
- /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/tools/stress_detect.py
ADDED

@@ -0,0 +1,63 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Stress detect."""
+from mindspore import _c_expression
+from mindspore import log as logger
+from mindspore.communication import init, create_group, get_rank
+from mindspore.communication import get_local_rank_size
+
+
+def stress_detect(detect_type="aic"):
+    """
+    Used to detect whether there are faults in hardware accuracy or communication between links.
+    The common usage scenario is to initiate a new thread or call this interface through a Callback function
+    at each step or when saving checkpoints, to check whether hardware malfunctions could affect accuracy.
+
+    Args:
+        detect_type (str, optional): The type of stress test to perform. There are two options available: ``'aic'`` and
+            ``'hccs'``, which perform AiCore and HCCS link stress tests on the device, respectively. Default: "aic".
+
+    Returns:
+        int, the return value represents the error type. 0 indicates normal. 1 indicates failure to start some or
+        all test cases. 2 indicates a hardware failure, and it is recommended to replace the device.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore.tools import stress_detect
+        >>> ret = stress_detect()
+        >>> print(ret)
+        0
+    """
+    if detect_type not in ["aic", "hccs"]:
+        logger.error(f"For stress detect, detection type must be 'aic' or 'hccs'."
+                     f"But got {detect_type}. Exiting stress detect.")
+        return 1
+
+    if detect_type == "aic":
+        return _c_expression.stress_detect("aic")
+
+    init()
+    local_ranks = []
+    local_rank_size = get_local_rank_size()
+    node_num = get_rank() // local_rank_size
+    for i in range(local_rank_size):
+        local_ranks.append(local_rank_size * node_num + i)
+    if get_rank() in local_ranks:
+        group = f"new_group_{node_num}"
+        create_group(group, local_ranks)
+
+    return _c_expression.stress_detect(group)
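A usage sketch for the new detect_type argument, following the docstring above. The 'hccs' branch calls init() and create_group() itself, so it assumes a multi-device Ascend job; the background-thread wrapper only illustrates the calling pattern recommended in the docstring and is not part of the MindSpore API.

    import threading
    from mindspore.tools import stress_detect

    def hardware_check():
        # 0 = normal, 1 = test cases failed to start, 2 = hardware fault (see Returns above)
        ret = stress_detect("aic")      # AiCore stress test on the local device
        # ret = stress_detect("hccs")   # HCCS link test, only meaningful in a distributed job
        if ret == 2:
            print("hardware fault detected, consider replacing the device")

    # run the check off the training thread, e.g. when a checkpoint is saved
    threading.Thread(target=hardware_check, daemon=True).start()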
mindspore/train/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -25,8 +25,8 @@ from mindspore.train import amp
 from mindspore.train.amp import build_train_network
 from mindspore.train.loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager
 from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, \
-    load,
-    load_checkpoint_async,
+    load, async_ckpt_thread_status, export_split_mindir, \
+    load_checkpoint_async, get_ckpt_path_with_strategy, ckpt_to_safetensors, safetensors_to_ckpt, \
     build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint, restore_group_info_list
 from mindspore.train.callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryCollector, \
     CheckpointConfig, RunContext, LearningRateScheduler, SummaryLandscape, FlopsUtilizationCollector, \
@@ -37,9 +37,9 @@ from mindspore.train.metrics import *
 from mindspore.train.data_sink import data_sink

 __all__ = ["Model", "DatasetHelper", "connect_network_with_dataset", "build_train_network", "LossScaleManager",
-           "FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
-           "load_param_into_net", "export", "load", "export_split_mindir", "
-           "
+           "FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
+           "load_param_into_net", "export", "load", "export_split_mindir", "async_ckpt_thread_status",
+           "data_sink", "load_checkpoint_async", "get_ckpt_path_with_strategy", "ckpt_to_safetensors",
            "safetensors_to_ckpt", "build_searched_strategy", "merge_sliced_parameter", "load_distributed_checkpoint",
            "restore_group_info_list"]
 __all__.extend(callback.__all__)
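The checkpoint format helpers are now re-exported from the mindspore.train namespace. A minimal import sketch, with the input path as a placeholder; keyword arguments are omitted and the authoritative signatures live in mindspore/train/serialization.py.

    from mindspore.train import ckpt_to_safetensors, safetensors_to_ckpt

    # convert a .ckpt file to the safetensors format and back again
    ckpt_to_safetensors("./net.ckpt")         # "./net.ckpt" is a placeholder path
    safetensors_to_ckpt("./net.safetensors")  # reverse conversion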
mindspore/train/_utils.py
CHANGED
@@ -26,7 +26,7 @@ import numpy as np
 from mindspore.common.tensor import Tensor
 from mindspore._c_expression import TensorPy as Tensor_
 from mindspore._c_expression import MSContext, ms_ctx_param
-from mindspore.common.dtype import
+from mindspore.common.dtype import _dtype_to_nptype, _pytype_to_dtype
 from mindspore.common import dtype as mstype
 from mindspore import context
 from mindspore import log as logger
@@ -54,7 +54,7 @@ def _convert_type(types):
     """
     ms_types = []
     for np_type in types:
-        ms_type =
+        ms_type = _pytype_to_dtype(np_type)  # pylint:disable=protected-access
         ms_types.append(ms_type)
     return ms_types

@@ -131,7 +131,7 @@ def _construct_tensor_list(types, shapes, batch_expand_num=1):
             new_shape += (item * batch_expand_num,)
         else:
             new_shape += (item,)
-        tensor = Tensor(np.zeros(new_shape,
+        tensor = Tensor(np.zeros(new_shape, _dtype_to_nptype(type_)), dtype=type_)  # pylint:disable=protected-access
         tensor.virtual_flag = True
         tensor_list.append(tensor)
     return tensor_list
@@ -344,15 +344,7 @@ def _get_layout_opt_shard(layout_obj, param_redundancy_dict):
     """Layout ckpt append opt shard."""
     for key, value in layout_obj.items():
         if value[5]:
-
-            if value[5] in world_groups:
-                opt_para_num = get_group_size()
-            elif "-" in value[5]:
-                opt_para_str = value[5].split("-")[0]
-                opt_para_num = int(opt_para_str)
-            else:
-                raise ValueError(f"For get_parameter_redundancy, the format of the parallel communication domain for "
-                                 f"the optimizer is incorrect.")
+            opt_para_num = get_group_size(value[5])
             param_redundancy_ranks = param_redundancy_dict.get(key)
             res = []
             for param_ranks in param_redundancy_ranks:
@@ -582,17 +574,12 @@ def _progress_bar(iterable, total=None):
         print_progress_bar(i)


-def _load_and_transform(path, name_map, load_func
+def _load_and_transform(path, name_map, load_func):
     """use load_func to load and use transform_func to convert"""
-
-    param_dict = load_func(path)
-    else:
-        param_dict = path
+    param_dict = load_func(path)
     transform_dict = {}
+
     for k, v in param_dict.items():
         new_name = name_map.get(k, k) if name_map is not None else k
-
-        transform_dict[new_name] = transform_func(v, new_name)
-        else:
-            transform_dict[new_name] = v
+        transform_dict[new_name] = v
     return transform_dict
mindspore/train/amp.py
CHANGED
@@ -463,9 +463,6 @@ def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
     ``Addcdiv``, ``Addcmul``, ``Cross``, ``_PyboostCrossPrim``, ``Dot``, ``GridSampler2D``, ``GridSampler3D``,
     ``BiasAdd``, ``AddN``, ``Concat``

-    For details on automatic mixed precision, refer to
-    `Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/master/beginner/mixed_precision.html>`_ .
-
    Note:
        - Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
          can result in a larger network hierarchy and slower performance.
@@ -821,8 +818,10 @@ def get_white_list():
         <class 'mindspore.ops.operations.nn_ops.Conv2DTranspose'>,
         <class 'mindspore.ops.operations.nn_ops.Conv3DTranspose'>,
         <class 'mindspore.ops.operations.nn_ops.Conv2DBackpropInput'>,
-        <class 'mindspore.ops.
-        <class 'mindspore.ops.
+        <class 'mindspore.ops.auto_generate.gen_ops_prim.MatMul'>,
+        <class 'mindspore.ops.auto_generate.gen_ops_prim.BatchMatMul'>,
+        <class 'mindspore.ops.auto_generate.gen_ops_prim.PReLU'>,
+        <class 'mindspore.ops.auto_generate.gen_ops_prim.ReLU'>,
         <class 'mindspore.ops.operations.math_ops.Ger'>]
     """
     white_list = AMP_WHITE_LIST.copy()
@@ -874,8 +873,8 @@ def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=m
            white list is not used.
        black_list (list[Cell], optional): Black list of custom mixed precision. Defaults: ``None`` , means
            black list is not used.
-        dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or
-            default: ``mstype.float16`` .
+        dtype (Type, optional): The type used in lower precision calculations, can be ``mstype.float16`` or
+            ``mstype.bfloat16`` , default: ``mstype.float16`` .

     Returns:
         network (Cell), A network supporting mixed precision.
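A minimal sketch of the custom_mixed_precision contract documented above, assuming an Ascend backend with bfloat16 support; the tiny Dense network is only a stand-in.

    from mindspore import nn
    from mindspore import dtype as mstype
    from mindspore.train.amp import custom_mixed_precision

    net = nn.Dense(16, 8)
    # run the white-listed cell types in bfloat16 rather than the default float16
    net = custom_mixed_precision(net, white_list=[nn.Dense], dtype=mstype.bfloat16)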
mindspore/train/serialization.py
CHANGED

@@ -60,7 +60,8 @@ def _fill_param_into_net(net, parameter_list):
         if np_val.shape == (1,):
             parameter_dict[param_name] = Parameter(np_val, name=param_name)
         elif np_val.shape == ():
-
+            # pylint:disable=protected-access
+            parameter_dict[param_name] = Parameter(Tensor(np_val.tolist(), mstype._pytype_to_dtype(np_val.dtype)),
                                                    name=param_name)
         else:
             parameter_dict[param_name] = Parameter(Tensor(np_val), name=param_name)
mindspore/train/callback/_checkpoint.py
CHANGED

@@ -27,7 +27,6 @@ from mindspore.train._utils import _make_directory
 from mindspore.train.serialization import save_checkpoint, _save_graph, _wait_async_process_save_ckpt, \
     _wait_async_thread_save_ckpt, _check_async_save
 from mindspore.parallel._cell_wrapper import destroy_allgather_cell
-from mindspore.parallel._recovery_context import _set_recovery_context, _get_recovery_context
 from mindspore.communication.management import get_rank, get_group_size
 from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy, _get_pp_size_from_redundancy_map
 from mindspore.train.callback._callback import Callback
@@ -509,9 +508,6 @@ class ModelCheckpoint(Callback):
         if callable(prefix):
             self._prefix_func = prefix

-        if context.get_context("device_target") == "GPU" and _get_recovery_context("enable_recovery"):
-            _set_recovery_context(ckpt_path=self._directory)
-
         if config is None:
             self._config = CheckpointConfig()
         else:
@@ -577,11 +573,6 @@ class ModelCheckpoint(Callback):
             self._directory = self._directory_func(cb_params)
             _make_directory(self._directory)
         collect_host_info("Callback", "ModelCheckpoint", "step_end", start_time=get_clock_syscnt(), level=1)
-        # In disaster recovery scenario, the training process may be rolled back to the last step where
-        # the ckpt was successfully saved, so the _last_triggered_step should be updated.
-        if _get_recovery_context("enable_recovery") and cb_params.last_save_ckpt_step is not None:
-            self._last_triggered_step = cb_params.last_save_ckpt_step
-            cb_params.last_save_ckpt_step = None

         # save graph (only once)
         if not self._graph_saved:
@@ -628,13 +619,6 @@ class ModelCheckpoint(Callback):
         if "step_num" in self._append_dict:
             self._append_dict["step_num"] = self._append_step_num + step_num

-    def _update_save_step(self, cb_params):
-        """update step if used async d2h copy"""
-        step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
-        if self._d2h_async and self._run_mode == context.GRAPH_MODE:
-            step_num_in_epoch -= 1
-        return step_num_in_epoch
-
     def _save_ckpt(self, cb_params, force_to_save=False):
         """Save checkpoint files."""
         if cb_params.cur_step_num == self._last_triggered_step:
@@ -645,7 +629,7 @@ class ModelCheckpoint(Callback):
         self._flush_from_cache(cb_params)

         save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
-        step_num_in_epoch =
+        step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)

         if save_ckpt:
mindspore/train/callback/_flops_collector.py
CHANGED

@@ -31,7 +31,6 @@ from mindspore.communication.management import (create_group, get_group_size,
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.ops import operations as P
 from mindspore.common import Tensor
-from mindspore import context
 import mindspore.nn as nn


@@ -152,16 +151,21 @@ class FlopsUtilizationCollector(Callback):
         """
         Check whether FlopsUtilizationCollector is working in the current environment
         """
-        if context.get_context("mode") != context.GRAPH_MODE:
-            if self.verbose:
-                raise ValueError("FlopsUtilizationCollector now only support graph mode.")
-            logger.info("FlopsUtilizationCollector now only support graph mode.")
-            return False
         cb_params = run_context.original_args()
         if cb_params.mode == 'train':
             network = cb_params.train_network
+            if not network.compiled:
+                if self.verbose:
+                    raise ValueError("FlopsUtilizationCollector now only support graph mode.")
+                logger.info("FlopsUtilizationCollector now only support graph mode.")
+                return False
         elif cb_params.mode == 'eval':
             network = cb_params.eval_network
+            if not network.compiled:
+                if self.verbose:
+                    raise ValueError("FlopsUtilizationCollector now only support graph mode.")
+                logger.info("FlopsUtilizationCollector now only support graph mode.")
+                return False
         else:
             if self.verbose:
                 raise ValueError('FlopsUtilizationCollector only support train and eval mode!')
mindspore/train/callback/_train_fault_tolerance.py
CHANGED

@@ -28,15 +28,15 @@ from mindspore._c_expression import _repair_device, _stop_device, _tft_sem_post,
 from mindspore._c_expression import _rebuild_world_group, _rebuild_sub_group, _finalize_comm, _clean_rootinfo
 from mindspore._c_expression import clean_tdt_channel
 from mindspore._c_expression import _pre_launch_send_recv
-from mindspore._c_expression import send_recv, reset_params
+from mindspore._c_expression import send_recv, reset_params, direct_copy_to_host
+from mindspore._c_expression import _reg_snapshot_params, _reset_snapshot_state, _clear_snapshot_saving_flag
 from mindspore._c_expression import CollectiveManager
 from mindspore._c_expression import _get_uce_process_strategy, _get_uce_mem_info
-from mindspore._c_expression import TensorPy as Tensor_
 from mindspore.ops.operations.manually_defined._inner import TensorReport
 import mindspore
 import mindspore.common.dtype as mstype
-from mindspore.parallel._recovery_context import _set_recovery_context
 from mindspore import runtime
+from mindspore._c_expression import set_is_arf


 def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
@@ -157,6 +157,7 @@ def _tft_clean_callback(is_uce_error, args, ctx):
     CollectiveManager.get_instance().resume_hccl_comm()
     logger.warning("Finish _tft_clean_callback, ret: {}".format(ret))
     if ctx.tft.tft_get_repair_type() == "recover":
+        _reset_snapshot_state()
         logger.warning(f"Destroy hcom")
         _finalize_comm()
         logger.warning(f"Destroy hcom end")
@@ -166,11 +167,10 @@ def _tft_clean_callback(is_uce_error, args, ctx):
 def _tft_stop_callback(args, cb_ctx):
     """ Callback used for TFT stop function."""
     logger.warning(f"Enter _tft_stop_callback device_id: {cb_ctx.device_id}")
-    _stop_device(cb_ctx.device_id)
-    cb_ctx.stop_been_called = True
     if (not cb_ctx.is_uce_rank) and (not cb_ctx._is_params_consistent()):  # pylint: disable=W0212
         raise RuntimeError("Can't stop device, because training parameters are left in inconsistent state!")
     cb_ctx.is_uce_rank = False
+    _stop_device(cb_ctx.device_id)
     if cb_ctx.tft.tft_get_repair_type() == "recover":
         logger.warning(f"Reset limit step")
         cb_ctx.tft.tft_reset_limit_step()
@@ -182,7 +182,7 @@ def _tft_rebuild_sub_groups(fault_ranks, args, ctx):
     logger.warning(f"Enter _tft_rebuild_sub_groups, device id: {ctx.device_id}")
     _rebuild_world_group()
     _rebuild_sub_group()
-
+    set_is_arf(True)
     logger.warning(f"try to pre launch send recv before real launch")
     _pre_launch_send_recv(context.get_context('device_id'))
     logger.warning(f"Pre launch send recv before real launch end")
@@ -192,7 +192,7 @@ def _tft_rebuild_sub_groups(fault_ranks, args, ctx):
 class TrainFaultTolerance(Callback):
     """
     This callback is used to enable the TFT feature
-    `MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/
+    `MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/600/clusterscheduling/ref/mindiottp/mindiotft001.html>`_
     and will execute TFT operations during training process, such as TFT init, report and exception handle.

     Note:
@@ -202,7 +202,10 @@ class TrainFaultTolerance(Callback):
        ckpt_save_path (str): Checkpoint save directory when failure occurs. When saved,
            a new directory named 'ttp_saved_checkpoints-step_{cur_step_num}'
            is created in that directory. Default: ``None``.
-        kwargs (dict): Other dictionary type parameters.
+        kwargs (dict): Other dictionary type parameters. When argument `ckpt_save_path` is ``None``, `kwargs` must
+            provide a parameter named `ckpt_save_fn`, which points to a function used to save checkpoint. The
+            prototype of `ckpt_save_fn` is ``def save_ckpt(cb_params, append_dict)``. When both `ckpt_save_path`
+            and `ckpt_save_fn` are provided, `ckpt_save_fn` is used in priority.

     Raises:
         Exception: TFT init failed.
@@ -329,7 +332,7 @@ class TrainFaultTolerance(Callback):
         # `def load_checkpoint() -> tuple(dict, bool)`, the return value is a tuple containing 2 values,
         # i.e. (param_dict, remove_redundancy)
         self.ckpt_load_func = kwargs.get("ckpt_load_fn", None)
-        if self._only_enable_tre():
+        if self._only_enable_tre() or self._only_enable_ckpt_d2h_async():
             return
         self.tft = _tft_handler.get_tft()
         self._check_init()
@@ -340,11 +343,9 @@ class TrainFaultTolerance(Callback):
         self.learning_rate = None
         self.has_init_replica = False
         self.is_uce_rank = False
-        self.stop_been_called = False

         self.assign = mindspore.ops.Assign()
-        self.g_one =
-        self.s1 = mindspore.hal.Stream()
+        self.g_one = Tensor([1], dtype=mstype.int32)
         _tft_sem_enable()
         self._tft_register()

@@ -354,7 +355,21 @@ class TrainFaultTolerance(Callback):
         non_tre_flags = ["TTP:1", "UCE:1", "ARF:1"]
         if any(flag in env_enable for flag in non_tre_flags):
             return False
-        return "TRE:1" in env_enable
+        return "TRE:1" in env_enable or "TRE:2" in env_enable
+
+    @staticmethod
+    def _only_enable_ckpt_d2h_async():
+        """Check whether only set MS_ENABLE_CKPT_D2H_ASYNC=1 without setting MS_ENABLE_TFT"""
+        if os.getenv("MS_ENABLE_TFT", "") != "":
+            return False
+        return os.getenv("MS_ENABLE_CKPT_D2H_ASYNC") == "1"
+
+    @staticmethod
+    def _enable_snapshot():
+        """Check whether parameter snapshot enabled"""
+        enable_step_tre = "TRE:2" in os.getenv("MS_ENABLE_TFT", "")
+        enable_ckpt_d2h_async = os.getenv("MS_ENABLE_CKPT_D2H_ASYNC") == "1"
+        return enable_step_tre or enable_ckpt_d2h_async

     def _only_enable_tsp(self):
         """Check if only configured MS_ENABLE_TFT='{TSP:1}'"""
@@ -382,18 +397,14 @@ class TrainFaultTolerance(Callback):
         _tft_handler.init(config=None)
         self.tft = _tft_handler.get_tft()
         logger.warning(f"TFT handle init ok.")
-        mode = context.get_context("mode")
         device_target = context.get_context("device_target")
-        if device_target != "Ascend"
-            raise ValueError(f"MindIO adataper only support on Ascend device
-                             f"device:{device_target}, run mode: {mode}")
+        if device_target != "Ascend":
+            raise ValueError(f"MindIO adataper only support on Ascend device but got device {device_target}!")

     def _is_params_consistent(self):
         for key, param in self.cb_params.train_network.parameters_and_names():
             if "tft_g_one_flag" in key:
-
-                tft_g_one_flag = Tensor(Tensor_.move_to(param, "CPU", False))
-                self.s1.synchronize()
+                tft_g_one_flag = direct_copy_to_host(param)
                 return int(tft_g_one_flag) == 1
         return False

@@ -438,7 +449,7 @@ class TrainFaultTolerance(Callback):
                 super(TFTOptSubCls, self).__init__(*args, **kwargs)
                 self.report = TensorReport()
                 self.report_end = TensorReport()
-                self.report_end.add_prim_attr("
+                self.report_end.add_prim_attr("optimizer_end", True)
                 self.depend = ops.Depend()
                 self.allreduce_sum = ops.AllReduce()
                 self.allreduce_sum.add_prim_attr("tft_report_before", True)
@@ -452,7 +463,27 @@ class TrainFaultTolerance(Callback):
                 self.report_end("tft_report", self.tft_g_one_flag)
                 return opt_ret

-
+        class TFTOptSnapShotCls(origin_opt_cls):
+            """
+            Optimizer wrapper class when using tft.
+            """
+
+            def __init__(self, *args, **kwargs):
+                super(TFTOptSnapShotCls, self).__init__(*args, **kwargs)
+                self.report = TensorReport()
+                self.report.add_prim_attr("side_effect_mem", True).add_prim_attr("snapshot", True)
+                self.dummy_input = Tensor([1], dtype=mstype.int32)
+
+            def construct(self, gradients, **kwargs):
+                """Add fake op TensorReport to insert wait event for copying parameters"""
+                self.report("tft_report", self.dummy_input)
+                opt_ret = super(TFTOptSnapShotCls, self).construct(gradients, **kwargs)
+                return opt_ret
+
+        env_tft = os.getenv('MS_ENABLE_TFT', '')
+        features = ['TTP:1', 'UCE:1', 'ARF:1']
+        need_redundancy = any([env_tft.find(feat) >= 0 for feat in features])
+        return TFTOptSubCls if need_redundancy else TFTOptSnapShotCls

     def _tft_register(self):
         """Register callback functions."""
@@ -480,6 +511,17 @@ class TrainFaultTolerance(Callback):
             _clean_rootinfo()
             self.clean_unique_id = True

+    def on_train_step_begin(self, run_context):
+        """
+        Clear saving snapshot state at each step begin.
+
+        Args:
+            run_context (RunContext): Context of the train running. Refer to
+                :class:`mindspore.train.RunContext` for detail.
+        """
+        if self._enable_snapshot():
+            _clear_snapshot_saving_flag()
+
     def on_train_step_end(self, run_context):
         """
         Report status to MindIO TFT after every step finished.
@@ -488,7 +530,7 @@ class TrainFaultTolerance(Callback):
             run_context (RunContext): Context of the train running. Refer to
                 :class:`mindspore.train.RunContext` for detail.
         """
-        if self._only_enable_tre():
+        if self._only_enable_tre() or self._only_enable_ckpt_d2h_async():
             return

         cb_params = run_context.original_args()
@@ -528,10 +570,15 @@ class TrainFaultTolerance(Callback):
             run_context (RunContext): Context of the train running. Refer to
                 :class:`mindspore.train.RunContext` for detail.
         """
+        cb_params = run_context.original_args()
+        if self._enable_snapshot():
+            param_dict = {}
+            for param in cb_params.train_network.trainable_params():
+                param_dict[param.name] = param
+            _reg_snapshot_params(param_dict)
         if self._only_enable_tsp():
             return
-
-        if self._only_enable_tre():
+        if self._only_enable_tre() or self._only_enable_ckpt_d2h_async():
             self.cb_params = cb_params
             return
         sink_size = cb_params.get("sink_size", 0)
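Based on the expanded kwargs description above, a hedged sketch of supplying a custom save function when ckpt_save_path is None. The import path for the callback and the checkpoint file name are assumptions, and MS_ENABLE_TFT still has to be configured in the environment for the callback to do anything.

    import mindspore as ms
    from mindspore.train import TrainFaultTolerance

    def save_ckpt(cb_params, append_dict):
        # prototype taken from the kwargs description: def save_ckpt(cb_params, append_dict)
        ms.save_checkpoint(cb_params.train_network,
                           f"./ttp_step_{cb_params.cur_step_num}.ckpt",
                           append_dict=append_dict)

    # when both ckpt_save_path and ckpt_save_fn are given, ckpt_save_fn is used in priority
    tft_cb = TrainFaultTolerance(ckpt_save_path=None, ckpt_save_fn=save_ckpt)
    # model.train(epochs, dataset, callbacks=[tft_cb])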
mindspore/train/data_sink.py
CHANGED
@@ -16,7 +16,7 @@
 from functools import wraps
 import mindspore.ops as ops
 from mindspore import context
-from mindspore.common.dtype import
+from mindspore.common.dtype import _pytype_to_dtype
 from mindspore.common.api import jit
 from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, enable_data_broadcast
 from mindspore.train.dataset_helper import _has_dynamic_shape, _check_inputs
@@ -61,7 +61,7 @@ def _init_sink_dataset(dataset, sink_size, input_signature, create_info):
     _check_inputs(input_signature, dataset_shapes, dataset_types)

     queue_name = transfer_dataset.queue_name
-    if _need_to_full()
+    if _need_to_full():
         device_num = _get_device_num() // _get_pipeline_stages()
         dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
     next_op = ops.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
@@ -94,12 +94,12 @@ def _get_next_op(dataset, ori_next_op, is_info_queue):

     queue_name = dataset.__transfer_dataset__.queue_name
     dataset_types, dataset_shapes = dataset.__transfer_dataset__.get_data_info()
-    dataset_types = [
+    dataset_types = [_pytype_to_dtype(x) for x in dataset_types]  # pylint:disable=protected-access
     key = str(dataset_types) + str(dataset_shapes)
     if key in dataset.__sink_aux__.next_ops:
         next_op = dataset.__sink_aux__.next_ops[key]
     else:
-        if _need_to_full()
+        if _need_to_full():
             device_num = _get_device_num() // _get_pipeline_stages()
             dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
             next_op = ops.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
@@ -238,12 +238,8 @@ def data_sink(fn, dataset, sink_size=1, jit_config=None, input_signature=None):

         real_sink_fun = _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config)

-        loop = sink_size
-        if jit_config is not None and context.get_context('mode') == context.GRAPH_MODE:
-            loop = 1
-
         out = None
-        for _ in range(
+        for _ in range(sink_size):
            out = real_sink_fun()

        return out
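For context on the data_sink change above (the jit_config special case is gone, so one call to the wrapped function always drains sink_size batches), a sketch with a toy in-memory dataset; the step function is a placeholder and a device that supports dataset sinking is assumed.

    import numpy as np
    import mindspore as ms
    from mindspore import dataset as ds
    from mindspore.train import data_sink

    # two-column toy dataset standing in for a real training pipeline
    data = ds.NumpySlicesDataset({"x": np.ones((8, 4), np.float32),
                                  "y": np.ones((8, 4), np.float32)}, shuffle=False)

    def step(x, y):
        # placeholder for a real forward/backward/update step
        return (x + y).mean()

    sunk_step = data_sink(step, data, sink_size=4)
    out = sunk_step()  # executes sink_size (=4) steps per call after this change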
mindspore/train/dataset_helper.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2025 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,8 +20,8 @@ import copy

 from mindspore import _checkparam as Validator
 from mindspore import log as logger
-from mindspore.common._auto_dynamic import is_auto_dynamic, convert_new_shapes
-from mindspore.common.dtype import
+from mindspore.common.dynamic_shape._auto_dynamic import is_auto_dynamic, convert_new_shapes
+from mindspore.common.dtype import _pytype_to_dtype
 from mindspore.common.api import _cell_graph_executor, _is_args_fullmode, ARG_SPECIFIED
 from mindspore.common._utils import is_shape_unknown
 from mindspore.dataset.core import config as dataset_config
@@ -34,7 +34,7 @@ from mindspore.parallel._utils import _get_device_num, _get_global_rank, _need_t
     _origin_shapes, _dynamic_shape_for_dataset
 from mindspore.parallel._ps_context import _is_role_sched
 from mindspore.ops import operations as P
-from mindspore.common.auto_dynamic_shape import _auto_dynamic_shape
+from mindspore.common.dynamic_shape.auto_dynamic_shape import _auto_dynamic_shape


 def _send_data(dataset, epoch_num):
@@ -275,7 +275,7 @@ def connect_network_with_dataset(network, dataset_helper):
     # Need to do full_batch for shapes which also do in the _DatasetIterMSLoopSink
     if _need_to_full():
         dataset_shapes = _to_full_shapes(dataset_shapes, _get_device_num() // _get_pipeline_stages())
-    dataset_types = [
+    dataset_types = [_pytype_to_dtype(x) for x in dataset_types]  # pylint:disable=protected-access
     if not is_dynamic:
         dataset_shapes = _auto_dynamic_shape.auto_dynamic_generate_compile_args(dataset_shapes, True)
     key = str(dataset_types) + str(dataset_shapes)
|