mindspore 2.7.0__cp311-cp311-win_amd64.whl → 2.7.1__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (290)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  6. mindspore/_extends/parse/compile_config.py +24 -1
  7. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
  8. mindspore/_extends/parse/resources.py +1 -1
  9. mindspore/_extends/parse/standard_method.py +8 -1
  10. mindspore/_extends/parse/trope.py +2 -1
  11. mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/boost/base.py +29 -2
  18. mindspore/common/_decorator.py +3 -2
  19. mindspore/common/_grad_function.py +3 -1
  20. mindspore/common/_tensor_cpp_method.py +1 -1
  21. mindspore/common/_tensor_docs.py +275 -64
  22. mindspore/common/_utils.py +0 -44
  23. mindspore/common/api.py +285 -35
  24. mindspore/common/dump.py +7 -108
  25. mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
  26. mindspore/common/hook_handle.py +60 -0
  27. mindspore/common/jit_config.py +5 -1
  28. mindspore/common/jit_trace.py +27 -12
  29. mindspore/common/lazy_inline.py +5 -3
  30. mindspore/common/parameter.py +13 -107
  31. mindspore/common/recompute.py +4 -11
  32. mindspore/common/tensor.py +16 -169
  33. mindspore/communication/_comm_helper.py +11 -1
  34. mindspore/communication/comm_func.py +138 -4
  35. mindspore/communication/management.py +85 -1
  36. mindspore/config/op_info.config +0 -15
  37. mindspore/context.py +5 -85
  38. mindspore/dataset/engine/datasets.py +8 -4
  39. mindspore/dataset/engine/datasets_vision.py +1 -1
  40. mindspore/dataset/engine/validators.py +1 -15
  41. mindspore/dnnl.dll +0 -0
  42. mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
  43. mindspore/graph/custom_pass.py +55 -0
  44. mindspore/include/dataset/execute.h +2 -2
  45. mindspore/jpeg62.dll +0 -0
  46. mindspore/mindrecord/__init__.py +3 -3
  47. mindspore/mindrecord/common/exceptions.py +1 -0
  48. mindspore/mindrecord/config.py +1 -1
  49. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  50. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  51. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  52. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  53. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  54. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  55. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  56. mindspore/mindrecord/filereader.py +4 -4
  57. mindspore/mindrecord/filewriter.py +5 -5
  58. mindspore/mindrecord/mindpage.py +2 -2
  59. mindspore/mindrecord/tools/cifar10.py +1 -1
  60. mindspore/mindrecord/tools/cifar100.py +1 -1
  61. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  62. mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
  63. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  64. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  65. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  66. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  67. mindspore/mindspore_backend_common.dll +0 -0
  68. mindspore/mindspore_backend_manager.dll +0 -0
  69. mindspore/mindspore_cluster.dll +0 -0
  70. mindspore/mindspore_common.dll +0 -0
  71. mindspore/mindspore_core.dll +0 -0
  72. mindspore/mindspore_cpu.dll +0 -0
  73. mindspore/mindspore_dump.dll +0 -0
  74. mindspore/mindspore_frontend.dll +0 -0
  75. mindspore/mindspore_glog.dll +0 -0
  76. mindspore/mindspore_hardware_abstract.dll +0 -0
  77. mindspore/mindspore_memory_pool.dll +0 -0
  78. mindspore/mindspore_ms_backend.dll +0 -0
  79. mindspore/mindspore_ops.dll +0 -0
  80. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  81. mindspore/mindspore_profiler.dll +0 -0
  82. mindspore/mindspore_pyboost.dll +0 -0
  83. mindspore/mindspore_pynative.dll +0 -0
  84. mindspore/mindspore_runtime_pipeline.dll +0 -0
  85. mindspore/mindspore_runtime_utils.dll +0 -0
  86. mindspore/mindspore_tools.dll +0 -0
  87. mindspore/mint/__init__.py +15 -10
  88. mindspore/mint/distributed/distributed.py +182 -62
  89. mindspore/mint/nn/__init__.py +2 -16
  90. mindspore/mint/nn/functional.py +4 -110
  91. mindspore/mint/nn/layer/__init__.py +0 -2
  92. mindspore/mint/nn/layer/activation.py +0 -6
  93. mindspore/mint/nn/layer/basic.py +0 -47
  94. mindspore/mint/nn/layer/conv.py +4 -4
  95. mindspore/mint/nn/layer/normalization.py +8 -13
  96. mindspore/mint/nn/layer/pooling.py +0 -4
  97. mindspore/nn/__init__.py +1 -3
  98. mindspore/nn/cell.py +16 -66
  99. mindspore/nn/layer/basic.py +49 -1
  100. mindspore/nn/layer/container.py +16 -0
  101. mindspore/nn/layer/embedding.py +4 -169
  102. mindspore/nn/layer/normalization.py +2 -1
  103. mindspore/nn/layer/thor_layer.py +4 -85
  104. mindspore/nn/optim/ada_grad.py +0 -1
  105. mindspore/nn/optim/adafactor.py +0 -1
  106. mindspore/nn/optim/adam.py +31 -124
  107. mindspore/nn/optim/adamax.py +0 -1
  108. mindspore/nn/optim/asgd.py +0 -1
  109. mindspore/nn/optim/ftrl.py +8 -102
  110. mindspore/nn/optim/lamb.py +0 -1
  111. mindspore/nn/optim/lars.py +0 -3
  112. mindspore/nn/optim/lazyadam.py +25 -218
  113. mindspore/nn/optim/momentum.py +5 -43
  114. mindspore/nn/optim/optimizer.py +6 -55
  115. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  116. mindspore/nn/optim/rmsprop.py +0 -1
  117. mindspore/nn/optim/rprop.py +0 -1
  118. mindspore/nn/optim/sgd.py +0 -1
  119. mindspore/nn/optim/tft_wrapper.py +0 -1
  120. mindspore/nn/optim/thor.py +0 -2
  121. mindspore/nn/probability/bijector/bijector.py +7 -8
  122. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  123. mindspore/nn/probability/bijector/power_transform.py +20 -21
  124. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  125. mindspore/nn/probability/bijector/softplus.py +13 -14
  126. mindspore/nn/wrap/grad_reducer.py +4 -74
  127. mindspore/numpy/array_creations.py +2 -2
  128. mindspore/numpy/fft.py +9 -9
  129. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  130. mindspore/onnx/onnx_export.py +137 -0
  131. mindspore/opencv_core4110.dll +0 -0
  132. mindspore/opencv_imgcodecs4110.dll +0 -0
  133. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  134. mindspore/ops/__init__.py +2 -0
  135. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  136. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  137. mindspore/ops/_op_impl/cpu/__init__.py +0 -5
  138. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
  139. mindspore/ops/auto_generate/gen_extend_func.py +2 -7
  140. mindspore/ops/auto_generate/gen_ops_def.py +98 -141
  141. mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
  142. mindspore/ops/communication.py +97 -0
  143. mindspore/ops/composite/__init__.py +5 -2
  144. mindspore/ops/composite/base.py +15 -1
  145. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  146. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  147. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  148. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  149. mindspore/ops/function/__init__.py +1 -0
  150. mindspore/ops/function/array_func.py +14 -12
  151. mindspore/ops/function/comm_func.py +3883 -0
  152. mindspore/ops/function/debug_func.py +3 -4
  153. mindspore/ops/function/math_func.py +45 -54
  154. mindspore/ops/function/nn_func.py +75 -294
  155. mindspore/ops/function/random_func.py +9 -18
  156. mindspore/ops/functional.py +2 -0
  157. mindspore/ops/functional_overload.py +354 -18
  158. mindspore/ops/operations/__init__.py +2 -5
  159. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  160. mindspore/ops/operations/_inner_ops.py +1 -38
  161. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  162. mindspore/ops/operations/array_ops.py +1 -0
  163. mindspore/ops/operations/comm_ops.py +94 -2
  164. mindspore/ops/operations/custom_ops.py +228 -19
  165. mindspore/ops/operations/debug_ops.py +27 -29
  166. mindspore/ops/operations/manually_defined/ops_def.py +27 -306
  167. mindspore/ops/operations/nn_ops.py +2 -2
  168. mindspore/ops/operations/sparse_ops.py +0 -83
  169. mindspore/ops/primitive.py +1 -17
  170. mindspore/ops/tensor_method.py +72 -3
  171. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  172. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  173. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  174. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  175. mindspore/ops_generate/common/gen_constants.py +11 -10
  176. mindspore/ops_generate/common/op_proto.py +18 -1
  177. mindspore/ops_generate/common/template.py +102 -245
  178. mindspore/ops_generate/common/template_utils.py +212 -0
  179. mindspore/ops_generate/gen_custom_ops.py +69 -0
  180. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  181. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  182. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  183. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  184. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  185. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  186. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  187. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  188. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  189. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  190. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  191. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  192. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  193. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  194. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  195. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  196. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  197. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  198. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  199. mindspore/parallel/_cell_wrapper.py +1 -1
  200. mindspore/parallel/_parallel_serialization.py +1 -4
  201. mindspore/parallel/_utils.py +29 -6
  202. mindspore/parallel/checkpoint_transform.py +18 -2
  203. mindspore/parallel/cluster/process_entity/_api.py +24 -32
  204. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  205. mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
  206. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  207. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  208. mindspore/parallel/strategy.py +336 -0
  209. mindspore/parallel/transform_safetensors.py +117 -16
  210. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
  211. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  212. mindspore/profiler/common/constant.py +5 -0
  213. mindspore/profiler/common/file_manager.py +9 -0
  214. mindspore/profiler/common/msprof_cmd_tool.py +38 -2
  215. mindspore/profiler/common/path_manager.py +56 -24
  216. mindspore/profiler/common/profiler_context.py +2 -12
  217. mindspore/profiler/common/profiler_info.py +3 -3
  218. mindspore/profiler/common/profiler_path_manager.py +13 -0
  219. mindspore/profiler/common/util.py +30 -3
  220. mindspore/profiler/experimental_config.py +2 -1
  221. mindspore/profiler/platform/npu_profiler.py +33 -6
  222. mindspore/run_check/_check_version.py +108 -24
  223. mindspore/runtime/__init__.py +3 -2
  224. mindspore/runtime/executor.py +11 -3
  225. mindspore/runtime/memory.py +112 -0
  226. mindspore/swresample-4.dll +0 -0
  227. mindspore/swscale-6.dll +0 -0
  228. mindspore/tinyxml2.dll +0 -0
  229. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  230. mindspore/tools/data_dump.py +130 -0
  231. mindspore/tools/sdc_detect.py +91 -0
  232. mindspore/tools/stress_detect.py +63 -0
  233. mindspore/train/__init__.py +6 -6
  234. mindspore/train/_utils.py +5 -18
  235. mindspore/train/amp.py +6 -4
  236. mindspore/train/callback/_checkpoint.py +0 -9
  237. mindspore/train/callback/_train_fault_tolerance.py +69 -18
  238. mindspore/train/data_sink.py +1 -5
  239. mindspore/train/model.py +38 -211
  240. mindspore/train/serialization.py +126 -387
  241. mindspore/turbojpeg.dll +0 -0
  242. mindspore/utils/__init__.py +6 -3
  243. mindspore/utils/dlpack.py +92 -0
  244. mindspore/utils/dryrun.py +1 -1
  245. mindspore/utils/runtime_execution_order_check.py +10 -0
  246. mindspore/utils/sdc_detect.py +14 -12
  247. mindspore/utils/stress_detect.py +43 -0
  248. mindspore/utils/utils.py +144 -8
  249. mindspore/version.py +1 -1
  250. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  251. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
  252. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
  253. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  254. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  255. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  256. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  257. mindspore/experimental/llm_boost/register.py +0 -130
  258. mindspore/experimental/llm_boost/utils.py +0 -31
  259. mindspore/include/OWNERS +0 -7
  260. mindspore/mindspore_cpu_res_manager.dll +0 -0
  261. mindspore/mindspore_ops_kernel_common.dll +0 -0
  262. mindspore/mindspore_res_manager.dll +0 -0
  263. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  264. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  265. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  266. mindspore/nn/reinforcement/tensor_array.py +0 -145
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  270. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  271. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  272. mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
  273. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  274. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  275. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  276. mindspore/ops/operations/_tensor_array.py +0 -359
  277. mindspore/ops/operations/rl_ops.py +0 -288
  278. mindspore/parallel/_offload_context.py +0 -275
  279. mindspore/parallel/_recovery_context.py +0 -115
  280. mindspore/parallel/_transformer/__init__.py +0 -35
  281. mindspore/parallel/_transformer/layers.py +0 -765
  282. mindspore/parallel/_transformer/loss.py +0 -251
  283. mindspore/parallel/_transformer/moe.py +0 -693
  284. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  285. mindspore/parallel/_transformer/transformer.py +0 -3124
  286. mindspore/parallel/mpi/_mpi_config.py +0 -116
  287. mindspore/train/memory_profiling_pb2.py +0 -298
  288. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  289. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  290. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2020-2024 Huawei Technologies Co., Ltd
1
+ # Copyright 2025 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -36,13 +36,12 @@ from functools import partial
36
36
  import math
37
37
  import sys
38
38
  import time
39
- import numpy as np
40
39
  from safetensors.numpy import save_file
40
+ import numpy as np
41
41
  import google
42
42
 
43
43
  from mindspore.train.checkpoint_pb2 import Checkpoint
44
44
  from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
45
- from mindspore.train.print_pb2 import Print
46
45
 
47
46
  import mindspore
48
47
  import mindspore.nn as nn
@@ -55,11 +54,10 @@ from mindspore.common import dtype as mstype
55
54
  from mindspore.common.api import _cell_graph_executor as _executor
56
55
  from mindspore.common.api import _JitExecutor
57
56
  from mindspore.common.api import _get_parameter_layout
58
- from mindspore.common.initializer import initializer, One
57
+ from mindspore.common.initializer import initializer
59
58
  from mindspore.common.parameter import Parameter, _offload_if_config
60
59
  from mindspore.common.tensor import Tensor
61
60
  from mindspore._c_expression import TensorPy as Tensor_
62
- from mindspore.common._utils import is_shape_unknown
63
61
  from mindspore.common.file_system import FileSystem, _register_basic_file_system, _register_mindio_file_system
64
62
  from mindspore.communication.management import get_rank, get_group_size
