mindspore 2.3.0__cp39-cp39-win_amd64.whl → 2.4.1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic according to the registry's automated checks.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/initializer.py +51 -15
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +62 -15
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +183 -37
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +315 -60
- mindspore/communication/management.py +14 -14
- mindspore/context.py +132 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +983 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +268 -23
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +26 -13
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +276 -96
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +137 -10
- mindspore/nn/layer/embedding.py +137 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +124 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +767 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
- mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +492 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +564 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +402 -12
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +7 -2
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +14 -146
- mindspore/ops/operations/comm_ops.py +63 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +273 -20
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +31 -9
- mindspore/parallel/_cell_wrapper.py +85 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +127 -13
- mindspore/parallel/_utils.py +53 -22
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +1146 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +285 -413
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +39 -104
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +105 -19
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +97 -31
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +145 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +375 -0
- mindspore/train/dataset_helper.py +15 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +154 -58
- mindspore/train/serialization.py +342 -128
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +260 -254
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +1 -1
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/parallel/_auto_parallel_context.py CHANGED
@@ -76,6 +76,7 @@ class _PipelineConfig:
 class _PipelineScheduler:
     PIPELINE_1F1B = "1f1b"
     PIPELINE_GPIPE = "gpipe"
+    PIPELINE_SEQPIPE = "seqpipe"


 class _AutoParallelContext:
@@ -168,6 +169,24 @@ class _AutoParallelContext:
         self.check_context_handle()
         return _ParallelFusionConfig.CONFIG

+    def set_dump_local_norm(self, dump_local_norm):
+        """
+        Set dump local norm for auto parallel.
+
+        Args:
+            dump_local_norm (bool): User need to specify if he want to dump local norm. Default: False
+
+        Raises:
+            KeyError: When key of comm_fusion is not 'allreduce'.
+        """
+        self.check_context_handle()
+        self._context_handle.set_dump_local_norm(dump_local_norm)
+
+    def get_dump_local_norm(self):
+        """Get dump local norm."""
+        self.check_context_handle()
+        return self._context_handle.get_dump_local_norm()
+
     def set_fusion_threshold_mb(self, fusion_threshold=64, comm_type="allreduce"):
         """
         Set fusion threshold (MB) for auto parallel.
@@ -584,7 +603,7 @@ class _AutoParallelContext:
         self.check_context_handle()
         dir_path = os.path.dirname(strategy_ckpt_save_file)
         if dir_path and not os.path.exists(dir_path):
-            os.makedirs(dir_path, exist_ok=True)
+            os.makedirs(dir_path, mode=0o700, exist_ok=True)
         self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file)

     def get_strategy_ckpt_save_file(self):
@@ -643,7 +662,7 @@ class _AutoParallelContext:
         self.check_context_handle()
         dir_path = os.path.dirname(group_ckpt_save_file)
         if dir_path and not os.path.exists(dir_path):
-            os.makedirs(dir_path)
+            os.makedirs(dir_path, mode=0o700, exist_ok=True)
         self._context_handle.set_group_ckpt_save_file(group_ckpt_save_file)

     def get_parameter_broadcast_is_set(self):
@@ -896,7 +915,8 @@ class _AutoParallelContext:
                                   pipeline_config[pp_interleave])

         Validator.check_string(pipeline_config[pp_scheduler], [_PipelineScheduler.PIPELINE_1F1B,
-                                                               _PipelineScheduler.PIPELINE_GPIPE])
+                                                               _PipelineScheduler.PIPELINE_GPIPE,
+                                                               _PipelineScheduler.PIPELINE_SEQPIPE])
         if not pipeline_config[pp_interleave] and pipeline_config[pp_scheduler] != _PipelineScheduler.PIPELINE_1F1B:
             raise ValueError(f"When pipeline_interleave is False, {pp_scheduler} is not supported")

@@ -1117,9 +1137,9 @@ class _AutoParallelContext:
         """
         self.check_context_handle()
         if comm_type == "allgather" and not self.get_enable_all_gather_fusion():
-
+            self.set_enable_all_gather_fusion(True)
         if comm_type == "reducescatter" and not self.get_enable_reduce_scatter_fusion():
-
+            self.set_enable_reduce_scatter_fusion(True)
         if not isinstance(comm_fusion, dict):
             raise TypeError("For 'comm_fusion', {} config must be dict, but got the type : {}.".format(
                 comm_type, type(comm_fusion)))
@@ -1153,7 +1173,7 @@ class _AutoParallelContext:
         """
         self.check_context_handle()
         if not self.get_enable_all_reduce_fusion():
-
+            self.set_enable_all_reduce_fusion(True)
         if not isinstance(comm_fusion, dict):
             raise TypeError("For 'comm_fusion', the 'allreduce' config must be dict, but got the type : {}.".format(
                 type(comm_fusion)))
@@ -1210,7 +1230,7 @@ def _set_ops_strategy_json_config(type="SAVE", path="", mode="all"):
     """
     dir_path = os.path.dirname(path)
     if dir_path and not os.path.exists(dir_path):
-        os.makedirs(dir_path)
+        os.makedirs(dir_path, mode=0o700, exist_ok=True)
     check_type = ["SAVE", "LOAD"]
     check_mode = ["all", "principal"]
     if type in check_type and mode in check_mode:
@@ -1266,7 +1286,8 @@ _set_auto_parallel_context_func_map = {
     "sharding_propagation": auto_parallel_context().set_sharding_propagation,
     "enable_alltoall": auto_parallel_context().set_enable_alltoall,
     "strategy_ckpt_config": auto_parallel_context().set_strategy_ckpt_config,
-    "comm_fusion": auto_parallel_context().set_comm_fusion}
+    "comm_fusion": auto_parallel_context().set_comm_fusion,
+    "dump_local_norm": auto_parallel_context().set_dump_local_norm}

 _get_auto_parallel_context_func_map = {
     "device_num": auto_parallel_context().get_device_num,
@@ -1298,7 +1319,8 @@ _get_auto_parallel_context_func_map = {
     "enable_alltoall": auto_parallel_context().get_enable_alltoall,
     "comm_fusion": auto_parallel_context().get_comm_fusion,
     "strategy_ckpt_config": auto_parallel_context().get_strategy_ckpt_config,
-    "full_batch_is_set": auto_parallel_context().get_full_batch_is_set}
+    "full_batch_is_set": auto_parallel_context().get_full_batch_is_set,
+    "dump_local_norm": auto_parallel_context().get_dump_local_norm}


 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
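The hunks above add set_dump_local_norm/get_dump_local_norm to _AutoParallelContext and register a "dump_local_norm" key in the setter and getter maps. A minimal sketch of how that key would be reached from user code, assuming it is exposed through mindspore.set_auto_parallel_context / get_auto_parallel_context the same way as the other keys in those maps (the keyword argument below is an assumption, not taken from the diff):

    import mindspore as ms

    # Assumed usage: "dump_local_norm" dispatches to
    # auto_parallel_context().set_dump_local_norm via the map added above.
    ms.set_auto_parallel_context(parallel_mode="semi_auto_parallel", dump_local_norm=True)
    print(ms.get_auto_parallel_context("dump_local_norm"))  # expected: True

The same file also hardens checkpoint-directory creation: the os.makedirs calls for strategy and group checkpoint paths now pass mode=0o700 and exist_ok=True, so the directories are created owner-only and a repeated call does not fail when the directory already exists.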
mindspore/parallel/_cell_wrapper.py CHANGED
@@ -16,11 +16,16 @@
 from __future__ import absolute_import
 from __future__ import division

+import numpy as np
+
+from mindspore import context
 from mindspore.nn.cell import Cell
 from mindspore.ops import operations as P
 from mindspore.ops.operations.comm_ops import AllGather
 from mindspore.communication import GlobalComm
 from mindspore.common import jit
+from mindspore.communication import create_group
+from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy

 _ALLGATHER_CELL = None

@@ -30,6 +35,7 @@ class AllGatherCell(Cell):
     Allgather cell, used in model parallel scenario.
     To allgather the selected parameter slice from each device.
     """
+
     def __init__(self, group, do_reshape, after_reshape_slice_shape):
         super(AllGatherCell, self).__init__(auto_prefix=False)
         self.allgather = AllGather(group)
@@ -54,6 +60,7 @@ class SaveOptShardCkptCell(Cell):
     Note:
         This could be optimized later with less communication consumption.
     """
+
     def __init__(self, group, do_reshape, after_reshape_slice_shape):
         super(SaveOptShardCkptCell, self).__init__(auto_prefix=False)
         self.allgather1 = AllGather(group)
@@ -71,6 +78,21 @@ class SaveOptShardCkptCell(Cell):
         return x


+class SingleCommunicator(Cell):
+    """
+    Used to broadcast single parameter.
+    """
+
+    def __init__(self, group_name):
+        super(SingleCommunicator, self).__init__()
+        self.allreduce = P.AllReduce(group=group_name)
+        self.add_flags(skip_auto_parallel_compile=True)
+
+    def construct(self, loaded_param):
+        result = self.allreduce(loaded_param)
+        return result
+
+
 def get_allgather_cell(group, need_merge_twice=False, do_reshape=False, after_reshape_slice_shape=()):
     """Get AllGatherCell object."""
     global _ALLGATHER_CELL
@@ -89,3 +111,66 @@ def destroy_allgather_cell():
     global _ALLGATHER_CELL
     if _ALLGATHER_CELL:
         _ALLGATHER_CELL = None
+
+
+def _chang_parallel_context(origin_dataset_strategy):
+    """Change the original parallel state."""
+    if context.get_context("mode") == context.GRAPH_MODE:
+        context.set_auto_parallel_context(parallel_mode="hybrid_parallel")
+    if origin_dataset_strategy != "data_parallel":
+        context.set_auto_parallel_context(dataset_strategy="data_parallel")
+
+
+def _restore_parallel_context(origin_parallel_mode, origin_dataset_strategy):
+    """Restore the original parallel state."""
+    if context.get_context("mode") == context.GRAPH_MODE:
+        context.set_auto_parallel_context(parallel_mode=origin_parallel_mode)
+    if origin_dataset_strategy != "data_parallel":
+        if origin_dataset_strategy is not None and isinstance(origin_dataset_strategy, list):
+            origin_dataset_strategy = tuple(tuple(ds_item) for ds_item in origin_dataset_strategy)
+        context.set_auto_parallel_context(dataset_strategy=origin_dataset_strategy)
+
+
+def _single_parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
+    """
+    Broadcast single parameter to other rank in data parallel dimension.
+    """
+    from mindspore import Tensor
+    origin_parallel_mode = context.get_auto_parallel_context("parallel_mode")
+    origin_dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
+    if layout:
+        param_redundancy = get_parameter_redundancy(layout, initial_rank)
+    else:
+        param_redundancy = get_parameter_redundancy(net)
+    if not param_redundancy:
+        return
+    single_params = remove_param_redundancy(param_redundancy)
+    if not single_params:
+        return
+    param_redundancy_reversed = {}
+    for key, redundancy in param_redundancy.items():
+        for item in redundancy:
+            if len(item) == 1:
+                continue
+            if cur_rank in item:
+                param_redundancy_reversed.setdefault(item, []).append(key)
+    if not param_redundancy_reversed or cur_rank not in single_params:
+        return
+    net_param_dict = net.parameters_dict()
+    _chang_parallel_context(origin_dataset_strategy)
+    for group, params in param_redundancy_reversed.items():
+        create_group(str(group), list(group))
+        allreduce_input = []
+        for param in params:
+            if param not in net_param_dict:
+                continue
+            real_param = net_param_dict[param]
+            if param not in single_params[cur_rank]:
+                real_param.set_data(Tensor(np.zeros(real_param.shape), dtype=real_param.dtype), real_param.sliced)
+            allreduce_input.append(real_param)
+        if not allreduce_input:
+            continue
+        communicator = SingleCommunicator(str(group))
+        for real_param in allreduce_input:
+            real_param.set_data(communicator(real_param), real_param.sliced)
+    _restore_parallel_context(origin_parallel_mode, origin_dataset_strategy)
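The additions above give _cell_wrapper.py a broadcast path that removes redundant parameter copies: within each group of ranks that share a parameter, ranks that do not hold the authoritative copy overwrite theirs with zeros, then SingleCommunicator (an AllReduce over that group) sums the copies, which leaves the owner's values on every rank. A standalone numpy sketch of that zero-then-sum idea, with made-up ranks and values (real execution goes through P.AllReduce on a group created with create_group):

    import numpy as np

    def allreduce_sum(local_copies):
        """Simulate AllReduce(sum) across the ranks of one communication group."""
        total = np.sum(np.stack(local_copies), axis=0)
        return [total.copy() for _ in local_copies]

    owner_rank = 0
    group = [0, 1, 2, 3]                       # hypothetical ranks sharing one parameter
    param_on_owner = np.array([1.0, 2.0, 3.0])

    # Non-owners hold zeros, so the sum reproduces the owner's slice on every rank.
    local_copies = [param_on_owner if rank == owner_rank else np.zeros_like(param_on_owner)
                    for rank in group]
    for rank, result in zip(group, allreduce_sum(local_copies)):
        print(rank, result)                    # each rank ends up with [1. 2. 3.]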
mindspore/parallel/_parallel_serialization.py CHANGED
@@ -24,7 +24,6 @@ from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_
     _generate_transform_operator_stack, _apply_tensor_transform_operators, _construct_tensor_layout_for_opt_shard, \
     _extract_layout_item

-
 MAX_PATH_LENGTH = 1024


@@ -37,14 +36,17 @@ def _convert_to_list(strategy, rank_id=None):
             dev_mat = list(layout.dev_matrix[0].dim)
             tensor_map = list(layout.tensor_map[0].dim)
             param_split_shape = list(layout.param_split_shape[0].dim)
+            field_size = int(layout.field)
+            shard_stride = int(layout.opt_weight_shard_step)
+            shard_size = int(layout.opt_weight_shard_size)
             pipeline_stage = 0
             origin_param_name = param_name
             if "-" in param_name:
                 pipeline_stage, origin_param_name = param_name.split("-")
                 pipeline_stage = int(pipeline_stage)
             if origin_param_name not in train_map:
-                train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape,
-
+                train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape, field_size,
+                                                shard_stride, shard_size,
                                                 [pipeline_stage]]
             else:
                 update_pipeline_stage_list = train_map.get(origin_param_name)[6] + [pipeline_stage]
@@ -54,15 +56,15 @@ def _convert_to_list(strategy, rank_id=None):
                     not_device0_nor_pipeline0 = ((rank_id // stage_device_num) > 0) and (pipeline_stage > 0)
                     if is_device0_and_pipeline0 or not_device0_nor_pipeline0:
                         train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape,
-
-
+                                                        field_size, shard_stride,
+                                                        shard_size, update_pipeline_stage_list]
                     else:
                         train_map.get(origin_param_name)[6] = update_pipeline_stage_list
                 else:
                     if np.all(pipeline_stage <= np.array(update_pipeline_stage_list)):
                         train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape,
-
-
+                                                        field_size, shard_stride,
+                                                        shard_size, update_pipeline_stage_list]
                     else:
                         train_map.get(origin_param_name)[6] = update_pipeline_stage_list
         except BaseException as e:
@@ -174,6 +176,8 @@ def _build_json_strategy(strategy_filename):

 def _build_searched_strategy(strategy_filename):
     """build searched strategy"""
+    if strategy_filename is None:
+        return strategy_filename
     _check_strategy_file(strategy_filename)
     if strategy_filename[-5:] != ".json":
         return _build_protobuf_strategy(strategy_filename)
@@ -239,7 +243,10 @@ def _extract_layout_map(strategy_file, rank_id=None):
     """Extract layout map"""
     layout_map = None
     if strategy_file is not None:
-
+        if not isinstance(strategy_file, dict):
+            src_strategy = _build_searched_strategy(strategy_file)
+        else:
+            src_strategy = strategy_file
         layout_map = _convert_to_list(src_strategy, rank_id)
     return layout_map

@@ -248,7 +255,10 @@ def _extract_pipeline_stage_num(strategy_file):
     """extract pipeline stage num"""
     pipeline_stage_num = 1
     if strategy_file is not None:
-
+        if not isinstance(strategy_file, dict):
+            src_strategy = _build_searched_strategy(strategy_file)
+        else:
+            src_strategy = strategy_file
         layout_map = _convert_to_list(src_strategy)
         pipeline_stage_set = set()
         for _, layout in layout_map.items():
@@ -323,7 +333,10 @@ def _get_device_num_from_strategy(strategy_file=None):
     """Get device num from strategy file"""
     if strategy_file is None:
         return 1
-
+    if not isinstance(strategy_file, dict):
+        src_strategy = _build_searched_strategy(strategy_file)
+    else:
+        src_strategy = strategy_file
     strategy_list = _convert_to_list(src_strategy)
     device_mat = list(strategy_list.values())[0][0]
     return np.prod(device_mat)
@@ -341,14 +354,15 @@ def _rank_list_for_transform_parallel_checkpoint(rank_id, src_strategy_list, dst
         from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size = _extract_layout_item(
             src_strategy_list.get(param_name))
         from_device_num = np.prod(from_dev_matrix)
-        fake_tensor_shape = [8] * len(from_tensor_map)
         to_dev_matrix = [1]
-        to_tensor_map = [-1] * len(
+        to_tensor_map = [-1] * len(from_tensor_map)
         to_opt_shard_step = 0
         to_opt_shard_size = 0
         if dst_strategy_list is not None:
             to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
                 dst_strategy_list.get(param_name))
+        to_device_num = np.prod(to_dev_matrix)
+        fake_tensor_shape = [max(from_device_num, to_device_num)] * len(from_tensor_map)
         handled_key = (from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size,
                        to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size)
         if handled_key in handled_layout:
@@ -433,7 +447,6 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
         param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
                                                                             device_list, rank_id)

-
         from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
         to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
         _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
@@ -443,10 +456,10 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
         transform_tensor = ms.Tensor(param_total_dict_copy[rank_id % device_num])
         requires_grad = param_attr_dict[param_name][rank_id % device_num][0]
         layerwise_parallel = param_attr_dict[param_name][rank_id % device_num][1]
-
+        transform_param = ms.Parameter(transform_tensor, param_name, requires_grad, layerwise_parallel)
         if param_type_dict[param_name][rank_id % device_num] == "BFloat16":
-
-            transform_param_dict[param_name] =
+            transform_param.set_dtype(ms.bfloat16)
+        transform_param_dict[param_name] = transform_param
     if device_num < 1:
         raise ValueError("None of the parameters in checkpoint file are in either src strategy or "
                          "dst strategy. Please check correctness of strategy files.")
@@ -454,13 +467,13 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
     # Handle those parameter like learning_rate, global_step which not in strategy_file.
     for param_name, _ in param_total_dict.items():
         if param_name not in transform_param_dict:
-
+            transform_param = ms.Parameter(
                 ms.Tensor(param_total_dict[param_name][rank_id % device_num]), param_name,
                 param_attr_dict[param_name][rank_id % device_num][0],
                 param_attr_dict[param_name][rank_id % device_num][1])
            if param_type_dict[param_name][rank_id % device_num] == "BFloat16":
-
-                transform_param_dict[param_name] =
+                transform_param.set_dtype(ms.bfloat16)
+            transform_param_dict[param_name] = transform_param

     transform_param_list = [{"name": param_name, "data": param_data}
                             for param_name, param_data in transform_param_dict.items()]
@@ -531,3 +544,18 @@ def _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple):
                 continue
             to_slice_tensor_shape += (item // to_tensor_strategy[i],)
         param_rank_map.get(param_rank).append(('Reshape', list(to_slice_tensor_shape)))
+
+
+def _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_sharded_device_index, rank):
+    """Calculate rank list for optimizer parallel when first dim of parameter is sharded by other parallel method"""
+    total_device_num = 1
+    for n in device_arrangement:
+        total_device_num *= n
+    if first_dim_sharded_device_index != len(device_arrangement) - 1:
+        return list(range(0, total_device_num))
+    first_dim_sharded_size = device_arrangement[-1 - first_dim_sharded_device_index]
+    range_size = total_device_num // first_dim_sharded_size
+    offset = rank % range_size
+    start = rank - offset
+    param_total_list = list(range(start, start + range_size))
+    return param_total_list
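After this change _convert_to_list records seven positional entries per parameter, and code elsewhere in the diff relies on those positions (layout_item[4] and layout_item[5] in _extract_layout_item, index 6 for the pipeline-stage list). A small sketch of the record layout as it reads from the added code, with made-up values for illustration:

    # Per-parameter layout record built by _convert_to_list after this change.
    layout_record = [
        [4, 2],   # 0: dev_mat
        [1, -1],  # 1: tensor_map
        [],       # 2: param_split_shape
        0,        # 3: field_size
        0,        # 4: shard_stride (opt_weight_shard_step)
        -1,       # 5: shard_size (opt_weight_shard_size)
        [0],      # 6: pipeline stages the parameter appears in
    ]
    dev_mat, tensor_map = layout_record[0], layout_record[1]
    opt_shard_step, opt_shard_size = layout_record[4], layout_record[5]
    print(dev_mat, tensor_map, opt_shard_step, opt_shard_size)

The None check added to _build_searched_strategy and the isinstance(strategy_file, dict) branches added to _extract_layout_map, _extract_pipeline_stage_num and _get_device_num_from_strategy mean these helpers now accept either a strategy file path or an already-parsed strategy dict.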
mindspore/parallel/_tensor.py CHANGED
@@ -334,8 +334,10 @@ def _extract_layout_item(layout_item):
     tensor_map = layout_item[1]
     opt_shard_step = layout_item[4]
     opt_shard_size = layout_item[5]
+    tensor_strategy = _get_tensor_strategy(dev_matrix, tensor_map)
+    model_parallel_shard_size = np.prod(tensor_strategy)
     if opt_shard_size == -1:
-        opt_shard_size = np.prod(dev_matrix) //
+        opt_shard_size = np.prod(dev_matrix) // model_parallel_shard_size
     return dev_matrix, tensor_map, opt_shard_step, opt_shard_size


@@ -406,12 +408,35 @@ def _construct_tensor_layout_for_opt_shard(dev_matrix, tensor_map, opt_shard_ste
     if opt_shard_step == 0 or opt_shard_size == 0:
         return dev_matrix, tensor_map, list(origin_full_tensor_shape)
     tensor_strategy = _get_tensor_strategy(dev_matrix, tensor_map)
-
-
+    repeated_dim = []
+    dev_sharded_index = []
+    for dim in tensor_map:
+        if dim != -1:
+            dev_sharded_index.append(len(dev_matrix) - dim - 1)
+    for index, value in enumerate(dev_matrix):
+        if index not in dev_sharded_index and value > 1:
+            repeated_dim.append(index)
+    if not repeated_dim:
+        raise ValueError("The device_matrix {} and tensor_map {} cannot sharding opt_shard".
+                         format(dev_matrix, tensor_map))
+    if len(repeated_dim) == 1 and np.prod(dev_matrix[repeated_dim[0] + 1:]) != opt_shard_step:
         raise ValueError("The optimizer sharding step {} is not equal to the model parallel sharding size {}.".
-                         format(opt_shard_step,
-
+                         format(opt_shard_step, np.prod(dev_matrix[repeated_dim[0] + 1:])))
     first_dim_no_sharding_size = origin_full_tensor_shape[0] // tensor_strategy[0]
+    if (len(repeated_dim) < len(dev_matrix) and len(repeated_dim) > 1) or repeated_dim[0] > 0:
+        tensor_shape_new = list(origin_full_tensor_shape)
+        tensor_shape_new[0] = tensor_strategy[0]
+        accu_shp = 1
+        for i in range(len(repeated_dim) - 1):
+            opt_sharding_size = dev_matrix[repeated_dim[i]]
+            tensor_shape_new.insert(i + 1, opt_sharding_size)
+            accu_shp = accu_shp * opt_sharding_size
+        tensor_shape_new.insert(len(repeated_dim), first_dim_no_sharding_size // accu_shp)
+        tensor_map_new = list(copy.deepcopy(tensor_map))
+        for index, r_dim in enumerate(repeated_dim):
+            tensor_map_new.insert(index + 1, len(dev_matrix) - r_dim - 1)
+        return list(dev_matrix), tensor_map_new, tensor_shape_new
+
     full_tensor_shape = list(origin_full_tensor_shape)
     full_tensor_shape[0] = tensor_strategy[0]
     full_tensor_shape.insert(1, first_dim_no_sharding_size)
@@ -452,7 +477,7 @@ def _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_te
     result_map = {self_rank: transform_operators}
     for operators in transform_operators:
         op_name = operators[0]
-        if op_name == "
+        if op_name == "AllConcat":
             groups = operators[1][:-1]
             stack.append((index, groups))
         index += 1
@@ -466,7 +491,7 @@ def _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_te
         index = 0
         for operators in new_transform_operators:
             op_name = operators[0]
-            if op_name == "
+            if op_name == "AllConcat" and index < group_info[0]:
                 groups = operators[1][:-1]
                 stack.insert(0, (index, groups))
             index += 1
@@ -491,7 +516,7 @@ def _generate_transform_operator_stack(transform_operators_map, self_rank):
         level = queue_front[1]
         current_operator = queue_front[2]
         if level >= 1:
-            if current_operator[0] == "
+            if current_operator[0] == "AllConcat":
                 current_group = current_operator[1][:-1]
                 for rank_id in current_group:
                     handle_queue.append((rank_id, level - 1, transform_operators_map[rank_id][level - 1]))
@@ -523,7 +548,7 @@ def _apply_tensor_transform_operators(transform_operator_stack, tensor_dict, dev
             if operator[0] != op_name:
                 raise ValueError("The operator in the same level should be equal in the transform tensor operator "
                                  "list, but the find {} and {} in level {}".format(op_name, operator[0], cur_level))
-            if operator[0] != "
+            if operator[0] != "AllConcat":
                 tensor_dict[rank_id % device_num] = _apply_operator(operator[0])(tensor_dict[rank_id % device_num],
                                                                                  operator)
                 continue
@@ -532,7 +557,7 @@ def _apply_tensor_transform_operators(transform_operator_stack, tensor_dict, dev
                 raise ValueError("The checkpoint file of rank {} is missing.".format(rank % device_num))
             allgather_list = [tensor_dict[rank % device_num] for rank in operator[1][:-1]]
             tmp_tensor_dict[rank_id % device_num] = _apply_operator(operator[0])(allgather_list, operator)
-        if op_name == "
+        if op_name == "AllConcat":
             for rank, value in tmp_tensor_dict.items():
                 tensor_dict[rank % device_num] = value
         level_operators.clear()
@@ -565,6 +590,8 @@ def _apply_operator(operator_name):
         Returns:
             The data of tensor after apply operator.
         """
+        if str(type(numpy_data)) == "<class 'builtins.PySafeSlice'>":
+            numpy_data = numpy_data[:]
         if not isinstance(numpy_data, np.ndarray):
             raise TypeError("The data should be a numpy.ndarray.")
         _check_operator(reshape_op)
@@ -604,8 +631,6 @@ def _apply_operator(operator_name):
         Returns:
             The data of tensor after apply operator.
         """
-        if not isinstance(numpy_data, np.ndarray):
-            raise TypeError("The data should be a numpy.ndarray.")
         _check_operator(slice_op)
         if len(slice_op[1]) % 3 != 0:
             raise ValueError("The slice operator information is wrong.")
@@ -621,7 +646,7 @@ def _apply_operator(operator_name):
         return numpy_data[slice_index]

     _apply_operator_map = {"Reshape": _apply_reshape_operator, "StridedSlice": _apply_slice_operator,
-                           "
+                           "AllConcat": _apply_allconcat_operator}
     return _apply_operator_map.get(operator_name)


@@ -658,3 +683,92 @@ def _reshape_param_data_with_weight(param_data, dev_mat, field_size):
     for i in range(1, len(tensor_slices_col)):
         new_tensor = np.concatenate((new_tensor, np.array(tensor_slices_col[i]).reshape(-1, 1)), axis=1)
     return Tensor(new_tensor)
+
+
+def _load_tensor_shape(dev_mat, tensor_map, full_shape=None, rank_id=-1):
+    """get tensor shape by slice"""
+    if rank_id == -1:
+        rank = get_rank()
+    else:
+        rank = rank_id
+    tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
+    tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
+    np_tensor_list = _chunk_shape_by_strategy(full_shape, tensor_strategy)
+    np_tensor_slice_index = np_tensor_list[int(tensor_slice_index)]
+    res = []
+    for index in np_tensor_slice_index:
+        res.append(slice(index[0], index[1]))
+    return tuple(res)
+
+
+def _count_tensor_shape(dev_mat, tensor_map, full_shape=None, rank_id=-1):
+    """get tensor shape"""
+    if rank_id == -1:
+        rank = get_rank()
+    else:
+        rank = rank_id
+    tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
+    tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
+    np_tensor_list = _chunk_shape_by_strategy(full_shape, tensor_strategy)
+    np_tensor_slice_index = np_tensor_list[int(tensor_slice_index)]
+    res = []
+    for index in np_tensor_slice_index:
+        res.append(index[1] - index[0])
+    return res
+
+
+def _load_tensor_shape_by_layout(tensor, layout, rank_id):
+    """get tensor shape by layout"""
+    if not isinstance(layout, tuple):
+        raise TypeError("The layout should be tuple! layout is {}".format(layout))
+    if len(layout) < 7:
+        raise ValueError("The length of layout must be larger than 6! layout is {}".format(layout))
+    slice_shape = layout[2]
+    if slice_shape:
+        return slice_shape
+    tensor_map = layout[1]
+    if not tensor_map:
+        return tensor.shape
+    dev_mat = layout[0]
+    uniform_split = layout[4]
+    group = layout[5]
+    full_shape = layout[6]
+    if not full_shape:
+        full_shape = tensor.shape
+    if uniform_split == 0:
+        raise RuntimeError("The load tensor only support uniform split now")
+    tensor_slice_shape = _count_tensor_shape(dev_mat, tensor_map, full_shape, rank_id)
+    if group:
+        # get a totally shard tensor slice for parallel optimizer
+        size = get_group_size(group)
+        tensor_slice_shape[0] //= size
+    return tensor_slice_shape
+
+
+def _chunk_shape_by_strategy(full_shape, strategy):
+    """chunk shape by strategy"""
+    shape = []
+    for i in full_shape:
+        shape.append([0, i])
+    return _chunk_shape(shape, strategy, len(strategy))
+
+
+def _chunk_shape(np_tensor, strategy, depth):
+    """_chunk shape"""
+    output = []
+    axis = len(np_tensor) - depth
+    left, right = np_tensor[axis]
+    num = strategy[0]
+    chunk_size = (right - left) / num
+    append = [[i, int(i + chunk_size)] for i in range(left, right) if i % chunk_size == 0]
+    np_tensor_new = []
+    for i in append:
+        np_tensor_tmp = copy.deepcopy(np_tensor)
+        np_tensor_tmp[axis] = i
+        np_tensor_new.append(np_tensor_tmp)
+    if depth == 1:
+        return np_tensor_new
+    for ret_ in np_tensor_new:
+        output.extend(
+            _chunk_shape(ret_, strategy[len(strategy) - depth + 1:len(strategy)], depth - 1))
+    return output