mindspore-2.3.0-cp39-cp39-win_amd64.whl → mindspore-2.4.0-cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +46 -13
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +209 -29
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +310 -55
- mindspore/communication/management.py +14 -14
- mindspore/context.py +123 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +495 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +266 -21
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +28 -7
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +275 -93
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +113 -3
- mindspore/nn/layer/embedding.py +120 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +734 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
- mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +490 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +558 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +184 -8
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +6 -1
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +12 -146
- mindspore/ops/operations/comm_ops.py +42 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +265 -10
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +28 -8
- mindspore/parallel/_cell_wrapper.py +83 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +81 -11
- mindspore/parallel/_utils.py +13 -1
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +280 -412
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +36 -103
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +28 -2
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +85 -22
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +134 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/dataset_helper.py +7 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +134 -58
- mindspore/train/serialization.py +336 -112
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +258 -252
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/swresample-4.dll
CHANGED (binary file)

mindspore/swscale-6.dll
CHANGED (binary file)

mindspore/tinyxml2.dll
CHANGED (binary file)
mindspore/train/__init__.py
CHANGED

@@ -27,10 +27,10 @@ from mindspore.train.loss_scale_manager import LossScaleManager, FixedLossScaleM
 from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, \
     load, parse_print, build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint, \
     async_ckpt_thread_status, restore_group_info_list, convert_model, obfuscate_model, export_split_mindir, \
-    load_checkpoint_async, check_checkpoint
+    load_checkpoint_async, check_checkpoint, get_ckpt_path_with_strategy
 from mindspore.train.callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryCollector, \
     CheckpointConfig, RunContext, LearningRateScheduler, SummaryLandscape, FlopsUtilizationCollector, \
-    History, LambdaCallback, ReduceLROnPlateau, EarlyStopping, OnRequestExit, BackupAndRestore
+    History, LambdaCallback, ReduceLROnPlateau, EarlyStopping, OnRequestExit, BackupAndRestore, TFTRegister
 from mindspore.train.summary import SummaryRecord
 from mindspore.train.train_thor import ConvertNetUtils, ConvertModelUtils
 from mindspore.train.metrics import *
@@ -40,7 +40,8 @@ __all__ = ["Model", "DatasetHelper", "connect_network_with_dataset", "build_trai
     "FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint", "check_checkpoint",
     "load_param_into_net", "export", "load", "export_split_mindir", "parse_print", "build_searched_strategy",
     "merge_sliced_parameter", "load_distributed_checkpoint", "async_ckpt_thread_status",
-    "restore_group_info_list", "convert_model", "data_sink", "obfuscate_model", "load_checkpoint_async"]
+    "restore_group_info_list", "convert_model", "data_sink", "obfuscate_model", "load_checkpoint_async",
+    "get_ckpt_path_with_strategy"]
 __all__.extend(callback.__all__)
 __all__.extend(summary.__all__)
 __all__.extend(train_thor.__all__)
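The hunk above adds `get_ckpt_path_with_strategy` to the public `mindspore.train` API. A minimal usage sketch follows; the two paths are hypothetical, and the positional signature and fallback behavior described in the comments are assumptions based on the export, not confirmed documentation:

from mindspore.train import get_ckpt_path_with_strategy

# Hypothetical inputs: this rank's checkpoint file and the merged parallel
# strategy file produced during distributed training.
ckpt_file = "/data/ckpt/rank_7/model.ckpt"
strategy_file = "/data/strategy/merged_strategy.ckpt"

# Assumed behavior: returns a checkpoint path whose saved parameters cover the
# current rank, using parameter redundancy when the original file is missing.
usable_ckpt = get_ckpt_path_with_strategy(ckpt_file, strategy_file)
print(usable_ckpt)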
mindspore/train/_utils.py
CHANGED

@@ -25,15 +25,18 @@ from mindspore.common.tensor import Tensor
 from mindspore._c_expression import Tensor as Tensor_
 from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
 from mindspore.common import dtype as mstype
+from mindspore import context
 from mindspore import log as logger
 from mindspore import _checkparam as Validator
 from mindspore.common.api import _cell_graph_executor
+from mindspore.communication import get_group_size
 from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
 from mindspore.train.checkpoint_pb2 import Checkpoint
 from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy
 from mindspore.train.lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo
 from mindspore.parallel._parallel_serialization import _make_dir
 from mindspore.ops.operations import debug_ops
+from mindspore.nn.cell import Cell


 def _convert_type(types):
@@ -71,6 +74,11 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset', create_data_inf
     queue_name = _cell_graph_executor.get_queue_name(phase)
     if queue_name is None:
         queue_name = str("")
+
+    use_pipeline_parallel = (context.get_auto_parallel_context("pipeline_stages") > 1)
+    if use_pipeline_parallel:
+        create_data_info_queue = False
+
     exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end,
                                            create_data_info_queue=create_data_info_queue, queue_name=queue_name)
     _cell_graph_executor.init_dataset(exec_dataset.queue_name,
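For context, the guard added above force-disables the data-info queue whenever pipeline parallelism is active. A small sketch of the setting that trips this branch; the stage count is a made-up value:

import mindspore as ms

# Any value greater than 1 enables pipeline parallelism.
ms.set_auto_parallel_context(pipeline_stages=4)

# Mirrors the check added in _exec_datagraph above.
if ms.get_auto_parallel_context("pipeline_stages") > 1:
    create_data_info_queue = False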
@@ -320,6 +328,11 @@ def get_parameter_redundancy(layout_obj, initial_rank=0):
     """
     if isinstance(layout_obj, str):
         parameter_layout = parse_strategy_ckpt(layout_obj)
+    elif isinstance(layout_obj, Cell):
+        from mindspore.communication.management import get_process_group_ranks
+        groups_ranks = (tuple(get_process_group_ranks()),)
+        param_redundancy_dict = {param.name: groups_ranks for _, param in layout_obj.parameters_and_names()}
+        return param_redundancy_dict
     else:
         parameter_layout = {}
         for k, v in layout_obj.items():
@@ -338,12 +351,25 @@ def get_parameter_redundancy(layout_obj, initial_rank=0):
         locate_list = redundancy_matrix.reshape((-1, len(slices))).tolist()
         redundancy_dict = {}
         for index, locate in enumerate(locate_list):
-            redundancy_dict.setdefault(tuple(locate), []).append(index+initial_rank)
+            redundancy_dict.setdefault(tuple(locate), []).append(index + initial_rank)
         redundancy_list = []
         for _, indices in sorted(redundancy_dict.items()):
             redundancy_list.append(tuple(indices))
-
         param_redundancy_dict[key] = tuple(redundancy_list)
+    if isinstance(layout_obj, str):
+        return param_redundancy_dict
+
+    for key, value in layout_obj.items():
+        if value[5]:
+            world_groups = ("hccl_world_group", "nccl_world_group", "mccl_world_group")
+            opt_para_num = int(value[5][0]) if value[5] not in world_groups else get_group_size()
+            param_redundancy_ranks = param_redundancy_dict.get(key)
+            res = []
+            for param_ranks in param_redundancy_ranks:
+                if len(param_ranks) % opt_para_num == 0:
+                    for i in range(0, opt_para_num):
+                        res.append(param_ranks[i::opt_para_num])
+            param_redundancy_dict[key] = tuple(res)
     return param_redundancy_dict
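The new optimizer-parallel branch above splits each redundancy group into `opt_para_num` interleaved subgroups using a stride slice. A standalone illustration with made-up ranks:

# Four ranks previously held a full redundant copy of the parameter; with the
# parameter sharded across 2 optimizer-parallel ranks, each subgroup below
# ends up sharing the same shard.
opt_para_num = 2
param_ranks = (0, 1, 2, 3)

res = []
if len(param_ranks) % opt_para_num == 0:
    for i in range(opt_para_num):
        res.append(param_ranks[i::opt_para_num])

print(tuple(res))  # ((0, 2), (1, 3))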
mindspore/train/amp.py
CHANGED

@@ -16,6 +16,9 @@
 from __future__ import absolute_import
 import inspect
 import types
+from typing import Any
+import functools
+import collections

 import mindspore as ms
 from mindspore import nn
@@ -29,8 +32,9 @@ from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScal
 from mindspore import boost, context
 from mindspore.ops import operations as P
 from mindspore.ops import Primitive
+from mindspore.ops import auto_generate as gen
 from mindspore import log as logger
-
+from mindspore._c_expression.amp import pop_amp_strategy, push_amp_strategy, create_amp_strategy, AmpLevel

 AMP_WHITE_LIST = [
     nn.Conv1d,
@@ -52,17 +56,67 @@
     P.BatchMatMul,
     P.PReLU,
     P.ReLU,
-    P.Ger
+    P.Ger,
 ]

-
 AMP_BLACK_LIST = [
     nn.BatchNorm1d,
     nn.BatchNorm2d,
     nn.BatchNorm3d,
-    nn.LayerNorm
+    nn.LayerNorm,
 ]

+AMP_AUTO_WHITE_LIST = [
+    P.Conv2D,
+    P.Conv3D,
+    P.Conv2DTranspose,
+    P.Conv3DTranspose,
+    gen.Convolution,
+    P.MatMul,
+    gen.MatMulExt,
+    P.BatchMatMul,
+    gen.BatchMatMulExt,
+    gen.PReLU,
+    P.Einsum,
+    gen.Dense,
+    gen.Addmm,
+]
+
+AMP_AUTO_BLACK_LIST = [
+    gen.Pow,
+    gen.ACos,
+    gen.Asin,
+    gen.Cosh,
+    P.Erfinv,
+    P.Exp,
+    P.Expm1,
+    P.Log,
+    P.Log1p,
+    P.Reciprocal,
+    P.Rsqrt,
+    P.Sinh,
+    P.Tan,
+    P.Softplus,
+    gen.SoftplusExt,
+    P.LayerNorm,
+    gen.LayerNormExt,
+    P.BatchNorm,
+    gen.GroupNorm,
+    P.KLDivLoss,
+    P.SmoothL1Loss,
+    P.MultilabelMarginLoss,
+    P.SoftMarginLoss,
+    P.TripletMarginLoss,
+    P.MultiMarginLoss,
+    P.BCEWithLogitsLoss,
+    P.Pdist,
+    P.Cdist,
+    P.Renorm,
+]
+
+# Indicates which inputs of primitives need to be converted
+AMP_PRIM_ARG_TABLE = collections.defaultdict(list, {})
+
 # Primitives in inner amp black list will not be converted in O2/O3
 _INNER_AMP_BLACK_LIST = []

@@ -302,6 +356,42 @@ def _auto_black_list(network, black_list, dtype):
     return network


+class amp_decorator:
+    """
+    Auto mixed precision decorator.
+    Type of lists: List[Tuple[str, List[int]]]
+    """
+    def __init__(self, amp_level, amp_dtype, white_list, black_list):
+        self.amp_level = amp_level
+        self.amp_dtype = amp_dtype
+        self.white_list = white_list
+        self.black_list = black_list
+
+    def __enter__(self):
+        push_amp_strategy(self.amp_level, self.amp_dtype, self.white_list, self.black_list)
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
+        pop_amp_strategy()
+
+
+def _set_amp_decorator(obj, amp_level, amp_dtype, white_list, black_list):
+    """
+    Set auto mixed precision context decorator for object.
+    Type of lists: List[Tuple[str, List[int]]]
+    """
+    if inspect.isfunction(obj) or inspect.ismethod(obj):
+        @functools.wraps(obj)
+        def wrapper(*args, **kwargs):
+            with amp_decorator(amp_level, amp_dtype, white_list, black_list):
+                return obj(*args, **kwargs)
+        return wrapper
+    if isinstance(obj, nn.Cell):
+        obj.construct = types.MethodType(
+            _set_amp_decorator(obj.construct.__func__, amp_level, amp_dtype, white_list, black_list), obj)
+        return obj
+    raise TypeError(f"For amp_level '{amp_level}', the network type should be Cell or function, bot got {type(obj)}.")
+
+
 def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
     """
     Returns a network processed with auto mixed precision.
@@ -312,26 +402,44 @@
     converted to lower precision float, and calculation results are converted back to full precision float,
     i.e. ``mstype.float32`` .

-    The
-    operators are specifically converted.
+    The `amp_level` and its corresponding lists determine which cells and operators are converted.

-
+    When `amp_level` is set to ``O0``, no cells and operators are converted.

-
-
-    :class:`mindspore.nn.Conv3dTranspose`, :class:`mindspore.nn.Dense`, :class:`mindspore.nn.LSTMCell`,
-    :class:`mindspore.nn.RNNCell`, :class:`mindspore.nn.GRUCell`, :class:`mindspore.ops.Conv2D`,
-    :class:`mindspore.ops.Conv3D`, :class:`mindspore.ops.Conv2DTranspose`,
-    :class:`mindspore.ops.Conv3DTranspose`, :class:`mindspore.ops.MatMul`, :class:`mindspore.ops.BatchMatMul`,
-    :class:`mindspore.ops.PReLU`, :class:`mindspore.ops.ReLU`, :class:`mindspore.ops.Ger`]
+    When `amp_level` is set to ``O1``, cells and operators in whitelist will be converted to lower precision
+    operations. For details on whitelist, refer to :func:`mindspore.amp.get_white_list`.

-
+    When `amp_level` is set to ``O2``, cells in blacklist will maintain full precision, and cells outside the
+    list will be converted to low precision. For details on blacklist, refer to :func:`mindspore.amp.get_black_list`.

-
-
+    When `amp_level` is set to ``O3``, all cells will be converted to low precision.
+
+    When `amp_level` is set to ``auto``, operators in `auto_whitelist` will be converted to lower precision
+    operations, operators in `auto_blacklist` will be converted to full precision operations, operators in
+    `promote_list` will be converted to the higher accuracy float type of the operator inputs, and operators
+    not listed will run in the type defined by their inputs.
+
+    Operators in `auto_whitelist` are:
+
+    ``Conv2D``, ``Conv3D``, ``Conv2DTranspose``, ``Conv3DTranspose``, ``Convolution``, ``MatMul``, ``MatMulExt``,
+    ``BatchMatMul``, ``BatchMatMulExt``, ``PReLU``, ``Einsum``, ``Dense``, ``Addmm``
+
+    Operators in `auto_blacklist` are:
+
+    ``Pow``, ``ACos``, ``Asin``, ``Cosh``, ``Erfinv``, ``Exp``, ``Expm1``, ``Log``, ``Log1p``, ``Reciprocal``,
+    ``Rsqrt``, ``Sinh``, ``Tan``, ``Softplus``, ``SoftplusExt``, ``LayerNorm``, ``LayerNormExt``, ``BatchNorm``,
+    ``GroupNorm``, ``KLDivLoss``, ``SmoothL1Loss``, ``MultilabelMarginLoss``, ``SoftMarginLoss``,
+    ``TripletMarginLoss``, ``MultiMarginLoss``, ``BCEWithLogitsLoss``, ``Pdist``, ``Cdist``, ``Renorm``,
+    ``ReduceProd``, ``Softmax``, ``LogSoftmax``, ``CumProd``, ``CumSum``, ``CumsumExt``, ``ProdExt``, ``SumExt``,
+    ``Norm``
+
+    Operators in `promote_list` are:
+
+    ``Addcdiv``, ``Addcmul``, ``Cross``, ``_PyboostCrossPrim``, ``Dot``, ``GridSampler2D``, ``GridSampler3D``,
+    ``BiasAdd``

     For details on automatic mixed precision, refer to
-    `Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/master/
+    `Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/master/beginner/mixed_precision.html>`_ .

     Note:
         - Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
@@ -339,10 +447,18 @@ def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
         - If interfaces like `Model` and `build_train_network` is used to train the network which is converted by
           mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level`
          need to be configured to ``O0`` to avoid the duplicated accuracy conversion.
+        - When `amp_level` is set to ``auto``, the output of the network may be lower precision. In this case, you
+          may need to manually convert the type to avoid type inconsistency errors of the loss function.
+        - When `amp_level` is set to ``auto``, and cells in the network are configured with `to_float`, the accuracy
+          specified by `to_float` takes effect first.
+
+    .. warning::
+        ``auto`` level of `amp_level` is an experimental API that is subject to change or deletion.

     Args:
-        network (Cell): Definition of the network.
-
+        network (Union[Cell, function]): Definition of the network. Function type is supported only when `amp_level`
+            is set to ``auto`` .
+        amp_level (str): Supports ["O0", "O1", "O2", "O3", "auto"]. Default: ``"O0"`` .

             - "O0": Do not change.
             - "O1": Convert cells and operators in whitelist to lower precision operations, and keep full
@@ -350,12 +466,16 @@
             - "O2": Keep full precision operations for cells and operators in blacklist, and convert the rest
               to lower precision operations.
             - "O3": Cast network to lower precision.
+            - "auto": Operators in `auto_whitelist` will be converted to lower precision operations, operators in
+              `auto_blacklist` will be converted to full precision, operators in `promote_list` will be converted
+              to the higher accuracy float type of the operator inputs, and operators not listed will run in the
+              type defined by their inputs.

         dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or ``mstype.bfloat16`` ,
             default: ``mstype.float16`` .

     Raises:
-        TypeError: If `network` is not a Cell.
+        TypeError: If `network` is not a Cell or a function.
         ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .
         ValueError: If `amp_level` is not within the supported range.

@@ -368,7 +488,12 @@
     >>> net = amp.auto_mixed_precision(network, amp_level)
     """
     if not isinstance(network, nn.Cell):
-
+        if amp_level == "auto":
+            if not inspect.isfunction(network) and not inspect.ismethod(network):
+                raise TypeError("For amp_level 'auto', the network type should be Cell or function.")
+            # function is supported for amp_level 'auto'
+        else:
+            raise TypeError(f"For amp_level '{amp_level}', the network type should be Cell.")

     if dtype not in (mstype.float16, mstype.bfloat16):
         raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")
@@ -377,7 +502,7 @@
         return network

     # Return network if the same amp level has already been configurated
-    if getattr(network, "_amp_level") in ("O1", "O2", "O3"):
+    if hasattr(network, "_amp_level") and getattr(network, "_amp_level") in ("O1", "O2", "O3", "auto"):
         logger.warning(f"The network's auto mixed-precision level is adjusted from {getattr(network, '_amp_level')} "
                        f"to {amp_level}, and repeated calls to mixed-precision interfaces can cause performance "
                        f"degradation.")
@@ -396,8 +521,16 @@
         else:
             network.to_float(dtype)
             network = _OutputTo32(network)
+    elif amp_level == "auto":
+        white_list = [(prim.__name__, AMP_PRIM_ARG_TABLE[prim]) for prim in AMP_AUTO_WHITE_LIST]
+        black_list = [(prim.__name__, AMP_PRIM_ARG_TABLE[prim]) for prim in AMP_AUTO_BLACK_LIST]
+        # set amp_strategy attribute for the object
+        amp_strategy = create_amp_strategy(AmpLevel.AmpAuto, dtype, white_list, black_list)
+        setattr(network, "amp_strategy", amp_strategy)
+        # set amp_strategy context decorator for the object
+        network = _set_amp_decorator(network, AmpLevel.AmpAuto, dtype, white_list, black_list)
     else:
-        raise ValueError("The amp level {} is not supported"
+        raise ValueError(f"The amp level {amp_level} is not supported")

     setattr(network, "_amp_level", amp_level)
@@ -437,6 +570,10 @@ _config_level = {
     "O3": {
         "keep_batchnorm_fp32": False,
         "cast_model_type": mstype.float16,
+        "loss_scale_manager": None},
+    "auto": {
+        "keep_batchnorm_fp32": False,
+        "cast_model_type": mstype.float32,
         "loss_scale_manager": None}}


@@ -461,20 +598,11 @@
 def _check_level(level, boost_level):
     """Check level."""
     if not isinstance(level, str):
-        raise TypeError("The argument `level` must be a string in ['O0', 'O1', 'O2', 'O3', 'auto'],
-
+        raise TypeError(f"The argument `level` must be a string in ['O0', 'O1', 'O2', 'O3', 'auto'],"
+                        f"but got type {type(level)}.")
     validator.check('level', level, "", ['O0', 'O1', 'O2', 'O3', 'auto'], validator.IN)
     validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], validator.IN)

-    if level == "auto":
-        device_target = context.get_context('device_target')
-        if device_target == "GPU":
-            level = "O2"
-        elif device_target == "Ascend":
-            level = "O3"
-        else:
-            raise ValueError("Level `auto` only support when `device_target` is GPU or Ascend.")
-
     enable_boost = False
     if boost_level in ["O1", "O2"]:
         enable_boost = True
@@ -499,7 +627,8 @@ def _add_loss_network(network, loss_fn, cast_model_type):
             return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label)

     validator.check_value_type('loss_fn', loss_fn, nn.Cell)
-    if cast_model_type
+    if cast_model_type in (mstype.float16, mstype.bfloat16) or \
+            (hasattr(network, "_amp_level") and getattr(network, "_amp_level") in ("O2", "O3", "auto")):
         network = WithLossCell(network, loss_fn)
     else:
         network = nn.WithLossCell(network, loss_fn)
@@ -555,20 +684,10 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
             Default: ``None`` .
         level (str): Supports ['O0', 'O1', 'O2', 'O3', 'auto']. Default: ``'O0'`` .

-
-
-
-
-            - 'O2': Cast network to float16, keep `mindspore.nn.BatchNorm` series interface,
-              :class:`mindspore.nn.LayerNorm` and `loss_fn` (if set) run in float32, using dynamic loss scale.
-            - 'O3': Cast network to float16, with additional property `keep_batchnorm_fp32=False` .
-            - 'auto': Set to level to recommended level in different devices. Set level to 'O2' on GPU, Set
-              level to 'O3' Ascend. The recommended level is chosen by the export experience, not applicable to all
-              scenarios. User should specify the level for special network.
-
-            'O2' is recommended on GPU, 'O3' is recommended on Ascend. Property of `keep_batchnorm_fp32`,
-            `cast_model_type` and `loss_scale_manager` determined by `level` setting may be overwritten by settings in
-            `kwargs`.
+            For details on amp level, refer to :func:`mindspore.amp.auto_mixed_precision`.
+
+            Property of `keep_batchnorm_fp32`, `cast_model_type` and `loss_scale_manager` determined by `level`
+            setting may be overwritten by settings in `kwargs`.

         boost_level (str): Option for argument `level` in `mindspore.boost` , level for boost mode
             training. Supports ['O0', 'O1', 'O2']. Default: ``'O0'`` .
@@ -649,7 +768,7 @@

 def get_white_list():
     """
-    Provide a copy of internal white list used by auto mixed precision
+    Provide a copy of internal white list used by auto mixed precision with `amp_level` set to ``O1``.

     The current built-in whitelist contents are:

@@ -687,7 +806,7 @@

 def get_black_list():
     """
-    Provide a copy of internal black list used by auto mixed precision
+    Provide a copy of internal black list used by auto mixed precision with `amp_level` set to ``O2``.

     The current built-in blacklist contents are:

@@ -710,7 +829,6 @@

 def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=mstype.float16):
     """
-    Custom mixed precision by setting whitelist or blacklist.
     When the `white_list` is provided, primitives and cells in `white_list` will perform the precision conversion.
     When the `black_list` is provided, cells that are not in `black_list` will perform the pereision conversion.
     Only one of `white_list` and `black_list` should be provided.
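Taken together, the amp.py changes make ``auto`` a first-class `amp_level` that accepts either a Cell or a plain function. A hedged sketch of typical use; the network, shapes, and loss are made-up examples, not from the diff:

import mindspore as ms
from mindspore import amp, nn

net = nn.Dense(16, 8)  # made-up network
net = amp.auto_mixed_precision(net, amp_level="auto", dtype=ms.float16)

# Whitelisted ops such as Dense/MatMul now run in float16, blacklisted ops stay
# in float32. As the new docstring warns, the network output may come back in
# low precision, so cast before a float32 loss:
loss_fn = nn.MSELoss()

def forward(x, label):
    out = net(x).astype(ms.float32)
    return loss_fn(out, label)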
mindspore/train/callback/__init__.py
CHANGED

@@ -36,9 +36,9 @@ from mindspore.train.callback._reduce_lr_on_plateau import ReduceLROnPlateau
 from mindspore.train.callback._on_request_exit import OnRequestExit
 from mindspore.train.callback._backup_and_restore import BackupAndRestore
 from mindspore.train.callback._flops_collector import FlopsUtilizationCollector
-from mindspore.train.callback.
+from mindspore.train.callback._tft_register import TFTRegister

 __all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", "FlopsUtilizationCollector",
            "SummaryCollector", "CheckpointConfig", "RunContext", "LearningRateScheduler", "SummaryLandscape",
            "History", "LambdaCallback", "ReduceLROnPlateau", "EarlyStopping", "OnRequestExit", "BackupAndRestore",
-           "
+           "TFTRegister"]
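The newly exported `TFTRegister` callback (implemented in the added `_tft_register.py`, which supersedes the removed `_mindio_ttp.py`) hooks MindIO fault-tolerance handling into training. A heavily hedged sketch; the constructor arguments below are assumptions for illustration, not a confirmed signature, so verify against the 2.4.0 API reference:

from mindspore.train import TFTRegister

# Assumed arguments: controller rank id, controller IP and port, and a
# directory for fault-tolerance checkpoints.
tft_cb = TFTRegister(0, "192.168.0.1", 2000, "/data/tft_ckpt/")

# `model`, `epochs` and `dataset` are assumed to be defined elsewhere.
model.train(epochs, dataset, callbacks=[tft_cb])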
mindspore/train/callback/_callback.py
CHANGED

@@ -123,7 +123,7 @@ class Callback:
     recording current attributes. Users can add custimized attributes to the information.
     Training process can also be stopped by calling `request_stop` method. For details
     of custom Callback, please check
-    `Callback tutorial <https://www.mindspore.cn/
+    `Callback tutorial <https://www.mindspore.cn/docs/en/master/model_train/train_process/model/
     callback.html#customized-callback-mechanism>`_.

     Examples:
@@ -493,7 +493,7 @@ class RunContext:
     `RunContext.original_args()` and add extra attributes to the information, but also can stop the
     training process by calling `request_stop` method. For details of custom Callback,
     please check
-    `Callback Mechanism <https://www.mindspore.cn/
+    `Callback Mechanism <https://www.mindspore.cn/docs/en/master/model_train/train_process/model/callback.html>`_.

     `RunContext.original_args()` holds the model context information as a dictionary variable, and
     different attributes of the dictionary are stored in training or eval process. Details are as follows:
@@ -575,7 +575,7 @@

         Tutorial Examples:
             - `Callback Mechanism - Customized Callback Mechanism
-              <https://mindspore.cn/
+              <https://mindspore.cn/docs/en/master/model_train/train_process/model/callback.html#customized-callback-mechanism>`_
         """
         return self._original_args

@@ -588,7 +588,7 @@

         Tutorial Examples:
             - `Callback Mechanism - Customized Training Termination Time
-              <https://mindspore.cn/
+              <https://mindspore.cn/docs/en/master/model_train/train_process/model/callback.html#
               customized-training-termination-time>`_
         """
         self._stop_requested = True