mindspore-2.3.0-cp310-cp310-win_amd64.whl → mindspore-2.4.1-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore was flagged as potentially problematic by the registry.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/initializer.py +51 -15
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +62 -15
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +183 -37
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +315 -60
- mindspore/communication/management.py +14 -14
- mindspore/context.py +132 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +983 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +268 -23
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +26 -13
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +276 -96
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +137 -10
- mindspore/nn/layer/embedding.py +137 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +124 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +767 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
- mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +492 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +564 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +402 -12
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +7 -2
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +14 -146
- mindspore/ops/operations/comm_ops.py +63 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +273 -20
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +31 -9
- mindspore/parallel/_cell_wrapper.py +85 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +127 -13
- mindspore/parallel/_utils.py +53 -22
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +1146 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +285 -413
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +39 -104
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +105 -19
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +97 -31
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +145 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +375 -0
- mindspore/train/dataset_helper.py +15 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +154 -58
- mindspore/train/serialization.py +342 -128
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +248 -242
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/train/model.py
CHANGED
@@ -36,7 +36,7 @@ from mindspore.train.metrics import get_metrics, get_metric_fn
 from mindspore._checkparam import check_input_data, check_output_data
 from mindspore import _checkparam as Validator
 from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback, TimeMonitor,\
-    FlopsUtilizationCollector,
+    FlopsUtilizationCollector, TFTRegister
 from mindspore.train.callback import __all__ as internal_cb_names
 from mindspore.train.callback._cluster_monitor import ClusterMonitor
 from mindspore import context
@@ -46,6 +46,7 @@ from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_
 from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_ps_mode, \
     _cache_enable, _enable_distributed_mindrt
 from mindspore.train.metrics import Loss
+from mindspore.train._utils import vlog_print
 from mindspore import nn
 from mindspore.boost import AutoBoost
 from mindspore.context import ParallelMode
@@ -119,6 +120,101 @@ def _save_final_ckpt(func):
             func(self, *args, **kwargs)
     return wrapper
 
+def _handle_tft(func):
+    """
+    Decorator function, which starts uce handle process when an exception occurs during training.
+    """
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        obj = None
+        if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), TFTRegister):
+            obj = kwargs.get('callbacks')
+        if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
+            for item in kwargs.get('callbacks'):
+                if isinstance(item, TFTRegister):
+                    obj = item
+        if obj:
+            tft = obj.tft
+            tft_env = os.getenv("MS_ENABLE_TFT", "")
+            uce_env = "UCE:1" in tft_env
+            while True:
+                try:
+                    return func(self, *args, **kwargs)
+                except RuntimeError as e:
+                    logger.info("uce wrapper caught RuntimeError")
+                    if not uce_env:
+                        logger.info("uce wrapper caught RuntimeError uce not enable")
+                        tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+                        raise e
+                    e_str = str(e)
+                    logger.info("uce wrapper caught RuntimeError e_str:{}".format(e_str))
+                    if "UCEError" in e_str:
+                        logger.info("uce wrapper report UCEError")
+                        tft.tft_report_error(tft.ReportState.RS_UCE.value)
+                    elif "ForceStopError" in e_str:
+                        logger.info("uce wrapper caught RuntimeError ForceStopError")
+                        force_stop_err = tft.ReportState.RS_NORMAL.value
+                        tft.tft_report_error(force_stop_err)
+                    else:
+                        logger.info("uce wrapper caught RuntimeError rankid: {} OTHER ERROR")
+                        tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+                        raise e
+                    ret = tft.tft_wait_next_action()
+                    if ret == tft.Action.EXIT.value:
+                        raise e
+                    repair_step = tft.tft_get_repair_step()
+                    logger.info("uce wrapper caught repair finish REPAIR STEP: {} batch_num: \
+                        {}".format(repair_step, self.batch_num))
+                    initial_epoch = int(repair_step/self.batch_num)
+                    initial_step = repair_step % self.batch_num
+                    kwargs["initial_epoch"] = initial_epoch
+
+                    train_dataset = args[1]
+                    dataset_sink_mode = args[3] if len(args) > 3 else kwargs.get('dataset_sink_mode', True)
+                    sink_size = args[4] if len(args) > 4 else kwargs.get('sink_size', -1)
+
+                    cb_initial_step = 0
+                    if dataset_sink_mode:
+                        train_dataset.set_init_step(initial_epoch)
+                        dataset_size = train_dataset.get_dataset_size()
+                        if sink_size != -1:
+                            cb_initial_step = initial_epoch * sink_size + initial_step
+                        else:
+                            cb_initial_step = initial_epoch * dataset_size + initial_step
+                    else:
+                        train_dataset.set_init_step(initial_step)
+                        cb_initial_step = initial_step
+
+                    kwargs["initial_step"] = cb_initial_step
+
+                    logger.info("uce wrapper repair complete \
+                        initial_epoch: {}, cb_initial_step: {} ".format(initial_epoch, cb_initial_step))
+                    continue
+                except BaseException as e:
+                    logger.info("uce wrapper caught BaseException error")
+                    tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+                    raise e
+        else:
+            return func(self, *args, **kwargs)
+    return wrapper
+
+
+def _check_tft():
+    """Check if TFT is supported"""
+    tft_env = os.getenv("MS_ENABLE_TFT")
+    device_target = context.get_context("device_target")
+    if tft_env and device_target == "Ascend":
+        from mindspore._c_expression import MSContext
+        ascend_target = MSContext.get_instance().get_ascend_soc_version()
+        if ascend_target == 'ascend910':
+            raise ValueError("TFT is not supported when using ascend910")
+        ms_mode = context.get_context("mode")
+        if ms_mode != mindspore.GRAPH_MODE:
+            raise ValueError("TFT is only supported in GRAPH_MODE")
+        jit_level = context.get_context("jit_level")
+        if jit_level == "O2" and "UCE:1" in tft_env:
+            raise ValueError("TFT is not supported when using jit_level == O2")
+
 
 def _append_ccae(callbacks):
     """Add cluster monitoring when CCAE is enabled."""
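Note on the hunk above: the new _handle_tft wrapper and _check_tft guard only take effect when the MS_ENABLE_TFT environment variable is set (UCE recovery additionally requires "UCE:1" in its value), the target is Ascend but not ascend910, the execution mode is GRAPH_MODE, and a TFTRegister callback is passed to training so the wrapper can reach its tft controller. A minimal sketch of that opt-in path is shown below, under those assumptions; the TFTRegister constructor arguments and the exact MS_ENABLE_TFT value are illustrative, not taken from this diff.

# Minimal sketch, not a verified recipe: assumes MindSpore 2.4.1 on a multi-card Ascend
# environment where the MindIO TFT controller is available.
import os
import mindspore as ms
from mindspore import nn
from mindspore.train import Model
from mindspore.train.callback import TFTRegister  # exported in 2.4.1 per the import hunk above

# _check_tft() runs at the top of Model.train(): it requires GRAPH_MODE on Ascend (not ascend910)
# and rejects jit_level "O2" when "UCE:1" is present in MS_ENABLE_TFT.
os.environ.setdefault("MS_ENABLE_TFT", "{TTP:1,UCE:1}")  # illustrative value; only "UCE:1" is inspected here
ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")

net = nn.Dense(16, 4)
model = Model(net, loss_fn=nn.MSELoss(),
              optimizer=nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9))

# Constructor arguments below are placeholders; see the new mindspore/train/callback/_tft_register.py
# for the real signature of TFTRegister.
tft_cb = TFTRegister(0, "127.0.0.1", 30051, "./ckpt")

# When a RuntimeError whose message contains "UCEError" escapes _train, the wrapper reports it,
# waits for the repair step, rewinds the dataset with set_init_step(), and re-enters _train with
# initial_epoch = repair_step // batch_num and initial_step derived as in the hunk above.
# model.train(2, train_dataset, callbacks=[tft_cb], dataset_sink_mode=True)  # train_dataset: user-provided dataset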
@@ -290,21 +386,11 @@ class Model:
         amp_level (str): Option for argument `level` in :func:`mindspore.amp.build_train_network`, level for mixed
             precision training. Supports ["O0", "O1", "O2", "O3", "auto"]. Default: ``"O0"`` .
 
-
-
-            The operators in the whitelist: [Conv1d, Conv2d, Conv3d, Conv1dTranspose, Conv2dTranspose,
-            Conv3dTranspose, Dense, LSTMCell, RNNCell, GRUCell, MatMul, BatchMatMul, PReLU, ReLU, Ger].
-            - "O2": Cast network to float16, keep BatchNorm run in float32, using dynamic loss scale.
-            - "O3": Cast network to float16, the BatchNorm is also cast to float16, loss scale will not be used.
-            - "auto": Set level to recommended level in different devices. Set level to "O2" on GPU, set
-              level to "O3" on Ascend. The recommended level is chosen by the expert experience, not applicable to all
-              scenarios. User should specify the level for special network.
-
-            "O2" is recommended on GPU, "O3" is recommended on Ascend.
+            For details on `amp_level` , refer to :func:`mindspore.amp.auto_mixed_precision`.
+
             The BatchNorm strategy can be changed by `keep_batchnorm_fp32` settings in `kwargs`. `keep_batchnorm_fp32`
             must be a bool. The loss scale strategy can be changed by `loss_scale_manager` setting in `kwargs`.
             `loss_scale_manager` should be a subclass of :class:`mindspore.amp.LossScaleManager`.
-            The more detailed explanation of `amp_level` setting can be found at `mindspore.amp.build_train_network`.
 
         boost_level (str): Option for argument `level` in `mindspore.boost`, level for boost mode
             training. Supports ["O0", "O1", "O2"]. Default: ``"O0"`` .
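The docstring hunk above replaces the inline description of the amp_level options with a pointer to mindspore.amp.auto_mixed_precision, while keeping the note that keep_batchnorm_fp32 and loss_scale_manager can still be overridden through kwargs. A short sketch of that documented usage follows; the network and optimizer are placeholders.

# Sketch of the documented Model construction; assumes a MindSpore 2.4.1 install.
from mindspore import nn
from mindspore.train import Model
from mindspore.amp import DynamicLossScaleManager

net = nn.Dense(16, 4)
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# amp_level option semantics ("O0"/"O1"/"O2"/"O3"/"auto") are documented under
# mindspore.amp.auto_mixed_precision; the kwargs below keep BatchNorm in float32 and
# switch the loss-scale strategy, as the retained docstring text describes.
model = Model(net, loss_fn=nn.MSELoss(), optimizer=opt, amp_level="O2",
              keep_batchnorm_fp32=True,
              loss_scale_manager=DynamicLossScaleManager())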
@@ -379,6 +465,7 @@ class Model:
         self._mindspore_lite = None
         self._lite_infer = True  # if backend lite infer fails, set False
         self._mindspore_lite_model_group_id = id(self) & 0xFFFF
+        self.batch_num = -1
 
     def _check_for_graph_cell(self, kwargs):
         """Check for graph cell"""
@@ -568,9 +655,13 @@ class Model:
             dataset.__loop_size__ = 1
 
         if dataset_helper is None:
+            vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to create DatasetHelper.")
+            logger.info("Begin to create DatasetHelper.")
             dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)
 
         if dataset_sink_mode:
+            vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to connect network with dataset.")
+            logger.info("Begin to connect network with dataset.")
             network = connect_network_with_dataset(network, dataset_helper)
 
         if _get_recovery_context("enable_recovery") and is_train:
@@ -589,6 +680,10 @@ class Model:
         if self._backbone_is_train != is_train:
             network.set_train(is_train)
             self._backbone_is_train = is_train
+        # Mode train and eval are the same net, network will be set_grad in _build_train_network.
+        # But if mode just want to do predict or eval, must set network set_grad False
+        if not is_train:
+            network.set_grad(False)
         return network
 
     def _check_need_ckpt(self, callbacks):
@@ -687,6 +782,8 @@ class Model:
         if not train_dataset and not valid_dataset:
             raise ValueError("The argument 'train_dataset' and 'valid_dataset' can not both be None or empty.")
 
+        vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to check device number in model.build().")
+        logger.info("Begin to check device number in model.build() procedure.")
         _device_number_check(self._parallel_mode, self._device_number)
 
         if train_dataset:
@@ -694,27 +791,44 @@ class Model:
                 raise TypeError("The type of 'train_dataset' must be `Dataset`, "
                                 "but got {}.".format(type(train_dataset)))
 
+            vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+                       "Begin to check parameter broadcast in model.build().")
+            logger.info("Begin to check parameter broadcast in model.build() procedure.")
             _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
             if self._parameter_broadcast:
                 self._train_network.set_broadcast_flag()
 
+            vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to exec preprocess in model.build().")
+            logger.info("Begin to exec preprocess in model.build() procedure.")
             train_dataset.__no_send__ = True
             train_dataset_helper, train_network = self._exec_preprocess(is_train=True,
                                                                         dataset=train_dataset,
                                                                         dataset_sink_mode=True,
                                                                         sink_size=sink_size)
+            vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to warmup dataset in model.build().")
+            logger.info("Begin to warmup dataset in model.build() procedure.")
             self._warmup_dataset(epoch, train_dataset, sink_size)
 
             # Since dataset pipeline has been triggered, delete flag
             delattr(train_dataset, "__no_send__")
 
             # Waiting for the dataset warmup ready
+            vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+                       "Begin waiting for dataset warmup in model.build().")
+            logger.info("Begin waiting for dataset warmup in model.build() procedure.")
             self._waiting_for_dataset_warmup_ready(train_dataset)
+            vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+                       "The dataset warmup was successful in model.build().")
+            logger.info("The dataset warmup was successful in model.build() procedure.")
 
             if context.get_auto_parallel_context("pipeline_stages") > 1 and valid_dataset:
                 train_network.add_flags_recursive(is_first_iteration=True)
             for inputs in train_dataset_helper:
+                vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+                           "Begin to compile train network in model.build().")
+                logger.info("Begin to compile train network in model.build() procedure.")
                 train_network.compile(*inputs)
+                self._train_network.parameter_layout_dict = train_network.parameter_layout_dict
                 break
 
             if valid_dataset:
@@ -732,6 +846,9 @@ class Model:
             if context.get_auto_parallel_context("pipeline_stages") > 1:
                 eval_network.add_flags_recursive(is_first_iteration=False)
             for inputs in valid_dataset_helper:
+                vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+                           "Begin to compile eval network in model.build().")
+                logger.info("Begin to compile eval network in model.build() procedure.")
                 eval_network.compile(*inputs)
                 break
 
@@ -746,9 +863,10 @@ class Model:
 
         return [callbacks]
 
+    @_handle_tft
     @_save_final_ckpt
     def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1, initial_epoch=0,
-               valid_dataset=None, valid_frequency=1, valid_dataset_sink_mode=True):
+               valid_dataset=None, valid_frequency=1, valid_dataset_sink_mode=True, initial_step=0):
         """
         Training.
 
@@ -772,12 +890,14 @@ class Model:
             self._train_network.set_broadcast_flag()
 
         cb_params = _InternalCallbackParam()
+        cb_params.cur_step_num = initial_step
         cb_params.train_network = self._train_network
         cb_params.epoch_num = epoch - initial_epoch
         if dataset_sink_mode and sink_size > 0:
             cb_params.batch_num = sink_size
         else:
             cb_params.batch_num = train_dataset.get_dataset_size()
+        self.batch_num = cb_params.batch_num
         cb_params.mode = "train"
         cb_params.loss_fn = self._loss_fn
         cb_params.optimizer = self._optimizer
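The two added lines above (cb_params.cur_step_num = initial_step and self.batch_num = cb_params.batch_num) are what let the _handle_tft wrapper resume callbacks from the reported repair step instead of from step 0. A small worked example of that arithmetic, with illustrative numbers:

# Illustrative numbers only; mirrors the arithmetic in _handle_tft above.
batch_num = 100      # steps per epoch, cached on the Model as self.batch_num
repair_step = 342    # global step returned by tft.tft_get_repair_step()

initial_epoch = repair_step // batch_num   # 3: _train restarts from this epoch index
initial_step = repair_step % batch_num     # 42: position inside that epoch

# In dataset sink mode with the default sink_size (-1), callbacks resume from a global counter:
cb_initial_step = initial_epoch * batch_num + initial_step   # 342, passed to _train as initial_step
assert (initial_epoch, initial_step, cb_initial_step) == (3, 42, 342)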
@@ -801,16 +921,19 @@ class Model:
             epoch = 1
         cb_params.last_save_ckpt_step = None
         cb_params.latest_ckpt_file = None
+        cb_params.loss_scale_mananger = self._loss_scale_manager
 
         # build callback list
         with _CallbackManager(callbacks) as list_callback:
             self._check_reuse_dataset(train_dataset)
             if not dataset_sink_mode:
-                self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch,
+                self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch,
+                                    valid_infos)
             elif context.get_context("device_target") == "CPU":
                 logger.info("The CPU cannot support dataset sink mode currently."
                             "So the training process will be performed with dataset not sink.")
-                self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch,
+                self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch,
+                                    valid_infos)
             else:
                 self._train_dataset_sink_process(epoch, train_dataset, list_callback,
                                                  cb_params, sink_size, initial_epoch, valid_infos)
@@ -850,9 +973,7 @@ class Model:
             train_dataset.__total_batch__ = epoch * sink_size
 
         cb_params.sink_size = sink_size
-        cb_params.cur_step_num = 0
         cb_params.dataset_sink_mode = True
-
         run_context = RunContext(cb_params)
         list_callback.on_train_begin(run_context)
         # used to stop training for early stop, such as stopAtTIme or stopATStep
@@ -861,7 +982,6 @@ class Model:
         dataset_helper = train_dataset._dataset_helper
 
         self.epoch_iter = 0
-
         self._check_enable_recovery()
         # Used to check whether need perform recovery for process which is restarted.
         self._check_need_load_ckpt(cb_params, dataset_size, sink_size)
@@ -997,7 +1117,6 @@ class Model:
             dataset_size (int): The number of batches in a dataset.
             sink_size (int): Control the amount of data in each sink. Default: -1.
         """
-
         if not self.enable_recovery:
             self.need_load_ckpt = False
 
@@ -1084,7 +1203,6 @@ class Model:
                                                             dataset=train_dataset,
                                                             dataset_sink_mode=False,
                                                             epoch_num=epoch)
-        cb_params.cur_step_num = 0
         cb_params.dataset_sink_mode = False
         run_context = RunContext(cb_params)
         list_callback.on_train_begin(run_context)
@@ -1106,7 +1224,6 @@ class Model:
                                     "returned by 'train_dataset'".format(len_element))
                 cb_params.cur_step_num += 1
                 self._current_step_num = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
-
                 cb_params.train_dataset_element = next_element
                 list_callback.on_train_step_begin(run_context)
                 self._check_network_mode(self._train_network, True)
@@ -1150,31 +1267,6 @@ class Model:
 
         list_callback.on_train_end(run_context)
 
-    def _wrapper_train(self, callbacks):
-        """
-        This method used to wrap train function with ttp wrapper which will do event notify when
-        exceptions throw.
-
-        Args:
-            callbacks (function): Callbacks passed by train method.
-        """
-
-        if not callbacks:
-            return self._train
-        cbs = callbacks if isinstance(callbacks, list) else [callbacks]
-        obj = None
-        _train_wrapper = None
-        for item in cbs:
-            if isinstance(item, MindIOTTPAdapter):
-                obj = item
-
-        if (obj is not None) and (obj.enable is True):
-            logger.info("MindIO TTP is enable, so we wrapper ttp exception handdler for self train method.")
-            _train_wrapper = obj.wrapper_ttp_persist(self._train)
-
-        return self._train if not _train_wrapper else _train_wrapper
-
-
     def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=False, sink_size=-1, initial_epoch=0):
         """
         Training API.
@@ -1240,9 +1332,10 @@ class Model:
         ...                loss_scale_manager=loss_scale_manager)
         >>> model.train(2, dataset)
         """
+        _check_tft()
+        device_target = context.get_context("device_target")
         # prepare dataset for obfuscated model
         train_dataset = self._prepare_obf_dataset(train_dataset)
-        device_target = context.get_context("device_target")
         if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
             logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
             dataset_sink_mode = False
@@ -1283,16 +1376,14 @@ class Model:
         _device_number_check(self._parallel_mode, self._device_number)
 
         callbacks = _append_ccae(callbacks)
-        _train_wrapper = None
         if callbacks:
             self._check_methods_for_custom_callbacks(callbacks, "train")
-
-
-
-
-
-
-                    initial_epoch=initial_epoch)
+        self._train(epoch,
+                    train_dataset,
+                    callbacks=callbacks,
+                    dataset_sink_mode=dataset_sink_mode,
+                    sink_size=sink_size,
+                    initial_epoch=initial_epoch)
 
         # When it's distributed training and using MindRT,
         # the node id should be reset to start from 0.
@@ -1396,7 +1487,7 @@ class Model:
 
         Tutorial Examples:
             - `Advanced Encapsulation: Model - Train and Save Model
-              <https://www.mindspore.cn/
+              <https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
         """
         device_target = context.get_context("device_target")
         if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
@@ -1493,7 +1584,12 @@ class Model:
         if hasattr(self._train_network, '_is_check_and_refresh') and not self._train_network._is_check_and_refresh:
             self._train_network.check_names_and_refresh_name()
             self._train_network._is_check_and_refresh = True
+        vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin to init dataset in model.build().")
+        logger.info("Begin to init dataset in model.build() procedure.")
         self._init(train_dataset, valid_dataset, sink_size, epoch)
+        vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+                   "The model.build() which contains dataset warmup and network compile is success.")
+        logger.info("The model.build() which contains dataset warmup and network compile is success.")
 
     def _eval_in_fit(self, valid_dataset, callbacks=None, dataset_sink_mode=True, cb_params=None):
         """
@@ -1663,7 +1759,7 @@ class Model:
 
         Tutorial Examples:
             - `Advanced Encapsulation: Model - Train and Save Model
-              <https://www.mindspore.cn/
+              <https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
         """
         valid_dataset = self._prepare_obf_dataset(valid_dataset)
         dataset_sink_mode = Validator.check_bool(dataset_sink_mode)