65
63
  from mindspore.experimental import MapParameter
@@ -75,9 +73,9 @@ from mindspore.parallel.checkpoint_transform import load_distributed_checkpoint
75
73
  from mindspore.parallel.checkpoint_transform import merge_sliced_parameter as new_merge_sliced_parameter
76
74
  from mindspore.parallel.checkpoint_transform import build_searched_strategy as new_build_searched_strategy
77
75
  from mindspore.parallel.transform_safetensors import _fast_safe_open
78
- from mindspore.train._utils import read_proto, get_parameter_redundancy, _progress_bar, _load_and_transform
76
+ from mindspore.train._utils import get_parameter_redundancy, _progress_bar, _load_and_transform
79
77
  from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, \
80
- split_mindir, split_dynamic_mindir
78
+ split_mindir, split_dynamic_mindir, _get_snapshot_params
81
79
  from mindspore.common.generator import Generator
82
80
 
83
81
  tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
@@ -416,9 +414,6 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
416
414
  crc_num, crc_check,
417
415
  ckpt_total_io_time)
418
416
  continue
419
- if isinstance(value[2], Tensor) and hasattr(value[2], "slice_num") and value[2].slice_num > 1:
420
- _write_hugeparameter(name, value, f)
421
- continue
422
417
 
423
418
  crc_num, ckpt_total_io_time = _write_parameter_bytes_data(name, value, f, enc_key, plain_data,
424
419
  crc_num, crc_check,
@@ -561,27 +556,6 @@ def _write_mapparameter(name, value, f, map_param_inc=False):
561
556
  break
562
557
 
563
558
 
564
- def _write_hugeparameter(name, value, f):
565
- """Write huge parameter into protobuf file."""
566
- slice_num = value[2].slice_num
567
- offset = 0
568
- max_size = value[0][0]
569
- for param_slice in range(slice_num):
570
- checkpoint_list = Checkpoint()
571
- param_value = checkpoint_list.value.add()
572
- param_value.tag = name
573
- param_tensor = param_value.tensor
574
- param_tensor.dims.extend(value[0])
575
- param_tensor.tensor_type = value[1]
576
- param_key = value[3]
577
- numpy_data = value[2].asnumpy_of_slice_persistent_data(param_key, param_slice)
578
- if offset + numpy_data.shape[0] > max_size:
579
- numpy_data = numpy_data[:max_size - offset]
580
- param_tensor.tensor_content = numpy_data.tobytes()
581
- f.write(checkpoint_list.SerializeToString())
582
- offset += numpy_data.shape[0]
583
-
584
-
585
559
  def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format):
586
560
  """Check save_obj and ckpt_file_name for save_checkpoint."""
587
561
  if format not in ["safetensors", "ckpt"]:
@@ -783,9 +757,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
783
757
  data_list[param["name"]].append(param["data"])
784
758
  continue
785
759
  if isinstance(param["data"], list):
786
- if param["data"][0] == "persistent_data":
787
- _save_param_list_data(data_list, key, param)
788
- elif param["data"][0] == "offload_parameter":
760
+ if param["data"][0] == "offload_parameter":
789
761
  data_list[key].append("offload_parameter")
790
762
  _save_param_list_data(data_list, key, param)
791
763
 
@@ -971,6 +943,8 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
971
943
  if not is_parallel_mode:
972
944
  save_obj.init_parameters_data()
973
945
  param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func, is_parallel_mode)
946
+ enable_ckpt_d2h_sync = os.getenv('MS_ENABLE_D2H_ASYNC') == '1'
947
+ param_snapshot = _get_snapshot_params() if enable_ckpt_d2h_sync else {}
974
948
  for (key, value) in param_dict.items():
975
949
  each_param = {"name": key}
976
950
  if isinstance(value, MapParameter):
