mindspore 2.3.0__cp310-cp310-win_amd64.whl → 2.4.1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/initializer.py +51 -15
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +62 -15
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +183 -37
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +315 -60
- mindspore/communication/management.py +14 -14
- mindspore/context.py +132 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +983 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +268 -23
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +26 -13
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +276 -96
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +137 -10
- mindspore/nn/layer/embedding.py +137 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +124 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +767 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
- mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +492 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +564 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +402 -12
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +7 -2
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +14 -146
- mindspore/ops/operations/comm_ops.py +63 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +273 -20
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +31 -9
- mindspore/parallel/_cell_wrapper.py +85 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +127 -13
- mindspore/parallel/_utils.py +53 -22
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +1146 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +285 -413
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +39 -104
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +105 -19
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +97 -31
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +145 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +375 -0
- mindspore/train/dataset_helper.py +15 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +154 -58
- mindspore/train/serialization.py +342 -128
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +248 -242
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/train/callback/_checkpoint.py

@@ -18,8 +18,8 @@ from __future__ import absolute_import
 import os
 import stat
 import time
-
 import threading
+
 import mindspore.context as context
 from mindspore import log as logger
 from mindspore import nn
@@ -37,14 +37,22 @@ from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.common.generator import Generator
 from mindspore.common.api import _cell_graph_executor
-from mindspore._c_expression import
-
+from mindspore._c_expression import collect_host_info, get_clock_syscnt
 
 _cur_dir = os.getcwd()
 SAVE_DIR = _cur_dir
 _info_list = ["epoch_num", "step_num"]
 
 
+def _wait_async_save_ckpt(async_save=False):
+    """Waiting for asynchronous saving of ckpt to complete."""
+    if async_save:
+        thread_list = threading.enumerate()
+        for thread in thread_list:
+            if thread.getName() == "asyn_save_ckpt":
+                thread.join()
+
+
 def _get_dp_tp_from_redundancy(redundancy_tuple):
     """From redundancy get dp and tp"""
     dp = []
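The new _wait_async_save_ckpt helper simply joins the background thread that asynchronous checkpoint saving names "asyn_save_ckpt". A self-contained illustration of that thread-name based wait, where the worker function is invented for the example:

import threading
import time

def _fake_async_save():            # stand-in for the real async checkpoint writer
    time.sleep(0.5)

# The async save path names its worker thread "asyn_save_ckpt"; the helper
# scans threading.enumerate() and joins any thread with that name.
threading.Thread(target=_fake_async_save, name="asyn_save_ckpt").start()

for thread in threading.enumerate():
    if thread.getName() == "asyn_save_ckpt":
        thread.join()
print("checkpoint writer finished")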
@@ -88,9 +96,9 @@ def _chg_ckpt_file_name_if_same_exist(directory, prefix, exception=False):
             if index == 0:
                 suffix_num = max(suffix_num, 1)
             elif index != -1:
-                num = filename[pre_len+1:pre_len+index]
+                num = filename[pre_len + 1:pre_len + index]
                 if num.isdigit():
-                    suffix_num = max(suffix_num, int(num)+1)
+                    suffix_num = max(suffix_num, int(num) + 1)
 
     if suffix_num != 0:
         prefix = f'{prefix}_{suffix_num}'
@@ -98,6 +106,14 @@ def _chg_ckpt_file_name_if_same_exist(directory, prefix, exception=False):
     return prefix
 
 
+def _check_format_and_other_params(format, enc_key, enc_mode, crc_check=False, async_save=False, exception_save=False,
+                                   map_param_inc=False, global_step_num=None):
+    param_not_default = (enc_key is not None or enc_mode != "AES-GCM" or crc_check or async_save
+                         or exception_save or map_param_inc or global_step_num is not None)
+    if format == "safetensors" and param_not_default:
+        raise ValueError("For 'save_checkpoint', when format is 'safetensors', other param must be default.")
+
+
 class CheckpointConfig:
     """
     The configuration of model checkpoint.
@@ -136,6 +152,10 @@ class CheckpointConfig:
         exception_save (bool): Whether to save the current checkpoint when an exception occurs. Default: ``False`` .
         crc_check (bool): Whether to perform crc32 calculation when saving checkpoint and save the calculation
             result to the end of ckpt. Default: ``False`` .
+        remove_redundancy (bool): Whether to enable saving the checkpoint with redundancy removal.
+            Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` , means
+            redundant-free saving is not enabled.
+        format (str): Format of the output file, can be "ckpt" or "safetensors". Default: "ckpt".
         kwargs (dict): Configuration options dictionary.
 
     Raises:
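The two options documented above can be combined with ModelCheckpoint roughly as follows. This is a minimal sketch based only on the parameters shown in this diff; the prefix, directory and step counts are illustrative, and _check_format_and_other_params (added earlier in this file) rejects "safetensors" whenever a non-default option such as enc_key, crc_check or async_save is also set.

from mindspore.train import ModelCheckpoint, CheckpointConfig

# Save a redundancy-free safetensors checkpoint every 500 steps (sketch only).
# With format="safetensors", options such as enc_key, crc_check or async_save
# must stay at their defaults, otherwise CheckpointConfig raises ValueError.
config = CheckpointConfig(save_checkpoint_steps=500,
                          keep_checkpoint_max=5,
                          remove_redundancy=True,   # only meaningful in parallel training
                          format="safetensors")
ckpt_cb = ModelCheckpoint(prefix="net", directory="./ckpt", config=config)
# model.train(num_epochs, train_dataset, callbacks=[ckpt_cb])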
@@ -188,6 +208,8 @@ class CheckpointConfig:
                  enc_mode='AES-GCM',
                  exception_save=False,
                  crc_check=False,
+                 remove_redundancy=False,
+                 format="ckpt",
                  **kwargs):
 
         if save_checkpoint_steps is not None:
@@ -231,8 +253,13 @@ class CheckpointConfig:
         self._enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
         self._enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
         self._crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
+        self._format = Validator.check_isinstance('format', format, str)
         self._map_param_inc = kwargs.get('incremental', False)
         self.enable_redundance = kwargs.get('enable_redundance', False)
+        self.remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
+
+        _check_format_and_other_params(format, enc_key, enc_mode, crc_check, async_save, exception_save,
+                                       self._map_param_inc)
 
     @property
     def save_checkpoint_steps(self):
@@ -333,6 +360,10 @@ class CheckpointConfig:
         """
         return self._crc_check
 
+    @property
+    def format(self):
+        return self._format
+
     @property
     def append_dict(self):
         """
@@ -495,10 +526,10 @@ class ModelCheckpoint(Callback):
         self._aiturbo_init_flag = os.getenv("AITURBO") == "1"
         # get existing checkpoint files
         if self._aiturbo_init_flag:
-            import
-            self._manager =
+            from aiturbo.checkpoint.aiturbo_mindspore_ckpt import CheckpointShmManager
+            self._manager = CheckpointShmManager()
         else:
-            self._manager = CheckpointManager()
+            self._manager = CheckpointManager(self._config.format)
         if not callable(directory) and not callable(prefix):
             self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
         self._append_dict = self._config.append_dict or {}
@@ -517,7 +548,7 @@ class ModelCheckpoint(Callback):
         """
         cb_params = run_context.original_args()
         if self._aiturbo_init_flag:
-            import aiturbo
+            from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
             ckpt_storage_path = self._directory
             rank_id = get_rank()
             stage_num = _get_auto_parallel_context("pipeline_stages")
@@ -536,7 +567,7 @@ class ModelCheckpoint(Callback):
                       "stage_layout": param_redundancy_dict}
             single_params = remove_param_redundancy(param_redundancy_dict)
             single_params = {device_id: list(params) for device_id, params in single_params.items()}
-            aiturbo.init(ckpt_storage_path, rank_id, layout, single_params, self._config.enable_redundance, dp)
+            aiturbo.init(ckpt_storage_path, rank_id, layout, single_params, not self._config.enable_redundance, dp)
             self._aiturbo_init_flag = False
         if self._prefix_func:
             self._prefix = self._prefix_func(cb_params)
@@ -546,14 +577,14 @@ class ModelCheckpoint(Callback):
                                  "string that does not contain '/', but got {}.".format(self._prefix))
         if self._directory_func:
             self._directory = self._directory_func(cb_params)
-
+        _make_directory(self._directory)
+        collect_host_info("Callback", "ModelCheckpoint", "step_end", start_time=get_clock_syscnt(), level=1)
         # In disaster recovery scenario, the training process may be rolled back to the last step where
         # the ckpt was successfully saved, so the _last_triggered_step should be updated.
         if _get_recovery_context("enable_recovery") and cb_params.last_save_ckpt_step is not None:
             self._last_triggered_step = cb_params.last_save_ckpt_step
             cb_params.last_save_ckpt_step = None
 
-        _make_directory(self._directory)
         # save graph (only once)
         if not self._graph_saved:
             graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
@@ -561,10 +592,6 @@ class ModelCheckpoint(Callback):
                 os.remove(graph_file_name)
             _save_graph(cb_params.train_network, graph_file_name)
             self._graph_saved = True
-        thread_list = threading.enumerate()
-        for thread in thread_list:
-            if thread.getName() == "asyn_save_ckpt":
-                thread.join()
         self._save_ckpt(cb_params)
 
     def end(self, run_context):
@@ -575,15 +602,12 @@ class ModelCheckpoint(Callback):
             run_context (RunContext): Context of the train running.
         """
         cb_params = run_context.original_args()
-
+        collect_host_info("Callback", "ModelCheckpoint", "end", start_time=get_clock_syscnt(), level=1)
         _to_save_last_ckpt = True
 
         self._save_ckpt(cb_params, _to_save_last_ckpt)
 
-
-        for thread in thread_list:
-            if thread.getName() == "asyn_save_ckpt":
-                thread.join()
+        _wait_async_save_ckpt(self._config.async_save)
 
         destroy_allgather_cell()
 
@@ -601,6 +625,13 @@ class ModelCheckpoint(Callback):
 
         return False
 
+    def _append_dict_content(self, epoch_num, step_num):
+        """Append append_dict content."""
+        if "epoch_num" in self._append_dict:
+            self._append_dict["epoch_num"] = self._append_epoch_num + epoch_num
+        if "step_num" in self._append_dict:
+            self._append_dict["step_num"] = self._append_step_num + step_num
+
     def _save_ckpt(self, cb_params, force_to_save=False):
         """Save checkpoint files."""
         if cb_params.cur_step_num == self._last_triggered_step:
@@ -614,11 +645,12 @@ class ModelCheckpoint(Callback):
         step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
 
         if save_ckpt:
+            _wait_async_save_ckpt(self._config.async_save)
            if self._prefix_func:
-                cur_ckpoint_file = self._prefix + ".
+                cur_ckpoint_file = self._prefix + f".{self._config.format}"
            else:
                cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
-
+                                  + str(step_num_in_epoch) + f".{self._config.format}"
            # update checkpoint file list.
            self._manager.update_ckpoint_filelist(self._directory, self._prefix)
            # keep checkpoint files number equal max number.
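With the format-aware naming above, the periodic checkpoint file name becomes prefix-epoch_step.<format> instead of a hard-coded .ckpt suffix. A quick illustration with assumed example values:

prefix, cur_epoch_num, step_num_in_epoch = "net", 3, 250   # assumed example values
fmt = "safetensors"                                        # self._config.format
cur_ckpoint_file = prefix + "-" + str(cur_epoch_num) + "_" \
                   + str(step_num_in_epoch) + f".{fmt}"
print(cur_ckpoint_file)  # net-3_250.safetensors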
@@ -644,20 +676,51 @@ class ModelCheckpoint(Callback):
                 set_cur_net(cb_params.train_network)
                 cb_params.train_network.add_flags(ge_sync_data=True)
                 _cell_graph_executor(cb_params.train_network, phase='save')
-
-            self._append_dict["epoch_num"] = self._append_epoch_num + cb_params.cur_epoch_num
-            if "step_num" in self._append_dict:
-                self._append_dict["step_num"] = self._append_step_num + cb_params.cur_step_num
+            self._append_dict_content(cb_params.cur_epoch_num, cb_params.cur_step_num)
             network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
             if os.getenv("AITURBO") == "1":
                 save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save,
                                 self._append_dict, self._config.enc_key, self._config.enc_mode,
                                 crc_check=self._config.crc_check, incremental=self._map_param_inc,
                                 global_step_num=cb_params.cur_step_num)
+            elif self._config.remove_redundancy:
+                parallel_mode = context.get_auto_parallel_context("parallel_mode")
+                if parallel_mode == "stand_alone":
+                    raise TypeError(f"The deduplication feature for saving checkpoint can only be used "
+                                    f"in parallel scenarios, but got {parallel_mode}.")
+                param_layout = network.parameter_layout_dict
+                rank_id = get_rank()
+                if param_layout:
+                    device_num = _get_device_num()
+                    stage_num = _get_auto_parallel_context("pipeline_stages")
+                    chunk_size = device_num // stage_num
+                    initial_rank = (rank_id // chunk_size) * chunk_size
+                    param_redundancy_dict = get_parameter_redundancy(param_layout, initial_rank)
+                    single_params = remove_param_redundancy(param_redundancy_dict)
+                    save_param_names = single_params.get(rank_id)
+                    param_layout_set = set(param_layout.keys())
+                    if save_param_names == param_layout.keys():
+                        logger.warning(
+                            f"For remove_redundancy save checkpoint, the saved parameters are non-redundant.")
+
+                    def choice_func(x):
+                        return x not in param_layout_set or x in save_param_names
+                else:
+                    param_redundancy_dict = get_parameter_redundancy(network)
+                    single_params = remove_param_redundancy(param_redundancy_dict)
+                    save_param_names = single_params.get(rank_id)
+
+                    def choice_func(x):
+                        return x in save_param_names
+                save_checkpoint(network, cur_file, False, self._config.async_save,
+                                self._append_dict, self._config.enc_key, self._config.enc_mode,
+                                crc_check=self._config.crc_check, format=self._config.format,
+                                incremental=self._map_param_inc, choice_func=choice_func)
             else:
                 save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save,
                                 self._append_dict, self._config.enc_key, self._config.enc_mode,
-                                crc_check=self._config.crc_check,
+                                crc_check=self._config.crc_check, format=self._config.format,
+                                incremental=self._map_param_inc)
 
             self._latest_ckpt_file_name = cur_file
 
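The remove_redundancy branch above decides, per rank, which parameters to write by handing a choice_func predicate to save_checkpoint. A standalone sketch of that selection logic follows; the parameter names and the per-rank ownership data are invented stand-ins for what get_parameter_redundancy and remove_param_redundancy would produce.

# Sketch only: mimic the per-rank parameter selection used by the
# remove_redundancy branch; names and ownership data are invented.
param_layout_set = {"dense.weight", "dense.bias", "embed.table"}  # sharded parameters
save_param_names = {"dense.weight"}                               # shards this rank must save

def choice_func(name):
    # Keep parameters that are not sharded at all, plus the shards this
    # rank owns; redundant copies held by other ranks are skipped.
    return name not in param_layout_set or name in save_param_names

all_params = ["dense.weight", "dense.bias", "embed.table", "learning_rate", "global_step"]
print([p for p in all_params if choice_func(p)])
# -> ['dense.weight', 'learning_rate', 'global_step']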
@@ -691,8 +754,9 @@ class ModelCheckpoint(Callback):
 class CheckpointManager:
     """Manage checkpoint files according to train_config of checkpoint."""
 
-    def __init__(self):
+    def __init__(self, format='ckpt'):
         self._ckpoint_filelist = []
+        self._format = format
 
     @property
     def ckpoint_filelist(self):
@@ -707,10 +771,12 @@ class CheckpointManager:
     def update_ckpoint_filelist(self, directory, prefix):
         """Update the checkpoint file list."""
         self._ckpoint_filelist = []
+        format = self._format
+        format_length = len(format) + 1
         files = os.listdir(directory)
         for filename in files:
-            if os.path.splitext(filename)[-1] == ".
-                mid_name = filename[len(prefix):-
+            if os.path.splitext(filename)[-1] == f".{format}" and filename.startswith(prefix + "-"):
+                mid_name = filename[len(prefix):-format_length]
                 flag = not (True in [char.isalpha() for char in mid_name])
                 if flag:
                     self._ckpoint_filelist.append(os.path.join(directory, filename))
mindspore/train/callback/_cluster_monitor.py

@@ -150,7 +150,7 @@ class ClusterMonitor(Callback):
         with _perf_mutex:
             dir_path = os.path.dirname(self.full_path)
             if not os.path.exists(dir_path):
-                os.makedirs(dir_path)
+                os.makedirs(dir_path, mode=0o700)
             if os.path.exists(self.full_path):
                 os.chmod(self.full_path, stat.S_IWUSR)
                 os.remove(self.full_path)
mindspore/train/callback/_flops_collector.py

@@ -65,6 +65,7 @@ class FlopsUtilizationCollector(Callback):
     Raises:
         TypeError: If data_size is not positive int.
         TypeError: If full_flops is not bool.
+        AssertionError: If the training mode is not a static graph or not a static shape.
 
     Examples:
         >>> import numpy as np
mindspore/train/callback/_loss_monitor.py

@@ -19,7 +19,7 @@ import numpy as np
 
 from mindspore import _checkparam as Validator
 from mindspore.train.callback._callback import Callback, _handle_loss
-from mindspore._c_expression import
+from mindspore._c_expression import collect_host_info, get_clock_syscnt
 
 
 class LossMonitor(Callback):
@@ -70,7 +70,7 @@ class LossMonitor(Callback):
                 please refer to :class:`mindspore.train.RunContext`.
         """
         cb_params = run_context.original_args()
-
+        collect_host_info("Callback", "LossMonitor", "step_end", start_time=get_clock_syscnt(), level=1)
         cur_epoch_num = cb_params.get("cur_epoch_num", 1)
         loss = _handle_loss(cb_params.net_outputs)
 
@@ -101,7 +101,7 @@ class LossMonitor(Callback):
                 please refer to :class:`mindspore.train.RunContext`.
         """
         cb_params = run_context.original_args()
-
+        collect_host_info("Callback", "LossMonitor", "train_epoch_end", start_time=get_clock_syscnt(), level=1)
         metrics = cb_params.get("metrics")
         if metrics:
             print("Eval result: epoch %d, metrics: %s" % (cb_params.cur_epoch_num, metrics))
mindspore/train/callback/_on_request_exit.py

@@ -16,12 +16,19 @@
 
 from __future__ import absolute_import
 import os
+import json
 import signal
-
-from mindspore import
+import threading
+from mindspore.common import dtype as mstype
+from mindspore import context
+from mindspore import log as logger
+from mindspore.common.tensor import Tensor
+from mindspore.train._utils import _make_directory
 from mindspore import _checkparam as Validator
 from mindspore.train.serialization import load_checkpoint, save_checkpoint, export
 from mindspore.train.callback._callback import Callback
+from mindspore.parallel._utils import _get_parallel_mode
+from mindspore.context import ParallelMode
 
 
 class OnRequestExit(Callback):
@@ -29,7 +36,8 @@ class OnRequestExit(Callback):
     Respond to the user's closing request, exit the training or eval process, and save the checkpoint and mindir.
 
     Register OnRequestExit Callback before training, when the user want to exit the training process
-    and save the training data, could send the registered exit signal 'sig' to the training process
+    and save the training data, could send the registered exit signal 'sig' to the training process or modify the
+    'GracefulExit' that a key in the json file specified by the 'config_file' to '1'.
     After the training process executes the current step, saves the current training status,
     including checkpoint and mindir, and then exit the training process.
 
@@ -38,9 +46,12 @@ class OnRequestExit(Callback):
         save_mindir (bool): Whether save the mindir before the training process exit. Default: ``True`` .
         file_name (str): The saved checkpoint and mindir file name,
             the checkpoint file add suffix '.ckpt', the mindir file add suffix '.mindir'. Default: ``'Net'`` .
-        directory (str): The
+        directory (str): The path to save files. It will generate a 'rank_{id}' path by rank_id
+            to save checkpoint and mindir. Default: ``'./'`` .
         sig (int): The user registered exit signal, it must be a captureable and negligible signal.
             When the process receives the signal, exits the training or eval process. Default: ``signal.SIGTERM`` .
+        config_file (str): A json config file used to exit training process gracefully. Key: ``{"GracefulExit": 1}`` .
+            Default: ``None`` .
 
     Raises:
         ValueError: If the 'save_ckpt' is not a bool.
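The parameters documented above add a second, file-driven way to request a graceful exit besides the signal path. A minimal usage sketch follows; the MS_ENABLE_GRACEFUL_EXIT environment variable and the {"GracefulExit": 1} key come from this diff, while the paths and file names are illustrative:

import json
import os
from mindspore.train.callback import OnRequestExit

# Graceful exit is gated behind this environment variable (see _do_step_begin below).
os.environ["MS_ENABLE_GRACEFUL_EXIT"] = "1"

# Register the callback with a watched JSON config file (illustrative path).
exit_cb = OnRequestExit(save_ckpt=True, save_mindir=True, file_name='Net',
                        directory='./output', config_file='./graceful_exit.json')
# model.train(num_epochs, dataset, callbacks=[exit_cb])

# From another process, request a graceful stop by flipping the key to 1:
with open('./graceful_exit.json', 'w') as f:
    json.dump({"GracefulExit": 1}, f)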
@@ -67,20 +78,28 @@ class OnRequestExit(Callback):
         >>> model.train(10, dataset, callbacks=on_request_exit)
     """
 
-    def __init__(self, save_ckpt=True, save_mindir=True, file_name='Net', directory='./',
+    def __init__(self, save_ckpt=True, save_mindir=True, file_name='Net', directory='./', config_file=None,
+                 sig=signal.SIGTERM):
         super(OnRequestExit, self).__init__()
         self.save_ckpt = Validator.check_isinstance('save_ckpt', save_ckpt, bool)
         self.save_mindir = Validator.check_isinstance('save_mindir', save_mindir, bool)
-        if self.save_ckpt or self.save_mindir:
-            file_name = Validator.check_isinstance('file_name', file_name, str)
-            directory = Validator.check_isinstance('directory', directory, str)
-            os.makedirs(os.path.abspath(directory), exist_ok=True)
-            self.train_file_path = os.path.abspath(os.path.join(directory, f"{file_name}_train"))
-            self.eval_file_path = os.path.abspath(os.path.join(directory, f"{file_name}_eval"))
         self.sig = Validator.check_isinstance('sig', sig, int)
         if hasattr(signal, "SIGKILL") and self.sig == signal.SIGKILL:
             raise ValueError("Not support send exit request by signal SIGKILL.")
-        self.exit = False
+        self.exit = False  # used signal to exit the training process
+        self.lock = threading.Lock()
+        self.save_path = directory
+        self.key = "GracefulExit"
+        self.remote_config_file = config_file  # used config file to save checkpoint and exit training process
+        self.use_graceful = os.environ.get("MS_ENABLE_GRACEFUL_EXIT") == "1"
+        self.is_distributed = _get_parallel_mode() != ParallelMode.STAND_ALONE
+        self.integrated_save = True
+        if self.is_distributed:
+            self.integrated_save = _get_parallel_mode() == ParallelMode.AUTO_PARALLEL
+        self.stop_train = False
+        self.need_do_step_end = False
+        if self.save_ckpt or self.save_mindir:
+            self.train_name, self.eval_name = self._get_save_path(file_name)
 
     def on_train_begin(self, run_context):
         """
@@ -91,22 +110,31 @@
                 For more details, please refer to :class:`mindspore.train.RunContext`.
         """
         signal.signal(self.sig, self._handle_signal)
-        if self.save_ckpt and os.path.isfile(f"{self.
+        if self.save_ckpt and os.path.isfile(f"{self.train_name}.ckpt"):
             cb_params = run_context.original_args()
             train_net = cb_params.train_network
-            load_checkpoint(f"{self.
+            load_checkpoint(f"{self.train_name}.ckpt", net=train_net)
+
+    def on_train_step_begin(self, run_context):
+        """
+        Check whether received the exit signal or
+        whether the value of 'GracefulExit' in 'config_file' was changed to '1'.
+
+        Args:
+            run_context (RunContext): Context information of the model.
+                For more details, please refer to :class:`mindspore.train.RunContext`.
+        """
+        self._do_step_begin(run_context)
 
     def on_train_step_end(self, run_context):
         """
-
-        Then exit the training process after this step training.
+        Save checkpoint file or mindir file according to config, and exit the training process.
 
         Args:
             run_context (RunContext): Include some information of the model.
                 For more details, please refer to :class:`mindspore.train.RunContext`.
         """
-
-        run_context.request_stop()
+        self._do_step_end(run_context)
 
     def on_train_epoch_end(self, run_context):
         """
@@ -118,8 +146,7 @@
             run_context (RunContext): Include some information of the model.
                 For more details, please refer to :class:`mindspore.train.RunContext`.
         """
-
-        run_context.request_stop()
+        self._do_step_end(run_context)
 
     def on_train_end(self, run_context):
         """
@@ -135,10 +162,10 @@
         cb_params = run_context.original_args()
         train_net = cb_params.train_network
         if self.save_ckpt:
-            save_checkpoint(train_net, ckpt_file_name=self.
+            save_checkpoint(train_net, ckpt_file_name=self.train_name)
         if self.save_mindir:
             inputs = cb_params.train_dataset_element
-            export(train_net, *inputs, file_name=self.
+            export(train_net, *inputs, file_name=self.train_name, file_format='MINDIR')
 
     def on_eval_begin(self, run_context):
         """
@@ -153,15 +180,15 @@
             return
         cb_params = run_context.original_args()
         eval_net = cb_params.eval_network
-        if os.path.isfile(f"{self.
-            load_checkpoint(f"{self.
-        elif os.path.isfile(f"{self.
-            load_checkpoint(f"{self.
+        if os.path.isfile(f"{self.eval_name}.ckpt"):
+            load_checkpoint(f"{self.eval_name}.ckpt", net=eval_net)
+        elif os.path.isfile(f"{self.train_name}.ckpt"):
+            load_checkpoint(f"{self.train_name}.ckpt", net=eval_net)
 
     def on_eval_step_end(self, run_context):
         """
-        When the eval step end, if received the exit signal, set
-        Then exit the eval process after this step eval.
+        When the eval step end, if received the exit signal, set attribute '_stop_requested' of the
+        'run_context' to True. Then exit the eval process after this step eval.
 
         Args:
             run_context (RunContext): Include some information of the model.
@@ -184,12 +211,99 @@
         cb_params = run_context.original_args()
         eval_net = cb_params.eval_network
         if self.save_ckpt:
-            save_checkpoint(eval_net, ckpt_file_name=self.
+            save_checkpoint(eval_net, ckpt_file_name=self.eval_name)
         if self.save_mindir:
             inputs = cb_params.eval_dataset_element
-            export(eval_net, *inputs, file_name=self.
+            export(eval_net, *inputs, file_name=self.eval_name, file_format='MINDIR')
 
     def _handle_signal(self, signum, frame):
         """Handle the received signal"""
-
+        logger.debug(f"signum: {signum}, frame: {frame}")
         self.exit = True
+
+    def _do_step_end(self, run_context):
+        """
+        Save the checkpoint or mindir, and then exit training process.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+                For more details, please refer to :class:`mindspore.train.RunContext`.
+        """
+        with self.lock:
+            # save once
+            if self.stop_train or not self.need_do_step_end:
+                return
+            logger.info("Gracefully exiting training process on step end.")
+            call_params = run_context.original_args()
+            net = call_params.train_network
+            for _, param in net.parameters_and_names():
+                if param.name == "graceful_exit" and param.asnumpy() == True:  # pylint: disable=C0121
+                    logger.warning("Graceful exit is triggered, stop training.")
+                    if self.save_ckpt:
+                        append_dict = {"epoch_num": call_params.cur_epoch_num,
+                                       "step_num": call_params.cur_step_num,
+                                       "batch_num": call_params.batch_num}
+                        if call_params.loss_scale_mananger is not None:
+                            append_dict["loss_scale"] = call_params.loss_scale_mananger.get_loss_scale()
+                        if call_params.optimizer is not None:
+                            global_step = int(call_params.optimizer.global_step.data)
+                        else:
+                            global_step = int(call_params.network.optimizer.global_step.data)
+                        append_dict["global_step"] = global_step
+                        save_checkpoint(net, self.train_name, integrated_save=self.integrated_save,
+                                        append_dict=append_dict)
+                    if self.save_mindir:
+                        inputs = call_params.train_dataset_element
+                        export(net, *inputs, file_name=self.train_name, file_format='MINDIR')
+                    run_context.request_stop()
+                    self.stop_train = True
+
+    def _do_step_begin(self, run_context):
+        """
+        Check training process exit configuration at the step begin.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+                For more details, please refer to :class:`mindspore.train.RunContext`.
+        """
+        with self.lock:
+            # no env
+            if not self.use_graceful:
+                return
+            if self._check_config_info() or self.exit:
+                call_params = run_context.original_args()
+                net = call_params.train_network
+                for _, param in net.parameters_and_names():
+                    if not self.is_distributed and param.name == "graceful_exit":
+                        param.set_data(Tensor(True, mstype.bool_))
+                        self.need_do_step_end = True
+                        break
+                    if param.name == "graceful_init":
+                        param.set_data(Tensor([1], mstype.int32))
+                        self.need_do_step_end = True
+                        break
+
+    def _check_config_info(self):
+        """check json config info"""
+        if self.remote_config_file is not None and os.path.exists(self.remote_config_file):
+            with open(self.remote_config_file, "r") as f:
+                try:
+                    config_info = json.load(f)
+                except json.JSONDecodeError as e:
+                    logger.warning(f"Parse json file failed: {e}, please check json file: {self.remote_config_file}")
+                    return False
+                if self.key in config_info and config_info[self.key] == 1:
+                    return True
+        return False
+
+    def _get_save_path(self, file_name):
+        """path to save checkpoint files or mindir files"""
+        device_id = context.get_context("device_id")
+        if self.save_path is None:
+            tmp = os.path.join(os.getcwd(), r"rank_" + str(device_id))
+            path_ = _make_directory(tmp)
+            return os.path.join(path_, f"{file_name}_train"), os.path.join(path_, f"{file_name}_eval")
+
+        save_path = os.path.join(self.save_path, r"rank_" + str(device_id))
+        save_path = _make_directory(save_path)
+        return os.path.join(save_path, f"{file_name}_train"), os.path.join(save_path, f"{file_name}_eval")
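For reference, the per-step file check done by _check_config_info above reduces to the following standalone sketch (the config path is an assumed example):

import json
import os

def graceful_exit_requested(config_file, key="GracefulExit"):
    """Return True if the watched JSON file exists and sets the key to 1."""
    if config_file is None or not os.path.exists(config_file):
        return False
    with open(config_file, "r") as f:
        try:
            config_info = json.load(f)
        except json.JSONDecodeError:
            return False
    return config_info.get(key) == 1

print(graceful_exit_requested("./graceful_exit.json"))  # False until the key is set to 1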
mindspore/train/callback/_summary_collector.py

@@ -41,7 +41,7 @@ from mindspore.nn.optim.optimizer import Optimizer
 from mindspore.nn.loss.loss import LossBase
 from mindspore.train._utils import check_value_type, _make_directory
 from mindspore._c_expression import security
-from mindspore._c_expression import
+from mindspore._c_expression import collect_host_info, get_clock_syscnt
 
 HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"
 HYPER_CONFIG_LEN_LIMIT = 100000
@@ -472,7 +472,7 @@ class SummaryCollector(Callback):
 
     def begin(self, run_context):
         cb_params = run_context.original_args()
-
+        collect_host_info("Callback", "SummaryCollector", "begin", start_time=get_clock_syscnt(), level=1)
         self._check_callbacks(cb_params)
 
         if cb_params.mode not in ModeEnum.to_list():
@@ -484,7 +484,7 @@ class SummaryCollector(Callback):
 
     def step_end(self, run_context):
         cb_params = run_context.original_args()
-
+        collect_host_info("Callback", "SummaryCollector", "step_end", start_time=get_clock_syscnt(), level=1)
         if cb_params.mode != ModeEnum.TRAIN.value:
             return
 
@@ -559,7 +559,7 @@ class SummaryCollector(Callback):
 
     def epoch_end(self, run_context):
         cb_params = run_context.original_args()
-
+        collect_host_info("Callback", "SummaryCollector", "epoch_end", start_time=get_clock_syscnt(), level=1)
         self._collect_tensor_data(cb_params)
         collect_landscape = self._collect_specified_data.get('collect_landscape')
         if collect_landscape is not None:
@@ -576,7 +576,7 @@ class SummaryCollector(Callback):
 
     def end(self, run_context):
         cb_params = run_context.original_args()
-
+        collect_host_info("Callback", "SummaryCollector", "end", start_time=get_clock_syscnt(), level=1)
        if cb_params.mode == ModeEnum.TRAIN.value:
            self._collect_train_lineage(cb_params)
        else:
|