mindspore 2.4.1-cp310-cp310-win_amd64.whl → 2.5.0-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +8 -3
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +0 -5
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/compile_config.py +64 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +375 -0
- mindspore/_extends/parse/parser.py +23 -5
- mindspore/_extends/parse/standard_method.py +123 -27
- mindspore/_extends/pijit/pijit_func_white_list.py +1 -1
- mindspore/amp.py +7 -1
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/boost_cell_wrapper.py +136 -41
- mindspore/common/__init__.py +3 -1
- mindspore/common/_register_for_tensor.py +0 -1
- mindspore/common/_stub_tensor.py +25 -4
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +6132 -0
- mindspore/common/api.py +99 -25
- mindspore/common/dtype.py +34 -34
- mindspore/common/dump.py +2 -1
- mindspore/common/file_system.py +8 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +3 -1
- mindspore/common/initializer.py +3 -4
- mindspore/common/lazy_inline.py +8 -2
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/parameter.py +30 -27
- mindspore/common/tensor.py +713 -1337
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +10 -0
- mindspore/communication/comm_func.py +215 -173
- mindspore/communication/management.py +23 -20
- mindspore/context.py +292 -193
- mindspore/dataset/__init__.py +23 -19
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +84 -3
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +5 -4
- mindspore/dataset/engine/datasets.py +192 -149
- mindspore/dataset/engine/datasets_audio.py +14 -0
- mindspore/dataset/engine/datasets_standard_format.py +28 -11
- mindspore/dataset/engine/datasets_text.py +38 -1
- mindspore/dataset/engine/datasets_user_defined.py +125 -65
- mindspore/dataset/engine/datasets_vision.py +81 -8
- mindspore/dataset/engine/iterators.py +281 -63
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +26 -2
- mindspore/dataset/engine/serializer_deserializer.py +1 -1
- mindspore/dataset/engine/validators.py +43 -11
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +29 -12
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +94 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +127 -0
- mindspore/device_context/cpu/__init__.py +25 -0
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +134 -0
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/llm_boost/__init__.py +3 -2
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +239 -64
- mindspore/experimental/llm_boost/atb/llama_boost.py +52 -30
- mindspore/experimental/llm_boost/atb/qwen_boost.py +47 -24
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/optim/adadelta.py +26 -22
- mindspore/experimental/optim/adam.py +3 -0
- mindspore/experimental/optim/lr_scheduler.py +33 -24
- mindspore/experimental/optim/radam.py +33 -30
- mindspore/hal/device.py +28 -0
- mindspore/hal/event.py +17 -0
- mindspore/hal/memory.py +94 -3
- mindspore/hal/stream.py +91 -6
- mindspore/include/api/context.h +1 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +12 -0
- mindspore/mindrecord/__init__.py +1 -1
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +824 -218
- mindspore/mint/distributed/__init__.py +66 -4
- mindspore/mint/distributed/distributed.py +2594 -44
- mindspore/mint/linalg/__init__.py +6 -0
- mindspore/mint/nn/__init__.py +473 -14
- mindspore/mint/nn/functional.py +486 -11
- mindspore/mint/nn/layer/__init__.py +17 -4
- mindspore/mint/nn/layer/_functions.py +330 -0
- mindspore/mint/nn/layer/activation.py +169 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +727 -0
- mindspore/mint/nn/layer/normalization.py +215 -19
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +170 -0
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/special/__init__.py +2 -1
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +2 -0
- mindspore/nn/cell.py +142 -21
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +6 -6
- mindspore/nn/layer/basic.py +35 -25
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/conv.py +3 -0
- mindspore/nn/layer/embedding.py +3 -3
- mindspore/nn/layer/normalization.py +8 -7
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +55 -23
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +48 -26
- mindspore/nn/learning_rate_schedule.py +5 -3
- mindspore/nn/loss/loss.py +31 -36
- mindspore/nn/optim/ada_grad.py +1 -0
- mindspore/nn/optim/adadelta.py +2 -2
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/rprop.py +2 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/utils/__init__.py +22 -0
- mindspore/nn/utils/init.py +73 -0
- mindspore/nn/wrap/cell_wrapper.py +4 -6
- mindspore/nn/wrap/loss_scale.py +3 -4
- mindspore/numpy/array_creations.py +60 -62
- mindspore/numpy/array_ops.py +148 -143
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +16 -16
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +2 -1
- mindspore/ops/_grad_experimental/grad_comm_ops.py +107 -8
- mindspore/ops/_grad_experimental/grad_debug_ops.py +6 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_vmap/vmap_array_ops.py +20 -19
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +19 -13
- mindspore/ops/_vmap/vmap_math_ops.py +11 -9
- mindspore/ops/_vmap/vmap_nn_ops.py +20 -34
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +149 -12
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -61
- mindspore/ops/auto_generate/gen_extend_func.py +554 -60
- mindspore/ops/auto_generate/gen_ops_def.py +1621 -115
- mindspore/ops/auto_generate/gen_ops_prim.py +8027 -3411
- mindspore/ops/auto_generate/pyboost_inner_prim.py +183 -79
- mindspore/ops/composite/base.py +1 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +229 -30
- mindspore/ops/composite/multitype_ops/pow_impl.py +0 -29
- mindspore/ops/function/__init__.py +12 -0
- mindspore/ops/function/array_func.py +561 -159
- mindspore/ops/function/clip_func.py +64 -0
- mindspore/ops/function/debug_func.py +28 -20
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +5 -4
- mindspore/ops/function/math_func.py +1664 -294
- mindspore/ops/function/nn_func.py +988 -317
- mindspore/ops/function/parameter_func.py +3 -56
- mindspore/ops/function/random_func.py +243 -33
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/functional.py +18 -5
- mindspore/ops/functional_overload.py +897 -0
- mindspore/ops/operations/__init__.py +3 -2
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -34
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +38 -8
- mindspore/ops/operations/array_ops.py +45 -303
- mindspore/ops/operations/comm_ops.py +23 -17
- mindspore/ops/operations/custom_ops.py +7 -49
- mindspore/ops/operations/debug_ops.py +42 -47
- mindspore/ops/operations/inner_ops.py +6 -4
- mindspore/ops/operations/linalg_ops.py +3 -2
- mindspore/ops/operations/manually_defined/ops_def.py +185 -104
- mindspore/ops/operations/math_ops.py +11 -216
- mindspore/ops/operations/nn_ops.py +153 -310
- mindspore/ops/primitive.py +23 -21
- mindspore/ops/tensor_method.py +1669 -0
- mindspore/ops_generate/aclnn_kernel_register_auto_cc_generator.py +110 -0
- mindspore/ops_generate/add_tensor_docs_generator.py +54 -0
- mindspore/ops_generate/arg_handler.py +0 -61
- mindspore/ops_generate/auto_grad_impl_cc_generator.py +135 -0
- mindspore/ops_generate/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/base_generator.py +11 -0
- mindspore/ops_generate/cpp_create_prim_instance_helper_generator.py +108 -0
- mindspore/ops_generate/functional_map_cpp_generator.py +491 -0
- mindspore/ops_generate/functional_overload_py_generator.py +110 -0
- mindspore/ops_generate/functions_cc_generator.py +233 -0
- mindspore/ops_generate/gen_aclnn_implement.py +110 -114
- mindspore/ops_generate/gen_constants.py +157 -3
- mindspore/ops_generate/gen_ops.py +245 -990
- mindspore/ops_generate/gen_pyboost_func.py +97 -998
- mindspore/ops_generate/gen_utils.py +119 -33
- mindspore/ops_generate/lite_ops_cpp_generator.py +155 -0
- mindspore/ops_generate/op_api_proto.py +206 -0
- mindspore/ops_generate/op_def_py_generator.py +131 -0
- mindspore/ops_generate/op_prim_py_generator.py +480 -0
- mindspore/ops_generate/op_proto.py +373 -108
- mindspore/ops_generate/op_template_parser.py +436 -0
- mindspore/ops_generate/ops_def_cc_generator.py +288 -0
- mindspore/ops_generate/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/ops_name_h_generator.py +68 -0
- mindspore/ops_generate/ops_primitive_h_generator.py +81 -0
- mindspore/ops_generate/pyboost_functions_cpp_generator.py +370 -0
- mindspore/ops_generate/pyboost_functions_h_generator.py +68 -0
- mindspore/ops_generate/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost_grad_function_cpp_generator.py +154 -0
- mindspore/ops_generate/pyboost_inner_prim_generator.py +131 -0
- mindspore/ops_generate/pyboost_native_grad_functions_generator.py +268 -0
- mindspore/ops_generate/pyboost_op_cpp_code_generator.py +851 -0
- mindspore/ops_generate/pyboost_overload_functions_cpp_generator.py +344 -0
- mindspore/ops_generate/pyboost_utils.py +92 -33
- mindspore/ops_generate/template.py +294 -44
- mindspore/ops_generate/tensor_func_reg_cpp_generator.py +422 -0
- mindspore/parallel/__init__.py +3 -3
- mindspore/parallel/_auto_parallel_context.py +44 -34
- mindspore/parallel/_cell_wrapper.py +22 -3
- mindspore/parallel/_parallel_serialization.py +13 -2
- mindspore/parallel/_utils.py +4 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +44 -0
- mindspore/parallel/cluster/process_entity/_api.py +131 -37
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +20 -3
- mindspore/parallel/parameter_broadcast.py +1 -1
- mindspore/parallel/shard.py +3 -0
- mindspore/parallel/transform_safetensors.py +119 -253
- mindspore/profiler/__init__.py +17 -4
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +166 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +261 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +84 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +260 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +333 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +252 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +313 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +322 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +265 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +97 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +138 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +174 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +202 -0
- mindspore/profiler/common/path_manager.py +371 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +476 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +210 -0
- mindspore/profiler/common/profiler_path_manager.py +120 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +270 -37
- mindspore/profiler/envprofiler.py +138 -0
- mindspore/profiler/mstx.py +199 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +309 -0
- mindspore/profiler/profiler.py +580 -93
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +114 -0
- mindspore/profiler/schedule.py +208 -0
- mindspore/rewrite/api/symbol_tree.py +1 -2
- mindspore/run_check/_check_version.py +18 -13
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +148 -0
- mindspore/runtime/memory.py +392 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +2 -2
- mindspore/train/_utils.py +53 -18
- mindspore/train/amp.py +8 -4
- mindspore/train/callback/_checkpoint.py +32 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +105 -69
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_summary_collector.py +44 -6
- mindspore/train/callback/_tft_register.py +37 -15
- mindspore/train/dataset_helper.py +11 -11
- mindspore/train/metrics/precision.py +4 -5
- mindspore/train/mind_ir_pb2.py +167 -46
- mindspore/train/model.py +13 -14
- mindspore/train/serialization.py +461 -72
- mindspore/train/summary/summary_record.py +1 -2
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +4 -2
- mindspore/utils/dryrun.py +138 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/METADATA +3 -4
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/RECORD +368 -242
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/entry_points.txt +1 -1
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/top_level.txt +0 -0

mindspore/train/callback/_checkpoint.py

@@ -18,14 +18,14 @@ from __future__ import absolute_import
 import os
 import stat
 import time
-import threading

 import mindspore.context as context
 from mindspore import log as logger
 from mindspore import nn
 from mindspore import _checkparam as Validator
 from mindspore.train._utils import _make_directory
-from mindspore.train.serialization import save_checkpoint, _save_graph
+from mindspore.train.serialization import save_checkpoint, _save_graph, _wait_async_process_save_ckpt, \
+    _wait_async_thread_save_ckpt, _check_async_save
 from mindspore.parallel._cell_wrapper import destroy_allgather_cell
 from mindspore.parallel._recovery_context import _set_recovery_context, _get_recovery_context
 from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context

@@ -44,15 +44,6 @@ SAVE_DIR = _cur_dir
 _info_list = ["epoch_num", "step_num"]


-def _wait_async_save_ckpt(async_save=False):
-    """Waiting for asynchronous saving of ckpt to complete."""
-    if async_save:
-        thread_list = threading.enumerate()
-        for thread in thread_list:
-            if thread.getName() == "asyn_save_ckpt":
-                thread.join()
-
-
 def _get_dp_tp_from_redundancy(redundancy_tuple):
     """From redundancy get dp and tp"""
     dp = []

@@ -76,6 +67,15 @@ def _get_dp_tp_from_layout(parameter_redundancy_dict):
     return dp, tp


+def _wait_async_save_ckpt(async_save=False):
+    """Waiting for asynchronous saving of ckpt to complete."""
+    if async_save:
+        if async_save == "process":
+            _wait_async_process_save_ckpt()
+        else:
+            _wait_async_thread_save_ckpt()
+
+
 def _chg_ckpt_file_name_if_same_exist(directory, prefix, exception=False):
     """Check if there is a file with the same name."""
     if callable(prefix) or callable(directory):

@@ -139,7 +139,10 @@ class CheckpointConfig:
         integrated_save (bool): Whether to merge and save the split Tensor in the automatic parallel scenario.
             Integrated save function is only supported in automatic parallel scene, not supported
             in manual parallel. Default: ``True`` .
-        async_save (bool):
+        async_save (Union[bool, str]): Whether to use asynchronous saving of the checkpoint file, if True,
+            the asynchronous thread is used by default. If the type is string,
+            the method of asynchronous saving, it can be "process" or "thread".
+            Default: ``False`` .
         saved_network (Cell): Network to be saved in checkpoint file. If the saved_network has no relation
             with the network in training, the initial value of saved_network will be saved. Default: ``None`` .
         append_info (list): The information save to checkpoint file. Support "epoch_num", "step_num" and

@@ -247,7 +250,7 @@ class CheckpointConfig:
             self._keep_checkpoint_max = 1

         self._integrated_save = Validator.check_bool(integrated_save)
-        self._async_save =
+        self._async_save = _check_async_save(async_save)
         self._saved_network = saved_network
         self._append_dict = self._handle_append_info(append_info)
         self._enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))

@@ -313,10 +316,10 @@ class CheckpointConfig:
     @property
     def async_save(self):
         """
-        Get the value of whether asynchronous execution saves the checkpoint to a file.
+        Get the value of whether or how asynchronous execution saves the checkpoint to a file.

         Returns:
-            bool, whether asynchronous execution saves the checkpoint to a file.
+            (bool, str), whether or how asynchronous execution saves the checkpoint to a file.
         """
         return self._async_save


@@ -538,6 +541,8 @@ class ModelCheckpoint(Callback):
         self._graph_saved = False
         self._need_flush_from_cache = True
         self._map_param_inc = self._config.map_param_inc
+        self._d2h_async = os.environ.get("MS_ENABLE_CKPT_D2H_ASYNC") == "1"
+        self._run_mode = context.get_context("mode")

     def step_end(self, run_context):
         """

@@ -632,6 +637,13 @@ class ModelCheckpoint(Callback):
         if "step_num" in self._append_dict:
             self._append_dict["step_num"] = self._append_step_num + step_num

+    def _update_save_step(self, cb_params):
+        """update step if used async d2h copy"""
+        step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
+        if self._d2h_async and self._run_mode == context.GRAPH_MODE:
+            step_num_in_epoch -= 1
+        return step_num_in_epoch
+
     def _save_ckpt(self, cb_params, force_to_save=False):
         """Save checkpoint files."""
         if cb_params.cur_step_num == self._last_triggered_step:

@@ -642,10 +654,12 @@ class ModelCheckpoint(Callback):
             self._flush_from_cache(cb_params)

         save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
-        step_num_in_epoch =
+        step_num_in_epoch = self._update_save_step(cb_params)

         if save_ckpt:
+
             _wait_async_save_ckpt(self._config.async_save)
+
             if self._prefix_func:
                 cur_ckpoint_file = self._prefix + f".{self._config.format}"
             else:

@@ -704,14 +718,14 @@ class ModelCheckpoint(Callback):
                     f"For remove_redundancy save checkpoint, the saved parameters are non-redundant.")

                 def choice_func(x):
-                    return x not in param_layout_set or x in save_param_names
+                    return x not in param_layout_set or (save_param_names is not None and x in save_param_names)
             else:
                 param_redundancy_dict = get_parameter_redundancy(network)
                 single_params = remove_param_redundancy(param_redundancy_dict)
                 save_param_names = single_params.get(rank_id)

                 def choice_func(x):
-                    return x in save_param_names
+                    return save_param_names is not None and x in save_param_names
             save_checkpoint(network, cur_file, False, self._config.async_save,
                             self._append_dict, self._config.enc_key, self._config.enc_mode,
                             crc_check=self._config.crc_check, format=self._config.format,

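The CheckpointConfig hunks above change async_save from a plain bool to Union[bool, str]: _check_async_save validates the value and _wait_async_save_ckpt dispatches to a process-based or thread-based wait. A minimal usage sketch under that reading (the step and keep settings below are illustrative, not taken from the diff):

    from mindspore.train import CheckpointConfig, ModelCheckpoint

    # True keeps the previous behaviour (background thread);
    # "process" selects the new process-based asynchronous save, "thread" names the thread-based one.
    config = CheckpointConfig(save_checkpoint_steps=100,
                              keep_checkpoint_max=5,
                              async_save="process")
    ckpt_cb = ModelCheckpoint(prefix="net", directory="./checkpoints", config=config)
    # ckpt_cb is then passed to Model.train(..., callbacks=[ckpt_cb]).
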
mindspore/train/callback/_early_stop.py

@@ -198,7 +198,7 @@ class EarlyStopping(Callback):
         """
         Get the monitor value at the end of epoch during training.

-        If
+        If :class:`mindspore.train.callback.ReduceLROnPlateau` used with `model.train`, no evaluation process
         during training, only monitor="loss" is valid; if it used with `model.fit`, evaluation process will be
         performed at the end of epoch, valid monitor is "loss", "eval_loss" and metrics passed to `Model`.

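The restored docstring sentence separates the monitors that are valid with model.train (no evaluation pass, so only "loss") from those valid with model.fit (a per-epoch evaluation pass, so "eval_loss" and the metrics passed to Model also work). A hedged illustration of that distinction; the commented train/fit calls assume a model and datasets built elsewhere:

    from mindspore.train import EarlyStopping

    # Training loss is always available, so this monitor works with both Model.train and Model.fit:
    es_train = EarlyStopping(monitor="loss", patience=3)
    # "eval_loss" (or a metric name passed to Model) needs the per-epoch eval pass of Model.fit:
    es_fit = EarlyStopping(monitor="eval_loss", patience=3)
    # model.train(10, train_ds, callbacks=[es_train])
    # model.fit(10, train_ds, valid_ds, callbacks=[es_fit])
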
mindspore/train/callback/_flops_collector.py

@@ -61,10 +61,13 @@ class FlopsUtilizationCollector(Callback):
         computility (int): The peak flops of each compute card. Default: ``1`` .
         full_flops(bool): Whether to count the full model flops. If set full_flops to False,
             FlopsUtilizationCollector would count the shard model flops in each device. Default: ``True`` .
+        enable_ma_collector(bool): Whether to write flops into the log and provide them to tasks
+            on the cloud for retrieval. Default: ``False`` .

     Raises:
         TypeError: If data_size is not positive int.
         TypeError: If full_flops is not bool.
+        TypeError: If enable_ma_collector is not bool.
         AssertionError: If the training mode is not a static graph or not a static shape.

     Examples:

@@ -86,7 +89,7 @@ class FlopsUtilizationCollector(Callback):
         Train per step time: 135.572 ms, mfu:0.47% hfu:0.47%
         Train per step time: 1.317 ms, mfu:48.59% hfu:48.59%
     """
-    def __init__(self, data_size, computility=1, full_flops=True):
+    def __init__(self, data_size=None, computility=1, full_flops=True, enable_ma_collector=False):
         super(FlopsUtilizationCollector, self).__init__()
         self.step_time = time.time()
         self.computility = computility

@@ -101,10 +104,14 @@ class FlopsUtilizationCollector(Callback):
         self.mfu_calculated = False
         self.data_size = data_size
         self.time_step_path = ''
-        self.
-        self.
-
-
+        self.full_flops = full_flops
+        self.verbose = not(computility == 1 and enable_ma_collector)
+        self.ma = enable_ma_collector
+        self.batch_step_size = None
+        Validator.check_bool(full_flops, "full_flops")
+        Validator.check_bool(enable_ma_collector, "enable_ma_collector")
+        if data_size:
+            Validator.check_positive_int(data_size, "data_size")

     def step_begin(self, run_context):
         """

@@ -115,6 +122,14 @@ class FlopsUtilizationCollector(Callback):
             run_context (RunContext): Context of the process running. For more details,
                 please refer to :class:`mindspore.train.RunContext`.
         """
+        if self.batch_step_size is None:
+            self.batch_step_size = self.data_size
+            cb_params = run_context.original_args()
+            if hasattr(cb_params, "batch_num"):
+                batch_num = cb_params.batch_num
+                if isinstance(batch_num, int) and batch_num > 0:
+                    self.batch_step_size = cb_params.batch_num
+            Validator.check_positive_int(self.batch_step_size)
         self.step_time = time.time()

     def _get_pipeline_group(self):

@@ -134,6 +149,40 @@ class FlopsUtilizationCollector(Callback):
         rank_list_str = "-".join(rank_str_list)
         return rank_list, rank_list_str

+    def _check_run_mode_valid(self, run_context):
+        """
+        Check whether FlopsUtilizationCollector is working in the current environment
+        """
+        if context.get_context("mode") != context.GRAPH_MODE:
+            if self.verbose:
+                raise ValueError("FlopsUtilizationCollector now only support graph mode.")
+            logger.info("FlopsUtilizationCollector now only support graph mode.")
+            return False
+        cb_params = run_context.original_args()
+        if cb_params.mode == 'train':
+            network = cb_params.train_network
+        elif cb_params.mode == 'eval':
+            network = cb_params.eval_network
+        else:
+            if self.verbose:
+                raise ValueError('FlopsUtilizationCollector only support train and eval mode!')
+            logger.info('FlopsUtilizationCollector only support train and eval mode!')
+            return False
+        try:
+            self.full_model_flops, self.full_hardware_flops, self.shard_model_flops, \
+                self.shard_hardware_flops, is_dynamic_shape = flops_collection(network.current_phase)
+        except Exception as e:
+            if self.verbose:
+                raise ValueError("FlopsUtilizationCollector is not supported because {}.".format(e))
+            logger.info("FlopsUtilizationCollector is not supported because {}.".format(e))
+            return False
+        if is_dynamic_shape:
+            if self.verbose:
+                raise ValueError("FlopsUtilizationCollector now do not support dynamic shape.")
+            logger.info("FlopsUtilizationCollector now do not support dynamic shape.")
+            return False
+        return True
+
     def step_end(self, run_context):
         """
         Print mfu and hfu time at the end of step.

@@ -142,84 +191,67 @@ class FlopsUtilizationCollector(Callback):
             run_context (RunContext): Context of the process running. For more details,
                 please refer to :class:`mindspore.train.RunContext`.
         """
-        if context.get_context("mode") != context.GRAPH_MODE:
-            logger.warning("FlopsUtilizationCollector now only support graph mode.")
-            return
-
         step_seconds = (time.time() - self.step_time) * 1000
         if not self.mfu_calculated:
-
-            if cb_params.mode == 'train':
-                network = cb_params.train_network
-            elif cb_params.mode == 'eval':
-                network = cb_params.eval_network
-            else:
-                logger.warning('FlopsUtilizationCollector only support train and eval mode!')
+            if not self._check_run_mode_valid(run_context):
                 return
-            full_model_flops
-
-
-
-                return
-            self.full_mfu = full_model_flops / self.computility
-            self.full_hfu = full_hardware_flops / self.computility
-
-            self.shard_mfu = shard_model_flops / self.computility
-            self.shard_hfu = shard_hardware_flops / self.computility
-            self.full_model_flops = full_model_flops
-            self.full_hardware_flops = full_hardware_flops
-            self.shard_model_flops = shard_model_flops
-            self.shard_hardware_flops = shard_hardware_flops
+            self.full_mfu = self.full_model_flops / self.computility
+            self.full_hfu = self.full_hardware_flops / self.computility
+            self.shard_mfu = self.shard_model_flops / self.computility
+            self.shard_hfu = self.shard_hardware_flops / self.computility
             self.mfu_calculated = True
             shard_mf_dir = os.path.realpath(os.getenv('MA_LOG_DIR', './'))
             if self.ma:
-
-
-
-
-
-                self.time_step_path = os.path.join(
-                    shard_mf_dir, "time_step_rank_" + str(get_rank())) + ".txt"
+                rank_id = get_rank() if auto_parallel_context().get_parallel_mode() != "stand_alone" else 0
+                flops_path = os.path.join(
+                    shard_mf_dir, "flops_rank_" + str(rank_id)) + ".txt"
+                self.time_step_path = os.path.join(
+                    shard_mf_dir, "time_step_rank_" + str(rank_id)) + ".txt"
                 time_stamp = time.time()
-                model_flops_log = "flops{{type=\"model_flops\"}} {} {}\n".\
-                    format(shard_model_flops, int(round(time_stamp * 1000)))
-                hardware_flops_log = "flops{{type=\"hardware_flops\"}} {} {}\n".\
-                    format(shard_hardware_flops, int(round(time_stamp * 1000)))
+                model_flops_log = "flops{{type=\"model_flops\", rank_id=\"{}\"}} {} {}\n".\
+                    format(str(rank_id), self.shard_model_flops, int(round(time_stamp * 1000)))
+                hardware_flops_log = "flops{{type=\"hardware_flops\", rank_id=\"{}\"}} {} {}\n".\
+                    format(str(rank_id), self.shard_hardware_flops, int(round(time_stamp * 1000)))
                 flags = os.O_WRONLY | os.O_CREAT
                 modes = stat.S_IWUSR | stat.S_IRUSR
                 with os.fdopen(os.open(flops_path, flags, modes), 'w') as f:
                     f.write(model_flops_log)
                     f.write(hardware_flops_log)
             if self.verbose:
-
-
+                if self.full_flops:
+                    pipeline_num = auto_parallel_context().get_pipeline_stages()
+                    if pipeline_num > 1:
+                        pipeline_group_list, pipeline_group_name = self._get_pipeline_group()
+                        auto_parallel_context().set_pipeline_stages(1)
+                        hashed = hashlib.md5(
+                            pipeline_group_name.encode()).hexdigest()[:48]
+                        pipeline_group_name = str(hashed)
+                        create_group(pipeline_group_name, pipeline_group_list)
+                        self.full_mfu = AllReduceNet(pipeline_group_name)(
+                            Tensor([self.full_mfu])).asnumpy()[0]
+                        self.full_hfu = AllReduceNet(pipeline_group_name)(
+                            Tensor([self.full_hfu])).asnumpy()[0]
+                        auto_parallel_context().set_pipeline_stages(pipeline_num)
+                    full_model_flops = self.full_mfu * self.computility
+                    full_hardware_flops = self.full_hfu * self.computility
+                    if auto_parallel_context().get_parallel_mode() != "stand_alone":
+                        self.full_mfu = self.full_mfu / get_group_size()
+                        self.full_hfu = self.full_hfu / get_group_size()
+                    flops_log = f"Full model flops is {full_model_flops}, " \
+                                f"Full hardware flops is {full_hardware_flops}, " \
+                                f"Shard model flops is {self.shard_model_flops}, " \
+                                f"Shard hardware flops is {self.shard_hardware_flops}."
+                else:
+                    flops_log = f"Shard model flops is {self.shard_model_flops}, " \
+                                f"Shard hardware flops is {self.shard_hardware_flops}."
                 print(flops_log, flush=True)
-            if auto_parallel_context().get_pipeline_stages() > 1:
-                pipeline_group_list, pipeline_group_name = self._get_pipeline_group()
-                auto_parallel_context().set_pipeline_stages(1)
-                hashed = hashlib.md5(
-                    pipeline_group_name.encode()).hexdigest()[:48]
-                pipeline_group_name = str(hashed)
-                create_group(pipeline_group_name, pipeline_group_list)
-                self.full_mfu = AllReduceNet(pipeline_group_name)(
-                    Tensor([self.full_mfu])).asnumpy()[0]
-                self.full_hfu = AllReduceNet(pipeline_group_name)(
-                    Tensor([self.full_hfu])).asnumpy()[0]
-            if auto_parallel_context().get_parallel_mode() != "stand_alone":
-                self.full_mfu = self.full_mfu / get_group_size()
-                self.full_hfu = self.full_hfu / get_group_size()
-
-        step_size = self.data_size
         cb_params = run_context.original_args()
-        if hasattr(cb_params, "batch_num"):
-            batch_num = cb_params.batch_num
-            if isinstance(batch_num, int) and batch_num > 0:
-                step_size = cb_params.batch_num
-        Validator.check_positive_int(step_size)
         if cb_params.dataset_sink_mode:
-            step_seconds = step_seconds /
+            step_seconds = step_seconds / self.batch_step_size
         time_stamp = time.time()
-
+        rank_id = get_rank() if auto_parallel_context().get_parallel_mode() != "stand_alone" else 0
+        train_log = "time_monitor{{type=\"per_step_time\", rank_id=\"{}\"}} {} {}".format(
+            str(rank_id), step_seconds, int(round(time_stamp * 1000)))
         if self.ma:
             flags = os.O_WRONLY | os.O_CREAT
             modes = stat.S_IWUSR | stat.S_IRUSR

@@ -227,9 +259,13 @@ class FlopsUtilizationCollector(Callback):
                 f.write(train_log + '\n')
         train_log = "{} per step time: {:5.3f} ms".format(
             cb_params.mode.title(), step_seconds)
-        if self.verbose:
-
-
+        if self.verbose and cb_params.cur_step_num % self.data_size:
+            if self.full_flops:
+                mfu = 1000 * self.full_mfu / step_seconds
+                hfu = 1000 * self.full_hfu / step_seconds
+            else:
+                mfu = 1000 * self.shard_mfu / step_seconds
+                hfu = 1000 * self.shard_hfu / step_seconds

         def floored_percentage(index, val, digits):
             val *= 10 ** (digits + 2)

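The FlopsUtilizationCollector hunks above make data_size optional, add enable_ma_collector, and move the mode and shape checks into _check_run_mode_valid. A hedged sketch of the updated constructor, assuming the collector is exported from mindspore.train as in its docstring examples (the peak-FLOPS figure and the commented training call are illustrative):

    import mindspore as ms
    from mindspore.train import FlopsUtilizationCollector

    ms.set_context(mode=ms.GRAPH_MODE)  # the callback still only supports graph mode
    flops_cb = FlopsUtilizationCollector(data_size=None,            # now optional; batch_num is read at step_begin
                                         computility=280 * 10**12,  # peak FLOPS of one card, illustrative value
                                         full_flops=True,
                                         enable_ma_collector=False) # True also writes flops_rank_*.txt under MA_LOG_DIR
    # model.train(epoch, dataset, callbacks=[flops_cb], dataset_sink_mode=True)
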
mindspore/train/callback/_history.py

@@ -31,7 +31,7 @@ class History(Callback):
         outputs will be recorded.

     Note:
-        Normally used in
+        Normally used in :func:`mindspore.train.Model.train` or :func:`mindspore.train.Model.fit`.

     Examples:
         >>> import numpy as np

mindspore/train/callback/_summary_collector.py

@@ -16,6 +16,7 @@
 from __future__ import absolute_import

 import os
+import platform
 import stat
 import re
 import json

@@ -43,6 +44,9 @@ from mindspore.train._utils import check_value_type, _make_directory
 from mindspore._c_expression import security
 from mindspore._c_expression import collect_host_info, get_clock_syscnt

+if platform.system() == "Linux":
+    import fcntl
+
 HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"
 HYPER_CONFIG_LEN_LIMIT = 100000

@@ -606,13 +610,32 @@ class SummaryCollector(Callback):
             "landscape_size": landscape_size,
             "create_landscape": create_landscape
         }
+
         meta_path = os.path.join(self._ckpt_dir, 'train_metadata.json')
+        if platform.system() != "Linux":
+            try:
+                with open(meta_path, 'w') as file:
+                    json.dump(data, file)
+                os.chmod(meta_path, stat.S_IRUSR)
+            except OSError as e:
+                logger.error("Write meta data %s failed, detail: %s" % (meta_path, str(e)))
+            return
+
+        lock_file = f"{meta_path}.lock"
         try:
-            with open(
-
-
+            with os.fdopen(os.open(lock_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, stat.S_IWUSR), 'w') as f:
+                fcntl.flock(f.fileno(), fcntl.LOCK_EX)
+                if not os.path.isfile(meta_path):
+                    with open(meta_path, 'w') as file:
+                        json.dump(data, file)
+                    os.chmod(meta_path, stat.S_IRUSR)
         except OSError as e:
             logger.error("Write meta data %s failed, detail: %s" % (meta_path, str(e)))
+        try:
+            if os.path.isfile(lock_file):
+                os.remove(lock_file)
+        except OSError:
+            logger.warning("The lock file %s has been removed.", lock_file)

     def _save_model_params(self, cur_num, unit, backbone):
         """Save model params."""

@@ -629,12 +652,27 @@ class SummaryCollector(Callback):

         ckpt_file_name = f"{type(backbone).__name__}_{cur_num}_{unit}.ckpt"
         file_path = os.path.join(self._ckpt_dir, ckpt_file_name)
+        self._model_params_file_map[str(cur_num)] = file_path
+        if platform.system() != "Linux":
+            try:
+                save_checkpoint(param_list, file_path)
+            except OSError as e:
+                logger.error(str(e))
+            return
+
+        lock_file = f"{file_path}.lock"
         try:
-
+            with os.fdopen(os.open(lock_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, stat.S_IWUSR), 'w') as f:
+                fcntl.flock(f.fileno(), fcntl.LOCK_EX)
+                if not os.path.isfile(file_path):
+                    save_checkpoint(param_list, file_path)
         except OSError as e:
             logger.error(str(e))
-
-
+        try:
+            if os.path.isfile(lock_file):
+                os.remove(lock_file)
+        except OSError:
+            logger.warning("The lock file %s has been removed.", lock_file)

     def _save_model_params_for_landscape(self, cb_params):
         """Save model params for landscape."""

mindspore/train/callback/_tft_register.py

@@ -25,9 +25,9 @@ from mindspore.common.tensor import Tensor
 from mindspore.communication import get_rank, get_group_size
 from mindspore import log as logger
 from mindspore.train.serialization import _get_cur_rank_dp
-from mindspore._c_expression import _repair_device, _stop_device, _tft_sem_post
+from mindspore._c_expression import _repair_device, _stop_device, _tft_sem_post, _tft_sem_enable
 from mindspore._c_expression import clean_tdt_channel
-from mindspore._c_expression import send_recv
+from mindspore._c_expression import send_recv, reset_params
 from mindspore._c_expression import CollectiveManager
 from mindspore._c_expression import _get_uce_process_strategy, _get_uce_mem_info
 from mindspore._c_expression import Tensor as Tensor_

@@ -90,6 +90,7 @@ def _tft_exit_cb(ctx):
     _tft_sem_post()
     os._exit(1)  # pylint: disable=W0212

+
 def _tft_repair_callback(step, need_rebuild, error_ranks, repair_info, args, cb_ctx):
     """ Callback used for TFT repair function."""
     logger.info("Enter _tft_repair_callback repair type: {}".format(repair_info["repair_type"]))

@@ -105,11 +106,12 @@ or repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_UCE_LOWLEVEL.value):
         cb_params = args
         src_rank = repair_info["src"][0]
         dst_rank = repair_info["dst"][0]
-        send_recv(cb_params.
+        if send_recv(cb_params.train_network.trainable_params(), src_rank, dst_rank) != 0:
+            raise ValueError("Call send_recv failed.")
     logger.info("Finish _tft_repair_callback")


-def _tft_clean_callback(is_uce_error, ctx):
+def _tft_clean_callback(is_uce_error, args, ctx):
     """ Callback used for TFT clean function."""
     logger.info("Enter _tft_clean_callback")
     ret = 0

@@ -130,12 +132,13 @@ def _tft_clean_callback(is_uce_error, ctx):
     return ret


-def _tft_stop_callback(cb_ctx):
+def _tft_stop_callback(args, cb_ctx):
     """ Callback used for TFT stop function."""
     logger.info("Enter _tft_stop_callback device_id: {}".format(cb_ctx.device_id))
     _stop_device(cb_ctx.device_id)
-    if not cb_ctx._is_params_consistent():  # pylint: disable=W0212
+    if (not cb_ctx.is_uce_rank) and (not cb_ctx._is_params_consistent()):  # pylint: disable=W0212
         raise RuntimeError("Can't stop device, because training parameters are left in inconsistent state!")
+    cb_ctx.is_uce_rank = False
     logger.info("Finish _tft_stop_callback")


@@ -160,13 +163,23 @@ class TFTRegister(Callback):
         ModuleNotFoundError: Mindio TFT whl package is not installed.

     Examples:
+        .. note::
+            Before running the following examples, you need to configure the communication environment variables.
+
+            It's recommended to use the msrun startup method.
+            Please see the `msrun start up
+            <https://www.mindspore.cn/docs/en/master/model_train/parallel/msrun_launcher.html>`_
+            for more details.
+
+            This example should be run with 4 devices.
+
         >>> import numpy as np
         >>> import os
         >>> import math
         >>> import mindspore as ms
         >>> import mindspore.dataset as ds
         >>> from mindspore import nn, ops, Parameter, train
-        >>> from mindspore.communication import init
+        >>> from mindspore.communication import init, get_rank
         >>> from mindspore.common.initializer import initializer, HeUniform
         >>> from mindspore.train import Model, TFTRegister
         >>> from mindspore import dataset as ds

@@ -175,7 +188,7 @@ class TFTRegister(Callback):
         >>> init()
         >>> ms.set_seed(1)
         >>> ms.set_auto_parallel_context(strategy_ckpt_config={"save_file":
-
+        ...     "./src_pipeline_strategys/src_strategy_{}.ckpt".format(get_rank())})
         >>> class MatMulCell(nn.Cell):
         ...     def __init__(self, param=None, shape=None):
         ...         super().__init__()

@@ -233,7 +246,7 @@ class TFTRegister(Callback):
         ...     dataset = dataset.batch(batch_size)
         ...     return dataset
         >>>
-        >>>
+        >>> dataset = create_dataset(32)
         >>>
         >>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
         >>> optimizer_wrapper = nn.OptTFTWrapper(optimizer)

@@ -241,8 +254,8 @@ class TFTRegister(Callback):
         >>>
         >>> net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 4)
         >>> net_with_loss.set_train()
-        >>> model = Model(net_with_loss, optimizer=
-        >>> tft_cb = TFTRegister("192.168.0.1", 2000, "./tft_checkpoint/")
+        >>> model = Model(net_with_loss, optimizer=optimizer_wrapper)
+        >>> tft_cb = TFTRegister(0, "192.168.0.1", 2000, "./tft_checkpoint/")
         >>> loss_cb = train.LossMonitor(1)
         >>> model.train(1, dataset, callbacks=[tft_cb, loss_cb])
         """

@@ -264,6 +277,7 @@ class TFTRegister(Callback):
         self.global_step = 0
         Validator.check_non_negative_int(ctrl_port)
         self.has_init_replica = False
+        self.is_uce_rank = False
         self._controller_ip = ctrl_ip
         self._controller_rank_id = ctrl_rank_id
         self._controller_port = ctrl_port

@@ -274,6 +288,7 @@ class TFTRegister(Callback):
         self.assign = mindspore.ops.Assign()
         self.g_one = Parameter(Tensor([1], dtype=mstype.int32))
         self.s1 = mindspore.hal.Stream()
+        _tft_sem_enable()

     def _is_params_consistent(self):
         for key, param in self.cb_params.train_network.parameters_and_names():

@@ -300,7 +315,7 @@ class TFTRegister(Callback):
         replica_info = [
             {
                 "type": 1,
-                "rank_list": dp,
+                "rank_list": list(dp),
                 "replica_cnt": len(dp),
                 "replica_shift": 0
             }

@@ -321,13 +336,12 @@ class TFTRegister(Callback):
         cur_rank = get_rank()
         enable_local_copy = False
         enable_arf = False
-        enable_zit = False
         enable_tls = False
         tls_key_dir = ""

         if cur_rank == self._controller_rank_id:
             logger.info(f"Begin to start tft controller on rank_id:{cur_rank}")
-            self.tft.tft_init_controller(cur_rank, world_size, enable_local_copy, enable_arf
+            self.tft.tft_init_controller(cur_rank, world_size, enable_local_copy, enable_arf)
             self.tft.tft_start_controller(self._controller_ip, self._controller_port, enable_tls, tls_key_dir)
             logger.info("Finish start tft controller.")

@@ -336,6 +350,14 @@ class TFTRegister(Callback):
         self.tft.tft_start_processor(self._controller_ip, self._controller_port)
         logger.info("Finished start tft processor.")

+    def _reset_acc_grads(self):
+        accu_grad_params = map(lambda e: e[1],
+                               filter(lambda e: e[1].name.startswith('accu_grads'),
+                                      self.cb_params.train_network.parameters_and_names()))
+        accu_grad_list = list(accu_grad_params)
+        if reset_params(accu_grad_list) != 0:
+            raise ValueError("Call reset_params failed.")
+
     def on_train_step_end(self, run_context):
         """
         And report status to MindIO TFT after every step finished.

@@ -349,13 +371,13 @@ class TFTRegister(Callback):
             self._set_tft_optimizer_replica(run_context)
         cb_params = run_context.original_args()
         logger.info("START Set optimizer finish step status to TFT. step: {}".format(cb_params.cur_step_num))
-        self.tft.tft_end_updating_os(cb_params.cur_step_num)
         if cb_params.optimizer is not None:
             self.global_step = int(cb_params.optimizer.global_step.data)
             self.assign(cb_params.optimizer.tft_g_one_flag, self.g_one)
         else:
             self.global_step = int(cb_params.network.optimizer.global_step.data)
             self.assign(cb_params.network.optimizer.tft_g_one_flag, self.g_one)
+        self.tft.tft_end_updating_os(cb_params.cur_step_num)
         logger.info("END Set optimizer finish step status to TFT.")