@@ -978,10 +952,7 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
978
952
  param_list.append(each_param)
979
953
  continue
980
954
 
981
- if value.data.is_persistent_data():
982
- # list save persistent_data: [Tensor, shape, type, param.key]
983
- param_data = ["persistent_data", value.data, value.param_info.origin_shape, str(value.dtype), value.key]
984
- elif value.data.offload_file_path() != "":
955
+ if value.data.offload_file_path() != "":
985
956
  # list save offload data: [Param, shape, type, param.key]
986
957
  param_data = ["offload_parameter"]
987
958
  param_tensor = value.data
@@ -996,7 +967,8 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
996
967
  if append_dict and "__exception_save__" in append_dict:
997
968
  param_data = Tensor(Tensor_.move_to(value, "CPU", False))
998
969
  else:
999
- param_data = Tensor(value.data)
970
+ # when MS_ENABLE_D2H_ASYNC=1 is enabled, prefer fetching the param from the snapshot
971
+ param_data = param_snapshot.get(key, Tensor(value.data))
1000
972
 
1001
973
  # in automatic model parallel scenario, some parameters were split to all the devices,
1002
974
  # which should be combined before saving
@@ -1020,13 +992,16 @@ def _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choi
1020
992
 
1021
993
  return _handle_shared_param_for_pipeline_parallel(save_obj)
1022
994
 
1023
- return _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func)
995
+ if isinstance(save_obj, nn.Cell):
996
+ return _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func)
997
+
998
+ raise TypeError("For 'save_checkpoint', the argument 'save_obj' must be list、dict or nn.cell, "
999
+ "but got {}.".format(type(save_obj)))
1024
1000
 
1025
1001
 
1026
1002
  def _save_param_list_data(data_list, key, param):
1027
1003
  """Save persistent data into save_obj."""
1028
1004
  dims = []
1029
- # persistent_data shape can not be ()
1030
1005
  for dim in param['data'][2]:
1031
1006
  dims.append(dim)
1032
1007
  data_list[key].append(dims)
@@ -1302,7 +1277,6 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
1302
1277
  param_data = Tensor_.convert_bytes_to_tensor(new_data, tuple(dims), ms_type)
1303
1278
  parameter = Parameter(param_data, name=element.tag)
1304
1279
  parameter_dict[element.tag] = parameter
1305
- _offload_if_config(parameter)
1306
1280
 
1307
1281
  logger.info("Loading checkpoint files process is finished.")
1308
1282
  return remove_redundancy
@@ -2148,6 +2122,7 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
2148
2122
  if file_format == 'AIR':
2149
2123
  _save_air(net, file_name, *inputs, **kwargs)
2150
2124
  elif file_format == 'ONNX':
2125
+ logger.warning("mindspore.export(file_format='ONNX') will be deleted, please use mindspore.onnx.export()")
2151
2126
  _save_onnx(net, file_name, *inputs, **kwargs)
2152
2127
  elif file_format == 'MINDIR':
2153
2128
  _save_mindir(net, file_name, *inputs, **kwargs)
@@ -2497,147 +2472,6 @@ def _save_dataset_to_mindir(model, dataset):
2497
2472
  model.preprocessor.op[-1].offload = op['offload'] if 'offload' in op.keys() else False
2498
2473
 
2499
2474
 
2500
- def check_checkpoint(ckpt_file_name):
2501
- """
2502
- Check whether the checkpoint is valid.
2503
-
2504
- Note:
2505
- The interface is deprecated from version 2.5 and will be removed in a future version.
2506
-
2507
- Args:
2508
- ckpt_file_name (str): Checkpoint file name.
2509
-
2510
- Returns:
2511
- bool, whether the checkpoint is valid.
2512
-
2513
- Examples:
2514
- >>> import mindspore as ms
2515
- >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
2516
- >>> check_result = ms.check_checkpoint(ckpt_file_name)
2517
- >>> print(check_result)
2518
- True
2519
- """
2520
- logger.warning("The interface 'mindspore.check_checkpoint' is deprecated from version 2.5 "
2521
- "and will be removed in a future version.")
2522
- if not ckpt_file_name.endswith('.ckpt'):
2523
- return False
2524
- checkpoint_list = Checkpoint()
2525
- with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f:
2526
- pb_content = f.read()
2527
- if pb_content[-17:-10] == b"crc_num":
2528
- crc_num_bytes = pb_content[-10:]
2529
- pb_content = pb_content[:-17]
2530
- crc_num = int.from_bytes(crc_num_bytes, byteorder='big')
2531
- cal_crc_num = binascii.crc32(pb_content, 0)
2532
- if cal_crc_num != crc_num:
2533
- logger.warning("For 'check_checkpoint', the ckpt crc check is failed.")
2534
- return False
2535
- try:
2536
- checkpoint_list.ParseFromString(pb_content)
2537
- except google.protobuf.message.DecodeError as e:
2538
- logger.warning("For 'check_checkpoint', the ckpt parse is failed.")
2539
- logger.warning(e)
2540
- return False
2541
- return True
2542
-
2543
-
2544
- def parse_print(print_file_name):
2545
- """
2546
- Parse data file generated by :class:`mindspore.ops.Print`.
2547
-
2548
- Note:
2549
- The interface is deprecated from version 2.5 and will be removed in a future version.
2550
-
2551
- Args:
2552
- print_file_name (str): The file name needs to be parsed.
2553
-
2554
- Returns:
2555
- List, element of list is Tensor.
2556
-
2557
- Raises:
2558
- ValueError: The print file does not exist or is empty.
2559
- RuntimeError: Failed to parse the file.
2560
-
2561
- Examples:
2562
- >>> import numpy as np
2563
- >>> import mindspore as ms
2564
- >>> from mindspore import nn, Tensor, ops
2565
- >>> ms.set_context(mode=ms.GRAPH_MODE, print_file_path='log.data')
2566
- >>> class PrintInputTensor(nn.Cell):
2567
- ... def __init__(self):
2568
- ... super().__init__()
2569
- ... self.print = ops.Print()
2570
- ...
2571
- ... def construct(self, input_pra):
2572
- ... self.print('print:', input_pra)
2573
- ... return input_pra
2574
- >>> x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]).astype(np.float32)
2575
- >>> input_pra = Tensor(x)
2576
- >>> net = PrintInputTensor()
2577
- >>> net(input_pra)
2578
- >>>
2579
- >>> data = ms.parse_print('./log.data')
2580
- >>> print(data)
2581
- ['print:', Tensor(shape=[2, 4], dtype=Float32, value=
2582
- [[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
2583
- [ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])]
2584
- """
2585
- logger.warning("The interface 'mindspore.parse_print' is deprecated from version 2.5 "
2586
- "and will be removed in a future version.")
2587
- print_file_path = os.path.realpath(print_file_name)
2588
-
2589
- if os.path.getsize(print_file_path) == 0:
2590
- raise ValueError("For 'parse_print', the print file may be empty, please make sure enter the correct "
2591
- "'print_file_name'.")
2592
-
2593
- logger.info("Execute load print process.")
2594
- print_list = Print()
2595
-
2596
- try:
2597
- with open(print_file_path, "rb") as f:
2598
- pb_content = f.read()
2599
- print_list.ParseFromString(pb_content)
2600
- except BaseException as e:
2601
- logger.critical("Failed to read the print file %s, please check whether the file is "
2602
- "correct.", print_file_name)
2603
- raise ValueError(e.__str__() + "\nFailed to read the print file {}, please check whether "
2604
- "the file is correct.".format(print_file_name)) from e
2605
-
2606
- tensor_list = []
2607
-
2608
- try:
2609
- for print_ in print_list.value:
2610
- # String type
2611
- if print_.HasField("desc"):
2612
- tensor_list.append(print_.desc)
2613
- elif print_.HasField("tensor"):
2614
- dims = print_.tensor.dims
2615
- data_type = print_.tensor.tensor_type
2616
- data = print_.tensor.tensor_content
2617
- np_type = tensor_to_np_type(data_type)
2618
- param_data = np.fromstring(data, np_type)
2619
- ms_type = tensor_to_ms_type.get(data_type)
2620
- if dims and dims != [0]:
2621
- param_value = param_data.reshape(dims)
2622
- tensor_list.append(Tensor(param_value, ms_type))
2623
- # Scalar type
2624
- else:
2625
- data_type_ = data_type.lower()
2626
- if 'float' in data_type_:
2627
- param_data = float(param_data[0])
2628
- elif 'int' in data_type_:
2629
- param_data = int(param_data[0])
2630
- elif 'bool' in data_type_:
2631
- param_data = bool(param_data[0])
2632
- tensor_list.append(Tensor(param_data, ms_type))
2633
-
2634
- except BaseException as e:
2635
- logger.critical("Failed to load the print file %s.", print_list)
2636
- raise RuntimeError(e.__str__() + "\nFailed to load the print file {}.".format(print_list)) from e
2637
-
2638
- return tensor_list
2639
-
2640
-
2641
2475
  def async_ckpt_thread_status():
2642
2476
  """
2643
2477
  Get the status of asynchronous save checkpoint thread.
@@ -2672,170 +2506,132 @@ def _calculation_net_size(net):
2672
2506
  return data_total
2673
2507
 
2674
2508
 
2675
- def _get_mindir_inputs(file_name):
2509
+ def _load_file_and_convert_name(path, name_map=None, format="ckpt"):
2676
2510
  """
2677
- Get MindIR file's inputs.
2678
-
2679
- Note:
2680
- 1. Parsing encrypted MindIR file is not supported.
2681
- 2. Parsing dynamic shape MindIR file is not supported.
2511
+ Load a file, converting parameter names according to name_map while loading.
2682
2512
 
2683
2513
  Args:
2684
- file_name (str): MindIR file name.
2514
+ path (str): The file path.
2515
+ name_map (dict): Convert the name of parameter by name_map.
2516
+ format (str): The format of the file. Option: 'ckpt', 'safetensors'
2685
2517
 
2686
2518
  Returns:
