mindspore 2.3.0__cp310-cp310-win_amd64.whl → 2.4.1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/initializer.py +51 -15
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +62 -15
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +183 -37
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +315 -60
- mindspore/communication/management.py +14 -14
- mindspore/context.py +132 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +983 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +268 -23
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +26 -13
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +276 -96
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +137 -10
- mindspore/nn/layer/embedding.py +137 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +124 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +767 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
- mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +492 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +564 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +402 -12
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +7 -2
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +14 -146
- mindspore/ops/operations/comm_ops.py +63 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +273 -20
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +31 -9
- mindspore/parallel/_cell_wrapper.py +85 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +127 -13
- mindspore/parallel/_utils.py +53 -22
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +1146 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +285 -413
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +39 -104
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +105 -19
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +97 -31
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +145 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +375 -0
- mindspore/train/dataset_helper.py +15 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +154 -58
- mindspore/train/serialization.py +342 -128
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +248 -242
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,375 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""Checkpoint related classes and functions."""
|
|
16
|
+
|
|
17
|
+
import os
|
|
18
|
+
from mindspore.train.serialization import save_checkpoint
|
|
19
|
+
from mindspore.parallel._utils import _get_device_num
|
|
20
|
+
from mindspore import _checkparam as Validator
|
|
21
|
+
from mindspore.train.callback._callback import Callback
|
|
22
|
+
from mindspore import context
|
|
23
|
+
from mindspore.common.parameter import Parameter
|
|
24
|
+
from mindspore.common.tensor import Tensor
|
|
25
|
+
from mindspore.communication import get_rank, get_group_size
|
|
26
|
+
from mindspore import log as logger
|
|
27
|
+
from mindspore.train.serialization import _get_cur_rank_dp
|
|
28
|
+
from mindspore._c_expression import _repair_device, _stop_device, _tft_sem_post
|
|
29
|
+
from mindspore._c_expression import clean_tdt_channel
|
|
30
|
+
from mindspore._c_expression import send_recv
|
|
31
|
+
from mindspore._c_expression import CollectiveManager
|
|
32
|
+
from mindspore._c_expression import _get_uce_process_strategy, _get_uce_mem_info
|
|
33
|
+
from mindspore._c_expression import Tensor as Tensor_
|
|
34
|
+
import mindspore
|
|
35
|
+
import mindspore.common.dtype as mstype
|
|
36
|
+
|
|
37
|
+
def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
|
|
38
|
+
""" Common func to generate ckpt dir name."""
|
|
39
|
+
tmp = "_tmp" if is_tmp_file else ""
|
|
40
|
+
mid_dir = f"tft_saved_checkpoints-step_{str(step)}{tmp}"
|
|
41
|
+
return os.path.join(ckpt_save_path, mid_dir)
|
|
42
|
+
|
|
43
|
+
def _save_checkpoint_on_failure(step, save_info, args, cb_ctx):
    """Callback registered with TFT to save a checkpoint when an error occurs.

    The checkpoint is written into the temporary step directory returned by
    ``_get_ckpt_dir(step, cb_ctx.ckpt_save_path, True)``, inside a per-rank
    sub-directory ``rank_<cur_rank>``.

    Args:
        step (int): Step number whose training state is saved.
        save_info: Save information passed by TFT; not used here.
        args: Callback parameters of the training run (used as ``cb_params``).
        cb_ctx (TFTRegister): The callback instance that registered this handler.

    Raises:
        RuntimeError: If the training parameters are left in an inconsistent state.
    """
    logger.info("Enter _save_checkpoint_on_failure function")
    if not cb_ctx._is_params_consistent():  # pylint: disable=W0212
        raise RuntimeError("Can't save parameters, because they are left in inconsistent state!")

    ckpt_save_path = cb_ctx.ckpt_save_path
    cb_params = args
    cur_rank = get_rank()
    cur_step_num = cb_params.cur_step_num
    cur_epoch_num = cb_params.cur_epoch_num
    batch_num = cb_params.batch_num
    # When training has already advanced past `step`, derive the epoch `step` belongs to.
    if cur_step_num > step:
        cur_epoch_num = (step - 1) // batch_num + 1
    step_num_in_epoch = int((step - 1) % batch_num + 1)

    append_dict = {
        "epoch_num": cur_epoch_num,
        "step_num": step,
        "cur_rank": cur_rank,
        "batch_num": batch_num,
        "__exception_save__": True,
        "global_step": Parameter([cb_ctx.global_step]),
    }
    outputs = cb_params.net_outputs
    # A third network output, when present, is stored as the loss scale.
    if isinstance(outputs, (tuple, list)) and len(outputs) >= 3:
        append_dict["loss_scale"] = outputs[2]

    ckpt_file = f"ttp_rank_{cur_rank}-{cur_epoch_num}_{step_num_in_epoch}.ckpt"
    # Use os.path.join (consistent with _get_ckpt_dir) instead of hard-coding "/".
    cur_ckpt_dir = os.path.join(_get_ckpt_dir(step, ckpt_save_path, True), "rank_" + str(cur_rank))
    os.makedirs(cur_ckpt_dir, exist_ok=True)
    cur_file = os.path.join(cur_ckpt_dir, ckpt_file)
    save_checkpoint(cb_params.train_network, cur_file,
                    integrated_save=False, append_dict=append_dict)
    logger.info("Finish _save_checkpoint_on_failure function")
|
|
78
|
+
|
|
79
|
+
def _rename_save_result(step, cb_ctx):
    """Callback registered with TFT, run after the failure checkpoint was saved successfully.

    Renames the temporary ("_tmp"-suffixed) step directory to its final name.

    Args:
        step (int): Step number whose checkpoint directory is renamed.
        cb_ctx (TFTRegister): The callback instance that registered this handler; its
            ``ckpt_save_path`` is the checkpoint root directory.
    """
    logger.info("Enter _rename_save_result function")
    save_root = cb_ctx.ckpt_save_path
    os.rename(_get_ckpt_dir(step, save_root, True),
              _get_ckpt_dir(step, save_root, False))
    logger.info("Finish _rename_save_result function")
|
|
87
|
+
|
|
88
|
+
def _tft_exit_cb(ctx):
    """Callback registered with TFT as the exit handler.

    Logs an error pointing at the other ranks, posts the TFT semaphore via
    ``_tft_sem_post`` and terminates this process with exit code 1 using
    ``os._exit`` (no interpreter cleanup is run).

    Args:
        ctx (TFTRegister): The callback instance that registered this handler; not used.
    """
    message = ("Enter mindio ttp exit process, which means other ranks occur exception, "
               "check other ranks' logs!")
    logger.error(message)
    _tft_sem_post()
    exit_code = 1
    os._exit(exit_code)  # pylint: disable=W0212
|
|
92
|
+
|
|
93
|
+
def _tft_repair_callback(step, need_rebuild, error_ranks, repair_info, args, cb_ctx):
    """Callback registered with TFT to repair this rank after a failure.

    Actions depend on ``repair_info["repair_type"]``:

    * ``RT_UCE_HIGHLEVEL`` or ``RT_UCE_LOWLEVEL``: repair the local device with
      ``_repair_device(cb_ctx.device_id)``.
    * ``RT_UCE_HIGHLEVEL`` or ``RT_SEND``: exchange the network's trainable parameters
      between ``repair_info["src"][0]`` and ``repair_info["dst"][0]`` via ``send_recv``.

    Args:
        step, need_rebuild, error_ranks: Passed by TFT; not used here.
        repair_info (dict): Contains ``"repair_type"``, plus ``"src"`` and ``"dst"`` rank
            sequences for send/receive repairs.
        args: Callback parameters of the training run (used as ``cb_params``).
        cb_ctx (TFTRegister): The callback instance that registered this handler.
    """
    repair_type = repair_info["repair_type"]
    repair_types = cb_ctx.tft.RepairType
    logger.info("Enter _tft_repair_callback repair type: {}".format(repair_type))
    if repair_type in (repair_types.RT_UCE_HIGHLEVEL.value, repair_types.RT_UCE_LOWLEVEL.value):
        logger.info("Enter _tft_repair_callback uce REPAIR_DEVICE device_id : {}".format(cb_ctx.device_id))
        _repair_device(cb_ctx.device_id)

    if repair_type in (repair_types.RT_UCE_HIGHLEVEL.value, repair_types.RT_SEND.value):
        # Single literal: a backslash continuation inside the string would embed
        # the next line's indentation into the log message.
        logger.info("Enter _tft_repair_callback SEND_RECV repair type: {}, src_rank:{}, dst_rank: {}".format(
            repair_type, repair_info["src"], repair_info["dst"]))
        cb_params = args
        src_rank = repair_info["src"][0]
        dst_rank = repair_info["dst"][0]
        send_recv(cb_params.network.trainable_params(), src_rank, dst_rank)
    logger.info("Finish _tft_repair_callback")
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _tft_clean_callback(is_uce_error, ctx):
    """Callback registered with TFT to clean up before a repair.

    When ``is_uce_error`` is true, queries UCE memory info for ``ctx.device_id``,
    then maps the UCE process strategy to the return code:
    ``"RS_UCE_HIGHLEVEL"`` -> 0, ``"RS_UCE_LOWLEVEL"`` -> 2, anything else -> 1.
    Without a UCE error the return code is 0. In every case the TDT channel is
    cleaned and the HCCL communicators are resumed.

    Args:
        is_uce_error (bool): Whether the failure being handled is a UCE error.
        ctx (TFTRegister): The callback instance that registered this handler.

    Returns:
        int: 0, 1 or 2 as described above.
    """
    logger.info("Enter _tft_clean_callback")
    ret = 0
    if is_uce_error:
        _get_uce_mem_info(ctx.device_id)
        err_strategy = _get_uce_process_strategy()
        logger.info("_tft_clean_callback err_strategy: {}".format(err_strategy))
        if err_strategy == "RS_UCE_HIGHLEVEL":
            pass  # high-level UCE keeps the default return code 0
        elif err_strategy == "RS_UCE_LOWLEVEL":
            ret = 2
        else:
            ret = 1
    clean_tdt_channel()
    logger.info("Enter _tft_clean_callback resume_hccl_comm")
    CollectiveManager.get_instance().resume_hccl_comm()
    logger.info("Finish _tft_clean_callback, ret: {}".format(ret))
    return ret
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _tft_stop_callback(cb_ctx):
    """Callback registered with TFT to stop the local device.

    Stops device ``cb_ctx.device_id`` and then checks that the training parameters
    are still consistent.

    Args:
        cb_ctx (TFTRegister): The callback instance that registered this handler.

    Raises:
        RuntimeError: If the training parameters are left in an inconsistent state.
    """
    device_id = cb_ctx.device_id
    logger.info("Enter _tft_stop_callback device_id: {}".format(device_id))
    _stop_device(device_id)
    params_ok = cb_ctx._is_params_consistent()  # pylint: disable=W0212
    if not params_ok:
        raise RuntimeError("Can't stop device, because training parameters are left in inconsistent state!")
    logger.info("Finish _tft_stop_callback")
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class TFTRegister(Callback):
    """
    This callback is used to enable the TFT feature
    `MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/mindio/mindiottp/mindiottp001.html>`_.
    This callback will execute TFT operations during training process, such as TFT init, report and exception handle.

    Note:
        Required for Ascend graph mode only. And sink size must be less than or equal to 1.

    Args:
        ctrl_rank_id (int): TFT controller's running rank_id, used for init TFT controller.
        ctrl_ip (str): TFT controller's ip address, used for init TFT controller.
        ctrl_port (int): TFT controller's ip port, used for init TFT controller and processor.
        ckpt_save_path (str): Checkpoint save directory when failure occurs, checkpoint file will save to directory
            named tft_saved_checkpoints-step_{cur_step_num} under this directory.

    Raises:
        ValueError: Environment variable ``MS_ENABLE_TFT`` enables neither ``TTP:1`` nor ``UCE:1``, or the
            context is not Ascend device with GRAPH mode, or ``sink_size`` is greater than 1 at train begin.
        Exception: TFT init failed.
        ModuleNotFoundError: Mindio TFT whl package is not installed.

    Examples:
        >>> import numpy as np
        >>> import os
        >>> import math
        >>> import mindspore as ms
        >>> import mindspore.dataset as ds
        >>> from mindspore import nn, ops, Parameter, train
        >>> from mindspore.communication import init, get_rank
        >>> from mindspore.common.initializer import initializer, HeUniform
        >>> from mindspore.train import Model, TFTRegister
        >>> ms.set_context(mode=ms.GRAPH_MODE, jit_level='O2')
        >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
        >>> init()
        >>> ms.set_seed(1)
        >>> ms.set_auto_parallel_context(strategy_ckpt_config={"save_file":
        ...     "./src_pipeline_strategys/src_strategy_{}.ckpt".format(get_rank())})
        >>> class MatMulCell(nn.Cell):
        ...     def __init__(self, param=None, shape=None):
        ...         super().__init__()
        ...         if shape is None:
        ...             shape = [28 * 28, 512]
        ...         weight_init = HeUniform(math.sqrt(5))
        ...         self.param = Parameter(initializer(weight_init, shape), name="param")
        ...         if param is not None:
        ...             self.param = param
        ...         self.print = ops.Print()
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         out = self.matmul(x, self.param)
        ...         self.print("out is:", out)
        ...         return out
        >>>
        >>> class Network(nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.flatten = nn.Flatten()
        ...         self.layer1 = MatMulCell()
        ...         self.relu1 = nn.ReLU()
        ...         self.layer2 = nn.Dense(512, 512)
        ...         self.relu2 = nn.ReLU()
        ...         self.layer3 = nn.Dense(512, 10)
        ...
        ...     def construct(self, x):
        ...         x = self.flatten(x)
        ...         x = self.layer1(x)
        ...         x = self.relu1(x)
        ...         x = self.layer2(x)
        ...         x = self.relu2(x)
        ...         logits = self.layer3(x)
        ...         return logits
        >>>
        >>> net = Network()
        >>> net.layer1.pipeline_stage = 0
        >>> net.relu1.pipeline_stage = 0
        >>> net.layer2.pipeline_stage = 0
        >>> net.relu2.pipeline_stage = 1
        >>> net.layer3.pipeline_stage = 1
        >>>
        >>> def create_dataset(batch_size):
        ...     dataset_path = os.getenv("DATA_PATH")
        ...     dataset = ds.MnistDataset(dataset_path)
        ...     image_transforms = [
        ...         ds.vision.Rescale(1.0 / 255.0, 0),
        ...         ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        ...         ds.vision.HWC2CHW()
        ...     ]
        ...     label_transform = ds.transforms.TypeCast(ms.int32)
        ...     dataset = dataset.map(image_transforms, 'image')
        ...     dataset = dataset.map(label_transform, 'label')
        ...     dataset = dataset.batch(batch_size)
        ...     return dataset
        >>>
        >>> data_set = create_dataset(32)
        >>>
        >>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
        >>> optimizer_wrapper = nn.OptTFTWrapper(optimizer)
        >>> loss_fn = nn.CrossEntropyLoss()
        >>>
        >>> net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 4)
        >>> net_with_loss.set_train()
        >>> model = Model(net_with_loss, optimizer=optimizer_wrapper)
        >>> tft_cb = TFTRegister(0, "192.168.0.1", 2000, "./tft_checkpoint/")
        >>> loss_cb = train.LossMonitor(1)
        >>> model.train(1, data_set, callbacks=[tft_cb, loss_cb])
    """

    def __init__(self, ctrl_rank_id, ctrl_ip, ctrl_port, ckpt_save_path):
        super(TFTRegister, self).__init__()

        tft_env = os.getenv("MS_ENABLE_TFT", "")
        if ("TTP:1" not in tft_env) and ("UCE:1" not in tft_env):
            raise ValueError("MindIO TFT register need custom switch on[MS_ENABLE_TFT='{TTP:1,UCE:1}']!")
        mode = context.get_context("mode")
        device_target = context.get_context("device_target")
        if device_target != "Ascend" or mode != context.GRAPH_MODE:
            raise ValueError("MindIO adapter only support on Ascend device with GRAPH Mode!")

        # let it raise errors if not install mindio_tft package
        from mindio_ttp import framework_ttp as tft
        self.tft = tft
        self.global_step = 0
        Validator.check_non_negative_int(ctrl_port)
        self.has_init_replica = False
        self._controller_ip = ctrl_ip
        self._controller_rank_id = ctrl_rank_id
        self._controller_port = ctrl_port
        self.cb_params = None
        self.device_id = context.get_context("device_id")
        # Set before _init_tft registers the handlers: the save/rename handlers read
        # ckpt_save_path from this instance.
        self.ckpt_save_path = ckpt_save_path
        self._init_tft()
        self.assign = mindspore.ops.Assign()
        self.g_one = Parameter(Tensor([1], dtype=mstype.int32))
        self.s1 = mindspore.hal.Stream()

    def _is_params_consistent(self):
        """Return True if the train network's ``tft_g_one_flag`` parameter equals 1.

        The flag is copied to CPU on the dedicated stream ``self.s1``. Returns False
        when no parameter name contains ``tft_g_one_flag``. Also returns False when
        ``on_train_begin`` has not yet recorded the callback parameters.
        """
        if self.cb_params is None:
            return False
        for key, param in self.cb_params.train_network.parameters_and_names():
            if "tft_g_one_flag" in key:
                with mindspore.hal.StreamCtx(self.s1):
                    tft_g_one_flag = Tensor(Tensor_.move_to(param, "CPU", False))
                self.s1.synchronize()
                return int(tft_g_one_flag) == 1
        return False

    def _set_tft_optimizer_replica(self, run_context):
        """ set Mindio TFT optimizer replica info, used internal. """
        cur_rank = get_rank()
        cb_params = run_context.original_args()
        train_network = cb_params.train_network
        # in data_parallel mode, every ranks has same train parameters
        if context.get_auto_parallel_context("parallel_mode") == "data_parallel":
            group_size = get_group_size()
            dp = tuple(range(group_size))
        else:
            param_layout_dict = train_network.parameter_layout_dict
            dp = _get_cur_rank_dp(param_layout_dict) if param_layout_dict else _get_cur_rank_dp(train_network)
        logger.warning(f"Set TFT replica with dp: {dp}.")
        replica_info = [
            {
                "type": 1,
                "rank_list": dp,
                "replica_cnt": len(dp),
                "replica_shift": 0
            }
        ]
        self.tft.tft_set_optimizer_replica(cur_rank, replica_info)

    def _init_tft(self):
        """ Init Mindio TFT, used internal. """
        logger.info("Begin to init tft.")
        self.tft.tft_register_save_ckpt_handler(_save_checkpoint_on_failure, self)
        self.tft.tft_register_rename_handler(_rename_save_result, self)
        self.tft.tft_register_exit_handler(_tft_exit_cb, self)
        self.tft.tft_register_stop_handler(_tft_stop_callback, self)
        self.tft.tft_register_clean_handler(_tft_clean_callback, self)
        self.tft.tft_register_repair_handler(_tft_repair_callback, self)

        world_size = _get_device_num()
        cur_rank = get_rank()
        enable_local_copy = False
        enable_arf = False
        enable_zit = False
        enable_tls = False
        tls_key_dir = ""

        # Only the configured controller rank starts the TFT controller.
        if cur_rank == self._controller_rank_id:
            logger.info(f"Begin to start tft controller on rank_id:{cur_rank}")
            self.tft.tft_init_controller(cur_rank, world_size, enable_local_copy, enable_arf, enable_zit)
            self.tft.tft_start_controller(self._controller_ip, self._controller_port, enable_tls, tls_key_dir)
            logger.info("Finish start tft controller.")

        logger.info("Begin to start tft processor.")
        self.tft.tft_init_processor(cur_rank, world_size, enable_local_copy, enable_tls, tls_key_dir)
        self.tft.tft_start_processor(self._controller_ip, self._controller_port)
        logger.info("Finished start tft processor.")

    def on_train_step_end(self, run_context):
        """
        And report status to MindIO TFT after every step finished.

        Args:
            run_context (RunContext): Context of the train running. Refer to
                                      :class:`mindspore.train.RunContext` for detail.
        """
        if not self.has_init_replica:
            self.has_init_replica = True
            self._set_tft_optimizer_replica(run_context)
        cb_params = run_context.original_args()
        logger.info("START Set optimizer finish step status to TFT. step: {}".format(cb_params.cur_step_num))
        self.tft.tft_end_updating_os(cb_params.cur_step_num)
        # The optimizer is either given to Model directly or held by the train network.
        optimizer = cb_params.optimizer if cb_params.optimizer is not None else cb_params.network.optimizer
        self.global_step = int(optimizer.global_step.data)
        self.assign(optimizer.tft_g_one_flag, self.g_one)
        logger.info("END Set optimizer finish step status to TFT.")

    def on_train_begin(self, run_context):
        """
        Check the sink size and hand the training callback parameters to MindIO TFT.

        Args:
            run_context (RunContext): Context of the train running.

        Raises:
            ValueError: If ``sink_size`` is greater than 1.
        """
        cb_params = run_context.original_args()
        sink_size = cb_params.get("sink_size", 0)
        if sink_size > 1:
            raise ValueError("TFT feature doesn't support sink_size > 1.")
        logger.info("Set set args to TFT.")
        self.tft.tft_set_step_args(cb_params)
        self.cb_params = cb_params

    def end(self, run_context):
        """
        Destroy the TFT controller (on the controller rank only) and the TFT processor.

        Args:
            run_context (RunContext): Context of the train running; not used.
        """
        cur_rank = get_rank()
        if cur_rank == self._controller_rank_id:
            self.tft.tft_destroy_controller()
        self.tft.tft_destroy_processor()
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""Dataset help for minddata dataset"""
|
|
16
16
|
from __future__ import absolute_import
|
|
17
17
|
|
|
18
|
+
import os
|
|
18
19
|
import math
|
|
19
20
|
import copy
|
|
20
21
|
|
|
@@ -213,7 +214,8 @@ def _get_dataset_aux(dataset):
|
|
|
213
214
|
def connect_network_with_dataset(network, dataset_helper):
|
|
214
215
|
"""
|
|
215
216
|
Connect the `network` with dataset in `dataset_helper`. Only supported in `sink mode
|
|
216
|
-
<https://mindspore.cn/
|
|
217
|
+
<https://mindspore.cn/docs/en/master/model_train/train_process/train_optimize.html>`_,
|
|
218
|
+
(dataset_sink_mode=True).
|
|
217
219
|
|
|
218
220
|
Args:
|
|
219
221
|
network (Cell): The training network for dataset.
|
|
@@ -261,7 +263,16 @@ def connect_network_with_dataset(network, dataset_helper):
|
|
|
261
263
|
"The dataset has been connected to other network, please check the code.")
|
|
262
264
|
is_dynamic = bool(network.get_inputs())
|
|
263
265
|
queue_name = dataset.__transfer_dataset__.queue_name
|
|
264
|
-
|
|
266
|
+
# In pipeline parallel, some stages have no GetNext, should not get in.
|
|
267
|
+
use_pipeline_parallel = (context.get_auto_parallel_context("pipeline_stages") > 1)
|
|
268
|
+
|
|
269
|
+
# temp env to disable dynamic feature of sink size 1
|
|
270
|
+
dynamic_sink1_env = os.getenv("MS_DEV_DYNAMIC_SINK1", None)
|
|
271
|
+
dynamic_sink1 = True
|
|
272
|
+
if dynamic_sink1_env and dynamic_sink1_env.strip() in ['False', 'false']:
|
|
273
|
+
dynamic_sink1 = False
|
|
274
|
+
|
|
275
|
+
if _dynamic_sink_scenario(dataset, dataset_iter, is_dynamic) and not use_pipeline_parallel and dynamic_sink1:
|
|
265
276
|
dataset_types, dataset_shapes = dataset_helper.get_data_info()
|
|
266
277
|
# Need to do full_batch for shapes which also do in the _DatasetIterMSLoopSink
|
|
267
278
|
if _need_to_full():
|
|
@@ -302,7 +313,8 @@ def connect_network_with_dataset(network, dataset_helper):
|
|
|
302
313
|
dataset_types, dataset_shapes = dataset_helper.types_shapes()
|
|
303
314
|
aux.__shape_type__ = str(dataset_types) + str(dataset_shapes)
|
|
304
315
|
|
|
305
|
-
if _dynamic_sink_data(dataset, dataset_iter) and _dynamic_sink_exception_scenario(dataset_iter, is_dynamic)
|
|
316
|
+
if _dynamic_sink_data(dataset, dataset_iter) and _dynamic_sink_exception_scenario(dataset_iter, is_dynamic) and \
|
|
317
|
+
not use_pipeline_parallel and dynamic_sink1:
|
|
306
318
|
dataset_helper.get_data_info()
|
|
307
319
|
network.add_flags(sink_mode=True)
|
|
308
320
|
return network
|
|
@@ -200,7 +200,7 @@ class Metric(metaclass=ABCMeta):
|
|
|
200
200
|
|
|
201
201
|
Tutorial Examples:
|
|
202
202
|
- `Evaluation Metrics - Customized Metrics
|
|
203
|
-
<https://mindspore.cn/
|
|
203
|
+
<https://mindspore.cn/docs/en/master/model_train/train_process/model/metric.html#customized-metrics>`_
|
|
204
204
|
"""
|
|
205
205
|
raise NotImplementedError('Must define clear function to use this base class')
|
|
206
206
|
|
|
@@ -214,7 +214,7 @@ class Metric(metaclass=ABCMeta):
|
|
|
214
214
|
|
|
215
215
|
Tutorial Examples:
|
|
216
216
|
- `Evaluation Metrics - Customized Metrics
|
|
217
|
-
<https://mindspore.cn/
|
|
217
|
+
<https://mindspore.cn/docs/en/master/model_train/train_process/model/metric.html#customized-metrics>`_
|
|
218
218
|
"""
|
|
219
219
|
raise NotImplementedError('Must define eval function to use this base class')
|
|
220
220
|
|
|
@@ -231,7 +231,7 @@ class Metric(metaclass=ABCMeta):
|
|
|
231
231
|
|
|
232
232
|
Tutorial Examples:
|
|
233
233
|
- `Evaluation Metrics - Customized Metrics
|
|
234
|
-
<https://mindspore.cn/
|
|
234
|
+
<https://mindspore.cn/docs/en/master/model_train/train_process/model/metric.html#customized-metrics>`_
|
|
235
235
|
"""
|
|
236
236
|
raise NotImplementedError('Must define update function to use this base class')
|
|
237
237
|
|
mindspore/train/metrics/roc.py
CHANGED
|
@@ -42,18 +42,18 @@ class ROC(Metric):
|
|
|
42
42
|
>>> from mindspore.train import ROC
|
|
43
43
|
>>>
|
|
44
44
|
>>> # 1) binary classification example
|
|
45
|
-
>>> x = Tensor(np.array([
|
|
45
|
+
>>> x = Tensor(np.array([0.28, 0.55, 0.15, 0.05]))
|
|
46
46
|
>>> y = Tensor(np.array([0, 1, 2, 3]))
|
|
47
47
|
>>> metric = ROC(pos_label=2)
|
|
48
48
|
>>> metric.clear()
|
|
49
49
|
>>> metric.update(x, y)
|
|
50
50
|
>>> fpr, tpr, thresholds = metric.eval()
|
|
51
51
|
>>> print(fpr)
|
|
52
|
-
[0.
|
|
52
|
+
[0. 0.33333333 0.66666667 0.66666667 1. ]
|
|
53
53
|
>>> print(tpr)
|
|
54
|
-
[0.
|
|
54
|
+
[0. 0. 0. 1. 1.]
|
|
55
55
|
>>> print(thresholds)
|
|
56
|
-
[
|
|
56
|
+
[1.55 0.55 0.28 0.15 0.05]
|
|
57
57
|
>>>
|
|
58
58
|
>>> # 2) multiclass classification example
|
|
59
59
|
>>> x = Tensor(np.array([[0.28, 0.55, 0.15, 0.05], [0.10, 0.20, 0.05, 0.05], [0.20, 0.05, 0.15, 0.05],
|