mindspore 2.7.0__cp311-cp311-win_amd64.whl → 2.7.1__cp311-cp311-win_amd64.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +8 -1
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +275 -64
- mindspore/common/_utils.py +0 -44
- mindspore/common/api.py +285 -35
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
- mindspore/common/hook_handle.py +60 -0
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/parameter.py +13 -107
- mindspore/common/recompute.py +4 -11
- mindspore/common/tensor.py +16 -169
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +5 -85
- mindspore/dataset/engine/datasets.py +8 -4
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dnnl.dll +0 -0
- mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +1 -1
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/distributed.py +182 -62
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +4 -4
- mindspore/mint/nn/layer/normalization.py +8 -13
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +16 -66
- mindspore/nn/layer/basic.py +49 -1
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +31 -124
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +0 -1
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +0 -1
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/wrap/grad_reducer.py +4 -74
- mindspore/numpy/array_creations.py +2 -2
- mindspore/numpy/fft.py +9 -9
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +0 -5
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
- mindspore/ops/auto_generate/gen_extend_func.py +2 -7
- mindspore/ops/auto_generate/gen_ops_def.py +98 -141
- mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +15 -1
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +1 -0
- mindspore/ops/function/array_func.py +14 -12
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +3 -4
- mindspore/ops/function/math_func.py +45 -54
- mindspore/ops/function/nn_func.py +75 -294
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +2 -0
- mindspore/ops/functional_overload.py +354 -18
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +1 -38
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +1 -0
- mindspore/ops/operations/comm_ops.py +94 -2
- mindspore/ops/operations/custom_ops.py +228 -19
- mindspore/ops/operations/debug_ops.py +27 -29
- mindspore/ops/operations/manually_defined/ops_def.py +27 -306
- mindspore/ops/operations/nn_ops.py +2 -2
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +1 -17
- mindspore/ops/tensor_method.py +72 -3
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +1 -4
- mindspore/parallel/_utils.py +29 -6
- mindspore/parallel/checkpoint_transform.py +18 -2
- mindspore/parallel/cluster/process_entity/_api.py +24 -32
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +117 -16
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +38 -2
- mindspore/profiler/common/path_manager.py +56 -24
- mindspore/profiler/common/profiler_context.py +2 -12
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/experimental_config.py +2 -1
- mindspore/profiler/platform/npu_profiler.py +33 -6
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +3 -2
- mindspore/runtime/executor.py +11 -3
- mindspore/runtime/memory.py +112 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +5 -18
- mindspore/train/amp.py +6 -4
- mindspore/train/callback/_checkpoint.py +0 -9
- mindspore/train/callback/_train_fault_tolerance.py +69 -18
- mindspore/train/data_sink.py +1 -5
- mindspore/train/model.py +38 -211
- mindspore/train/serialization.py +126 -387
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +144 -8
- mindspore/version.py +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/train/memory_profiling_pb2.py +0 -298
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/parallel/strategy.py (new file)

@@ -0,0 +1,336 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Checkpoint strategy info"""
+from __future__ import absolute_import
+
+__all__ = ["get_strategy_metadata", "get_current_strategy_metadata", "enable_save_strategy_online", \
+           "clear_strategy_metadata"]
+
+from itertools import chain
+from typing import Sequence, Union, Tuple, List, Dict
+from types import SimpleNamespace
+
+import numpy as np
+
+from mindspore import log as logger
+from mindspore._c_expression import StrategyInfo
+from mindspore._c_expression import StrategyLayout
+from mindspore.parallel.shard import Layout
+
+LayoutInfo = Tuple[Layout, str, str]
+StrOrTuple = Union[str, Tuple["StrOrTuple", ...], List["StrOrTuple"]]
+
+
+def get_strategy_metadata(network, rank_id=None) -> Dict[int, Dict[str, List[LayoutInfo]]]:
+    """
+    Get all params strategy info or specific rank strategy info in this cell.
+    For more information on layouts, please refer to: :class:`mindspore.parallel.Layout`.
+
+    Args:
+        network (str): The network name.
+        rank_id (int, optional): The rank id of the process on which this cell will be launched.
+            Defaults to ``None``, which means strategy metadata for all ranks will be returned.
+
+    Returns:
+        Dict. A dictionary containing the parameter slicing strategies for either all ranks or a specific rank.
+        The key is `rank_id`, and the value is the slicing strategy for all parameters on that rank.
+        Within each rank's strategy, the key is the parameter name, and the value is the slicing strategy.
+        If a `rank_id` is specified, the dictionary returns the strategy information for that specific rank.
+        Otherwise, it returns the strategy information for all ranks in the network. If not supported, returns None.
+
+    Examples:
+        >>> import mindspore as ms
+        >>> from mindspore import nn
+        >>> from mindspore.communication import init
+        >>> from mindspore.nn.utils import no_init_parameters
+        >>> from mindspore.parallel.auto_parallel import AutoParallel
+        >>> from mindspore.train import Model
+        >>> from mindspore.parallel.strategy import get_strategy_metadata, get_current_strategy_metadata,
+        ...     enable_save_strategy_online, clear_strategy_metadata
+        >>>
+        >>> ms.set_context(mode=ms.GRAPH_MODE)
+        >>> init()
+        >>> ms.set_seed(1)
+        >>>
+        >>> # Define the network structure of LeNet5. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
+        >>> with no_init_parameters():
+        ...     net = LeNet5()
+        ...     optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+        >>>
+        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
+        >>> train_net = AutoParallel(net, parallel_mode="semi_auto")
+        >>> model = Model(network=train_net, loss_fn=loss, optimizer=optim, metrics=None)
+        >>>
+        >>> # Create the dataset taking MNIST as an example. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
+        >>> dataset = create_dataset()
+        >>>
+        >>> enable_save_strategy_online()
+        >>> model.train(2, dataset)
+        >>>
+        >>> global_info = get_strategy_metadata(network=model.train_network)
+        >>> rank0_info = get_strategy_metadata(network=model.train_network, rank_id=0)
+        >>> local_info = get_current_strategy_metadata(network=model.train_network)
+        >>> clear_strategy_metadata()
+    """
+    return _NetStrategyInfo(network, global_layout=None, local_layout=None).get_rank_layout(rank_id)
+
+
+def get_current_strategy_metadata(network) -> Dict[int, Dict[str, List[LayoutInfo]]]:
+    """
+    Get parameters dictionary of cur rank of the network.
+
+    Args:
+        network(str): The network name.
+
+    Returns:
+        Dict. The key is 0 (representing the local rank), and the value is the slicing strategy for all parameters.
+        The key within the value represents the parameter name, and the value is the corresponding slicing strategy \
+        for that parameter. If not supported, returns None.
+    """
+    return _NetStrategyInfo(network, global_layout=None, local_layout=None).get_local_rank_layout()
+
+
+def enable_save_strategy_online():
+    """
+    Enable save strategy metadata online.
+    """
+    strategy_layout_handle = StrategyLayout.get_instance()
+    if strategy_layout_handle is None:
+        raise ValueError("Strategy layout handle is none in parallel_strategy_checkpoint!!!")
+    strategy_layout_handle.enable_save_strategy_online()
+
+
+def clear_strategy_metadata():
+    """Clear all saved strategy metadata on the C++ side."""
+    strategy_layout_handle = StrategyLayout.get_instance()
+    if strategy_layout_handle is None:
+        raise ValueError("Strategy layout handle is none in parallel_strategy_checkpoint!!!")
+    return strategy_layout_handle.clear_strategy_metadata()
+
+
+class _NetStrategyInfo:
+    """
+    Describe the strategy information of a network.
+    """
+
+    def __init__(self, network, global_layout=None, local_layout=None):
+        self._network = network
+        self._compile_phase = network.compile_phase
+        if global_layout is None or local_layout is None:
+            layout_handle = self._get_layout_handle()
+            global_layout = layout_handle.global_network_layout()
+            local_layout = layout_handle.local_network_layout()
+        self._raw_global_layout = global_layout
+        self._raw_local_layout = local_layout
+
+    @staticmethod
+    def _get_layout_handle():
+        """Get strategy handle"""
+        layout_handle = StrategyLayout.get_instance()
+        if layout_handle is None:
+            raise ValueError("Strategy layout handle is none in parallel_strategy_checkpoint!!!")
+        return layout_handle
+
+    def get_rank_layout(self, rank_id=None):
+        """Get params of the network, global rank or special rank, interface."""
+        raw_global_layout = self._get_valid_layout(self._compile_phase, self._raw_global_layout)
+        if raw_global_layout is None:
+            return None
+        global_layout = self._extract_layout_metadata(raw_global_layout)
+        if rank_id is not None:
+            cur_rank_layout = {rank_id: global_layout[rank_id]}
+            self._layout_to_string(cur_rank_layout)
+            return cur_rank_layout
+        self._layout_to_string(global_layout)
+        return global_layout
+
+    def get_local_rank_layout(self):
+        """Get local rank params of the network, {param_name: param_info[layout]}."""
+        raw_local_layout = self._get_valid_layout(self._compile_phase, self._raw_local_layout)
+        if raw_local_layout is None:
+            return None
+        local_layout = self._extract_layout_metadata(raw_local_layout)
+        self._layout_to_string(local_layout)
+        return local_layout
+
+    @staticmethod
+    def _get_valid_layout(phase, layout_dict):
+        """Helper: Validate and extract layout by phase."""
+        if not phase:
+            return None
+        layout = layout_dict.get(phase)
+        if not layout or all(not v for v in layout.values()):
+            return None
+        return layout
+
+    def _extract_layout_metadata(self, layout: Dict[int, Dict[str, StrategyInfo]]) -> Dict:
+        """Return new layout of special network."""
+        new_layout = {}
+        for rank_id, param_dict in layout.items():
+            new_param_info = {}
+            for param_name, param_info in param_dict.items():
+                new_param_layout = self._layout_process(param_info)
+                new_param_info[param_name] = new_param_layout
+            new_layout[rank_id] = new_param_info
+        return new_layout
+
+    def _layout_process(self, stra_layout):
+        """
+        Return the layout list, stra_layout is one of params_info of cur_rank.
+        """
+        new_dev_mat, counter, new_tensor_map, full_opt_shard = self._get_dev_mat_for_opt_shard(
+            stra_layout.opt_weight_shard_size, stra_layout.dev_matrix, stra_layout.tensor_map)
+        alphabet = 'abcdefghijklmnopqrstuvwxyz'
+        alias_name = [alphabet[i] for i in range(len(new_dev_mat))]
+        if stra_layout.opt_weight_shard_size == 0:
+            new_tensor_map = tuple(tuple(alias_name[len(alias_name) - idx - 1] if idx != -1 else "None" for idx in sub)
+                                   for sub in new_tensor_map)
+        else:
+            info = SimpleNamespace(
+                new_dev_mat=new_dev_mat,
+                new_tensor_map=new_tensor_map,
+                full_opt_shard=full_opt_shard,
+                counter=counter,
+                alias_name=alias_name
+            )
+            new_tensor_map = self._get_tensor_map_for_opt_shard(info)
+        new_tensor_map = self._compact_tensor_map(new_tensor_map)
+        new_dev_mat = tuple(new_dev_mat)
+        alias_name = tuple(alias_name)
+        layout = Layout(new_dev_mat, alias_name, stra_layout.rank_list)
+        final_layout = layout(*new_tensor_map)
+        logger.debug("The final layout is %s", final_layout.to_dict())
+        cur_param_list = [final_layout, stra_layout.tensor_type, stra_layout.tensor_shape]
+        return cur_param_list
+
+    def _get_dev_mat_for_opt_shard(self, opt_shard, dev_mat, tensor_map):
+        """generate device matrix for opt shard scenario"""
+        if opt_shard == 0:
+            return dev_mat, -1, tensor_map, True
+        used_dev_num = self._calc_used_dev_num(dev_mat, tensor_map)
+        total_dev_num = int(np.prod(np.array(dev_mat)))
+        if opt_shard == -1 or used_dev_num * opt_shard == total_dev_num:
+            return dev_mat, -1, tensor_map, True
+        remain_dev_num = total_dev_num // (used_dev_num * opt_shard)
+        used_dev_mat_mask = self._get_used_dev_mat(dev_mat, tensor_map)
+        info = SimpleNamespace(
+            dev_mat=dev_mat,
+            tensor_map=tensor_map,
+            counter=-1,
+            real_remain_dev_num=1,
+            remain_dev_num=remain_dev_num
+        )
+        for axis, value in enumerate(dev_mat):
+            if used_dev_mat_mask[axis]:
+                continue
+            info.counter = axis
+            if info.real_remain_dev_num == info.remain_dev_num:
+                return dev_mat, axis, tensor_map, False
+            if info.real_remain_dev_num < info.remain_dev_num:
+                info.real_remain_dev_num *= value
+                continue
+            # info.real_remain_dev_num > info.remain_dev_num, split axis.
+            return self._split_dev_dim(info)
+        if info.real_remain_dev_num == info.remain_dev_num:
+            return dev_mat, info.counter, tensor_map, False
+        return self._split_dev_dim(info)
+
+    def _get_tensor_map_for_opt_shard(self, info: SimpleNamespace):
+        """generate tensor map for opt shard scenario"""
+
+        def idx_to_alias(idx):
+            return "None" if idx == -1 else info.alias_name[len(info.alias_name) - idx - 1]
+
+        def entry_to_alias(entry):
+            if isinstance(entry, (list, tuple)):
+                return tuple(idx_to_alias(i) for i in entry)
+            return idx_to_alias(entry)
+
+        used_dev_mat = self._get_used_dev_mat(info.new_dev_mat, info.new_tensor_map)
+        if info.full_opt_shard:
+            unused_idx = [len(used_dev_mat) - i - 1 for i, used in enumerate(used_dev_mat) if not used]
+        else:
+            unused_idx = [len(used_dev_mat) - i - 1 for i, used in enumerate(used_dev_mat) if
+                          not used and i > info.counter]
+        first_entry = info.new_tensor_map[0]
+        first_list = list(first_entry) if isinstance(first_entry, (list, tuple)) else [first_entry]
+        new_first_list = [dim for dim in first_list + unused_idx if dim != -1]
+        first_alias_list = [idx_to_alias(i) for i in new_first_list] or ["None"]
+        first_alias = first_alias_list[0] if len(first_alias_list) == 1 else tuple(first_alias_list)
+        rest_alias = [entry_to_alias(entry) for entry in info.new_tensor_map[1:]]
+        new_tensor_map = tuple([first_alias] + rest_alias)
+        return new_tensor_map
+
+    @staticmethod
+    def _split_dev_dim(info: SimpleNamespace):
+        """Split the counter dimension of dev_mat and adjust tensor_map."""
+        dev_mat = info.dev_mat
+        counter = info.counter
+        splitted_dev_value = dev_mat[counter]
+        new_dev_mat_value_first = info.remain_dev_num // (info.real_remain_dev_num // splitted_dev_value)
+        new_dev_mat_value_second = splitted_dev_value // new_dev_mat_value_first
+        new_dev_mat = dev_mat[:counter] + [new_dev_mat_value_first, new_dev_mat_value_second] + dev_mat[counter + 1:]
+        flag = len(new_dev_mat) - 1 - counter
+        new_tensor_map = [[v if v < flag or v == -1 else v + 1 for v in sub] for sub in info.tensor_map]
+        return new_dev_mat, counter, new_tensor_map, False
+
+    @staticmethod
+    def _calc_used_dev_num(dev_mat, tensor_map):
+        """Count the total number of device nums that have been used."""
+        idx_flat = [idx for idx in chain.from_iterable(tensor_map) if idx != -1]
+        if not idx_flat:
+            return 1
+        prod_list = [dev_mat[len(dev_mat) - idx - 1] for idx in idx_flat]
+        return int(np.prod(prod_list))
+
+    @staticmethod
+    def _get_used_dev_mat(dev_mat, tensor_map) -> List[bool]:
+        """List that records whether the device ID is being used or not."""
+        used = set()
+        for elem in tensor_map:
+            if isinstance(elem, (list, tuple)):
+                used.update(i for i in elem if i != -1)
+            elif elem != -1:
+                used.add(elem)
+        return [(len(dev_mat) - i - 1) in used for i in range(len(dev_mat))]
+
+    @staticmethod
+    def _compact_tensor_map(alias_map: Sequence[StrOrTuple]) -> Tuple[StrOrTuple, ...]:
+        """Extend tensor map of 'None'."""
+
+        def _compress(elem: StrOrTuple) -> StrOrTuple:
+            if isinstance(elem, (list, tuple)):
+                compressed = tuple(_compress(e) for e in elem)
+                if len(compressed) == 1:
+                    return compressed[0]
+                if all(x == 'None' for x in compressed):
+                    return 'None'
+                return compressed
+            return elem
+
+        return tuple(_compress(e) for e in alias_map)

+    @staticmethod
+    def _layout_to_string(layout_info):
+        """Print layout info."""
+        for rank_id, param_layout in layout_info.items():
+            logger.info("rank_id=%s", rank_id)
+            for param_name, cur_param_list in param_layout.items():
+                final_layout, param_type, global_shape = cur_param_list
+                logger.info("param_name=%s: [param_layout=%s, param_type=%s, global_shape=%s]",
+                            param_name, final_layout.to_dict(), param_type, global_shape)
+            logger.info("\n")
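The hunk above adds the complete new module mindspore/parallel/strategy.py. Condensing the docstring example embedded in the diff into a minimal usage sketch (run under GRAPH_MODE after init(); `model` and `dataset` are built as in that example):

    from mindspore.parallel.strategy import (get_strategy_metadata, get_current_strategy_metadata,
                                             enable_save_strategy_online, clear_strategy_metadata)

    enable_save_strategy_online()      # enabled before training, as in the docstring example
    model.train(2, dataset)
    all_ranks = get_strategy_metadata(network=model.train_network)
    # -> {rank_id: {param_name: [Layout, dtype_str, global_shape]}}, per _layout_process above
    rank0 = get_strategy_metadata(network=model.train_network, rank_id=0)
    local = get_current_strategy_metadata(network=model.train_network)   # keyed by 0 (local rank)
    clear_strategy_metadata()          # release the metadata held on the C++ side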
mindspore/parallel/transform_safetensors.py

@@ -15,6 +15,7 @@
 """Transform distributed safetensors"""
 from __future__ import absolute_import
 
+import copy
 import os
 import sys
 import glob
@@ -68,6 +69,7 @@ dtype_size = {
     "F64": 8,
 }
 np_dtype_size = {
+    "bool": 1,
     "bool_": 1,
     "uint8": 1,
     "int8": 1,
@@ -696,6 +698,8 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
     else:
         if transform_param_dict:
             if output_format == "safetensors":
+                if meta_data and "remove_redundancy" in meta_data:
+                    meta_data["remove_redundancy"] = "False"
                 _save_file_atomically(transform_param_dict, save_file_name, metadata=meta_data)
             else:
                 transform_param_dict = _load_and_transform(transform_param_dict, None, None,
@@ -765,6 +769,11 @@ def transform_safetensors_by_stage(src_safetensors_dir, dst_safetensors_dir, ckp
             param_type_dict[param_name][src_rank] = str(param.data.dtype)
             param_total_dict[param_name][src_rank] = param
             param_attr_dict[param_name][src_rank] = (True, False)
+
+    ckpt_prefix = os.path.basename(ckpt_prefix)
+    if '..' in ckpt_prefix or '/' in ckpt_prefix or '\\' in ckpt_prefix:
+        raise ValueError(f"Invalid ckpt_prefix: {ckpt_prefix}. Must not contain path traversal characters.")
+
     for local_rank_id in range(dst_stage_device_num):
         transform_param_dict = _transform_parallel_safetensor(local_rank_id, param_total_dict,
                                                               param_attr_dict, src_strategy_list, dst_strategy_list,
@@ -782,6 +791,7 @@ def transform_safetensors_by_rank(rank_id, safetensor_files_map, save_safetensor
     """
     Transform distributed checkpoint from source sharding strategy to destination sharding strategy by rank.
    """
+    save_safetensor_file_name = os.path.abspath(save_safetensor_file_name)
     if not isinstance(safetensor_files_map, dict):
         raise TypeError("The safetensor_files_map should be a dict.")
     if not isinstance(rank_id, int):
@@ -829,11 +839,84 @@ def transform_safetensors_by_rank(rank_id, safetensor_files_map, save_safetensor
         _save_file_atomically(transform_param_dict, save_safetensor_file_name, metadata={"format": "ms"})
 
 
-def
-"""
-
-
-
+def _extract_numbers(s):
+    """Extract all numbers from a string and convert them to integers."""
+    return [int(num) for num in re.findall(r'\d+', s)]
+
+
+def _extract_last_two_numbers(file_name):
+    """Get the last two numbers from a filename."""
+    all_numbers = _extract_numbers(file_name)
+    return all_numbers[-2:]
+
+
+def _find_shortest_file(matched_files, rank_ckpts, new_file_suffix, file_suffix):
+    """Find the shortest file from a list of matched files."""
+    min_length = min(len(os.path.basename(ckpt)) for ckpt in matched_files)
+    shortest_files = [ckpt for ckpt in matched_files if len(os.path.basename(ckpt)) == min_length]
+    if len(shortest_files) == 1:
+        return shortest_files[0]
+    raise ValueError(f"Multiple files with suffix '{file_suffix}' found in {rank_ckpts}. Following MindSpore naming "
+                     f"rules, searched for files ending with '{new_file_suffix}' but found multiple "
+                     f"files {matched_files}. Then searched for the shortest filename, but found multiple shortest "
+                     f"files {shortest_files}. Please set file_suffix to the longest common suffix of all files.")
+
+
+def _get_matched_file(matched, rank_ckpts, new_file_suffix, file_suffix):
+    """Get the file from a list of matched files."""
+    if len(matched) == 1:
+        return matched[0]
+    if len(matched) > 1:
+        return _find_shortest_file(matched, rank_ckpts, new_file_suffix, file_suffix)
+    raise ValueError(f"Multiple files with suffix '{file_suffix}' found in {rank_ckpts}. Following MindSpore naming "
+                     f"rules, searched for files ending with '{new_file_suffix}' but found zero files. "
+                     f"Please set file_suffix to the longest common suffix of all files.")
+
+
+def _find_most_matching_file(rank_ckpts, file_suffix, format):
+    """Finds the most matching checkpoint file based on the file_suffix."""
+    if file_suffix is None:
+        rank_ckpts.sort(key=_extract_last_two_numbers)
+        return rank_ckpts[-1]
+
+    new_file_suffix = file_suffix
+    pattern1 = rf'^_(\d+)-(\d+)_(\d+)$'
+    matches1 = re.search(pattern1, file_suffix)
+    pattern2 = rf'^(\d+)-(\d+)_(\d+)$'
+    matches2 = re.search(pattern2, file_suffix)
+    # Pattern matching for _{task_id}-{epoch}_{step} format (e.g., _1-10_100 or 1-10_100)
+    if matches1 is not None or matches2 is not None:
+        if matches2 is not None:
+            new_file_suffix = "_" + new_file_suffix
+        matched = [ckpt for ckpt in rank_ckpts if ckpt.endswith(f"{new_file_suffix}.{format}") and
+                   not ckpt.endswith(f"rank{new_file_suffix}.{format}")]
+        return _get_matched_file(matched, rank_ckpts, new_file_suffix, file_suffix)
+
+    pattern3 = rf'^-(\d+)_(\d+)$'
+    matches3 = re.search(pattern3, file_suffix)
+    pattern4 = rf'^(\d+)_(\d+)$'
+    matches4 = re.search(pattern4, file_suffix)
+    # Pattern matching for -{epoch}_{step} format (e.g., -10_100 or 10_100)
+    if matches3 is not None or matches4 is not None:
+        if matches4 is not None:
+            new_file_suffix = "-" + new_file_suffix
+        matched = [ckpt for ckpt in rank_ckpts if ckpt.endswith(f"{new_file_suffix}.{format}")]
+        return _get_matched_file(matched, rank_ckpts, new_file_suffix, file_suffix)
+
+    pattern5 = rf'^_(\d+)$'
+    matches5 = re.search(pattern5, file_suffix)
+    pattern6 = rf'^(\d+)$'
+    matches6 = re.search(pattern6, file_suffix)
+    # Pattern matching for _{step} format (e.g., _100 or 100)
+    if matches5 is not None or matches6 is not None:
+        if matches6 is not None:
+            new_file_suffix = "_" + new_file_suffix
+        matched = [ckpt for ckpt in rank_ckpts if ckpt.endswith(f"{new_file_suffix}.{format}")]
+        return _get_matched_file(matched, rank_ckpts, new_file_suffix, file_suffix)
+
+    raise ValueError(f"Multiple {format} files ending with '{file_suffix}' found in {rank_ckpts}. "
+                     f"Cannot determine which file is the intended one. "
+                     f"Please set file_suffix to the longest common suffix.")
 
 
 def _collect_safetensor_files(src_safetensors_dir, format='safetensors', file_suffix=None):
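The helper block above replaces a single removed function (its body is truncated in the rendered diff). `_find_most_matching_file` disambiguates checkpoints by normalizing `file_suffix` against the three MindSpore checkpoint-name shapes before filtering with `endswith`. A quick sketch of the normalization, read directly off the patterns (the suffix strings are hypothetical):

    import re
    # "{task_id}-{epoch}_{step}": a bare "1-10_100" is prefixed to "_1-10_100"
    assert re.search(r'^(\d+)-(\d+)_(\d+)$', "1-10_100")
    # "{epoch}_{step}": a bare "10_100" is prefixed to "-10_100"
    assert re.search(r'^(\d+)_(\d+)$', "10_100")
    # "{step}": a bare "100" is prefixed to "_100"
    assert re.search(r'^(\d+)$', "100")

The added separator makes the `endswith` check precise: a file ending in "_1-10_100.safetensors" matches, while one ending in "11-10_100.safetensors" no longer does.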
mindspore/parallel/transform_safetensors.py (continued)

@@ -844,6 +927,9 @@ def _collect_safetensor_files(src_safetensors_dir, format='safetensors', file_su
         return {0: src_safetensors_dir}
     safetensors_rank_dir_list = os.path.join(src_safetensors_dir, "rank_[0-9]*")
     all_safetensor_files_map = {}
+    multiple_files_found_flag = False
+    multiple_files_list = None
+    chosen_file = None
     for safetensor_dir in glob.glob(safetensors_rank_dir_list):
         if not os.path.isdir(safetensor_dir):
             ms.log.warning("{} is not a directory.".format(safetensor_dir))
@@ -859,9 +945,23 @@ def _collect_safetensor_files(src_safetensors_dir, format='safetensors', file_su
         else:
             safetensor_file_name = os.path.join(safetensor_dir, f"*{file_suffix}.{format}")
             rank_ckpts = glob.glob(safetensor_file_name)
-            rank_ckpts
-
-
+            if len(rank_ckpts) > 1:
+                all_safetensor_files_map[rank_id] = _find_most_matching_file(rank_ckpts, file_suffix, format)
+                if not multiple_files_found_flag:
+                    multiple_files_found_flag = True
+                    multiple_files_list = copy.deepcopy(rank_ckpts)
+                    chosen_file = all_safetensor_files_map[rank_id]
+            elif rank_ckpts:
+                all_safetensor_files_map[rank_id] = rank_ckpts[0]
+            elif file_suffix is not None:
+                raise ValueError(f"No safetensors files found in directory '{safetensor_dir}' "
+                                 f"with suffix '{file_suffix}' and format '{format}'. "
+                                 f"Please verify the directory contains the expected files. "
+                                 f"Recommend setting file_suffix to the longest common suffix.")
+    if file_suffix is not None and multiple_files_found_flag:
+        logger.warning(f"When unified_safetensors files with file_suffix `{file_suffix}`, multiple files were found. "
+                       f"Showing one list: {multiple_files_list}; selected `{chosen_file}` from it. "
+                       f"Please check whether the file_suffix is set correctly.")
     return all_safetensor_files_map
 
 
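Inside `_collect_safetensor_files`, a glob that matches several files per rank directory is now resolved through `_find_most_matching_file` instead of silently taking an arbitrary element, a one-time warning names the candidates and the chosen file, and an empty match with an explicit `file_suffix` is now an error. Tracing two hypothetical candidates through the code above with file_suffix="1-10_100":

    rank_ckpts = ["net_1-10_100.safetensors", "net_11-10_100.safetensors"]   # both hit the glob
    # normalized suffix "_1-10_100" -> endswith("_1-10_100.safetensors"):
    #   "net_1-10_100.safetensors"  -> True  -> selected
    #   "net_11-10_100.safetensors" -> False -> rejected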
mindspore/parallel/transform_safetensors.py (continued)

@@ -978,7 +1078,7 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
 def _cal_param_size(shape, dtype):
     """cal param size by dtype and shape"""
     num_elements = math.prod(shape)
-    element_size = np_dtype_size.get(dtype, 4)
+    element_size = np_dtype_size.get(str(dtype), 4)
     total_bytes = num_elements * element_size
     return total_bytes
 
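The one-character fix above matters because `np_dtype_size` is keyed by strings, while the dtype taken from a safetensors slice is typically a NumPy dtype object, so the old lookup could miss the table and silently fall back to the 4-byte default (undercounting float64, overcounting float16). A plain-NumPy sketch of the difference:

    import numpy as np
    sizes = {"float16": 2, "float32": 4, "float64": 8}
    dt = np.dtype("float16")
    sizes.get(dt, 4)        # usually 4: the dtype object does not hash like the str key
    sizes.get(str(dt), 4)   # 2: str(dt) == "float16"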
mindspore/parallel/transform_safetensors.py (continued)

@@ -1141,7 +1241,7 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
     if os.path.isfile(src_dir):
         raise ValueError("For 'unified_safetensors', the 'src_dir' can not be a file.")
     all_safetensor_files_map = _collect_safetensor_files(src_dir, format="safetensors", file_suffix=file_suffix)
-    all_ckpt_files_map = _collect_safetensor_files(src_dir, format="ckpt"
+    all_ckpt_files_map = _collect_safetensor_files(src_dir, format="ckpt")
     if all_safetensor_files_map and all_ckpt_files_map:
         raise ValueError("For 'unified_safetensors', the 'src_dir' cannot contain "
                          "both ckpt file and safetensors file simultaneously")
@@ -1179,11 +1279,6 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
         with _fast_safe_open(file_name, framework="np") as f:
             for k in f.keys():
                 if k in name_list:
-                    py_slice = f.get_tensor(k)
-                    param_total_size += _cal_param_size(py_slice.shape, py_slice.dtype)
-                    param_dst_shape = _get_dst_shape(k, py_slice.shape, origin_src_strategy_list)
-                    # Convert the shape of np.int32 type to int type to prevent overflow in subsequent calculations.
-                    param_dst_shape = [int(item) for item in param_dst_shape]
                     if choice_func is not None:
                         choice_out = choice_func(k)
                         if isinstance(choice_out, bool):
@@ -1191,7 +1286,13 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
                             name_list.remove(k)
                             continue
                     if k not in param_size_dict:
-
+                        py_slice = f.get_tensor(k)
+                        param_dst_shape = _get_dst_shape(k, py_slice.shape, origin_src_strategy_list)
+                        # Convert the shape of np.int32 type to int type to prevent overflow in subsequent calculations.
+                        param_dst_shape = [int(item) for item in param_dst_shape]
+                        param_size = _cal_param_size(param_dst_shape, py_slice.dtype)
+                        param_total_size += param_size
+                        param_size_dict[k] = param_size
     split_num = math.ceil(sum(param_size_dict.values()) / 1024 / 1024 / 1024 / 3)
     split_num = min(split_num, len(name_list))
     split_list = _split_weight_dict(param_size_dict, split_num)
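The two hunks above relocate the size accounting in `unified_safetensors`: the per-parameter size is now computed once, from the destination shape returned by `_get_dst_shape`, and cached in `param_size_dict`, rather than re-added for every slice of every file scanned. That total drives the split heuristic visible in the context lines, split_num = ceil(total / 3 GiB). For example (assumed numbers): a float32 weight of global shape [4096, 4096] is 4096 * 4096 * 4 B = 64 MiB; if an unsharded copy of it sits in each of 8 rank files, per-slice accumulation would have counted 512 MiB toward split_num, while the per-parameter accounting counts 64 MiB exactly once.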
mindspore/profiler/analysis/viewer/ms_minddata_viewer.py

@@ -500,7 +500,7 @@ class BottleneckAnalyzer:
         in_op_id, out_q = self._get_non_inline_child_recur(op_id), self.queue_utilization_pct[op_id]
         # This is a leaf node since input queue does not exist and output queue exists
         if in_op_id == self.op_id_not_exist and out_q != self.queue_usage_not_exist:
-            if out_q
+            if out_q <= self._THRESHOLDS['_LEAF_OUTPUT_QUEUE_EMPTY_FREQ_PCT_MAXIMUM']:
                 queue_usage_analysis.append(self._format_leaf_node_suggestion(op_id, out_q))
         # This is device_queue op
         elif self.op_names[op_id] == "DeviceQueue" and in_op_id != self.op_id_not_exist:
mindspore/profiler/common/file_manager.py

@@ -206,3 +206,12 @@ class FileManager:
             if file_name.startswith(start_name) and file_name.endswith(".csv"):
                 file_list.append(os.path.join(source_path, file_name))
         return file_list
+
+    @classmethod
+    def check_file_owner(cls, path):
+        """Check whether the file owner is the current user or root."""
+        stat_info = os.stat(path)
+        if stat_info.st_uid == 0:
+            return True
+        current_uid = os.geteuid()
+        return current_uid == stat_info.st_uid
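The new `check_file_owner` classmethod accepts a file owned either by root or by the effective user. A usage sketch (the caller and path are hypothetical; note that `os.geteuid` only exists on POSIX, so the helper is POSIX-oriented):

    from mindspore.profiler.common.file_manager import FileManager

    path = "./profiler_output/kernel_details.csv"   # hypothetical file
    if not FileManager.check_file_owner(path):
        raise PermissionError(f"{path} is owned by neither root nor the current user")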
mindspore/profiler/common/msprof_cmd_tool.py

@@ -22,6 +22,7 @@ from typing import Dict, List, Optional
 from mindspore import log as logger
 from mindspore.profiler.common.command_executor import CommandExecutor
 from mindspore.profiler.common.constant import ExportType
+from mindspore.profiler.common.path_manager import PathManager
 
 
 class MsprofCmdTool:
@@ -120,6 +121,7 @@ class MsprofCmdTool:
         Raises:
             FileNotFoundError: If msprof or python3 command is not found.
         """
+        self._check_msprof_profile_path_is_valid()
         if not shutil.which(self._MSPROF_CMD):
             logger.warning(
                 "The msprof command is not found in PATH. Searching in environment variables..."
@@ -131,11 +133,44 @@ class MsprofCmdTool:
                 logger.info("Successfully added msprof command to PATH.")
             else:
                 raise FileNotFoundError("Failed to find msprof command in environment.")
-
+        else:
+            msprof_path = shutil.which(self._MSPROF_CMD)
+            self._check_msprof_permission(msprof_path)
         if not shutil.which("python3"):
             logger.warning("Failed to find python3 command in environment.")
             raise FileNotFoundError("Failed to find python3 command in environment.")
 
+    def _check_msprof_profile_path_is_valid(self):
+        """Check msprof profiler path is invalid."""
+        PathManager.check_directory_path_readable(self._msprof_profile_path)
+        PathManager.check_directory_path_writeable(self._msprof_profile_path)
+        PathManager.check_path_owner_consistent(self._msprof_profile_path)
+        PathManager.check_path_is_other_writable(self._msprof_profile_path)
+        if not PathManager.check_path_is_executable(self._msprof_profile_path):
+            raise PermissionError(f"The '{self._msprof_profile_path}' path is not executable."
+                                  f"Please execute chmod -R 755 {self._msprof_profile_path}")
+
+    def _check_msprof_permission(self, msprof_path):
+        """Check msprof path permissions."""
+        msprof_script_path = self._get_msprof_script_path(self._MSPROF_PY_PATH)
+        if not msprof_script_path:
+            raise FileNotFoundError(
+                "Failed to find msprof.py path. Perhaps the permission of the 'msprof' tool is unexecutable. "
+                "Please check the CANN environment. You can modify the 'msprof' file to an executable permission "
+                "through the chmod method."
+            )
+        if not PathManager.check_path_is_owner_or_root(msprof_script_path) or \
+                not PathManager.check_path_is_owner_or_root(msprof_path):
+            raise PermissionError(f"PermissionError, CANN package user id: {os.stat(msprof_path).st_uid}, "
+                                  f"current user id: {os.getuid()}. "
+                                  f"Ensure CANN package user id and current user id consistency")
+        if not PathManager.check_path_is_executable(msprof_script_path) or \
+                not PathManager.check_path_is_executable(msprof_path):
+            raise PermissionError(f"The '{msprof_script_path}' path or '{msprof_path}' path is not executable."
+                                  f"Please execute chmod u+x {msprof_script_path} and "
+                                  f"chmod u+x {msprof_path}")
+        PathManager.check_path_is_other_writable(msprof_script_path)
+
     def _find_msprof_path(self) -> Optional[str]:
         """Find msprof path in environment variables.
 
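The validation chain added above spells out a policy for the msprof tooling: the profiling directory must be readable, writable, owner-consistent, not writable by others, and executable, and both the `msprof` launcher and its `msprof.py` script must be executable and owned by an acceptable user. A standalone sketch of that ownership rule, assuming `PathManager.check_path_is_owner_or_root` behaves like the `FileManager.check_file_owner` helper added earlier in this release:

    import os

    def is_owner_or_root(path: str) -> bool:
        # Accept paths owned by root (uid 0) or by the invoking user; POSIX only.
        uid = os.stat(path).st_uid
        return uid == 0 or uid == os.getuid()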
mindspore/profiler/common/msprof_cmd_tool.py (continued)

@@ -166,7 +201,8 @@ class MsprofCmdTool:
         if not script_path:
             logger.error("Failed to find get_msprof_info.py path.")
             return {}
-
+        if not PathManager.check_path_is_executable(script_path):
+            raise PermissionError(f"The '{script_path}' path is not executable. Please execute chmod u+x {script_path}")
         host_dir = os.path.join(self._msprof_profile_path, "host")
         cmd = ["python3", script_path, "-dir", host_dir]
         command_outs = CommandExecutor.execute(cmd)[0]