2687
- Tensor, list(Tensor), the input of MindIR file.
2688
-
2689
- Raises:
2690
- TypeError: If the parameter file_name is not `str`.
2691
- RuntimeError: MindIR's input is not tensor type or has no dims.
2692
-
2693
- Examples:
2694
- >>> input_tensor = get_mindir_inputs("lenet.mindir")
2695
- """
2696
- Validator.check_file_name_by_regular(file_name)
2697
- file_name = os.path.realpath(file_name)
2698
- model = read_proto(file_name)
2699
- input_tensor = []
2700
-
2701
- for ele_input in model.graph.input:
2702
- input_shape = []
2703
- if not hasattr(ele_input, "tensor") or not hasattr(ele_input.tensor[0], "dims"):
2704
- raise RuntimeError("MindIR's inputs has no tensor or tensor has no dims, please check MindIR file.")
2705
-
2706
- for ele_shape in ele_input.tensor[0].dims:
2707
- input_shape.append(ele_shape)
2708
- if is_shape_unknown(input_shape):
2709
- raise RuntimeError(f"MindIR input's shape is: {input_shape}, dynamic shape is not supported.")
2710
-
2711
- mindir_type = ele_input.tensor[0].data_type
2712
- if mindir_type not in mindir_to_tensor_type:
2713
- raise RuntimeError(f"MindIR input's type: {mindir_type} is not supported.")
2714
-
2715
- input_type = mindir_to_tensor_type.get(mindir_type)
2716
- input_tensor.append(Tensor(shape=input_shape, dtype=input_type, init=One()))
2717
-
2718
- if not input_tensor:
2719
- logger.warning("The MindIR model has no input, return None.")
2720
- return None
2721
- return input_tensor[0] if len(input_tensor) == 1 else input_tensor
2722
-
2723
-
2724
- def convert_model(mindir_file, convert_file, file_format):
2725
- """
2726
- Convert mindir model to other format model. The current version only supports conversion to ONNX models.
2727
-
2728
- Note:
2729
- The interface is deprecated from version 2.5 and will be removed in a future version.
2730
-
2731
- Args:
2732
- mindir_file (str): MindIR file name.
2733
- convert_file (str): Convert model file name.
2734
- file_format (str): Convert model's format, current version only supports "ONNX".
2735
-
2736
- Raises:
2737
- TypeError: If the parameter `mindir_file` is not `str`.
2738
- TypeError: If the parameter `convert_file` is not `str`.
2739
- ValueError: If the parameter `file_format` is not "ONNX".
2740
-
2741
- Examples:
2742
- >>> import mindspore as ms
2743
- >>> ms.convert_model("lenet.mindir", "lenet.onnx", "ONNX")
2519
+ Dict, key is parameter name, value is a Parameter or string.
2744
2520
  """
2745
- logger.warning("The interface 'mindspore.train.serialization.convert_model' is deprecated from version 2.5 "
2746
- "and will be removed in a future version.")
2747
- Validator.check_file_name_by_regular(mindir_file)
2748
- Validator.check_file_name_by_regular(convert_file)
2749
- if file_format != "ONNX":
2750
- raise ValueError(f"For 'convert_model', 'file_format' must be 'ONNX', but got {file_format}.")
2751
- net_input = _get_mindir_inputs(mindir_file)
2752
- graph = load(mindir_file)
2753
- net = nn.GraphCell(graph)
2754
- if isinstance(net_input, Tensor):
2755
- export(net, net_input, file_name=convert_file, file_format=file_format)
2756
- else:
2757
- export(net, *net_input, file_name=convert_file, file_format=file_format)
2758
-
2759
-
2760
- def _load_ckpt_to_new_name_map(path, name_map=None):
2761
- return _load_and_transform(path, name_map, mindspore.load_checkpoint, None)
2762
-
2521
+ if name_map is not None:
2522
+ load_func = partial(mindspore.load_checkpoint, format=format)
2523
+ return _load_and_transform(path, name_map, load_func)
2763
2524
 
2764
- def _load_sf_to_new_name_map(path, name_map=None):
2765
- load_func = partial(mindspore.load_checkpoint, format="safetensors")
2766
- return _load_and_transform(path, name_map, load_func, None)
2525
+ return mindspore.load_checkpoint(path, format=format)
2767
2526
 
2768
2527
 
2769
2528
  def _process_file(file_info):
2770
- cur_ckpt_path, name_map, save_path, file = file_info
2771
- if name_map is not None:
2772
- param_dict = _load_ckpt_to_new_name_map(cur_ckpt_path, name_map)
2529
+ """Process file."""
2530
+ cur_path, name_map, save_path, file, dst_format = file_info
2531
+ if dst_format == "safetensors":
2532
+ param_dict = _load_file_and_convert_name(cur_path, name_map, format="ckpt")
2533
+ safetensors_filename = file.replace(".ckpt", ".safetensors")
2534
+ dst_file = os.path.join(save_path, safetensors_filename)
2535
+ mindspore.save_checkpoint(param_dict, dst_file, format='safetensors')
2773
2536
  else:
2774
- param_dict = mindspore.load_checkpoint(cur_ckpt_path)
2775
- safetensors_filename = file.replace(".ckpt", ".safetensors")
2776
- dst_file = os.path.join(save_path, safetensors_filename)
2777
- mindspore.save_checkpoint(param_dict, dst_file, format='safetensors')
2537
+ param_dict = _load_file_and_convert_name(cur_path, name_map, format="safetensors")
2538
+ ckpt_filename = file.replace(".safetensors", ".ckpt")
2539
+ dst_file = os.path.join(save_path, ckpt_filename)
2540
+ mindspore.save_checkpoint(param_dict, dst_file)
2778
2541
 
2779
2542
 
2780
- def _process_file_safetensors(file_info):
2781
- cur_safe_path, name_map, save_path, file = file_info
2782
- if name_map is not None:
2783
- param_dict = _load_sf_to_new_name_map(cur_safe_path, name_map)
2543
+ def _gather_all_tasks(file_path, save_path, file_name_regex, name_map, dst_format="ckpt"):
2544
+ """gather transform rank together"""
2545
+ if dst_format == "ckpt":
2546
+ cur_file_suffix = ".safetensors"
2784
2547
  else:
2785
- param_dict = mindspore.load_checkpoint(cur_safe_path, format="safetensors")
2786
- ckpt_filename = file.replace(".safetensors", ".ckpt")
2787
- dst_file = os.path.join(save_path, ckpt_filename)
2788
- mindspore.save_checkpoint(param_dict, dst_file)
2789
-
2548
+ cur_file_suffix = ".ckpt"
2790
2549
 
2791
- def _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map):
2792
- """gather transform rank together"""
2793
- tasks = []
2550
+ tasks_list = []
2794
2551
  for root, dirs, _ in os.walk(file_path):
2795
2552
  if root != file_path:
2796
2553
  continue
2797
2554
 
2798
2555
  rank_dirs = [d for d in dirs if d.startswith('rank')]
2799
2556
  if not rank_dirs:
2800
- raise ValueError(
2801
- f"For 'safetensors_to_ckpt', no directories starting with 'rank' found in {file_path}")
2557
+ if dst_format == "safetensors":
2558
+ raise ValueError(
2559
+ f"For 'ckpt_to_safetensors', no directories starting with 'rank' found in {file_path}.")
2560
+ if dst_format == "ckpt":
2561
+ raise ValueError(
2562
+ f"For 'safetensors_to_ckpt', no directories starting with 'rank' found in {file_path}.")
2563
+
2564
+ raise ValueError(f"For '_gather_all_tasks', error args 'format': {dst_format}.")
2802
2565
 
2803
2566
  for rank_dir in rank_dirs:
2804
2567
  rank_dir_path = os.path.join(root, rank_dir)
2805
- dst_root = os.path.join(save_path,
2806
- os.path.relpath(rank_dir_path, file_path)) if save_path else rank_dir_path
2568
+ if save_path is not None:
2569
+ dst_root = os.path.join(save_path, os.path.relpath(rank_dir_path, file_path))
2570
+ else:
2571
+ dst_root = rank_dir_path
2572
+
2807
2573
  os.makedirs(dst_root, exist_ok=True)
2808
- tasks.extend(
2809
- (os.path.join(rank_dir_path, file), name_map, dst_root, file)
2810
- for file in os.listdir(rank_dir_path)
2811
- if file.endswith(".safetensors") and (file_name_regex is None or re.findall(file_name_regex, file))
2812
- )
2813
- return tasks
2814
2574
 
2575
+ for file in os.listdir(rank_dir_path):
2576
+ if file.endswith(cur_file_suffix) and (file_name_regex is None or re.search(file_name_regex, file)):
2577
+ tasks_list.append((os.path.join(rank_dir_path, file), name_map, dst_root, file, dst_format))
2815
2578
 
2816
- def _gather_tasks_covert(file_path, save_path, file_name_regex, name_map):
2817
- """gather transform rank together"""
2818
- tasks = []
2819
- for root, dirs, _ in os.walk(file_path):
2820
- if root != file_path:
2821
- continue
2579
+ return tasks_list
2822
2580
 
2823
- rank_dirs = [d for d in dirs if d.startswith('rank')]
2824
- if not rank_dirs:
2825
- raise ValueError(
2826
- f"For 'ckpt_to_safetensors', no directories starting with 'rank' found in {file_path}")
2827
2581
 
2828
- for rank_dir in rank_dirs:
2829
- rank_dir_path = os.path.join(root, rank_dir)
2830
- dst_root = os.path.join(save_path,
2831
- os.path.relpath(rank_dir_path, file_path)) if save_path else rank_dir_path
2832
- os.makedirs(dst_root, exist_ok=True)
2833
- tasks.extend(
2834
- (os.path.join(rank_dir_path, file), name_map, dst_root, file)
2835
- for file in os.listdir(rank_dir_path)
2836
- if file.endswith(".ckpt") and (file_name_regex is None or re.findall(file_name_regex, file))
2837
- )
2838
- return tasks
2582
+ def _convert_checkpoint_file(file_path, save_path=None, name_map=None, file_name_regex=None,
2583
+ processes_num=1, dst_format="safetensors"):
2584
+ """
2585
+ Converts the format of MindSpore checkpoint files and saves them to `save_path`.
2586
+ Safetensors is a reliable and portable machine learning model storage format introduced by Huggingface,
2587
+ used for securely storing Tensors with fast speed (zero copy).
2588
+
2589
+ Args:
2590
+ file_path (str): Path to the directory containing checkpoint files or a single checkpoint file (.ckpt).
2591
+ save_path (str, optional): Directory path where safetensors files will be saved. Default: ``None``.
2592
+ name_map (dict, optional): Dictionary mapping original parameter names to new names. Default: ``None``.
2593
+ file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
2594
+ Default: ``None``.
2595
+ processes_num (int, optional): Number of processes to use for parallel processing. Default: 1.
2596
+ dst_format (str): dst file format. Default: "safetensors".
2597
+ """
2598
+ if dst_format == "safetensors":
2599
+ src_format = "ckpt"
2600
+ src_file_suffix = ".ckpt"
2601
+ dst_file_suffix = ".safetensors"
2602
+ func_name = "ckpt_to_safetensors"
2603
+ else:
2604
+ src_format = "safetensors"
2605
+ src_file_suffix = ".safetensors"
2606
+ dst_file_suffix = ".ckpt"
2607
+ func_name = "safetensors_to_ckpt"
2608
+ is_dir = os.path.isdir(file_path)
2609
+ is_file = os.path.isfile(file_path)
2610
+ if not is_dir and not is_file:
2611
+ raise ValueError(f"For {func_name}, the input path must be a valid path or file, but got {file_path}")
2612
+ if save_path and os.path.splitext(save_path)[1]:
2613
+ raise ValueError(f"For {func_name}, the save_path must be a directory, but got '{save_path}'")
2614
+ if name_map is not None and not isinstance(name_map, dict):
2615
+ raise ValueError(
2616
+ f"For {func_name}, the type of 'name_map' must be a directory, but got '{type(name_map)}'")
2617
+
2618
+ if is_dir:
2619
+ tasks_list = _gather_all_tasks(file_path, save_path, file_name_regex, name_map, dst_format=dst_format)
2620
+ with mp.Pool(processes=processes_num) as pool:
2621
+ list(_progress_bar(pool.imap(_process_file, tasks_list), total=len(tasks_list)))
2622
+ elif is_file:
2623
+ if not file_path.endswith(src_file_suffix):
2624
+ raise ValueError(f"For {func_name}, the input file must be a {src_file_suffix} file, but got {file_path}")
2625
+ if file_name_regex is not None and not re.findall(file_name_regex, file_path):
2626
+ raise ValueError(f"For {func_name}, the input file does not match the regular expression.")
2627
+ if save_path and not os.path.exists(save_path):
2628
+ os.makedirs(save_path, exist_ok=True)
2629
+
2630
+ param_dict = _load_file_and_convert_name(file_path, name_map, format=src_format)
2631
+
2632
+ file_filename = os.path.basename(file_path).replace(src_file_suffix, dst_file_suffix)
2633
+ dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), file_filename)
2634
+ mindspore.save_checkpoint(param_dict, dst_file, format=dst_format)
2839
2635
 
2840
2636
 
2841
2637
  def ckpt_to_safetensors(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
@@ -2854,11 +2650,11 @@ def ckpt_to_safetensors(file_path, save_path=None, name_map=None, file_name_rege
2854
2650
 
2855
2651
  Args:
2856
2652
  file_path (str): Path to the directory containing checkpoint files or a single checkpoint file (.ckpt).
2857
- save_path (str, optional): Directory path where safetensors files will be saved. Defaults: ``None``.
2858
- name_map (dict, optional): Dictionary mapping original parameter names to new names. Defaults: ``None``.
2653
+ save_path (str, optional): Directory path where safetensors files will be saved. Default: ``None``.
2654
+ name_map (dict, optional): Dictionary mapping original parameter names to new names. Default: ``None``.
2859
2655
  file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
2860
- Defaults: ``None``.
2861
- processes_num (int, optional): Number of processes to use for parallel processing. Defaults: 1.
2656
+ Default: ``None``.
2657
+ processes_num (int, optional): Number of processes to use for parallel processing. Default: 1.
2862
2658
  Raises:
2863
2659
  ValueError: If the input path is invalid or the save_path is not a directory,
2864
2660
  or the file_path does not end with '.ckpt'.
@@ -2874,36 +2670,8 @@ def ckpt_to_safetensors(file_path, save_path=None, name_map=None, file_name_rege
2874
2670
  >>> namemap = {"lin.weight":"new_name"}
2875
2671
  >>> ms.ckpt_to_safetensors("./ckpt_save_path/rank0/checkpoint_0.ckpt", "./new_path/", namemap)
2876
2672
  """
