mindspore 2.3.0-cp310-cp310-win_amd64.whl → 2.4.0-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +46 -13
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +209 -29
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +310 -55
- mindspore/communication/management.py +14 -14
- mindspore/context.py +123 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +495 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +266 -21
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +28 -7
- mindspore/mint/special/__init__.py +63 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +275 -93
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +113 -3
- mindspore/nn/layer/embedding.py +120 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +734 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
- mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +490 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +558 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +184 -8
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +6 -1
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +12 -146
- mindspore/ops/operations/comm_ops.py +42 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +265 -10
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +28 -8
- mindspore/parallel/_cell_wrapper.py +83 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +81 -11
- mindspore/parallel/_utils.py +13 -1
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +280 -412
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +36 -103
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +28 -2
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +85 -22
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +134 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/dataset_helper.py +7 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +134 -58
- mindspore/train/serialization.py +336 -112
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +281 -275
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/train/callback/_checkpoint.py

@@ -18,8 +18,8 @@ from __future__ import absolute_import
 import os
 import stat
 import time
-
 import threading
+
 import mindspore.context as context
 from mindspore import log as logger
 from mindspore import nn
@@ -37,8 +37,7 @@ from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.common.generator import Generator
 from mindspore.common.api import _cell_graph_executor
-from mindspore._c_expression import
-
+from mindspore._c_expression import collect_host_info, get_clock_syscnt
 
 _cur_dir = os.getcwd()
 SAVE_DIR = _cur_dir
@@ -88,9 +87,9 @@ def _chg_ckpt_file_name_if_same_exist(directory, prefix, exception=False):
             if index == 0:
                 suffix_num = max(suffix_num, 1)
             elif index != -1:
-                num = filename[pre_len+1:pre_len+index]
+                num = filename[pre_len + 1:pre_len + index]
                 if num.isdigit():
-                    suffix_num = max(suffix_num, int(num)+1)
+                    suffix_num = max(suffix_num, int(num) + 1)
 
     if suffix_num != 0:
         prefix = f'{prefix}_{suffix_num}'
@@ -98,6 +97,14 @@ def _chg_ckpt_file_name_if_same_exist(directory, prefix, exception=False):
     return prefix
 
 
+def _check_format_and_other_params(format, enc_key, enc_mode, crc_check=False, async_save=False, exception_save=False,
+                                   map_param_inc=False, global_step_num=None):
+    param_not_default = (enc_key is not None or enc_mode != "AES-GCM" or crc_check or async_save
+                         or exception_save or map_param_inc or global_step_num is not None)
+    if format == "safetensors" and param_not_default:
+        raise ValueError("For 'save_checkpoint', when format is 'safetensors', other param must be default.")
+
+
 class CheckpointConfig:
     """
     The configuration of model checkpoint.
@@ -136,6 +143,10 @@ class CheckpointConfig:
         exception_save (bool): Whether to save the current checkpoint when an exception occurs. Default: ``False`` .
         crc_check (bool): Whether to perform crc32 calculation when saving checkpoint and save the calculation
             result to the end of ckpt. Default: ``False`` .
+        remove_redundancy (bool): Whether to enable saving the checkpoint with redundancy removal.
+            Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` , means
+            redundant-free saving is not enabled.
+        format (str): Format of the output file, can be "ckpt" or "safetensors". Default: "ckpt".
         kwargs (dict): Configuration options dictionary.
 
     Raises:
@@ -188,6 +199,8 @@
                  enc_mode='AES-GCM',
                  exception_save=False,
                  crc_check=False,
+                 remove_redundancy=False,
+                 format="ckpt",
                  **kwargs):
 
         if save_checkpoint_steps is not None:
@@ -231,8 +244,13 @@
         self._enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
         self._enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
         self._crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
+        self._format = Validator.check_isinstance('format', format, str)
         self._map_param_inc = kwargs.get('incremental', False)
         self.enable_redundance = kwargs.get('enable_redundance', False)
+        self.remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
+
+        _check_format_and_other_params(format, enc_key, enc_mode, crc_check, async_save, exception_save,
+                                       self._map_param_inc)
 
     @property
     def save_checkpoint_steps(self):
@@ -333,6 +351,10 @@
         """
         return self._crc_check
 
+    @property
+    def format(self):
+        return self._format
+
     @property
     def append_dict(self):
         """
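Taken together, the CheckpointConfig hunks above add a 'format' argument ("ckpt" or "safetensors") and a 'remove_redundancy' switch, with _check_format_and_other_params rejecting "safetensors" whenever encryption, CRC checking, async, exception or incremental saving is also requested. A minimal usage sketch under those assumptions; the parameter values and names below are illustrative, not taken from the diff:

# Sketch only: exercises the 2.4.0 options added in the hunks above.
from mindspore.train import CheckpointConfig, ModelCheckpoint

# safetensors output requires the other saving options to stay at their defaults,
# otherwise _check_format_and_other_params raises ValueError.
safetensors_cfg = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=5,
                                   format="safetensors")

# Redundancy-free saving (meaningful in data-parallel training); plain "ckpt" still works.
dedup_cfg = CheckpointConfig(save_checkpoint_steps=100, remove_redundancy=True)

ckpt_cb = ModelCheckpoint(prefix="resnet", directory="./checkpoints", config=safetensors_cfg)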
@@ -495,10 +517,10 @@ class ModelCheckpoint(Callback):
         self._aiturbo_init_flag = os.getenv("AITURBO") == "1"
         # get existing checkpoint files
         if self._aiturbo_init_flag:
-            import
-            self._manager =
+            from aiturbo.checkpoint.aiturbo_mindspore_ckpt import CheckpointShmManager
+            self._manager = CheckpointShmManager()
         else:
-            self._manager = CheckpointManager()
+            self._manager = CheckpointManager(self._config.format)
         if not callable(directory) and not callable(prefix):
             self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
         self._append_dict = self._config.append_dict or {}
@@ -517,7 +539,7 @@
         """
         cb_params = run_context.original_args()
         if self._aiturbo_init_flag:
-            import aiturbo
+            from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
             ckpt_storage_path = self._directory
             rank_id = get_rank()
             stage_num = _get_auto_parallel_context("pipeline_stages")
@@ -536,7 +558,7 @@
                       "stage_layout": param_redundancy_dict}
             single_params = remove_param_redundancy(param_redundancy_dict)
             single_params = {device_id: list(params) for device_id, params in single_params.items()}
-            aiturbo.init(ckpt_storage_path, rank_id, layout, single_params, self._config.enable_redundance, dp)
+            aiturbo.init(ckpt_storage_path, rank_id, layout, single_params, not self._config.enable_redundance, dp)
             self._aiturbo_init_flag = False
         if self._prefix_func:
             self._prefix = self._prefix_func(cb_params)
@@ -546,7 +568,7 @@
                              "string that does not contain '/', but got {}.".format(self._prefix))
         if self._directory_func:
             self._directory = self._directory_func(cb_params)
-
+        collect_host_info("Callback", "ModelCheckpoint", "step_end", start_time=get_clock_syscnt(), level=1)
         # In disaster recovery scenario, the training process may be rolled back to the last step where
         # the ckpt was successfully saved, so the _last_triggered_step should be updated.
         if _get_recovery_context("enable_recovery") and cb_params.last_save_ckpt_step is not None:
@@ -575,7 +597,7 @@
             run_context (RunContext): Context of the train running.
         """
         cb_params = run_context.original_args()
-
+        collect_host_info("Callback", "ModelCheckpoint", "end", start_time=get_clock_syscnt(), level=1)
         _to_save_last_ckpt = True
 
         self._save_ckpt(cb_params, _to_save_last_ckpt)
@@ -601,6 +623,13 @@
 
         return False
 
+    def _append_dict_content(self, epoch_num, step_num):
+        """Append append_dict content."""
+        if "epoch_num" in self._append_dict:
+            self._append_dict["epoch_num"] = self._append_epoch_num + epoch_num
+        if "step_num" in self._append_dict:
+            self._append_dict["step_num"] = self._append_step_num + step_num
+
     def _save_ckpt(self, cb_params, force_to_save=False):
         """Save checkpoint files."""
         if cb_params.cur_step_num == self._last_triggered_step:
@@ -615,10 +644,10 @@
 
         if save_ckpt:
             if self._prefix_func:
-                cur_ckpoint_file = self._prefix + ".
+                cur_ckpoint_file = self._prefix + f".{self._config.format}"
             else:
                 cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
-
+                    + str(step_num_in_epoch) + f".{self._config.format}"
             # update checkpoint file list.
             self._manager.update_ckpoint_filelist(self._directory, self._prefix)
             # keep checkpoint files number equal max number.
@@ -644,20 +673,51 @@
                 set_cur_net(cb_params.train_network)
                 cb_params.train_network.add_flags(ge_sync_data=True)
                 _cell_graph_executor(cb_params.train_network, phase='save')
-
-                self._append_dict["epoch_num"] = self._append_epoch_num + cb_params.cur_epoch_num
-            if "step_num" in self._append_dict:
-                self._append_dict["step_num"] = self._append_step_num + cb_params.cur_step_num
+            self._append_dict_content(cb_params.cur_epoch_num, cb_params.cur_step_num)
             network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
             if os.getenv("AITURBO") == "1":
                 save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save,
                                 self._append_dict, self._config.enc_key, self._config.enc_mode,
                                 crc_check=self._config.crc_check, incremental=self._map_param_inc,
                                 global_step_num=cb_params.cur_step_num)
+            elif self._config.remove_redundancy:
+                parallel_mode = context.get_auto_parallel_context("parallel_mode")
+                if parallel_mode == "stand_alone":
+                    raise TypeError(f"The deduplication feature for saving checkpoint can only be used "
+                                    f"in parallel scenarios, but got {parallel_mode}.")
+                param_layout = network.parameter_layout_dict
+                rank_id = get_rank()
+                if param_layout:
+                    device_num = _get_device_num()
+                    stage_num = _get_auto_parallel_context("pipeline_stages")
+                    chunk_size = device_num // stage_num
+                    initial_rank = (rank_id // chunk_size) * chunk_size
+                    param_redundancy_dict = get_parameter_redundancy(param_layout, initial_rank)
+                    single_params = remove_param_redundancy(param_redundancy_dict)
+                    save_param_names = single_params.get(rank_id)
+                    param_layout_set = set(param_layout.keys())
+                    if save_param_names == param_layout.keys():
+                        logger.warning(
+                            f"For remove_redundancy save checkpoint, the saved parameters are non-redundant.")
+
+                    def choice_func(x):
+                        return x not in param_layout_set or x in save_param_names
+                else:
+                    param_redundancy_dict = get_parameter_redundancy(network)
+                    single_params = remove_param_redundancy(param_redundancy_dict)
+                    save_param_names = single_params.get(rank_id)
+
+                    def choice_func(x):
+                        return x in save_param_names
+                save_checkpoint(network, cur_file, False, self._config.async_save,
+                                self._append_dict, self._config.enc_key, self._config.enc_mode,
+                                crc_check=self._config.crc_check, format=self._config.format,
+                                incremental=self._map_param_inc, choice_func=choice_func)
             else:
                 save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save,
                                 self._append_dict, self._config.enc_key, self._config.enc_mode,
-                                crc_check=self._config.crc_check,
+                                crc_check=self._config.crc_check, format=self._config.format,
+                                incremental=self._map_param_inc)
 
             self._latest_ckpt_file_name = cur_file
 
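The remove_redundancy branch above derives, per rank, the set of parameters that rank must persist (get_parameter_redundancy / remove_param_redundancy) and hands save_checkpoint a choice_func predicate that filters out everything else. A self-contained sketch of that predicate logic, with invented parameter names:

def make_choice_func(save_param_names, param_layout_set=None):
    """Build the per-rank filter used by the remove_redundancy save path (illustration only)."""
    if param_layout_set is not None:
        # A parallel layout is known: keep parameters outside the layout plus
        # the non-redundant slice assigned to this rank.
        return lambda name: name not in param_layout_set or name in save_param_names
    # No layout: keep only the parameters assigned to this rank.
    return lambda name: name in save_param_names

choice_func = make_choice_func({"layer1.weight"}, {"layer1.weight", "layer1.bias"})
print(choice_func("layer1.weight"))  # True  - this rank owns it
print(choice_func("layer1.bias"))    # False - another rank saves it
print(choice_func("global_step"))    # True  - not in the layout, every rank keeps it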
@@ -691,8 +751,9 @@
 class CheckpointManager:
     """Manage checkpoint files according to train_config of checkpoint."""
 
-    def __init__(self):
+    def __init__(self, format='ckpt'):
         self._ckpoint_filelist = []
+        self._format = format
 
     @property
     def ckpoint_filelist(self):
@@ -707,10 +768,12 @@
     def update_ckpoint_filelist(self, directory, prefix):
         """Update the checkpoint file list."""
         self._ckpoint_filelist = []
+        format = self._format
+        format_length = len(format) + 1
         files = os.listdir(directory)
         for filename in files:
-            if os.path.splitext(filename)[-1] == ".
-                mid_name = filename[len(prefix):-
+            if os.path.splitext(filename)[-1] == f".{format}" and filename.startswith(prefix + "-"):
+                mid_name = filename[len(prefix):-format_length]
                 flag = not (True in [char.isalpha() for char in mid_name])
                 if flag:
                     self._ckpoint_filelist.append(os.path.join(directory, filename))
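The update_ckpoint_filelist change above makes checkpoint rotation extension-aware: a file is tracked only if its suffix matches the configured format and it starts with '<prefix>-'. A stand-alone restatement of that filter (not MindSpore code; file names are invented for illustration):

import os

def is_managed_checkpoint(filename, prefix, fmt="safetensors"):
    """Mirror of the filter above: '<prefix>-<digits/underscores>.<fmt>' files only."""
    if os.path.splitext(filename)[-1] != f".{fmt}" or not filename.startswith(prefix + "-"):
        return False
    mid_name = filename[len(prefix):-(len(fmt) + 1)]
    return not any(char.isalpha() for char in mid_name)

print(is_managed_checkpoint("resnet-3_200.safetensors", "resnet"))  # True
print(is_managed_checkpoint("resnet-final.safetensors", "resnet"))  # False: letters in the middle part
print(is_managed_checkpoint("resnet-3_200.ckpt", "resnet"))         # False: extension does not match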
mindspore/train/callback/_cluster_monitor.py

@@ -150,7 +150,7 @@ class ClusterMonitor(Callback):
         with _perf_mutex:
             dir_path = os.path.dirname(self.full_path)
             if not os.path.exists(dir_path):
-                os.makedirs(dir_path)
+                os.makedirs(dir_path, mode=0o700)
             if os.path.exists(self.full_path):
                 os.chmod(self.full_path, stat.S_IWUSR)
                 os.remove(self.full_path)

mindspore/train/callback/_flops_collector.py

@@ -65,6 +65,7 @@ class FlopsUtilizationCollector(Callback):
     Raises:
         TypeError: If data_size is not positive int.
         TypeError: If full_flops is not bool.
+        AssertionError: If the training mode is not a static graph or not a static shape.
 
     Examples:
         >>> import numpy as np

mindspore/train/callback/_loss_monitor.py

@@ -19,7 +19,7 @@ import numpy as np
 
 from mindspore import _checkparam as Validator
 from mindspore.train.callback._callback import Callback, _handle_loss
-from mindspore._c_expression import
+from mindspore._c_expression import collect_host_info, get_clock_syscnt
 
 
 class LossMonitor(Callback):
@@ -70,7 +70,7 @@ class LossMonitor(Callback):
                 please refer to :class:`mindspore.train.RunContext`.
         """
         cb_params = run_context.original_args()
-
+        collect_host_info("Callback", "LossMonitor", "step_end", start_time=get_clock_syscnt(), level=1)
         cur_epoch_num = cb_params.get("cur_epoch_num", 1)
         loss = _handle_loss(cb_params.net_outputs)
 
@@ -101,7 +101,7 @@ class LossMonitor(Callback):
                 please refer to :class:`mindspore.train.RunContext`.
         """
         cb_params = run_context.original_args()
-
+        collect_host_info("Callback", "LossMonitor", "train_epoch_end", start_time=get_clock_syscnt(), level=1)
         metrics = cb_params.get("metrics")
         if metrics:
             print("Eval result: epoch %d, metrics: %s" % (cb_params.cur_epoch_num, metrics))

mindspore/train/callback/_on_request_exit.py

@@ -16,12 +16,19 @@
 
 from __future__ import absolute_import
 import os
+import json
 import signal
-
-from mindspore import
+import threading
+from mindspore.common import dtype as mstype
+from mindspore import context
+from mindspore import log as logger
+from mindspore.common.tensor import Tensor
+from mindspore.train._utils import _make_directory
 from mindspore import _checkparam as Validator
 from mindspore.train.serialization import load_checkpoint, save_checkpoint, export
 from mindspore.train.callback._callback import Callback
+from mindspore.parallel._utils import _get_parallel_mode
+from mindspore.context import ParallelMode
 
 
 class OnRequestExit(Callback):
@@ -29,7 +36,8 @@ class OnRequestExit(Callback):
     Respond to the user's closing request, exit the training or eval process, and save the checkpoint and mindir.
 
     Register OnRequestExit Callback before training, when the user want to exit the training process
-    and save the training data, could send the registered exit signal 'sig' to the training process
+    and save the training data, could send the registered exit signal 'sig' to the training process or modify the
+    'GracefulExit' that a key in the json file specified by the 'config_file' to '1'.
     After the training process executes the current step, saves the current training status,
     including checkpoint and mindir, and then exit the training process.
 
@@ -38,9 +46,12 @@
         save_mindir (bool): Whether save the mindir before the training process exit. Default: ``True`` .
         file_name (str): The saved checkpoint and mindir file name,
             the checkpoint file add suffix '.ckpt', the mindir file add suffix '.mindir'. Default: ``'Net'`` .
-        directory (str): The
+        directory (str): The path to save files. It will generate a 'rank_{id}' path by rank_id
+            to save checkpoint and mindir. Default: ``'./'`` .
         sig (int): The user registered exit signal, it must be a captureable and negligible signal.
            When the process receives the signal, exits the training or eval process. Default: ``signal.SIGTERM`` .
+        config_file (str): A json config file used to exit training process gracefully. Key: ``{"GracefulExit": 1}`` .
+            Default: ``None`` .
 
     Raises:
         ValueError: If the 'save_ckpt' is not a bool.
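The new 'config_file' option documented above works together with the MS_ENABLE_GRACEFUL_EXIT environment check and the JSON polling added further down (_do_step_begin / _check_config_info): training saves and stops after the current step either on the registered signal or once the file's "GracefulExit" key becomes 1. A hedged usage sketch; the paths and the commented-out train call are placeholders, not part of the diff:

import json
import os
import signal
from mindspore.train.callback import OnRequestExit

os.environ["MS_ENABLE_GRACEFUL_EXIT"] = "1"   # without this, _do_step_begin returns immediately
on_request_exit = OnRequestExit(save_ckpt=True, save_mindir=False, file_name='Net',
                                directory='./save', config_file='./graceful_exit.json',
                                sig=signal.SIGTERM)
# model.train(epoch, dataset, callbacks=[on_request_exit])

# From another process: request a save-and-exit after the current training step.
with open('./graceful_exit.json', 'w') as f:
    json.dump({"GracefulExit": 1}, f)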
@@ -67,20 +78,28 @@
         >>> model.train(10, dataset, callbacks=on_request_exit)
     """
 
-    def __init__(self, save_ckpt=True, save_mindir=True, file_name='Net', directory='./',
+    def __init__(self, save_ckpt=True, save_mindir=True, file_name='Net', directory='./', config_file=None,
+                 sig=signal.SIGTERM):
         super(OnRequestExit, self).__init__()
         self.save_ckpt = Validator.check_isinstance('save_ckpt', save_ckpt, bool)
         self.save_mindir = Validator.check_isinstance('save_mindir', save_mindir, bool)
-        if self.save_ckpt or self.save_mindir:
-            file_name = Validator.check_isinstance('file_name', file_name, str)
-            directory = Validator.check_isinstance('directory', directory, str)
-            os.makedirs(os.path.abspath(directory), exist_ok=True)
-            self.train_file_path = os.path.abspath(os.path.join(directory, f"{file_name}_train"))
-            self.eval_file_path = os.path.abspath(os.path.join(directory, f"{file_name}_eval"))
         self.sig = Validator.check_isinstance('sig', sig, int)
         if hasattr(signal, "SIGKILL") and self.sig == signal.SIGKILL:
             raise ValueError("Not support send exit request by signal SIGKILL.")
-        self.exit = False
+        self.exit = False  # used signal to exit the training process
+        self.lock = threading.Lock()
+        self.save_path = directory
+        self.key = "GracefulExit"
+        self.remote_config_file = config_file  # used config file to save checkpoint and exit training process
+        self.use_graceful = os.environ.get("MS_ENABLE_GRACEFUL_EXIT") == "1"
+        self.is_distributed = _get_parallel_mode() != ParallelMode.STAND_ALONE
+        self.integrated_save = True
+        if self.is_distributed:
+            self.integrated_save = _get_parallel_mode() == ParallelMode.AUTO_PARALLEL
+        self.stop_train = False
+        self.need_do_step_end = False
+        if self.save_ckpt or self.save_mindir:
+            self.train_name, self.eval_name = self._get_save_path(file_name)
 
     def on_train_begin(self, run_context):
         """
@@ -91,22 +110,31 @@
         For more details, please refer to :class:`mindspore.train.RunContext`.
         """
         signal.signal(self.sig, self._handle_signal)
-        if self.save_ckpt and os.path.isfile(f"{self.
+        if self.save_ckpt and os.path.isfile(f"{self.train_name}.ckpt"):
             cb_params = run_context.original_args()
             train_net = cb_params.train_network
-            load_checkpoint(f"{self.
+            load_checkpoint(f"{self.train_name}.ckpt", net=train_net)
+
+    def on_train_step_begin(self, run_context):
+        """
+        Check whether received the exit signal or
+        whether the value of 'GracefulExit' in 'config_file' was changed to '1'.
+
+        Args:
+            run_context (RunContext): Context information of the model.
+                For more details, please refer to :class:`mindspore.train.RunContext`.
+        """
+        self._do_step_begin(run_context)
 
     def on_train_step_end(self, run_context):
         """
-
-        Then exit the training process after this step training.
+        Save checkpoint file or mindir file according to config, and exit the training process.
 
         Args:
             run_context (RunContext): Include some information of the model.
                 For more details, please refer to :class:`mindspore.train.RunContext`.
         """
-
-        run_context.request_stop()
+        self._do_step_end(run_context)
 
     def on_train_epoch_end(self, run_context):
         """
@@ -118,8 +146,7 @@
             run_context (RunContext): Include some information of the model.
                 For more details, please refer to :class:`mindspore.train.RunContext`.
         """
-
-        run_context.request_stop()
+        self._do_step_end(run_context)
 
     def on_train_end(self, run_context):
         """
@@ -135,10 +162,10 @@
         cb_params = run_context.original_args()
         train_net = cb_params.train_network
         if self.save_ckpt:
-            save_checkpoint(train_net, ckpt_file_name=self.
+            save_checkpoint(train_net, ckpt_file_name=self.train_name)
         if self.save_mindir:
             inputs = cb_params.train_dataset_element
-            export(train_net, *inputs, file_name=self.
+            export(train_net, *inputs, file_name=self.train_name, file_format='MINDIR')
 
     def on_eval_begin(self, run_context):
         """
@@ -153,15 +180,15 @@
             return
         cb_params = run_context.original_args()
         eval_net = cb_params.eval_network
-        if os.path.isfile(f"{self.
-            load_checkpoint(f"{self.
-        elif os.path.isfile(f"{self.
-            load_checkpoint(f"{self.
+        if os.path.isfile(f"{self.eval_name}.ckpt"):
+            load_checkpoint(f"{self.eval_name}.ckpt", net=eval_net)
+        elif os.path.isfile(f"{self.train_name}.ckpt"):
+            load_checkpoint(f"{self.train_name}.ckpt", net=eval_net)
 
     def on_eval_step_end(self, run_context):
         """
-        When the eval step end, if received the exit signal, set
-        Then exit the eval process after this step eval.
+        When the eval step end, if received the exit signal, set attribute '_stop_requested' of the
+        'run_context' to True. Then exit the eval process after this step eval.
 
         Args:
             run_context (RunContext): Include some information of the model.
@@ -184,12 +211,88 @@
         cb_params = run_context.original_args()
         eval_net = cb_params.eval_network
         if self.save_ckpt:
-            save_checkpoint(eval_net, ckpt_file_name=self.
+            save_checkpoint(eval_net, ckpt_file_name=self.eval_name)
         if self.save_mindir:
             inputs = cb_params.eval_dataset_element
-            export(eval_net, *inputs, file_name=self.
+            export(eval_net, *inputs, file_name=self.eval_name, file_format='MINDIR')
 
     def _handle_signal(self, signum, frame):
         """Handle the received signal"""
-
+        logger.debug(f"signum: {signum}, frame: {frame}")
         self.exit = True
+
+    def _do_step_end(self, run_context):
+        """
+        Save the checkpoint or mindir, and then exit training process.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+                For more details, please refer to :class:`mindspore.train.RunContext`.
+        """
+        with self.lock:
+            # save once
+            if self.stop_train or not self.need_do_step_end:
+                return
+            logger.info("Gracefully exiting training process on step end.")
+            call_params = run_context.original_args()
+            net = call_params.train_network
+            for _, param in net.parameters_and_names():
+                if param.name == "graceful_exit" and param.asnumpy() == True:  # pylint: disable=C0121
+                    logger.warning("Graceful exit is triggered, stop training.")
+                    if self.save_ckpt:
+                        save_checkpoint(net, self.train_name, integrated_save=self.integrated_save)
+                    if self.save_mindir:
+                        inputs = call_params.train_dataset_element
+                        export(net, *inputs, file_name=self.train_name, file_format='MINDIR')
+                    run_context.request_stop()
+                    self.stop_train = True
+
+    def _do_step_begin(self, run_context):
+        """
+        Check training process exit configuration at the step begin.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+                For more details, please refer to :class:`mindspore.train.RunContext`.
+        """
+        with self.lock:
+            # no env
+            if not self.use_graceful:
+                return
+            if self._check_config_info() or self.exit:
+                call_params = run_context.original_args()
+                net = call_params.train_network
+                for _, param in net.parameters_and_names():
+                    if not self.is_distributed and param.name == "graceful_exit":
+                        param.set_data(Tensor(True, mstype.bool_))
+                        self.need_do_step_end = True
+                        break
+                    if param.name == "graceful_init":
+                        param.set_data(Tensor([1], mstype.int32))
+                        self.need_do_step_end = True
+                        break
+
+    def _check_config_info(self):
+        """check json config info"""
+        if self.remote_config_file is not None and os.path.exists(self.remote_config_file):
+            with open(self.remote_config_file, "r") as f:
+                try:
+                    config_info = json.load(f)
+                except json.JSONDecodeError as e:
+                    logger.warning(f"Parse json file failed: {e}, please check json file: {self.remote_config_file}")
+                    return False
+                if self.key in config_info and config_info[self.key] == 1:
+                    return True
+        return False
+
+    def _get_save_path(self, file_name):
+        """path to save checkpoint files or mindir files"""
+        device_id = context.get_context("device_id")
+        if self.save_path is None:
+            tmp = os.path.join(os.getcwd(), r"rank_" + str(device_id))
+            path_ = _make_directory(tmp)
+            return os.path.join(path_, f"{file_name}_train"), os.path.join(path_, f"{file_name}_eval")
+
+        save_path = os.path.join(self.save_path, r"rank_" + str(device_id))
+        save_path = _make_directory(save_path)
+        return os.path.join(save_path, f"{file_name}_train"), os.path.join(save_path, f"{file_name}_eval")
mindspore/train/callback/_summary_collector.py

@@ -41,7 +41,7 @@ from mindspore.nn.optim.optimizer import Optimizer
 from mindspore.nn.loss.loss import LossBase
 from mindspore.train._utils import check_value_type, _make_directory
 from mindspore._c_expression import security
-from mindspore._c_expression import
+from mindspore._c_expression import collect_host_info, get_clock_syscnt
 
 HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"
 HYPER_CONFIG_LEN_LIMIT = 100000
@@ -472,7 +472,7 @@ class SummaryCollector(Callback):
 
     def begin(self, run_context):
         cb_params = run_context.original_args()
-
+        collect_host_info("Callback", "SummaryCollector", "begin", start_time=get_clock_syscnt(), level=1)
         self._check_callbacks(cb_params)
 
         if cb_params.mode not in ModeEnum.to_list():
@@ -484,7 +484,7 @@
 
     def step_end(self, run_context):
         cb_params = run_context.original_args()
-
+        collect_host_info("Callback", "SummaryCollector", "step_end", start_time=get_clock_syscnt(), level=1)
         if cb_params.mode != ModeEnum.TRAIN.value:
             return
 
@@ -559,7 +559,7 @@
 
     def epoch_end(self, run_context):
         cb_params = run_context.original_args()
-
+        collect_host_info("Callback", "SummaryCollector", "epoch_end", start_time=get_clock_syscnt(), level=1)
         self._collect_tensor_data(cb_params)
         collect_landscape = self._collect_specified_data.get('collect_landscape')
         if collect_landscape is not None:
@@ -576,7 +576,7 @@
 
     def end(self, run_context):
         cb_params = run_context.original_args()
-
+        collect_host_info("Callback", "SummaryCollector", "end", start_time=get_clock_syscnt(), level=1)
         if cb_params.mode == ModeEnum.TRAIN.value:
             self._collect_train_lineage(cb_params)
         else: