mindspore 2.3.0__cp39-cp39-win_amd64.whl → 2.4.0__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +46 -13
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +209 -29
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +310 -55
- mindspore/communication/management.py +14 -14
- mindspore/context.py +123 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +495 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +266 -21
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +28 -7
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +275 -93
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +113 -3
- mindspore/nn/layer/embedding.py +120 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +734 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
- mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +490 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +558 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +184 -8
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +6 -1
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +12 -146
- mindspore/ops/operations/comm_ops.py +42 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +265 -10
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +28 -8
- mindspore/parallel/_cell_wrapper.py +83 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +81 -11
- mindspore/parallel/_utils.py +13 -1
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +280 -412
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +36 -103
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +28 -2
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +85 -22
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +134 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/dataset_helper.py +7 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +134 -58
- mindspore/train/serialization.py +336 -112
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +258 -252
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
@@ -168,6 +168,24 @@ class _AutoParallelContext:
         self.check_context_handle()
         return _ParallelFusionConfig.CONFIG

+    def set_dump_local_norm(self, dump_local_norm):
+        """
+        Set dump local norm for auto parallel.
+
+        Args:
+            dump_local_norm (bool): User need to specify if he want to dump local norm. Default: False
+
+        Raises:
+            KeyError: When key of comm_fusion is not 'allreduce'.
+        """
+        self.check_context_handle()
+        self._context_handle.set_dump_local_norm(dump_local_norm)
+
+    def get_dump_local_norm(self):
+        """Get dump local norm."""
+        self.check_context_handle()
+        return self._context_handle.get_dump_local_norm()
+
     def set_fusion_threshold_mb(self, fusion_threshold=64, comm_type="allreduce"):
         """
         Set fusion threshold (MB) for auto parallel.
@@ -584,7 +602,7 @@ class _AutoParallelContext:
         self.check_context_handle()
         dir_path = os.path.dirname(strategy_ckpt_save_file)
         if dir_path and not os.path.exists(dir_path):
-            os.makedirs(dir_path, exist_ok=True)
+            os.makedirs(dir_path, mode=0o700, exist_ok=True)
         self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file)

     def get_strategy_ckpt_save_file(self):
@@ -643,7 +661,7 @@ class _AutoParallelContext:
         self.check_context_handle()
         dir_path = os.path.dirname(group_ckpt_save_file)
         if dir_path and not os.path.exists(dir_path):
-            os.makedirs(dir_path)
+            os.makedirs(dir_path, mode=0o700, exist_ok=True)
         self._context_handle.set_group_ckpt_save_file(group_ckpt_save_file)

     def get_parameter_broadcast_is_set(self):
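Both hunks above switch checkpoint-directory creation to a restrictive mode. A minimal, hedged sketch of the effect (the path is hypothetical; note that os.makedirs masks the requested mode with the process umask, so 0o700 is an upper bound):

    import os
    import stat

    os.makedirs("strategy_ckpt_dir", mode=0o700, exist_ok=True)   # hypothetical directory
    # Permissions are at most drwx------; the umask can only remove bits from 0o700.
    print(stat.filemode(os.stat("strategy_ckpt_dir").st_mode))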
@@ -1117,9 +1135,9 @@ class _AutoParallelContext:
         """
         self.check_context_handle()
         if comm_type == "allgather" and not self.get_enable_all_gather_fusion():
-
+            self.set_enable_all_gather_fusion(True)
         if comm_type == "reducescatter" and not self.get_enable_reduce_scatter_fusion():
-
+            self.set_enable_reduce_scatter_fusion(True)
         if not isinstance(comm_fusion, dict):
             raise TypeError("For 'comm_fusion', {} config must be dict, but got the type : {}.".format(
                 comm_type, type(comm_fusion)))
@@ -1153,7 +1171,7 @@ class _AutoParallelContext:
         """
         self.check_context_handle()
         if not self.get_enable_all_reduce_fusion():
-
+            self.set_enable_all_reduce_fusion(True)
         if not isinstance(comm_fusion, dict):
             raise TypeError("For 'comm_fusion', the 'allreduce' config must be dict, but got the type : {}.".format(
                 type(comm_fusion)))
@@ -1210,7 +1228,7 @@ def _set_ops_strategy_json_config(type="SAVE", path="", mode="all"):
     """
     dir_path = os.path.dirname(path)
     if dir_path and not os.path.exists(dir_path):
-        os.makedirs(dir_path)
+        os.makedirs(dir_path, mode=0o700, exist_ok=True)
     check_type = ["SAVE", "LOAD"]
     check_mode = ["all", "principal"]
     if type in check_type and mode in check_mode:
@@ -1266,7 +1284,8 @@ _set_auto_parallel_context_func_map = {
     "sharding_propagation": auto_parallel_context().set_sharding_propagation,
     "enable_alltoall": auto_parallel_context().set_enable_alltoall,
     "strategy_ckpt_config": auto_parallel_context().set_strategy_ckpt_config,
-    "comm_fusion": auto_parallel_context().set_comm_fusion
+    "comm_fusion": auto_parallel_context().set_comm_fusion,
+    "dump_local_norm": auto_parallel_context().set_dump_local_norm}

 _get_auto_parallel_context_func_map = {
     "device_num": auto_parallel_context().get_device_num,
@@ -1298,7 +1317,8 @@ _get_auto_parallel_context_func_map = {
     "enable_alltoall": auto_parallel_context().get_enable_alltoall,
     "comm_fusion": auto_parallel_context().get_comm_fusion,
     "strategy_ckpt_config": auto_parallel_context().get_strategy_ckpt_config,
-    "full_batch_is_set": auto_parallel_context().get_full_batch_is_set
+    "full_batch_is_set": auto_parallel_context().get_full_batch_is_set,
+    "dump_local_norm": auto_parallel_context().get_dump_local_norm}


 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
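With the setter and getter registered in the context maps above, the new flag should be reachable through the public auto-parallel context API. A hedged usage sketch (only the wiring is shown in this diff; accepted values beyond a bool are an assumption):

    import mindspore as ms

    ms.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    ms.set_auto_parallel_context(dump_local_norm=True)       # dispatched to set_dump_local_norm
    print(ms.get_auto_parallel_context("dump_local_norm"))   # dispatched to get_dump_local_norm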
mindspore/parallel/_cell_wrapper.py
CHANGED
@@ -16,11 +16,16 @@
 from __future__ import absolute_import
 from __future__ import division

+import numpy as np
+
+from mindspore import context
 from mindspore.nn.cell import Cell
 from mindspore.ops import operations as P
 from mindspore.ops.operations.comm_ops import AllGather
 from mindspore.communication import GlobalComm
 from mindspore.common import jit
+from mindspore.communication import create_group
+from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy

 _ALLGATHER_CELL = None

@@ -30,6 +35,7 @@ class AllGatherCell(Cell):
     Allgather cell, used in model parallel scenario.
     To allgather the selected parameter slice from each device.
     """
+
     def __init__(self, group, do_reshape, after_reshape_slice_shape):
         super(AllGatherCell, self).__init__(auto_prefix=False)
         self.allgather = AllGather(group)
@@ -54,6 +60,7 @@ class SaveOptShardCkptCell(Cell):
     Note:
         This could be optimized later with less communication consumption.
     """
+
     def __init__(self, group, do_reshape, after_reshape_slice_shape):
         super(SaveOptShardCkptCell, self).__init__(auto_prefix=False)
         self.allgather1 = AllGather(group)
@@ -71,6 +78,21 @@ class SaveOptShardCkptCell(Cell):
         return x


+class SingleCommunicator(Cell):
+    """
+    Used to broadcast single parameter.
+    """
+
+    def __init__(self, group_name):
+        super(SingleCommunicator, self).__init__()
+        self.allreduce = P.AllReduce(group=group_name)
+        self.add_flags(skip_auto_parallel_compile=True)
+
+    def construct(self, loaded_param):
+        result = self.allreduce(loaded_param)
+        return result
+
+
 def get_allgather_cell(group, need_merge_twice=False, do_reshape=False, after_reshape_slice_shape=()):
     """Get AllGatherCell object."""
     global _ALLGATHER_CELL
@@ -89,3 +111,64 @@ def destroy_allgather_cell():
     global _ALLGATHER_CELL
     if _ALLGATHER_CELL:
         _ALLGATHER_CELL = None
+
+
+def _chang_parallel_context(origin_dataset_strategy):
+    """Change the original parallel state."""
+    if context.get_context("mode") == context.GRAPH_MODE:
+        context.set_auto_parallel_context(parallel_mode="hybrid_parallel")
+    if origin_dataset_strategy != "data_parallel":
+        context.set_auto_parallel_context(dataset_strategy="data_parallel")
+
+
+def _restore_parallel_context(origin_parallel_mode, origin_dataset_strategy):
+    """Restore the original parallel state."""
+    if context.get_context("mode") == context.GRAPH_MODE:
+        context.set_auto_parallel_context(parallel_mode=origin_parallel_mode)
+    if origin_dataset_strategy != "data_parallel":
+        context.set_auto_parallel_context(dataset_strategy=origin_dataset_strategy)
+
+
+def _single_parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
+    """
+    Broadcast single parameter to other rank in data parallel dimension.
+    """
+    from mindspore import Tensor
+    origin_parallel_mode = context.get_auto_parallel_context("parallel_mode")
+    origin_dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
+    if layout:
+        param_redundancy = get_parameter_redundancy(layout, initial_rank)
+    else:
+        param_redundancy = get_parameter_redundancy(net)
+    if not param_redundancy:
+        return
+    single_params = remove_param_redundancy(param_redundancy)
+    if not single_params:
+        return
+    param_redundancy_reversed = {}
+    for key, redundancy in param_redundancy.items():
+        for item in redundancy:
+            if len(item) == 1:
+                continue
+            if cur_rank in item:
+                param_redundancy_reversed.setdefault(item, []).append(key)
+    if not param_redundancy_reversed or cur_rank not in single_params:
+        return
+    net_param_dict = net.parameters_dict()
+    _chang_parallel_context(origin_dataset_strategy)
+    for group, params in param_redundancy_reversed.items():
+        create_group(str(group), list(group))
+        allreduce_input = []
+        for param in params:
+            if param not in net_param_dict:
+                continue
+            real_param = net_param_dict[param]
+            if param not in single_params[cur_rank]:
+                real_param.set_data(Tensor(np.zeros(real_param.shape), dtype=real_param.dtype), real_param.sliced)
+            allreduce_input.append(real_param)
+        if not allreduce_input:
+            continue
+        communicator = SingleCommunicator(str(group))
+        for real_param in allreduce_input:
+            real_param.set_data(communicator(real_param), real_param.sliced)
+    _restore_parallel_context(origin_parallel_mode, origin_dataset_strategy)
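The broadcast in _single_parameter_broadcast is built from a sum AllReduce: every rank that does not own a parameter first overwrites its copy with zeros, so the reduction reproduces the owner's values everywhere. A minimal single-process illustration of that idea with plain numpy (no MindSpore runtime; ranks are simulated):

    import numpy as np

    owner_value = np.array([1.0, 2.0, 3.0])          # the parameter held by the owning rank
    contributions = [owner_value if rank == 0 else np.zeros_like(owner_value)
                     for rank in range(4)]           # non-owners contribute zeros
    allreduce_sum = np.sum(contributions, axis=0)    # what AllReduce(op=sum) would return
    assert np.array_equal(allreduce_sum, owner_value)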
mindspore/parallel/_parallel_serialization.py
CHANGED
@@ -24,7 +24,6 @@ from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_
     _generate_transform_operator_stack, _apply_tensor_transform_operators, _construct_tensor_layout_for_opt_shard, \
     _extract_layout_item

-
 MAX_PATH_LENGTH = 1024


@@ -37,14 +36,17 @@ def _convert_to_list(strategy, rank_id=None):
             dev_mat = list(layout.dev_matrix[0].dim)
             tensor_map = list(layout.tensor_map[0].dim)
             param_split_shape = list(layout.param_split_shape[0].dim)
+            field_size = int(layout.field)
+            shard_stride = int(layout.opt_weight_shard_step)
+            shard_size = int(layout.opt_weight_shard_size)
             pipeline_stage = 0
             origin_param_name = param_name
             if "-" in param_name:
                 pipeline_stage, origin_param_name = param_name.split("-")
                 pipeline_stage = int(pipeline_stage)
             if origin_param_name not in train_map:
-                train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape,
-
+                train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape, field_size,
+                                                shard_stride, shard_size,
                                                 [pipeline_stage]]
             else:
                 update_pipeline_stage_list = train_map.get(origin_param_name)[6] + [pipeline_stage]
@@ -54,15 +56,15 @@ def _convert_to_list(strategy, rank_id=None):
                     not_device0_nor_pipeline0 = ((rank_id // stage_device_num) > 0) and (pipeline_stage > 0)
                     if is_device0_and_pipeline0 or not_device0_nor_pipeline0:
                         train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape,
-
-
+                                                        field_size, shard_stride,
+                                                        shard_size, update_pipeline_stage_list]
                     else:
                         train_map.get(origin_param_name)[6] = update_pipeline_stage_list
                 else:
                     if np.all(pipeline_stage <= np.array(update_pipeline_stage_list)):
                         train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape,
-
-
+                                                        field_size, shard_stride,
+                                                        shard_size, update_pipeline_stage_list]
                     else:
                         train_map.get(origin_param_name)[6] = update_pipeline_stage_list
         except BaseException as e:
@@ -174,6 +176,8 @@ def _build_json_strategy(strategy_filename):

 def _build_searched_strategy(strategy_filename):
     """build searched strategy"""
+    if strategy_filename is None:
+        return strategy_filename
     _check_strategy_file(strategy_filename)
     if strategy_filename[-5:] != ".json":
         return _build_protobuf_strategy(strategy_filename)
@@ -239,7 +243,10 @@ def _extract_layout_map(strategy_file, rank_id=None):
     """Extract layout map"""
     layout_map = None
     if strategy_file is not None:
-
+        if not isinstance(strategy_file, dict):
+            src_strategy = _build_searched_strategy(strategy_file)
+        else:
+            src_strategy = strategy_file
         layout_map = _convert_to_list(src_strategy, rank_id)
     return layout_map

@@ -248,7 +255,10 @@ def _extract_pipeline_stage_num(strategy_file):
     """extract pipeline stage num"""
     pipeline_stage_num = 1
     if strategy_file is not None:
-
+        if not isinstance(strategy_file, dict):
+            src_strategy = _build_searched_strategy(strategy_file)
+        else:
+            src_strategy = strategy_file
         layout_map = _convert_to_list(src_strategy)
         pipeline_stage_set = set()
         for _, layout in layout_map.items():
@@ -323,7 +333,10 @@ def _get_device_num_from_strategy(strategy_file=None):
     """Get device num from strategy file"""
     if strategy_file is None:
         return 1
-
+    if not isinstance(strategy_file, dict):
+        src_strategy = _build_searched_strategy(strategy_file)
+    else:
+        src_strategy = strategy_file
     strategy_list = _convert_to_list(src_strategy)
     device_mat = list(strategy_list.values())[0][0]
     return np.prod(device_mat)
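The three hunks above repeat one dispatch pattern: a strategy argument may now arrive either as a checkpoint path that still needs parsing or as an already-parsed dict. A hedged, standalone sketch of that pattern (the helper name _resolve_strategy is hypothetical; _build_searched_strategy is the parser shown in the diff):

    def _resolve_strategy(strategy_file, build_fn):
        """Return a strategy dict, parsing strategy_file only when it is a path."""
        if strategy_file is None:
            return None
        if isinstance(strategy_file, dict):   # already parsed, use as-is
            return strategy_file
        return build_fn(strategy_file)        # e.g. _build_searched_strategy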
@@ -341,14 +354,15 @@ def _rank_list_for_transform_parallel_checkpoint(rank_id, src_strategy_list, dst
         from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size = _extract_layout_item(
             src_strategy_list.get(param_name))
         from_device_num = np.prod(from_dev_matrix)
-        fake_tensor_shape = [8] * len(from_tensor_map)
         to_dev_matrix = [1]
-        to_tensor_map = [-1] * len(
+        to_tensor_map = [-1] * len(from_tensor_map)
         to_opt_shard_step = 0
         to_opt_shard_size = 0
         if dst_strategy_list is not None:
             to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
                 dst_strategy_list.get(param_name))
+        to_device_num = np.prod(to_dev_matrix)
+        fake_tensor_shape = [max(from_device_num, to_device_num)] * len(from_tensor_map)
         handled_key = (from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size,
                        to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size)
         if handled_key in handled_layout:
@@ -433,7 +447,6 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
     param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
                                                                         device_list, rank_id)

-
     from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
     to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
     _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
@@ -443,10 +456,10 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
         transform_tensor = ms.Tensor(param_total_dict_copy[rank_id % device_num])
         requires_grad = param_attr_dict[param_name][rank_id % device_num][0]
         layerwise_parallel = param_attr_dict[param_name][rank_id % device_num][1]
-
+        transform_param = ms.Parameter(transform_tensor, param_name, requires_grad, layerwise_parallel)
         if param_type_dict[param_name][rank_id % device_num] == "BFloat16":
-
-            transform_param_dict[param_name] =
+            transform_param.set_dtype(ms.bfloat16)
+        transform_param_dict[param_name] = transform_param
     if device_num < 1:
         raise ValueError("None of the parameters in checkpoint file are in either src strategy or "
                          "dst strategy. Please check correctness of strategy files.")
@@ -454,13 +467,13 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
     # Handle those parameter like learning_rate, global_step which not in strategy_file.
     for param_name, _ in param_total_dict.items():
         if param_name not in transform_param_dict:
-
+            transform_param = ms.Parameter(
                 ms.Tensor(param_total_dict[param_name][rank_id % device_num]), param_name,
                 param_attr_dict[param_name][rank_id % device_num][0],
                 param_attr_dict[param_name][rank_id % device_num][1])
             if param_type_dict[param_name][rank_id % device_num] == "BFloat16":
-
-                transform_param_dict[param_name] =
+                transform_param.set_dtype(ms.bfloat16)
+            transform_param_dict[param_name] = transform_param

     transform_param_list = [{"name": param_name, "data": param_data}
                             for param_name, param_data in transform_param_dict.items()]
@@ -531,3 +544,18 @@ def _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple):
                 continue
             to_slice_tensor_shape += (item // to_tensor_strategy[i],)
         param_rank_map.get(param_rank).append(('Reshape', list(to_slice_tensor_shape)))
+
+
+def _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_sharded_device_index, rank):
+    """Calculate rank list for optimizer parallel when first dim of parameter is sharded by other parallel method"""
+    total_device_num = 1
+    for n in device_arrangement:
+        total_device_num *= n
+    if first_dim_sharded_device_index != len(device_arrangement) - 1:
+        return list(range(0, total_device_num))
+    first_dim_sharded_size = device_arrangement[-1 - first_dim_sharded_device_index]
+    range_size = total_device_num // first_dim_sharded_size
+    offset = rank % range_size
+    start = rank - offset
+    param_total_list = list(range(start, start + range_size))
+    return param_total_list
mindspore/parallel/_tensor.py
CHANGED
@@ -334,8 +334,10 @@ def _extract_layout_item(layout_item):
     tensor_map = layout_item[1]
     opt_shard_step = layout_item[4]
     opt_shard_size = layout_item[5]
+    tensor_strategy = _get_tensor_strategy(dev_matrix, tensor_map)
+    model_parallel_shard_size = np.prod(tensor_strategy)
     if opt_shard_size == -1:
-        opt_shard_size = np.prod(dev_matrix) //
+        opt_shard_size = np.prod(dev_matrix) // model_parallel_shard_size
     return dev_matrix, tensor_map, opt_shard_step, opt_shard_size

@@ -406,12 +408,35 @@ def _construct_tensor_layout_for_opt_shard(dev_matrix, tensor_map, opt_shard_ste
     if opt_shard_step == 0 or opt_shard_size == 0:
         return dev_matrix, tensor_map, list(origin_full_tensor_shape)
     tensor_strategy = _get_tensor_strategy(dev_matrix, tensor_map)
-
-
+    repeated_dim = []
+    dev_sharded_index = []
+    for dim in tensor_map:
+        if dim != -1:
+            dev_sharded_index.append(len(dev_matrix) - dim - 1)
+    for index, value in enumerate(dev_matrix):
+        if index not in dev_sharded_index and value > 1:
+            repeated_dim.append(index)
+    if not repeated_dim:
+        raise ValueError("The device_matrix {} and tensor_map {} cannot sharding opt_shard".
+                         format(dev_matrix, tensor_map))
+    if len(repeated_dim) == 1 and np.prod(dev_matrix[repeated_dim[0] + 1:]) != opt_shard_step:
         raise ValueError("The optimizer sharding step {} is not equal to the model parallel sharding size {}.".
-                         format(opt_shard_step,
-
+                         format(opt_shard_step, np.prod(dev_matrix[repeated_dim[0] + 1:])))
     first_dim_no_sharding_size = origin_full_tensor_shape[0] // tensor_strategy[0]
+    if (len(repeated_dim) < len(dev_matrix) and len(repeated_dim) > 1) or repeated_dim[0] > 0:
+        tensor_shape_new = list(origin_full_tensor_shape)
+        tensor_shape_new[0] = tensor_strategy[0]
+        accu_shp = 1
+        for i in range(len(repeated_dim) - 1):
+            opt_sharding_size = dev_matrix[repeated_dim[i]]
+            tensor_shape_new.insert(i + 1, opt_sharding_size)
+            accu_shp = accu_shp * opt_sharding_size
+        tensor_shape_new.insert(len(repeated_dim), first_dim_no_sharding_size // accu_shp)
+        tensor_map_new = list(copy.deepcopy(tensor_map))
+        for index, r_dim in enumerate(repeated_dim):
+            tensor_map_new.insert(index + 1, len(dev_matrix) - r_dim - 1)
+        return list(dev_matrix), tensor_map_new, tensor_shape_new
+
     full_tensor_shape = list(origin_full_tensor_shape)
     full_tensor_shape[0] = tensor_strategy[0]
     full_tensor_shape.insert(1, first_dim_no_sharding_size)
@@ -452,7 +477,7 @@ def _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_te
     result_map = {self_rank: transform_operators}
     for operators in transform_operators:
         op_name = operators[0]
-        if op_name == "
+        if op_name == "AllConcat":
             groups = operators[1][:-1]
             stack.append((index, groups))
             index += 1
@@ -466,7 +491,7 @@
     index = 0
     for operators in new_transform_operators:
         op_name = operators[0]
-        if op_name == "
+        if op_name == "AllConcat" and index < group_info[0]:
             groups = operators[1][:-1]
             stack.insert(0, (index, groups))
             index += 1
@@ -491,7 +516,7 @@ def _generate_transform_operator_stack(transform_operators_map, self_rank):
         level = queue_front[1]
         current_operator = queue_front[2]
         if level >= 1:
-            if current_operator[0] == "
+            if current_operator[0] == "AllConcat":
                 current_group = current_operator[1][:-1]
                 for rank_id in current_group:
                     handle_queue.append((rank_id, level - 1, transform_operators_map[rank_id][level - 1]))
@@ -523,7 +548,7 @@ def _apply_tensor_transform_operators(transform_operator_stack, tensor_dict, dev
         if operator[0] != op_name:
             raise ValueError("The operator in the same level should be equal in the transform tensor operator "
                              "list, but the find {} and {} in level {}".format(op_name, operator[0], cur_level))
-        if operator[0] != "
+        if operator[0] != "AllConcat":
             tensor_dict[rank_id % device_num] = _apply_operator(operator[0])(tensor_dict[rank_id % device_num],
                                                                              operator)
             continue
@@ -532,7 +557,7 @@ def _apply_tensor_transform_operators(transform_operator_stack, tensor_dict, dev
             raise ValueError("The checkpoint file of rank {} is missing.".format(rank % device_num))
         allgather_list = [tensor_dict[rank % device_num] for rank in operator[1][:-1]]
         tmp_tensor_dict[rank_id % device_num] = _apply_operator(operator[0])(allgather_list, operator)
-        if op_name == "
+        if op_name == "AllConcat":
             for rank, value in tmp_tensor_dict.items():
                 tensor_dict[rank % device_num] = value
             level_operators.clear()
@@ -621,7 +646,7 @@ def _apply_operator(operator_name):
         return numpy_data[slice_index]

     _apply_operator_map = {"Reshape": _apply_reshape_operator, "StridedSlice": _apply_slice_operator,
-                           "
+                           "AllConcat": _apply_allconcat_operator}
     return _apply_operator_map.get(operator_name)

@@ -658,3 +683,48 @@ def _reshape_param_data_with_weight(param_data, dev_mat, field_size):
     for i in range(1, len(tensor_slices_col)):
         new_tensor = np.concatenate((new_tensor, np.array(tensor_slices_col[i]).reshape(-1, 1)), axis=1)
     return Tensor(new_tensor)
+
+
+def _load_tensor_shape(dev_mat, tensor_map, full_shape=None, rank_id=-1):
+    """get tensor shape by slice"""
+    if rank_id == -1:
+        rank = get_rank()
+    else:
+        rank = rank_id
+    tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
+    tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
+    np_tensor_list = _chunk_shape_by_strategy(full_shape, tensor_strategy)
+    np_tensor_slice_index = np_tensor_list[int(tensor_slice_index)]
+    res = []
+    for index in np_tensor_slice_index:
+        res.append(slice(index[0], index[1]))
+    return tuple(res)
+
+
+def _chunk_shape_by_strategy(full_shape, strategy):
+    """chunk shape by strategy"""
+    shape = []
+    for i in full_shape:
+        shape.append([0, i])
+    return _chunk_shape(shape, strategy, len(strategy))
+
+
+def _chunk_shape(np_tensor, strategy, depth):
+    """_chunk shape"""
+    output = []
+    axis = len(np_tensor) - depth
+    left, right = np_tensor[axis]
+    num = strategy[0]
+    chunk_size = (right - left) / num
+    append = [[i, int(i + chunk_size)] for i in range(left, right) if i % chunk_size == 0]
+    np_tensor_new = []
+    for i in append:
+        np_tensor_tmp = copy.deepcopy(np_tensor)
+        np_tensor_tmp[axis] = i
+        np_tensor_new.append(np_tensor_tmp)
+    if depth == 1:
+        return np_tensor_new
+    for ret_ in np_tensor_new:
+        output.extend(
+            _chunk_shape(ret_, strategy[len(strategy) - depth + 1:len(strategy)], depth - 1))
+    return output
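The helpers above map a full tensor shape plus a per-dimension shard strategy to the slice owned by one rank. A standalone, hedged reimplementation of that chunking idea (block_slices is a hypothetical name; the real code derives the block index from dev_mat and tensor_map via _get_tensor_slice_index):

    def block_slices(full_shape, strategy, flat_index):
        """Return per-dimension slice() objects for the block at flat_index (row-major)."""
        coords = []
        for cuts in reversed(strategy):
            flat_index, c = divmod(flat_index, cuts)
            coords.append(c)
        coords.reverse()
        return tuple(slice(c * (size // cuts), (c + 1) * (size // cuts))
                     for size, cuts, c in zip(full_shape, strategy, coords))

    print(block_slices((8, 6), (2, 3), 4))   # (slice(4, 8, None), slice(2, 4, None))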
mindspore/parallel/_utils.py
CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """Utils of auto parallel"""
+import os
 from importlib import import_module
 import numpy as np
 import mindspore as ms
@@ -22,12 +23,13 @@ from mindspore.common.tensor import Tensor
 from mindspore.common.dtype import dtype_to_nptype
 from mindspore.common import dtype as mstype
 from mindspore.communication.management import get_group_size, get_rank
+from mindspore.communication._comm_helper import _is_initialized
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.common.seed import get_seed
 from mindspore._c_expression import GraphExecutor_
 from mindspore.parallel._tensor import _load_tensor_by_layout

-SUPPORTED_TUPLE_IN_TUPLE_STRATEGY = ["GroupedMatmul", "FusedInferAttentionScore"]
+SUPPORTED_TUPLE_IN_TUPLE_STRATEGY = ["GroupedMatmul", "FusedInferAttentionScore", "Custom"]


 def _get_parallel_mode():
@@ -45,6 +47,16 @@ def _is_in_auto_parallel_mode():
     return _get_parallel_mode() in [ms.ParallelMode.SEMI_AUTO_PARALLEL, ms.ParallelMode.AUTO_PARALLEL]


+def _is_parallel_mode():
+    if not _is_initialized() or context.get_context('mode') == context.PYNATIVE_MODE:
+        return False
+    if os.getenv("RUN_MODE") != "predict":
+        return False
+    if get_group_size() > 1 and _get_parallel_mode() == ms.ParallelMode.STAND_ALONE:
+        return True
+    return False
+
+
 def _is_in_data_parallel_mode():
     return _get_parallel_mode() == ms.ParallelMode.DATA_PARALLEL

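A hedged reading of the new _is_parallel_mode() helper: it only returns True when communication has been initialized, the process runs in graph mode, the RUN_MODE environment variable is "predict", the group holds more than one device, and the parallel mode is still STAND_ALONE. A sketch of how a caller would satisfy the environment-side conditions (the surrounding script details are assumptions):

    import os
    import mindspore as ms
    from mindspore.communication import init

    os.environ["RUN_MODE"] = "predict"     # environment gate checked by the helper
    ms.set_context(mode=ms.GRAPH_MODE)     # PyNative mode makes the helper return False
    init()                                 # communication must be initialized first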
mindspore/parallel/algo_parameter_config.py
CHANGED
@@ -234,7 +234,7 @@ def set_algo_parameters(**kwargs):

     Args:
         fully_use_devices (bool): Whether ONLY searching strategies that fully use all available devices.
-            Default: ``
+            Default: ``False`` . For example with 8 devices available, if set ``True`` , strategy (4, 1) will not be
             included in ReLU's candidate strategies, because strategy (4, 1) only utilizes 4 devices.
         elementwise_op_strategy_follow (bool): Whether the elementwise operator has the consistent strategies as its
             subsequent operators. Elementwise operators refer to operators that operate on input element by element,
@@ -264,14 +264,14 @@

     For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
     Please see the `rank table startup
-    <https://www.mindspore.cn/
+    <https://www.mindspore.cn/docs/en/master/model_train/parallel/rank_table.html>`_
     for more details.

     For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
-    <https://www.mindspore.cn/
+    <https://www.mindspore.cn/docs/en/master/model_train/parallel/mpirun.html>`_ .

     For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
-    Startup <https://www.mindspore.cn/
+    Startup <https://www.mindspore.cn/docs/en/master/model_train/parallel/dynamic_cluster.html>`_ .

     >>> import numpy as np
     >>> import mindspore as ms
@@ -386,7 +386,7 @@ def reset_algo_parameters():

     After reset, the values of the attributes are:

-    - fully_use_devices:
+    - fully_use_devices: False.
     - elementwise_op_strategy_follow: False.
     - enable_algo_approxi: False.
     - algo_approxi_epsilon: 0.1.