2877
- is_dir = os.path.isdir(file_path)
2878
- is_file = os.path.isfile(file_path)
2879
- if not is_dir and not is_file:
2880
- raise ValueError(f"For 'ckpt_to_safetensors', the input path must be a valid path or file, but got {file_path}")
2881
- if save_path and os.path.splitext(save_path)[1]:
2882
- raise ValueError(f"For 'ckpt_to_safetensors', the save_path must be a directory, but got '{save_path}'")
2883
- if name_map is not None and not isinstance(name_map, dict):
2884
- raise ValueError(
2885
- f"For 'ckpt_to_safetensors', the type of 'name_map' must be a directory, but got '{type(name_map)}'")
2886
-
2887
- if is_dir:
2888
- tasks = _gather_tasks_covert(file_path, save_path, file_name_regex, name_map)
2889
- with mp.Pool(processes=processes_num) as pool:
2890
- list(_progress_bar(pool.imap(_process_file, tasks), total=len(tasks)))
2891
- elif is_file:
2892
- if not file_path.endswith(".ckpt"):
2893
- raise ValueError(f"For 'ckpt_to_safetensors', the input file must be a .ckpt file, but got {file_path}")
2894
- if file_name_regex is not None and not re.findall(file_name_regex, file_path):
2895
- raise ValueError(f"For 'ckpt_to_safetensors', the input file does not match the regular expression.")
2896
- if save_path and not os.path.exists(save_path):
2897
- os.makedirs(save_path, exist_ok=True)
2898
-
2899
- if name_map is not None:
2900
- param_dict = _load_ckpt_to_new_name_map(file_path, name_map)
2901
- else:
2902
- param_dict = mindspore.load_checkpoint(file_path)
2903
-
2904
- safetensors_filename = os.path.basename(file_path).replace(".ckpt", ".safetensors")
2905
- dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), safetensors_filename)
2906
- mindspore.save_checkpoint(param_dict, dst_file, format='safetensors')
2673
+ _convert_checkpoint_file(file_path, save_path, name_map,
2674
+ file_name_regex, processes_num, "safetensors")
2907
2675
 
