mindspore-2.7.0-cp311-cp311-win_amd64.whl → mindspore-2.7.1-cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (290)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  6. mindspore/_extends/parse/compile_config.py +24 -1
  7. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
  8. mindspore/_extends/parse/resources.py +1 -1
  9. mindspore/_extends/parse/standard_method.py +8 -1
  10. mindspore/_extends/parse/trope.py +2 -1
  11. mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/boost/base.py +29 -2
  18. mindspore/common/_decorator.py +3 -2
  19. mindspore/common/_grad_function.py +3 -1
  20. mindspore/common/_tensor_cpp_method.py +1 -1
  21. mindspore/common/_tensor_docs.py +275 -64
  22. mindspore/common/_utils.py +0 -44
  23. mindspore/common/api.py +285 -35
  24. mindspore/common/dump.py +7 -108
  25. mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
  26. mindspore/common/hook_handle.py +60 -0
  27. mindspore/common/jit_config.py +5 -1
  28. mindspore/common/jit_trace.py +27 -12
  29. mindspore/common/lazy_inline.py +5 -3
  30. mindspore/common/parameter.py +13 -107
  31. mindspore/common/recompute.py +4 -11
  32. mindspore/common/tensor.py +16 -169
  33. mindspore/communication/_comm_helper.py +11 -1
  34. mindspore/communication/comm_func.py +138 -4
  35. mindspore/communication/management.py +85 -1
  36. mindspore/config/op_info.config +0 -15
  37. mindspore/context.py +5 -85
  38. mindspore/dataset/engine/datasets.py +8 -4
  39. mindspore/dataset/engine/datasets_vision.py +1 -1
  40. mindspore/dataset/engine/validators.py +1 -15
  41. mindspore/dnnl.dll +0 -0
  42. mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
  43. mindspore/graph/custom_pass.py +55 -0
  44. mindspore/include/dataset/execute.h +2 -2
  45. mindspore/jpeg62.dll +0 -0
  46. mindspore/mindrecord/__init__.py +3 -3
  47. mindspore/mindrecord/common/exceptions.py +1 -0
  48. mindspore/mindrecord/config.py +1 -1
  49. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  50. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  51. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  52. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  53. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  54. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  55. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  56. mindspore/mindrecord/filereader.py +4 -4
  57. mindspore/mindrecord/filewriter.py +5 -5
  58. mindspore/mindrecord/mindpage.py +2 -2
  59. mindspore/mindrecord/tools/cifar10.py +1 -1
  60. mindspore/mindrecord/tools/cifar100.py +1 -1
  61. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  62. mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
  63. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  64. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  65. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  66. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  67. mindspore/mindspore_backend_common.dll +0 -0
  68. mindspore/mindspore_backend_manager.dll +0 -0
  69. mindspore/mindspore_cluster.dll +0 -0
  70. mindspore/mindspore_common.dll +0 -0
  71. mindspore/mindspore_core.dll +0 -0
  72. mindspore/mindspore_cpu.dll +0 -0
  73. mindspore/mindspore_dump.dll +0 -0
  74. mindspore/mindspore_frontend.dll +0 -0
  75. mindspore/mindspore_glog.dll +0 -0
  76. mindspore/mindspore_hardware_abstract.dll +0 -0
  77. mindspore/mindspore_memory_pool.dll +0 -0
  78. mindspore/mindspore_ms_backend.dll +0 -0
  79. mindspore/mindspore_ops.dll +0 -0
  80. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  81. mindspore/mindspore_profiler.dll +0 -0
  82. mindspore/mindspore_pyboost.dll +0 -0
  83. mindspore/mindspore_pynative.dll +0 -0
  84. mindspore/mindspore_runtime_pipeline.dll +0 -0
  85. mindspore/mindspore_runtime_utils.dll +0 -0
  86. mindspore/mindspore_tools.dll +0 -0
  87. mindspore/mint/__init__.py +15 -10
  88. mindspore/mint/distributed/distributed.py +182 -62
  89. mindspore/mint/nn/__init__.py +2 -16
  90. mindspore/mint/nn/functional.py +4 -110
  91. mindspore/mint/nn/layer/__init__.py +0 -2
  92. mindspore/mint/nn/layer/activation.py +0 -6
  93. mindspore/mint/nn/layer/basic.py +0 -47
  94. mindspore/mint/nn/layer/conv.py +4 -4
  95. mindspore/mint/nn/layer/normalization.py +8 -13
  96. mindspore/mint/nn/layer/pooling.py +0 -4
  97. mindspore/nn/__init__.py +1 -3
  98. mindspore/nn/cell.py +16 -66
  99. mindspore/nn/layer/basic.py +49 -1
  100. mindspore/nn/layer/container.py +16 -0
  101. mindspore/nn/layer/embedding.py +4 -169
  102. mindspore/nn/layer/normalization.py +2 -1
  103. mindspore/nn/layer/thor_layer.py +4 -85
  104. mindspore/nn/optim/ada_grad.py +0 -1
  105. mindspore/nn/optim/adafactor.py +0 -1
  106. mindspore/nn/optim/adam.py +31 -124
  107. mindspore/nn/optim/adamax.py +0 -1
  108. mindspore/nn/optim/asgd.py +0 -1
  109. mindspore/nn/optim/ftrl.py +8 -102
  110. mindspore/nn/optim/lamb.py +0 -1
  111. mindspore/nn/optim/lars.py +0 -3
  112. mindspore/nn/optim/lazyadam.py +25 -218
  113. mindspore/nn/optim/momentum.py +5 -43
  114. mindspore/nn/optim/optimizer.py +6 -55
  115. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  116. mindspore/nn/optim/rmsprop.py +0 -1
  117. mindspore/nn/optim/rprop.py +0 -1
  118. mindspore/nn/optim/sgd.py +0 -1
  119. mindspore/nn/optim/tft_wrapper.py +0 -1
  120. mindspore/nn/optim/thor.py +0 -2
  121. mindspore/nn/probability/bijector/bijector.py +7 -8
  122. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  123. mindspore/nn/probability/bijector/power_transform.py +20 -21
  124. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  125. mindspore/nn/probability/bijector/softplus.py +13 -14
  126. mindspore/nn/wrap/grad_reducer.py +4 -74
  127. mindspore/numpy/array_creations.py +2 -2
  128. mindspore/numpy/fft.py +9 -9
  129. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  130. mindspore/onnx/onnx_export.py +137 -0
  131. mindspore/opencv_core4110.dll +0 -0
  132. mindspore/opencv_imgcodecs4110.dll +0 -0
  133. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  134. mindspore/ops/__init__.py +2 -0
  135. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  136. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  137. mindspore/ops/_op_impl/cpu/__init__.py +0 -5
  138. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
  139. mindspore/ops/auto_generate/gen_extend_func.py +2 -7
  140. mindspore/ops/auto_generate/gen_ops_def.py +98 -141
  141. mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
  142. mindspore/ops/communication.py +97 -0
  143. mindspore/ops/composite/__init__.py +5 -2
  144. mindspore/ops/composite/base.py +15 -1
  145. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  146. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  147. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  148. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  149. mindspore/ops/function/__init__.py +1 -0
  150. mindspore/ops/function/array_func.py +14 -12
  151. mindspore/ops/function/comm_func.py +3883 -0
  152. mindspore/ops/function/debug_func.py +3 -4
  153. mindspore/ops/function/math_func.py +45 -54
  154. mindspore/ops/function/nn_func.py +75 -294
  155. mindspore/ops/function/random_func.py +9 -18
  156. mindspore/ops/functional.py +2 -0
  157. mindspore/ops/functional_overload.py +354 -18
  158. mindspore/ops/operations/__init__.py +2 -5
  159. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  160. mindspore/ops/operations/_inner_ops.py +1 -38
  161. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  162. mindspore/ops/operations/array_ops.py +1 -0
  163. mindspore/ops/operations/comm_ops.py +94 -2
  164. mindspore/ops/operations/custom_ops.py +228 -19
  165. mindspore/ops/operations/debug_ops.py +27 -29
  166. mindspore/ops/operations/manually_defined/ops_def.py +27 -306
  167. mindspore/ops/operations/nn_ops.py +2 -2
  168. mindspore/ops/operations/sparse_ops.py +0 -83
  169. mindspore/ops/primitive.py +1 -17
  170. mindspore/ops/tensor_method.py +72 -3
  171. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  172. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  173. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  174. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  175. mindspore/ops_generate/common/gen_constants.py +11 -10
  176. mindspore/ops_generate/common/op_proto.py +18 -1
  177. mindspore/ops_generate/common/template.py +102 -245
  178. mindspore/ops_generate/common/template_utils.py +212 -0
  179. mindspore/ops_generate/gen_custom_ops.py +69 -0
  180. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  181. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  182. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  183. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  184. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  185. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  186. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  187. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  188. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  189. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  190. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  191. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  192. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  193. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  194. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  195. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  196. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  197. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  198. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  199. mindspore/parallel/_cell_wrapper.py +1 -1
  200. mindspore/parallel/_parallel_serialization.py +1 -4
  201. mindspore/parallel/_utils.py +29 -6
  202. mindspore/parallel/checkpoint_transform.py +18 -2
  203. mindspore/parallel/cluster/process_entity/_api.py +24 -32
  204. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  205. mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
  206. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  207. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  208. mindspore/parallel/strategy.py +336 -0
  209. mindspore/parallel/transform_safetensors.py +117 -16
  210. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
  211. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  212. mindspore/profiler/common/constant.py +5 -0
  213. mindspore/profiler/common/file_manager.py +9 -0
  214. mindspore/profiler/common/msprof_cmd_tool.py +38 -2
  215. mindspore/profiler/common/path_manager.py +56 -24
  216. mindspore/profiler/common/profiler_context.py +2 -12
  217. mindspore/profiler/common/profiler_info.py +3 -3
  218. mindspore/profiler/common/profiler_path_manager.py +13 -0
  219. mindspore/profiler/common/util.py +30 -3
  220. mindspore/profiler/experimental_config.py +2 -1
  221. mindspore/profiler/platform/npu_profiler.py +33 -6
  222. mindspore/run_check/_check_version.py +108 -24
  223. mindspore/runtime/__init__.py +3 -2
  224. mindspore/runtime/executor.py +11 -3
  225. mindspore/runtime/memory.py +112 -0
  226. mindspore/swresample-4.dll +0 -0
  227. mindspore/swscale-6.dll +0 -0
  228. mindspore/tinyxml2.dll +0 -0
  229. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  230. mindspore/tools/data_dump.py +130 -0
  231. mindspore/tools/sdc_detect.py +91 -0
  232. mindspore/tools/stress_detect.py +63 -0
  233. mindspore/train/__init__.py +6 -6
  234. mindspore/train/_utils.py +5 -18
  235. mindspore/train/amp.py +6 -4
  236. mindspore/train/callback/_checkpoint.py +0 -9
  237. mindspore/train/callback/_train_fault_tolerance.py +69 -18
  238. mindspore/train/data_sink.py +1 -5
  239. mindspore/train/model.py +38 -211
  240. mindspore/train/serialization.py +126 -387
  241. mindspore/turbojpeg.dll +0 -0
  242. mindspore/utils/__init__.py +6 -3
  243. mindspore/utils/dlpack.py +92 -0
  244. mindspore/utils/dryrun.py +1 -1
  245. mindspore/utils/runtime_execution_order_check.py +10 -0
  246. mindspore/utils/sdc_detect.py +14 -12
  247. mindspore/utils/stress_detect.py +43 -0
  248. mindspore/utils/utils.py +144 -8
  249. mindspore/version.py +1 -1
  250. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  251. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
  252. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
  253. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  254. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  255. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  256. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  257. mindspore/experimental/llm_boost/register.py +0 -130
  258. mindspore/experimental/llm_boost/utils.py +0 -31
  259. mindspore/include/OWNERS +0 -7
  260. mindspore/mindspore_cpu_res_manager.dll +0 -0
  261. mindspore/mindspore_ops_kernel_common.dll +0 -0
  262. mindspore/mindspore_res_manager.dll +0 -0
  263. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  264. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  265. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  266. mindspore/nn/reinforcement/tensor_array.py +0 -145
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  270. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  271. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  272. mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
  273. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  274. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  275. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  276. mindspore/ops/operations/_tensor_array.py +0 -359
  277. mindspore/ops/operations/rl_ops.py +0 -288
  278. mindspore/parallel/_offload_context.py +0 -275
  279. mindspore/parallel/_recovery_context.py +0 -115
  280. mindspore/parallel/_transformer/__init__.py +0 -35
  281. mindspore/parallel/_transformer/layers.py +0 -765
  282. mindspore/parallel/_transformer/loss.py +0 -251
  283. mindspore/parallel/_transformer/moe.py +0 -693
  284. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  285. mindspore/parallel/_transformer/transformer.py +0 -3124
  286. mindspore/parallel/mpi/_mpi_config.py +0 -116
  287. mindspore/train/memory_profiling_pb2.py +0 -298
  288. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  289. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  290. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/parallel/strategy.py

@@ -0,0 +1,336 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Checkpoint strategy info"""
+from __future__ import absolute_import
+
+__all__ = ["get_strategy_metadata", "get_current_strategy_metadata", "enable_save_strategy_online",
+           "clear_strategy_metadata"]
+
+from itertools import chain
+from typing import Sequence, Union, Tuple, List, Dict
+from types import SimpleNamespace
+
+import numpy as np
+
+from mindspore import log as logger
+from mindspore._c_expression import StrategyInfo
+from mindspore._c_expression import StrategyLayout
+from mindspore.parallel.shard import Layout
+
+LayoutInfo = Tuple[Layout, str, str]
+StrOrTuple = Union[str, Tuple["StrOrTuple", ...], List["StrOrTuple"]]
+
+
+def get_strategy_metadata(network, rank_id=None) -> Dict[int, Dict[str, List[LayoutInfo]]]:
+    """
+    Get the parameter sharding strategy info of this cell, for all ranks or for a specific rank.
+    For more information on layouts, please refer to: :class:`mindspore.parallel.Layout`.
+
+    Args:
+        network (Cell): The network whose strategy metadata is queried.
+        rank_id (int, optional): The rank id of the process on which this cell will be launched.
+            Defaults to ``None``, which means strategy metadata for all ranks will be returned.
+
+    Returns:
+        Dict. A dictionary containing the parameter slicing strategies for either all ranks or a specific rank.
+        The key is `rank_id`, and the value is the slicing strategy for all parameters on that rank.
+        Within each rank's strategy, the key is the parameter name, and the value is the slicing strategy.
+        If a `rank_id` is specified, the dictionary returns the strategy information for that specific rank.
+        Otherwise, it returns the strategy information for all ranks in the network. If not supported, returns None.
+
+    Examples:
+        >>> import mindspore as ms
+        >>> from mindspore import nn
+        >>> from mindspore.communication import init
+        >>> from mindspore.nn.utils import no_init_parameters
+        >>> from mindspore.parallel.auto_parallel import AutoParallel
+        >>> from mindspore.train import Model
+        >>> from mindspore.parallel.strategy import (get_strategy_metadata, get_current_strategy_metadata,
+        ...                                          enable_save_strategy_online, clear_strategy_metadata)
+        >>>
+        >>> ms.set_context(mode=ms.GRAPH_MODE)
+        >>> init()
+        >>> ms.set_seed(1)
+        >>>
+        >>> # Define the network structure of LeNet5. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
+        >>> with no_init_parameters():
+        ...     net = LeNet5()
+        ...     optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+        >>>
+        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
+        >>> train_net = AutoParallel(net, parallel_mode="semi_auto")
+        >>> model = Model(network=train_net, loss_fn=loss, optimizer=optim, metrics=None)
+        >>>
+        >>> # Create the dataset taking MNIST as an example. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
+        >>> dataset = create_dataset()
+        >>>
+        >>> enable_save_strategy_online()
+        >>> model.train(2, dataset)
+        >>>
+        >>> global_info = get_strategy_metadata(network=model.train_network)
+        >>> rank0_info = get_strategy_metadata(network=model.train_network, rank_id=0)
+        >>> local_info = get_current_strategy_metadata(network=model.train_network)
+        >>> clear_strategy_metadata()
+    """
+    return _NetStrategyInfo(network, global_layout=None, local_layout=None).get_rank_layout(rank_id)
+
+
+def get_current_strategy_metadata(network) -> Dict[int, Dict[str, List[LayoutInfo]]]:
+    """
+    Get the parameter dictionary of the current rank of the network.
+
+    Args:
+        network (Cell): The network whose strategy metadata is queried.
+
+    Returns:
+        Dict. The key is 0 (representing the local rank), and the value is the slicing strategy for all parameters.
+        The key within the value represents the parameter name, and the value is the corresponding slicing strategy \
+        for that parameter. If not supported, returns None.
+    """
+    return _NetStrategyInfo(network, global_layout=None, local_layout=None).get_local_rank_layout()
+
+
+def enable_save_strategy_online():
+    """
+    Enable saving strategy metadata online.
+    """
+    strategy_layout_handle = StrategyLayout.get_instance()
+    if strategy_layout_handle is None:
+        raise ValueError("Strategy layout handle is none in parallel_strategy_checkpoint.")
+    strategy_layout_handle.enable_save_strategy_online()
+
+
+def clear_strategy_metadata():
+    """Clear all saved strategy metadata on the C++ side."""
+    strategy_layout_handle = StrategyLayout.get_instance()
+    if strategy_layout_handle is None:
+        raise ValueError("Strategy layout handle is none in parallel_strategy_checkpoint.")
+    return strategy_layout_handle.clear_strategy_metadata()
+
+
+class _NetStrategyInfo:
+    """
+    Describe the strategy information of a network.
+    """
+
+    def __init__(self, network, global_layout=None, local_layout=None):
+        self._network = network
+        self._compile_phase = network.compile_phase
+        if global_layout is None or local_layout is None:
+            layout_handle = self._get_layout_handle()
+            global_layout = layout_handle.global_network_layout()
+            local_layout = layout_handle.local_network_layout()
+        self._raw_global_layout = global_layout
+        self._raw_local_layout = local_layout
+
+    @staticmethod
+    def _get_layout_handle():
+        """Get the strategy handle."""
+        layout_handle = StrategyLayout.get_instance()
+        if layout_handle is None:
+            raise ValueError("Strategy layout handle is none in parallel_strategy_checkpoint.")
+        return layout_handle
+
+    def get_rank_layout(self, rank_id=None):
+        """Get parameter layouts of the network, for all ranks or for a specific rank."""
+        raw_global_layout = self._get_valid_layout(self._compile_phase, self._raw_global_layout)
+        if raw_global_layout is None:
+            return None
+        global_layout = self._extract_layout_metadata(raw_global_layout)
+        if rank_id is not None:
+            cur_rank_layout = {rank_id: global_layout[rank_id]}
+            self._layout_to_string(cur_rank_layout)
+            return cur_rank_layout
+        self._layout_to_string(global_layout)
+        return global_layout
+
+    def get_local_rank_layout(self):
+        """Get local-rank parameters of the network, as {param_name: param_info[layout]}."""
+        raw_local_layout = self._get_valid_layout(self._compile_phase, self._raw_local_layout)
+        if raw_local_layout is None:
+            return None
+        local_layout = self._extract_layout_metadata(raw_local_layout)
+        self._layout_to_string(local_layout)
+        return local_layout
+
+    @staticmethod
+    def _get_valid_layout(phase, layout_dict):
+        """Helper: validate and extract the layout for the given phase."""
+        if not phase:
+            return None
+        layout = layout_dict.get(phase)
+        if not layout or all(not v for v in layout.values()):
+            return None
+        return layout
+
+    def _extract_layout_metadata(self, layout: Dict[int, Dict[str, StrategyInfo]]) -> Dict:
+        """Return the converted layout of the given network."""
+        new_layout = {}
+        for rank_id, param_dict in layout.items():
+            new_param_info = {}
+            for param_name, param_info in param_dict.items():
+                new_param_layout = self._layout_process(param_info)
+                new_param_info[param_name] = new_param_layout
+            new_layout[rank_id] = new_param_info
+        return new_layout
+
+    def _layout_process(self, stra_layout):
+        """
+        Return the layout list; stra_layout is one parameter's info on the current rank.
+        """
+        new_dev_mat, counter, new_tensor_map, full_opt_shard = self._get_dev_mat_for_opt_shard(
+            stra_layout.opt_weight_shard_size, stra_layout.dev_matrix, stra_layout.tensor_map)
+        alphabet = 'abcdefghijklmnopqrstuvwxyz'
+        alias_name = [alphabet[i] for i in range(len(new_dev_mat))]
+        if stra_layout.opt_weight_shard_size == 0:
+            new_tensor_map = tuple(tuple(alias_name[len(alias_name) - idx - 1] if idx != -1 else "None" for idx in sub)
+                                   for sub in new_tensor_map)
+        else:
+            info = SimpleNamespace(
+                new_dev_mat=new_dev_mat,
+                new_tensor_map=new_tensor_map,
+                full_opt_shard=full_opt_shard,
+                counter=counter,
+                alias_name=alias_name
+            )
+            new_tensor_map = self._get_tensor_map_for_opt_shard(info)
+        new_tensor_map = self._compact_tensor_map(new_tensor_map)
+        new_dev_mat = tuple(new_dev_mat)
+        alias_name = tuple(alias_name)
+        layout = Layout(new_dev_mat, alias_name, stra_layout.rank_list)
+        final_layout = layout(*new_tensor_map)
+        logger.debug("The final layout is %s", final_layout.to_dict())
+        cur_param_list = [final_layout, stra_layout.tensor_type, stra_layout.tensor_shape]
+        return cur_param_list
+
+    def _get_dev_mat_for_opt_shard(self, opt_shard, dev_mat, tensor_map):
+        """Generate the device matrix for the optimizer weight shard scenario."""
+        if opt_shard == 0:
+            return dev_mat, -1, tensor_map, True
+        used_dev_num = self._calc_used_dev_num(dev_mat, tensor_map)
+        total_dev_num = int(np.prod(np.array(dev_mat)))
+        if opt_shard == -1 or used_dev_num * opt_shard == total_dev_num:
+            return dev_mat, -1, tensor_map, True
+        remain_dev_num = total_dev_num // (used_dev_num * opt_shard)
+        used_dev_mat_mask = self._get_used_dev_mat(dev_mat, tensor_map)
+        info = SimpleNamespace(
+            dev_mat=dev_mat,
+            tensor_map=tensor_map,
+            counter=-1,
+            real_remain_dev_num=1,
+            remain_dev_num=remain_dev_num
+        )
+        for axis, value in enumerate(dev_mat):
+            if used_dev_mat_mask[axis]:
+                continue
+            info.counter = axis
+            if info.real_remain_dev_num == info.remain_dev_num:
+                return dev_mat, axis, tensor_map, False
+            if info.real_remain_dev_num < info.remain_dev_num:
+                info.real_remain_dev_num *= value
+                continue
+            # info.real_remain_dev_num > info.remain_dev_num: split this axis.
+            return self._split_dev_dim(info)
+        if info.real_remain_dev_num == info.remain_dev_num:
+            return dev_mat, info.counter, tensor_map, False
+        return self._split_dev_dim(info)
+
+    def _get_tensor_map_for_opt_shard(self, info: SimpleNamespace):
+        """Generate the tensor map for the optimizer weight shard scenario."""
+
+        def idx_to_alias(idx):
+            return "None" if idx == -1 else info.alias_name[len(info.alias_name) - idx - 1]
+
+        def entry_to_alias(entry):
+            if isinstance(entry, (list, tuple)):
+                return tuple(idx_to_alias(i) for i in entry)
+            return idx_to_alias(entry)
+
+        used_dev_mat = self._get_used_dev_mat(info.new_dev_mat, info.new_tensor_map)
+        if info.full_opt_shard:
+            unused_idx = [len(used_dev_mat) - i - 1 for i, used in enumerate(used_dev_mat) if not used]
+        else:
+            unused_idx = [len(used_dev_mat) - i - 1 for i, used in enumerate(used_dev_mat) if
+                          not used and i > info.counter]
+        first_entry = info.new_tensor_map[0]
+        first_list = list(first_entry) if isinstance(first_entry, (list, tuple)) else [first_entry]
+        new_first_list = [dim for dim in first_list + unused_idx if dim != -1]
+        first_alias_list = [idx_to_alias(i) for i in new_first_list] or ["None"]
+        first_alias = first_alias_list[0] if len(first_alias_list) == 1 else tuple(first_alias_list)
+        rest_alias = [entry_to_alias(entry) for entry in info.new_tensor_map[1:]]
+        new_tensor_map = tuple([first_alias] + rest_alias)
+        return new_tensor_map
+
+    @staticmethod
+    def _split_dev_dim(info: SimpleNamespace):
+        """Split the counter dimension of dev_mat and adjust tensor_map."""
+        dev_mat = info.dev_mat
+        counter = info.counter
+        splitted_dev_value = dev_mat[counter]
+        new_dev_mat_value_first = info.remain_dev_num // (info.real_remain_dev_num // splitted_dev_value)
+        new_dev_mat_value_second = splitted_dev_value // new_dev_mat_value_first
+        new_dev_mat = dev_mat[:counter] + [new_dev_mat_value_first, new_dev_mat_value_second] + dev_mat[counter + 1:]
+        flag = len(new_dev_mat) - 1 - counter
+        new_tensor_map = [[v if v < flag or v == -1 else v + 1 for v in sub] for sub in info.tensor_map]
+        return new_dev_mat, counter, new_tensor_map, False
+
+    @staticmethod
+    def _calc_used_dev_num(dev_mat, tensor_map):
+        """Count the total number of devices already used by the tensor map."""
+        idx_flat = [idx for idx in chain.from_iterable(tensor_map) if idx != -1]
+        if not idx_flat:
+            return 1
+        prod_list = [dev_mat[len(dev_mat) - idx - 1] for idx in idx_flat]
+        return int(np.prod(prod_list))
+
+    @staticmethod
+    def _get_used_dev_mat(dev_mat, tensor_map) -> List[bool]:
+        """Return a list recording whether each device-matrix axis is used."""
+        used = set()
+        for elem in tensor_map:
+            if isinstance(elem, (list, tuple)):
+                used.update(i for i in elem if i != -1)
+            elif elem != -1:
+                used.add(elem)
+        return [(len(dev_mat) - i - 1) in used for i in range(len(dev_mat))]
+
+    @staticmethod
+    def _compact_tensor_map(alias_map: Sequence[StrOrTuple]) -> Tuple[StrOrTuple, ...]:
+        """Compact the tensor map, collapsing singleton and all-'None' tuples."""
+
+        def _compress(elem: StrOrTuple) -> StrOrTuple:
+            if isinstance(elem, (list, tuple)):
+                compressed = tuple(_compress(e) for e in elem)
+                if len(compressed) == 1:
+                    return compressed[0]
+                if all(x == 'None' for x in compressed):
+                    return 'None'
+                return compressed
+            return elem
+
+        return tuple(_compress(e) for e in alias_map)
+
+    @staticmethod
+    def _layout_to_string(layout_info):
+        """Log the layout info."""
+        for rank_id, param_layout in layout_info.items():
+            logger.info("rank_id=%s", rank_id)
+            for param_name, cur_param_list in param_layout.items():
+                final_layout, param_type, global_shape = cur_param_list
+                logger.info("param_name=%s: [param_layout=%s, param_type=%s, global_shape=%s]",
+                            param_name, final_layout.to_dict(), param_type, global_shape)
+            logger.info("\n")
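
A note on the reversed-index convention used throughout the new file: entry `idx` in a tensor map refers to device-matrix axis `len(dev_mat) - idx - 1`, and `-1` marks an unsharded dimension. Below is a minimal standalone sketch of the alias conversion performed in `_layout_process`; the device matrix and tensor map values are hypothetical, chosen only for illustration:

    alphabet = 'abcdefghijklmnopqrstuvwxyz'

    dev_mat = (2, 4)                                         # hypothetical device matrix
    alias_name = [alphabet[i] for i in range(len(dev_mat))]  # ['a', 'b']

    # Reversed indexing: idx 1 -> 'a' (first axis), idx 0 -> 'b' (last axis);
    # -1 becomes the literal string "None" (dimension not sharded).
    tensor_map = ((1,), (0, -1))
    aliased = tuple(
        tuple(alias_name[len(alias_name) - idx - 1] if idx != -1 else "None" for idx in sub)
        for sub in tensor_map
    )
    print(aliased)  # (('a',), ('b', 'None'))
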
mindspore/parallel/transform_safetensors.py

@@ -15,6 +15,7 @@
 """Transform distributed safetensors"""
 from __future__ import absolute_import
 
+import copy
 import os
 import sys
 import glob
@@ -68,6 +69,7 @@ dtype_size = {
     "F64": 8,
 }
 np_dtype_size = {
+    "bool": 1,
     "bool_": 1,
     "uint8": 1,
     "int8": 1,
@@ -696,6 +698,8 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
     else:
         if transform_param_dict:
             if output_format == "safetensors":
+                if meta_data and "remove_redundancy" in meta_data:
+                    meta_data["remove_redundancy"] = "False"
                 _save_file_atomically(transform_param_dict, save_file_name, metadata=meta_data)
             else:
                 transform_param_dict = _load_and_transform(transform_param_dict, None, None,
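
For context on the hunk above: safetensors metadata is a flat string-to-string mapping, which is why `remove_redundancy` is stored as the string "False". A small sketch using the public safetensors API; the file name and tensor contents are illustrative only:

    import numpy as np
    from safetensors.numpy import save_file
    from safetensors import safe_open

    # Metadata values must be strings, matching the "False" written above.
    save_file({"w": np.zeros((2, 2), dtype=np.float32)}, "demo.safetensors",
              metadata={"remove_redundancy": "False", "format": "ms"})

    with safe_open("demo.safetensors", framework="np") as f:
        print(f.metadata())  # {'format': 'ms', 'remove_redundancy': 'False'}
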
@@ -765,6 +769,11 @@ def transform_safetensors_by_stage(src_safetensors_dir, dst_safetensors_dir, ckp
         param_type_dict[param_name][src_rank] = str(param.data.dtype)
         param_total_dict[param_name][src_rank] = param
         param_attr_dict[param_name][src_rank] = (True, False)
+
+    ckpt_prefix = os.path.basename(ckpt_prefix)
+    if '..' in ckpt_prefix or '/' in ckpt_prefix or '\\' in ckpt_prefix:
+        raise ValueError(f"Invalid ckpt_prefix: {ckpt_prefix}. Must not contain path traversal characters.")
+
     for local_rank_id in range(dst_stage_device_num):
         transform_param_dict = _transform_parallel_safetensor(local_rank_id, param_total_dict,
                                                               param_attr_dict, src_strategy_list, dst_strategy_list,
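
The added check strips any directory component from `ckpt_prefix` and then rejects residual traversal characters. A standalone sketch of the same logic; the helper name `_sanitize_prefix` is ours, not part of the release:

    import os

    def _sanitize_prefix(ckpt_prefix):
        # Drop the directory part, then reject anything that still looks like a path.
        ckpt_prefix = os.path.basename(ckpt_prefix)
        if '..' in ckpt_prefix or '/' in ckpt_prefix or '\\' in ckpt_prefix:
            raise ValueError(f"Invalid ckpt_prefix: {ckpt_prefix}. "
                             "Must not contain path traversal characters.")
        return ckpt_prefix

    print(_sanitize_prefix("checkpoint"))         # checkpoint
    print(_sanitize_prefix("/tmp/x/checkpoint"))  # checkpoint (directory part dropped)
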
@@ -782,6 +791,7 @@ def transform_safetensors_by_rank(rank_id, safetensor_files_map, save_safetensor
     """
     Transform distributed checkpoint from source sharding strategy to destination sharding strategy by rank.
     """
+    save_safetensor_file_name = os.path.abspath(save_safetensor_file_name)
    if not isinstance(safetensor_files_map, dict):
         raise TypeError("The safetensor_files_map should be a dict.")
     if not isinstance(rank_id, int):
@@ -829,11 +839,84 @@ def transform_safetensors_by_rank(rank_id, safetensor_files_map, save_safetensor
     _save_file_atomically(transform_param_dict, save_safetensor_file_name, metadata={"format": "ms"})
 
 
-def _extrace_number(file_name):
-    """get file last two number"""
-    number_ls = re.findall(r'\d+', file_name)
-    number_ls = [int(i) for i in number_ls]
-    return number_ls[-2:]
+def _extract_numbers(s):
+    """Extract all numbers from a string and convert them to integers."""
+    return [int(num) for num in re.findall(r'\d+', s)]
+
+
+def _extract_last_two_numbers(file_name):
+    """Get the last two numbers from a filename."""
+    all_numbers = _extract_numbers(file_name)
+    return all_numbers[-2:]
+
+
+def _find_shortest_file(matched_files, rank_ckpts, new_file_suffix, file_suffix):
+    """Find the file with the shortest name from a list of matched files."""
+    min_length = min(len(os.path.basename(ckpt)) for ckpt in matched_files)
+    shortest_files = [ckpt for ckpt in matched_files if len(os.path.basename(ckpt)) == min_length]
+    if len(shortest_files) == 1:
+        return shortest_files[0]
+    raise ValueError(f"Multiple files with suffix '{file_suffix}' found in {rank_ckpts}. Following MindSpore naming "
+                     f"rules, searched for files ending with '{new_file_suffix}' but found multiple "
+                     f"files {matched_files}. Then searched for the shortest filename, but found multiple shortest "
+                     f"files {shortest_files}. Please set file_suffix to the longest common suffix of all files.")
+
+
+def _get_matched_file(matched, rank_ckpts, new_file_suffix, file_suffix):
+    """Get the file from a list of matched files."""
+    if len(matched) == 1:
+        return matched[0]
+    if len(matched) > 1:
+        return _find_shortest_file(matched, rank_ckpts, new_file_suffix, file_suffix)
+    raise ValueError(f"Multiple files with suffix '{file_suffix}' found in {rank_ckpts}. Following MindSpore naming "
+                     f"rules, searched for files ending with '{new_file_suffix}' but found zero files. "
+                     f"Please set file_suffix to the longest common suffix of all files.")
+
+
+def _find_most_matching_file(rank_ckpts, file_suffix, format):
+    """Find the most closely matching checkpoint file based on file_suffix."""
+    if file_suffix is None:
+        rank_ckpts.sort(key=_extract_last_two_numbers)
+        return rank_ckpts[-1]
+
+    new_file_suffix = file_suffix
+    pattern1 = r'^_(\d+)-(\d+)_(\d+)$'
+    matches1 = re.search(pattern1, file_suffix)
+    pattern2 = r'^(\d+)-(\d+)_(\d+)$'
+    matches2 = re.search(pattern2, file_suffix)
+    # Pattern matching for _{task_id}-{epoch}_{step} format (e.g., _1-10_100 or 1-10_100)
+    if matches1 is not None or matches2 is not None:
+        if matches2 is not None:
+            new_file_suffix = "_" + new_file_suffix
+        matched = [ckpt for ckpt in rank_ckpts if ckpt.endswith(f"{new_file_suffix}.{format}") and
+                   not ckpt.endswith(f"rank{new_file_suffix}.{format}")]
+        return _get_matched_file(matched, rank_ckpts, new_file_suffix, file_suffix)
+
+    pattern3 = r'^-(\d+)_(\d+)$'
+    matches3 = re.search(pattern3, file_suffix)
+    pattern4 = r'^(\d+)_(\d+)$'
+    matches4 = re.search(pattern4, file_suffix)
+    # Pattern matching for -{epoch}_{step} format (e.g., -10_100 or 10_100)
+    if matches3 is not None or matches4 is not None:
+        if matches4 is not None:
+            new_file_suffix = "-" + new_file_suffix
+        matched = [ckpt for ckpt in rank_ckpts if ckpt.endswith(f"{new_file_suffix}.{format}")]
+        return _get_matched_file(matched, rank_ckpts, new_file_suffix, file_suffix)
+
+    pattern5 = r'^_(\d+)$'
+    matches5 = re.search(pattern5, file_suffix)
+    pattern6 = r'^(\d+)$'
+    matches6 = re.search(pattern6, file_suffix)
+    # Pattern matching for _{step} format (e.g., _100 or 100)
+    if matches5 is not None or matches6 is not None:
+        if matches6 is not None:
+            new_file_suffix = "_" + new_file_suffix
+        matched = [ckpt for ckpt in rank_ckpts if ckpt.endswith(f"{new_file_suffix}.{format}")]
+        return _get_matched_file(matched, rank_ckpts, new_file_suffix, file_suffix)
+
+    raise ValueError(f"Multiple {format} files ending with '{file_suffix}' found in {rank_ckpts}. "
+                     f"Cannot determine which file is the intended one. "
+                     f"Please set file_suffix to the longest common suffix.")
 
 
 def _collect_safetensor_files(src_safetensors_dir, format='safetensors', file_suffix=None):
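
For reference, the three suffix families that `_find_most_matching_file` recognizes, checked against hypothetical suffix strings:

    import re

    print(bool(re.search(r'^_(\d+)-(\d+)_(\d+)$', "_1-10_100")))  # True: _{task_id}-{epoch}_{step}
    print(bool(re.search(r'^(\d+)-(\d+)_(\d+)$', "1-10_100")))    # True: same family, no leading '_'
    print(bool(re.search(r'^-(\d+)_(\d+)$', "-10_100")))          # True: -{epoch}_{step}
    print(bool(re.search(r'^_(\d+)$', "_100")))                   # True: _{step}

A suffix that matches none of the families (for example "best_model") falls through to the final ValueError.
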
@@ -844,6 +927,9 @@ def _collect_safetensor_files(src_safetensors_dir, format='safetensors', file_su
         return {0: src_safetensors_dir}
     safetensors_rank_dir_list = os.path.join(src_safetensors_dir, "rank_[0-9]*")
     all_safetensor_files_map = {}
+    multiple_files_found_flag = False
+    multiple_files_list = None
+    chosen_file = None
     for safetensor_dir in glob.glob(safetensors_rank_dir_list):
         if not os.path.isdir(safetensor_dir):
             ms.log.warning("{} is not a directory.".format(safetensor_dir))
@@ -859,9 +945,23 @@ def _collect_safetensor_files(src_safetensors_dir, format='safetensors', file_su
         else:
             safetensor_file_name = os.path.join(safetensor_dir, f"*{file_suffix}.{format}")
         rank_ckpts = glob.glob(safetensor_file_name)
-        rank_ckpts.sort(key=_extrace_number)
-        if rank_ckpts:
-            all_safetensor_files_map[rank_id] = rank_ckpts[-1]
+        if len(rank_ckpts) > 1:
+            all_safetensor_files_map[rank_id] = _find_most_matching_file(rank_ckpts, file_suffix, format)
+            if not multiple_files_found_flag:
+                multiple_files_found_flag = True
+                multiple_files_list = copy.deepcopy(rank_ckpts)
+                chosen_file = all_safetensor_files_map[rank_id]
+        elif rank_ckpts:
+            all_safetensor_files_map[rank_id] = rank_ckpts[0]
+        elif file_suffix is not None:
+            raise ValueError(f"No safetensors files found in directory '{safetensor_dir}' "
+                             f"with suffix '{file_suffix}' and format '{format}'. "
+                             f"Please verify the directory contains the expected files. "
+                             f"Recommend setting file_suffix to the longest common suffix.")
+    if file_suffix is not None and multiple_files_found_flag:
+        logger.warning(f"When unifying safetensors files with file_suffix `{file_suffix}`, multiple matching files "
+                       f"were found. Showing one list: {multiple_files_list}; selected `{chosen_file}` from it. "
+                       f"Please check whether the file_suffix is set correctly.")
     return all_safetensor_files_map
 
 
@@ -978,7 +1078,7 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
 def _cal_param_size(shape, dtype):
     """cal param size by dtype and shape"""
     num_elements = math.prod(shape)
-    element_size = np_dtype_size.get(dtype, 4)
+    element_size = np_dtype_size.get(str(dtype), 4)
     total_bytes = num_elements * element_size
     return total_bytes
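
The switch to `str(dtype)` matters because the dtype passed in is a numpy dtype object, and `str()` of a dtype yields names such as 'bool' and 'float16'. This also explains the new "bool" key added to `np_dtype_size` earlier in this diff: numpy spells the boolean dtype 'bool', not 'bool_'. A quick check:

    import numpy as np

    np_dtype_size = {"bool": 1, "bool_": 1, "float16": 2, "float32": 4}  # excerpt

    print(str(np.zeros(1, dtype=bool).dtype))               # bool (the new key)
    print(np_dtype_size.get(str(np.dtype(np.float16)), 4))  # 2
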
 
@@ -1141,7 +1241,7 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
     if os.path.isfile(src_dir):
         raise ValueError("For 'unified_safetensors', the 'src_dir' can not be a file.")
     all_safetensor_files_map = _collect_safetensor_files(src_dir, format="safetensors", file_suffix=file_suffix)
-    all_ckpt_files_map = _collect_safetensor_files(src_dir, format="ckpt", file_suffix=file_suffix)
+    all_ckpt_files_map = _collect_safetensor_files(src_dir, format="ckpt")
     if all_safetensor_files_map and all_ckpt_files_map:
         raise ValueError("For 'unified_safetensors', the 'src_dir' cannot contain "
                          "both ckpt file and safetensors file simultaneously")
@@ -1179,11 +1279,6 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
     with _fast_safe_open(file_name, framework="np") as f:
         for k in f.keys():
             if k in name_list:
-                py_slice = f.get_tensor(k)
-                param_total_size += _cal_param_size(py_slice.shape, py_slice.dtype)
-                param_dst_shape = _get_dst_shape(k, py_slice.shape, origin_src_strategy_list)
-                # Convert the shape of np.int32 type to int type to prevent overflow in subsequent calculations.
-                param_dst_shape = [int(item) for item in param_dst_shape]
                 if choice_func is not None:
                     choice_out = choice_func(k)
                     if isinstance(choice_out, bool):
@@ -1191,7 +1286,13 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
                         name_list.remove(k)
                         continue
                 if k not in param_size_dict:
-                    param_size_dict[k] = _cal_param_size(param_dst_shape, py_slice.dtype)
+                    py_slice = f.get_tensor(k)
+                    param_dst_shape = _get_dst_shape(k, py_slice.shape, origin_src_strategy_list)
+                    # Convert the shape of np.int32 type to int type to prevent overflow in subsequent calculations.
+                    param_dst_shape = [int(item) for item in param_dst_shape]
+                    param_size = _cal_param_size(param_dst_shape, py_slice.dtype)
+                    param_total_size += param_size
+                    param_size_dict[k] = param_size
     split_num = math.ceil(sum(param_size_dict.values()) / 1024 / 1024 / 1024 / 3)
     split_num = min(split_num, len(name_list))
     split_list = _split_weight_dict(param_size_dict, split_num)
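
The restructuring above defers `f.get_tensor(k)` until after the `choice_func` filter, so parameters that are filtered out are never read from disk. A sketch of the same pattern using the public safetensors API (`_fast_safe_open` is MindSpore-internal; the function name and filter below are assumptions for illustration):

    from safetensors import safe_open

    def sizes_of_kept(path, keep):
        sizes = {}
        with safe_open(path, framework="np") as f:
            for k in f.keys():
                if not keep(k):
                    continue              # filtered out: the tensor is never loaded
                tensor = f.get_tensor(k)  # read only the parameters we keep
                sizes[k] = tensor.nbytes
        return sizes
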

mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py

@@ -248,4 +248,7 @@ def _get_step_id_by_ts(ts: Decimal, step_events_dict: dict):
         if st <= ts <= et:
             return step_id
 
+    if step_events_dict:
+        return list(step_events_dict.keys())[-1]
+
     return None
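
The new branch changes out-of-range behavior: a timestamp later than every recorded step window now resolves to the last step id instead of `None`. With hypothetical step windows:

    from decimal import Decimal

    step_events_dict = {1: (Decimal(0), Decimal(10)), 2: (Decimal(10), Decimal(20))}
    ts = Decimal(25)
    # ts matches no (st, et) window, so the added fallback returns
    # list(step_events_dict.keys())[-1] == 2 rather than None.
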

mindspore/profiler/analysis/viewer/ms_minddata_viewer.py

@@ -500,7 +500,7 @@ class BottleneckAnalyzer:
         in_op_id, out_q = self._get_non_inline_child_recur(op_id), self.queue_utilization_pct[op_id]
         # This is a leaf node since input queue does not exist and output queue exists
         if in_op_id == self.op_id_not_exist and out_q != self.queue_usage_not_exist:
-            if out_q < self._THRESHOLDS['_LEAF_OUTPUT_QUEUE_EMPTY_FREQ_PCT_MAXIMUM']:
+            if out_q <= self._THRESHOLDS['_LEAF_OUTPUT_QUEUE_EMPTY_FREQ_PCT_MAXIMUM']:
                 queue_usage_analysis.append(self._format_leaf_node_suggestion(op_id, out_q))
         # This is device_queue op
         elif self.op_names[op_id] == "DeviceQueue" and in_op_id != self.op_id_not_exist:

mindspore/profiler/common/constant.py

@@ -226,3 +226,8 @@ class HostSystem(Enum):
     DISK = "disk"
     NETWORK = "network"
     OSRT = "osrt"
+
+
+class MsprofModeName:
+    """msprof mode name"""
+    MSPROF_DYNAMIC_ENV = "PROFILING_MODE"

mindspore/profiler/common/file_manager.py

@@ -206,3 +206,12 @@ class FileManager:
             if file_name.startswith(start_name) and file_name.endswith(".csv"):
                 file_list.append(os.path.join(source_path, file_name))
         return file_list
+
+    @classmethod
+    def check_file_owner(cls, path):
+        """Check whether the file owner is the current user or root."""
+        stat_info = os.stat(path)
+        if stat_info.st_uid == 0:
+            return True
+        current_uid = os.geteuid()
+        return current_uid == stat_info.st_uid
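
A standalone equivalent of `check_file_owner` for reference; note that `os.geteuid` is POSIX-only, so this check (like the original) assumes a Unix-like host:

    import os

    def is_owned_by_me_or_root(path):
        stat_info = os.stat(path)
        if stat_info.st_uid == 0:   # root-owned files are accepted
            return True
        return os.geteuid() == stat_info.st_uid

    print(is_owned_by_me_or_root("/etc/hosts"))  # typically True (root-owned)
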

mindspore/profiler/common/msprof_cmd_tool.py

@@ -22,6 +22,7 @@ from typing import Dict, List, Optional
 from mindspore import log as logger
 from mindspore.profiler.common.command_executor import CommandExecutor
 from mindspore.profiler.common.constant import ExportType
+from mindspore.profiler.common.path_manager import PathManager
 
 
 class MsprofCmdTool:
@@ -120,6 +121,7 @@ class MsprofCmdTool:
         Raises:
             FileNotFoundError: If msprof or python3 command is not found.
         """
+        self._check_msprof_profile_path_is_valid()
         if not shutil.which(self._MSPROF_CMD):
             logger.warning(
                 "The msprof command is not found in PATH. Searching in environment variables..."
@@ -131,11 +133,44 @@ class MsprofCmdTool:
                 logger.info("Successfully added msprof command to PATH.")
             else:
                 raise FileNotFoundError("Failed to find msprof command in environment.")
-
+        else:
+            msprof_path = shutil.which(self._MSPROF_CMD)
+            self._check_msprof_permission(msprof_path)
         if not shutil.which("python3"):
             logger.warning("Failed to find python3 command in environment.")
             raise FileNotFoundError("Failed to find python3 command in environment.")
 
+    def _check_msprof_profile_path_is_valid(self):
+        """Check that the msprof profile path is valid."""
+        PathManager.check_directory_path_readable(self._msprof_profile_path)
+        PathManager.check_directory_path_writeable(self._msprof_profile_path)
+        PathManager.check_path_owner_consistent(self._msprof_profile_path)
+        PathManager.check_path_is_other_writable(self._msprof_profile_path)
+        if not PathManager.check_path_is_executable(self._msprof_profile_path):
+            raise PermissionError(f"The '{self._msprof_profile_path}' path is not executable. "
+                                  f"Please execute chmod -R 755 {self._msprof_profile_path}")
+
+    def _check_msprof_permission(self, msprof_path):
+        """Check msprof path permissions."""
+        msprof_script_path = self._get_msprof_script_path(self._MSPROF_PY_PATH)
+        if not msprof_script_path:
+            raise FileNotFoundError(
+                "Failed to find msprof.py path. Perhaps the 'msprof' tool does not have execute permission. "
+                "Please check the CANN environment. You can give the 'msprof' file execute permission "
+                "with chmod."
+            )
+        if not PathManager.check_path_is_owner_or_root(msprof_script_path) or \
+                not PathManager.check_path_is_owner_or_root(msprof_path):
+            raise PermissionError(f"PermissionError: CANN package user id is {os.stat(msprof_path).st_uid}, "
+                                  f"current user id is {os.getuid()}. "
+                                  f"Ensure the CANN package user id and the current user id are consistent.")
+        if not PathManager.check_path_is_executable(msprof_script_path) or \
+                not PathManager.check_path_is_executable(msprof_path):
+            raise PermissionError(f"The '{msprof_script_path}' path or '{msprof_path}' path is not executable. "
+                                  f"Please execute chmod u+x {msprof_script_path} and "
+                                  f"chmod u+x {msprof_path}")
+        PathManager.check_path_is_other_writable(msprof_script_path)
+
     def _find_msprof_path(self) -> Optional[str]:
         """Find msprof path in environment variables.
 
@@ -166,7 +201,8 @@ class MsprofCmdTool:
         if not script_path:
             logger.error("Failed to find get_msprof_info.py path.")
             return {}
-
+        if not PathManager.check_path_is_executable(script_path):
+            raise PermissionError(f"The '{script_path}' path is not executable. Please execute chmod u+x {script_path}")
         host_dir = os.path.join(self._msprof_profile_path, "host")
         cmd = ["python3", script_path, "-dir", host_dir]
         command_outs = CommandExecutor.execute(cmd)[0]
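
Taken together, the msprof_cmd_tool hunks gate profiling on ownership and execute permission of both the `msprof` binary and its helper scripts. A standard-library sketch of that style of check (`PathManager` is MindSpore-internal; this version substitutes `os.stat` and `os.access` and is an assumption, not the shipped implementation):

    import os
    import shutil

    def check_tool_usable(cmd="msprof"):
        path = shutil.which(cmd)
        if path is None:
            raise FileNotFoundError(f"Failed to find {cmd} command in environment.")
        st = os.stat(path)
        if st.st_uid not in (0, os.geteuid()):  # owner must be root or the current user
            raise PermissionError(f"{path} is owned by uid {st.st_uid}, "
                                  f"current uid is {os.geteuid()}.")
        if not os.access(path, os.X_OK):
            raise PermissionError(f"{path} is not executable; try: chmod u+x {path}")
        return path
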