2908
2676
 
2909
2677
  def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
@@ -2918,11 +2686,11 @@ def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_rege
2918
2686
 
2919
2687
  Args:
2920
2688
  file_path (str): Path to the directory containing safetensors files or a single safetensors file (.safetensors).
2921
- save_path (str, optional): Directory path where checkpoint files will be saved. Defaults: ``None``.
2922
- name_map (dict, optional): Dictionary mapping original parameter names to new names. Defaults: ``None``.
2689
+ save_path (str, optional): Directory path where checkpoint files will be saved. Default: ``None``.
2690
+ name_map (dict, optional): Dictionary mapping original parameter names to new names. Default: ``None``.
2923
2691
  file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
2924
- Defaults: ``None``.
2925
- processes_num (int, optional): Number of processes to use for parallel processing. Defaults: 1.
2692
+ Default: ``None``.
2693
+ processes_num (int, optional): Number of processes to use for parallel processing. Default: 1.
2926
2694
 
2927
2695
  Raises:
2928
2696
  ValueError: If the input path is invalid, the save_path is not a directory,
@@ -2939,37 +2707,8 @@ def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_rege
2939
2707
  >>> namemap = {"lin.weight":"new_name"}
2940
2708
  >>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors", "./new_path/", namemap)
2941
2709
  """
2942
- is_dir = os.path.isdir(file_path)
2943
- is_file = os.path.isfile(file_path)
2944
- if not is_dir and not is_file:
2945
- raise ValueError(f"For 'safetensors_to_ckpt', the input path must be a valid path or file, but got {file_path}")
2946
- if save_path and os.path.splitext(save_path)[1]:
2947
- raise ValueError(f"For 'safetensors_to_ckpt', the save_path must be a directory, but got '{save_path}'")
2948
- if name_map is not None and not isinstance(name_map, dict):
2949
- raise ValueError(
2950
- f"For 'safetensors_to_ckpt', the type of 'name_map' must be a directory, but got '{type(name_map)}'")
2951
-
2952
- if is_dir:
2953
- tasks = _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map)
2954
- with mp.Pool(processes=processes_num) as pool:
2955
- list(_progress_bar(pool.imap(_process_file_safetensors, tasks), total=len(tasks)))
2956
- elif is_file:
2957
- if not file_path.endswith(".safetensors"):
2958
- raise ValueError(
2959
- f"For 'safetensors_to_ckpt', the input file must be a .safetensors file, but got {file_path}")
2960
- if file_name_regex is not None and not re.findall(file_name_regex, file_path):
2961
- raise ValueError(f"For 'safetensors_to_ckpt', the input file does not match the regular expression.")
2962
- if save_path and not os.path.exists(save_path):
2963
- os.makedirs(save_path, exist_ok=True)
2964
-
2965
- if name_map is not None:
2966
- param_dict = _load_sf_to_new_name_map(file_path, name_map)
2967
- else:
2968
- param_dict = mindspore.load_checkpoint(file_path, format="safetensors")
2969
-
2970
- ckpt_filename = os.path.basename(file_path).replace(".safetensors", ".ckpt")
2971
- dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), ckpt_filename)
2972
- mindspore.save_checkpoint(param_dict, dst_file)
2710
+ _convert_checkpoint_file(file_path, save_path, name_map,
2711
+ file_name_regex, processes_num, "ckpt")
2973
2712
 
2974
2713
 
2975
2714
  def restore_group_info_list(group_info_file_name):