mindspore-2.6.0-cp310-cp310-win_amd64.whl → mindspore-2.7.0-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (455)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +2 -2
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +42 -11
  9. mindspore/_extends/builtin_operations.py +3 -3
  10. mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
  11. mindspore/_extends/optimize/cell_utils.py +96 -0
  12. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +3 -3
  15. mindspore/_extends/parse/compile_config.py +44 -22
  16. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -2
  17. mindspore/_extends/parse/parser.py +64 -83
  18. mindspore/_extends/parse/resources.py +39 -0
  19. mindspore/_extends/parse/standard_method.py +47 -14
  20. mindspore/_extends/parse/trope.py +8 -1
  21. mindspore/_extends/pijit/__init__.py +1 -2
  22. mindspore/_extends/pijit/pijit_func_white_list.py +2 -5
  23. mindspore/amp.py +4 -22
  24. mindspore/atlprov.dll +0 -0
  25. mindspore/avcodec-59.dll +0 -0
  26. mindspore/avdevice-59.dll +0 -0
  27. mindspore/avfilter-8.dll +0 -0
  28. mindspore/avformat-59.dll +0 -0
  29. mindspore/avutil-57.dll +0 -0
  30. mindspore/boost/adasum.py +1 -1
  31. mindspore/boost/boost_cell_wrapper.py +4 -4
  32. mindspore/c1.dll +0 -0
  33. mindspore/c1xx.dll +0 -0
  34. mindspore/c2.dll +0 -0
  35. mindspore/common/__init__.py +43 -12
  36. mindspore/common/_grad_function.py +2 -1
  37. mindspore/common/_pijit_context.py +28 -7
  38. mindspore/common/_stub_tensor.py +1 -209
  39. mindspore/common/_tensor_cpp_method.py +1 -1
  40. mindspore/common/_tensor_docs.py +177 -52
  41. mindspore/common/_utils.py +9 -1
  42. mindspore/common/api.py +338 -208
  43. mindspore/common/dtype.py +108 -57
  44. mindspore/common/dump.py +11 -16
  45. mindspore/common/dynamic_shape/__init__.py +0 -0
  46. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +17 -23
  47. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  48. mindspore/common/file_system.py +59 -9
  49. mindspore/common/generator.py +2 -3
  50. mindspore/common/hook_handle.py +33 -5
  51. mindspore/common/jit_config.py +1 -1
  52. mindspore/common/jit_trace.py +84 -105
  53. mindspore/common/np_dtype.py +3 -3
  54. mindspore/common/parameter.py +27 -29
  55. mindspore/common/recompute.py +5 -7
  56. mindspore/common/sparse_tensor.py +0 -3
  57. mindspore/common/symbol.py +0 -1
  58. mindspore/common/tensor.py +84 -133
  59. mindspore/communication/_comm_helper.py +46 -4
  60. mindspore/communication/management.py +79 -7
  61. mindspore/context.py +47 -38
  62. mindspore/dataset/__init__.py +1 -1
  63. mindspore/dataset/audio/transforms.py +1 -1
  64. mindspore/dataset/core/config.py +38 -4
  65. mindspore/dataset/engine/datasets.py +350 -322
  66. mindspore/dataset/engine/datasets_user_defined.py +69 -23
  67. mindspore/dataset/engine/iterators.py +2 -2
  68. mindspore/dataset/engine/obs/config_loader.py +2 -2
  69. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
  70. mindspore/dataset/transforms/c_transforms.py +2 -2
  71. mindspore/dataset/transforms/py_transforms.py +7 -3
  72. mindspore/dataset/transforms/transforms.py +10 -6
  73. mindspore/dataset/vision/__init__.py +1 -1
  74. mindspore/dataset/vision/py_transforms.py +8 -8
  75. mindspore/dataset/vision/transforms.py +17 -5
  76. mindspore/dataset/vision/utils.py +632 -21
  77. mindspore/dataset/vision/validators.py +1 -0
  78. mindspore/device_context/ascend/device.py +1 -1
  79. mindspore/device_context/ascend/op_tuning.py +35 -1
  80. mindspore/device_context/gpu/__init__.py +2 -2
  81. mindspore/device_context/gpu/device.py +1 -1
  82. mindspore/device_context/gpu/op_precision.py +4 -2
  83. mindspore/device_context/gpu/op_tuning.py +6 -3
  84. mindspore/device_manager.py +16 -9
  85. mindspore/dnnl.dll +0 -0
  86. mindspore/dpcmi.dll +0 -0
  87. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +5 -4
  88. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  89. mindspore/experimental/optim/adadelta.py +13 -20
  90. mindspore/experimental/optim/adagrad.py +15 -22
  91. mindspore/experimental/optim/adam.py +17 -24
  92. mindspore/experimental/optim/adamax.py +14 -22
  93. mindspore/experimental/optim/adamw.py +28 -34
  94. mindspore/experimental/optim/asgd.py +15 -25
  95. mindspore/experimental/optim/lr_scheduler.py +27 -45
  96. mindspore/experimental/optim/nadam.py +14 -24
  97. mindspore/experimental/optim/optimizer.py +13 -23
  98. mindspore/experimental/optim/radam.py +18 -24
  99. mindspore/experimental/optim/rmsprop.py +14 -25
  100. mindspore/experimental/optim/rprop.py +15 -26
  101. mindspore/experimental/optim/sgd.py +9 -19
  102. mindspore/hal/__init__.py +4 -4
  103. mindspore/hal/contiguous_tensors_handle.py +2 -2
  104. mindspore/hal/memory.py +1 -0
  105. mindspore/include/api/cell.h +65 -5
  106. mindspore/include/api/cfg.h +24 -7
  107. mindspore/include/api/context.h +1 -0
  108. mindspore/include/api/delegate.h +10 -2
  109. mindspore/include/api/dual_abi_helper.h +100 -19
  110. mindspore/include/api/graph.h +14 -1
  111. mindspore/include/api/kernel.h +16 -3
  112. mindspore/include/api/kernel_api.h +9 -1
  113. mindspore/include/api/metrics/accuracy.h +9 -0
  114. mindspore/include/api/model.h +8 -1
  115. mindspore/include/api/model_group.h +4 -0
  116. mindspore/include/api/model_parallel_runner.h +2 -0
  117. mindspore/include/api/status.h +48 -10
  118. mindspore/include/api/types.h +8 -3
  119. mindspore/include/c_api/model_c.h +0 -58
  120. mindspore/include/c_api/tensor_c.h +0 -26
  121. mindspore/include/dataset/constants.h +9 -0
  122. mindspore/include/dataset/vision_ascend.h +1 -1
  123. mindspore/jpeg62.dll +0 -0
  124. mindspore/mindrecord/tools/cifar10.py +61 -11
  125. mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
  126. mindspore/mindspore_backend_common.dll +0 -0
  127. mindspore/mindspore_backend_manager.dll +0 -0
  128. mindspore/mindspore_common.dll +0 -0
  129. mindspore/mindspore_core.dll +0 -0
  130. mindspore/mindspore_cpu_res_manager.dll +0 -0
  131. mindspore/mindspore_dump.dll +0 -0
  132. mindspore/mindspore_frontend.dll +0 -0
  133. mindspore/mindspore_glog.dll +0 -0
  134. mindspore/mindspore_memory_pool.dll +0 -0
  135. mindspore/mindspore_ms_backend.dll +0 -0
  136. mindspore/mindspore_ops.dll +0 -0
  137. mindspore/mindspore_ops_host.dll +0 -0
  138. mindspore/mindspore_ops_kernel_common.dll +0 -0
  139. mindspore/mindspore_profiler.dll +0 -0
  140. mindspore/mindspore_pyboost.dll +0 -0
  141. mindspore/mindspore_pynative.dll +0 -0
  142. mindspore/mindspore_res_manager.dll +0 -0
  143. mindspore/mindspore_runtime_pipeline.dll +0 -0
  144. mindspore/mint/__init__.py +4 -44
  145. mindspore/mint/distributed/__init__.py +5 -0
  146. mindspore/mint/distributed/distributed.py +425 -19
  147. mindspore/mint/nn/__init__.py +1 -1
  148. mindspore/mint/nn/functional.py +53 -6
  149. mindspore/mint/nn/layer/_functions.py +163 -294
  150. mindspore/mint/nn/layer/activation.py +8 -6
  151. mindspore/mint/nn/layer/conv.py +125 -101
  152. mindspore/mint/nn/layer/normalization.py +11 -25
  153. mindspore/mint/optim/adam.py +19 -18
  154. mindspore/mint/optim/adamw.py +14 -8
  155. mindspore/mint/optim/sgd.py +5 -5
  156. mindspore/msobj140.dll +0 -0
  157. mindspore/mspdb140.dll +0 -0
  158. mindspore/mspdbcore.dll +0 -0
  159. mindspore/mspdbst.dll +0 -0
  160. mindspore/mspft140.dll +0 -0
  161. mindspore/msvcdis140.dll +0 -0
  162. mindspore/msvcp140_1.dll +0 -0
  163. mindspore/msvcp140_2.dll +0 -0
  164. mindspore/msvcp140_atomic_wait.dll +0 -0
  165. mindspore/msvcp140_codecvt_ids.dll +0 -0
  166. mindspore/nn/cell.py +488 -620
  167. mindspore/nn/grad/cell_grad.py +11 -12
  168. mindspore/nn/layer/activation.py +36 -36
  169. mindspore/nn/layer/basic.py +74 -77
  170. mindspore/nn/layer/channel_shuffle.py +4 -4
  171. mindspore/nn/layer/combined.py +4 -2
  172. mindspore/nn/layer/conv.py +86 -85
  173. mindspore/nn/layer/dense.py +9 -7
  174. mindspore/nn/layer/embedding.py +50 -52
  175. mindspore/nn/layer/image.py +38 -40
  176. mindspore/nn/layer/math.py +111 -112
  177. mindspore/nn/layer/normalization.py +56 -44
  178. mindspore/nn/layer/pooling.py +58 -63
  179. mindspore/nn/layer/rnn_cells.py +33 -33
  180. mindspore/nn/layer/rnns.py +56 -56
  181. mindspore/nn/layer/thor_layer.py +74 -73
  182. mindspore/nn/layer/transformer.py +11 -1
  183. mindspore/nn/learning_rate_schedule.py +20 -20
  184. mindspore/nn/loss/loss.py +79 -81
  185. mindspore/nn/optim/adam.py +2 -4
  186. mindspore/nn/optim/adasum.py +2 -2
  187. mindspore/nn/optim/lamb.py +1 -3
  188. mindspore/nn/optim/optimizer.py +1 -1
  189. mindspore/nn/optim/tft_wrapper.py +2 -3
  190. mindspore/nn/optim/thor.py +2 -2
  191. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  192. mindspore/nn/probability/distribution/exponential.py +2 -1
  193. mindspore/nn/probability/distribution/poisson.py +2 -1
  194. mindspore/nn/sparse/sparse.py +3 -3
  195. mindspore/nn/wrap/cell_wrapper.py +73 -42
  196. mindspore/nn/wrap/grad_reducer.py +37 -52
  197. mindspore/nn/wrap/loss_scale.py +72 -74
  198. mindspore/numpy/array_creations.py +7 -7
  199. mindspore/numpy/fft.py +1 -1
  200. mindspore/numpy/math_ops.py +1 -1
  201. mindspore/numpy/utils_const.py +1 -1
  202. mindspore/opencv_core452.dll +0 -0
  203. mindspore/opencv_imgcodecs452.dll +0 -0
  204. mindspore/opencv_imgproc452.dll +0 -0
  205. mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
  206. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
  207. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  208. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  209. mindspore/{experimental/es/__init__.py → ops/_op_impl/cpu/joinedstr_op.py} +12 -6
  210. mindspore/ops/_vmap/vmap_array_ops.py +6 -13
  211. mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
  212. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +29 -10
  213. mindspore/ops/auto_generate/gen_extend_func.py +5 -55
  214. mindspore/ops/auto_generate/gen_ops_def.py +753 -273
  215. mindspore/ops/auto_generate/gen_ops_prim.py +1687 -958
  216. mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
  217. mindspore/ops/composite/__init__.py +10 -0
  218. mindspore/ops/composite/base.py +9 -5
  219. mindspore/ops/composite/multitype_ops/__init__.py +12 -1
  220. mindspore/ops/composite/multitype_ops/_compile_utils.py +132 -108
  221. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  222. mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
  223. mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
  224. mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
  225. mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
  226. mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
  227. mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
  228. mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
  229. mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
  230. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
  231. mindspore/ops/function/__init__.py +4 -1
  232. mindspore/ops/function/_add_attr_func.py +11 -6
  233. mindspore/ops/function/array_func.py +17 -100
  234. mindspore/ops/function/debug_func.py +8 -5
  235. mindspore/ops/function/grad/grad_func.py +5 -13
  236. mindspore/ops/function/math_func.py +65 -399
  237. mindspore/ops/function/nn_func.py +44 -61
  238. mindspore/ops/function/other_func.py +4 -1
  239. mindspore/ops/function/random_func.py +31 -4
  240. mindspore/ops/functional.py +2 -3
  241. mindspore/ops/functional_overload.py +486 -18
  242. mindspore/ops/op_info_register.py +21 -0
  243. mindspore/ops/operations/__init__.py +5 -2
  244. mindspore/ops/operations/_custom_ops_utils.py +675 -8
  245. mindspore/ops/operations/_inner_ops.py +14 -18
  246. mindspore/ops/operations/_sequence_ops.py +1 -1
  247. mindspore/ops/operations/array_ops.py +4 -50
  248. mindspore/ops/operations/comm_ops.py +186 -41
  249. mindspore/ops/operations/custom_ops.py +244 -175
  250. mindspore/ops/operations/debug_ops.py +55 -4
  251. mindspore/ops/operations/image_ops.py +13 -13
  252. mindspore/ops/operations/manually_defined/ops_def.py +27 -28
  253. mindspore/ops/operations/math_ops.py +8 -9
  254. mindspore/ops/operations/nn_ops.py +6 -7
  255. mindspore/ops/primitive.py +9 -20
  256. mindspore/ops/tensor_method.py +52 -11
  257. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
  258. mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
  259. mindspore/ops_generate/api/functions_cc_generator.py +58 -10
  260. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
  261. mindspore/ops_generate/common/base_generator.py +14 -0
  262. mindspore/ops_generate/common/gen_constants.py +7 -2
  263. mindspore/ops_generate/common/gen_utils.py +0 -19
  264. mindspore/ops_generate/common/op_proto.py +11 -4
  265. mindspore/ops_generate/common/template.py +88 -11
  266. mindspore/ops_generate/gen_ops.py +1 -1
  267. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
  268. mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
  269. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
  270. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
  271. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
  272. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
  273. mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -16
  274. mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
  275. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
  276. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
  277. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
  278. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
  279. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
  280. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
  281. mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
  282. mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
  283. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
  284. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
  285. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
  286. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
  287. mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
  288. mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
  289. mindspore/parallel/_auto_parallel_context.py +9 -17
  290. mindspore/parallel/_cell_wrapper.py +106 -40
  291. mindspore/parallel/_parallel_serialization.py +4 -3
  292. mindspore/parallel/_ps_context.py +4 -6
  293. mindspore/parallel/_tensor.py +167 -12
  294. mindspore/parallel/_transformer/moe.py +1 -1
  295. mindspore/parallel/_transformer/transformer.py +17 -12
  296. mindspore/parallel/_utils.py +5 -11
  297. mindspore/parallel/auto_parallel.py +33 -12
  298. mindspore/parallel/checkpoint_convert.py +3 -3
  299. mindspore/parallel/checkpoint_transform.py +5 -1
  300. mindspore/parallel/cluster/process_entity/_api.py +88 -49
  301. mindspore/parallel/cluster/process_entity/_utils.py +95 -7
  302. mindspore/parallel/cluster/run.py +48 -7
  303. mindspore/parallel/function/__init__.py +8 -1
  304. mindspore/parallel/function/reshard_func.py +7 -6
  305. mindspore/parallel/nn/__init__.py +15 -2
  306. mindspore/parallel/nn/parallel_cell_wrapper.py +50 -14
  307. mindspore/parallel/nn/parallel_grad_reducer.py +7 -14
  308. mindspore/parallel/shard.py +9 -23
  309. mindspore/parallel/transform_safetensors.py +468 -174
  310. mindspore/pgodb140.dll +0 -0
  311. mindspore/pgort140.dll +0 -0
  312. mindspore/profiler/__init__.py +2 -1
  313. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
  314. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
  315. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +3 -0
  316. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
  317. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  318. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
  319. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
  320. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
  321. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
  322. mindspore/profiler/analysis/task_manager.py +1 -1
  323. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
  324. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
  325. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +10 -9
  326. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +43 -23
  327. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
  328. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
  329. mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
  330. mindspore/profiler/common/constant.py +16 -0
  331. mindspore/profiler/common/msprof_cmd_tool.py +2 -2
  332. mindspore/profiler/common/path_manager.py +9 -0
  333. mindspore/profiler/common/profiler_context.py +50 -29
  334. mindspore/profiler/common/profiler_info.py +0 -16
  335. mindspore/profiler/common/profiler_meta_data.py +1 -0
  336. mindspore/profiler/common/profiler_op_analyse.py +239 -0
  337. mindspore/profiler/common/profiler_output_path.py +23 -8
  338. mindspore/profiler/common/profiler_parameters.py +128 -35
  339. mindspore/profiler/dynamic_profile/__init__.py +0 -0
  340. mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
  341. mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
  342. mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
  343. mindspore/profiler/dynamic_profiler.py +374 -338
  344. mindspore/profiler/envprofiler.py +42 -12
  345. mindspore/profiler/experimental_config.py +112 -7
  346. mindspore/profiler/mstx.py +33 -12
  347. mindspore/profiler/platform/__init__.py +2 -3
  348. mindspore/profiler/platform/cpu_profiler.py +10 -4
  349. mindspore/profiler/platform/npu_profiler.py +30 -20
  350. mindspore/profiler/profiler.py +218 -154
  351. mindspore/profiler/profiler_action_controller.py +65 -77
  352. mindspore/profiler/profiler_interface.py +2 -2
  353. mindspore/profiler/schedule.py +10 -4
  354. mindspore/rewrite/common/config.py +1 -0
  355. mindspore/rewrite/common/namer.py +1 -0
  356. mindspore/rewrite/common/namespace.py +1 -0
  357. mindspore/rewrite/node/node.py +31 -11
  358. mindspore/rewrite/parsers/assign_parser.py +1 -1
  359. mindspore/rewrite/symbol_tree/symbol_tree.py +2 -2
  360. mindspore/run_check/_check_version.py +7 -10
  361. mindspore/runtime/__init__.py +8 -6
  362. mindspore/runtime/event.py +10 -4
  363. mindspore/runtime/executor.py +87 -45
  364. mindspore/runtime/memory.py +22 -30
  365. mindspore/runtime/thread_bind_core.py +299 -165
  366. mindspore/safeguard/rewrite_obfuscation.py +12 -13
  367. mindspore/swresample-4.dll +0 -0
  368. mindspore/swscale-6.dll +0 -0
  369. mindspore/tbbmalloc.dll +0 -0
  370. mindspore/tinyxml2.dll +0 -0
  371. mindspore/train/_utils.py +9 -5
  372. mindspore/train/amp.py +43 -23
  373. mindspore/train/callback/__init__.py +5 -5
  374. mindspore/train/callback/_callback.py +2 -1
  375. mindspore/train/callback/_checkpoint.py +4 -14
  376. mindspore/train/callback/_flops_collector.py +11 -7
  377. mindspore/train/callback/_landscape.py +0 -1
  378. mindspore/train/callback/_train_fault_tolerance.py +72 -18
  379. mindspore/train/data_sink.py +15 -6
  380. mindspore/train/dataset_helper.py +14 -5
  381. mindspore/train/model.py +49 -47
  382. mindspore/train/serialization.py +168 -126
  383. mindspore/train/summary/summary_record.py +13 -2
  384. mindspore/train/train_thor/model_thor.py +2 -2
  385. mindspore/turbojpeg.dll +0 -0
  386. mindspore/utils/__init__.py +3 -2
  387. mindspore/utils/dryrun.py +0 -6
  388. mindspore/utils/runtime_execution_order_check.py +162 -78
  389. mindspore/utils/sdc_detect.py +68 -0
  390. mindspore/utils/utils.py +14 -17
  391. mindspore/vcmeta.dll +0 -0
  392. mindspore/vcruntime140.dll +0 -0
  393. mindspore/vcruntime140_1.dll +0 -0
  394. mindspore/version.py +1 -1
  395. {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/METADATA +5 -4
  396. {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/RECORD +400 -439
  397. mindspore/_deprecated/jit.py +0 -198
  398. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  399. mindspore/communication/_hccl_management.py +0 -297
  400. mindspore/experimental/es/embedding_service.py +0 -891
  401. mindspore/experimental/es/embedding_service_layer.py +0 -581
  402. mindspore/profiler/common/validator/__init__.py +0 -14
  403. mindspore/profiler/common/validator/validate_path.py +0 -84
  404. mindspore/profiler/parser/__init__.py +0 -14
  405. mindspore/profiler/parser/aicpu_data_parser.py +0 -272
  406. mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
  407. mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
  408. mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
  409. mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
  410. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
  411. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
  412. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
  413. mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
  414. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
  415. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
  416. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
  417. mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
  418. mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
  419. mindspore/profiler/parser/ascend_flops_generator.py +0 -116
  420. mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
  421. mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
  422. mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
  423. mindspore/profiler/parser/ascend_memory_generator.py +0 -185
  424. mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
  425. mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
  426. mindspore/profiler/parser/ascend_op_generator.py +0 -334
  427. mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
  428. mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
  429. mindspore/profiler/parser/base_timeline_generator.py +0 -483
  430. mindspore/profiler/parser/container.py +0 -229
  431. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
  432. mindspore/profiler/parser/flops_parser.py +0 -531
  433. mindspore/profiler/parser/framework_enum.py +0 -111
  434. mindspore/profiler/parser/framework_parser.py +0 -464
  435. mindspore/profiler/parser/framework_struct.py +0 -61
  436. mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
  437. mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
  438. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
  439. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
  440. mindspore/profiler/parser/hccl_parser.py +0 -573
  441. mindspore/profiler/parser/hwts_log_parser.py +0 -122
  442. mindspore/profiler/parser/integrator.py +0 -526
  443. mindspore/profiler/parser/memory_usage_parser.py +0 -277
  444. mindspore/profiler/parser/minddata_analyzer.py +0 -800
  445. mindspore/profiler/parser/minddata_parser.py +0 -186
  446. mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
  447. mindspore/profiler/parser/op_intermediate_parser.py +0 -149
  448. mindspore/profiler/parser/optime_parser.py +0 -250
  449. mindspore/profiler/parser/profiler_info.py +0 -213
  450. mindspore/profiler/parser/step_trace_parser.py +0 -666
  451. mindspore/utils/hooks.py +0 -81
  452. /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  453. {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/WHEEL +0 -0
  454. {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/entry_points.txt +0 -0
  455. {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/top_level.txt +0 -0
@@ -21,27 +21,269 @@ import glob
  import math
  import json
  import re
- from collections import defaultdict
+ import mmap
+ import stat
+ from collections import defaultdict, OrderedDict

  import time
  import multiprocessing as mp
+
+ from safetensors.numpy import save_file, load_file
  import psutil
  import numpy as np
- from safetensors.numpy import save_file, load_file
- from safetensors import safe_open

  import mindspore as ms
  from mindspore import log as logger
  from mindspore.log import vlog_print
+ from mindspore.common.parameter import Parameter
+ from mindspore.common.tensor import Tensor
  from mindspore.parallel._parallel_serialization import _get_device_num_from_strategy, _make_dir, \
      _extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
      _insert_opt_shard_reshape, _extract_src_dst_layout_map_by_src, _insert_expand_layout_reshape
  from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_tensor_layout, \
      _get_needed_rank_transform_operator_map_by_layouts, \
      _generate_transform_operator_stack, _apply_tensor_transform_operators, _construct_tensor_layout_for_opt_shard, \
-     _extract_layout_item, _load_tensor_shape, _apply_operator
+     _extract_layout_item, _apply_operator
  from mindspore.parallel._parallel_serialization import _build_searched_strategy, _load_protobuf_strategy, \
      _convert_to_list
+ from mindspore.common import dtype as mstype
+
+ safetensors_to_mstype = {'Int4': mstype.qint4x2}
+
+ MAX_HEADER_SIZE = 100 * 1000 * 1000
+
+ dtype_size = {
+     "BOOL": 1,
+     "U8": 1,
+     "I8": 1,
+     "I16": 2,
+     "U16": 2,
+     "I32": 4,
+     "U32": 4,
+     "I64": 8,
+     "U64": 8,
+     "F16": 2,
+     "BF16": 2,
+     "F32": 4,
+     "F64": 8,
+ }
+ np_dtype_size = {
+     "bool_": 1,
+     "uint8": 1,
+     "int8": 1,
+     "int16": 2,
+     "uint16": 2,
+     "int32": 4,
+     "uint32": 4,
+     "int64": 8,
+     "uint64": 8,
+     "float16": 2,
+     "bfloat16": 2,
+     "float32": 4,
+     "float64": 8,
+ }
+ numpy_dtype = {
+     "BOOL": np.bool_,
+     "U8": np.uint8,
+     "I8": np.int8,
+     "I16": np.int16,
+     "U16": np.uint16,
+     "I32": np.int32,
+     "U32": np.uint32,
+     "I64": np.int64,
+     "U64": np.uint64,
+     "F16": np.float16,
+     "F32": np.float32,
+     "F64": np.float64,
+ }
+
+
+ def getSize(fileobject):
+     fileobject.seek(0, 2)  # move the cursor to the end of the file
+     size = fileobject.tell()
+     fileobject.seek(0)  # move the cursor to the start of the file
+     return size
+
+
+ def _save_file_atomically(transform_param_dict, save_file_name, metadata=None):
+     """Atomically save file using temporary name and rename."""
+     if metadata is None:
+         metadata = {"format": "ms"}
+     file_name_list = list(os.path.splitext(save_file_name))
+     file_name_list[1] = file_name_list[1].replace('.safetensors', '.tmp')
+     tmp_name = ''.join(file_name_list)
+     try:
+         if os.path.exists(save_file_name):
+             os.chmod(save_file_name, stat.S_IWUSR)
+             os.remove(save_file_name)
+         if os.path.exists(tmp_name):
+             os.chmod(tmp_name, stat.S_IWUSR)
+             os.remove(tmp_name)
+         save_file(transform_param_dict, tmp_name, metadata=metadata)
+         os.rename(tmp_name, save_file_name)
+         os.chmod(save_file_name, stat.S_IRUSR)
+     except Exception as e:
+         if not os.path.exists(save_file_name):
+             logger.warning(f"Save failed, {save_file_name} not found. "
+                            f"This may indicate multiple processes modifying the same file "
+                            f"or insufficient disk space.")
+         raise e
+
+
+ def metadata_validate(metadata):
+     """validation metadata"""
+     start = 0
+     for key, info in metadata.items():
+         s, e = info["data_offsets"]
+         if s != start or e < s:
+             raise ValueError(f"SafeTensorError::InvalidOffset({key})")
+         start = e
+         nelements = np.prod(info["shape"])
+         nbytes = nelements * dtype_size[info["dtype"]]
+         if (e - s) != nbytes:
+             raise ValueError("SafeTensorError::TensorInvalidInfo")
+     return start
+
+
+ def read_metadata(buffer):
+     """read metadata by buffer"""
+     buffer_len = getSize(buffer)
+     if buffer_len < 8:
+         raise ValueError("SafeTensorError::HeaderTooSmall")
+
+     n = np.frombuffer(buffer.read(8), dtype=np.uint64).item()
+     if n > MAX_HEADER_SIZE:
+         raise ValueError("SafeTensorError::HeaderTooLarge")
+
+     stop = n + 8
+     if stop > buffer_len:
+         raise ValueError("SafeTensorError::InvalidHeaderLength")
+
+     tensors = json.loads(buffer.read(n), object_pairs_hook=OrderedDict)
+     metadata = tensors.pop("__metadata__", None)
+     buffer_end = metadata_validate(tensors)
+
+     if buffer_end + 8 + n != buffer_len:
+         raise ValueError("SafeTensorError::MetadataIncompleteBuffer")
+
+     return stop, tensors, metadata
+
+
+ class PySafeSlice:
+     """Create PySafeSlice by file"""
+
+     def __init__(self, info, bufferfile, base_ptr, buffermmap):
+         self.info = info
+         self.bufferfile = bufferfile
+         self.buffermmap = buffermmap
+         self.base_ptr = base_ptr
+
+         self.start = [0 for dim in self.shape]
+         self.stop = [dim for dim in self.shape]
+         self.step = [1 for dim in self.shape]
+
+     @property
+     def ndim(self):
+         return len(self.shape)
+
+     def get(self, *args, **kwargs):
+         """Get tensor from buffer by data_offset"""
+         nbytes = int(np.prod(self.shape)) * np.dtype(self.dtype).itemsize
+         offset = self.start_offset
+         tensor = np.frombuffer(self.buffermmap, dtype=self.dtype, offset=offset,
+                                count=nbytes // np.dtype(self.dtype).itemsize)
+         tensor = tensor.reshape(self.shape)
+         if not tensor.flags["ALIGNED"]:
+             logger.info("This safetensors file is not aligned.")
+             tensor = tensor.copy()
+         return tensor
+
+     @property
+     def start_offset(self):
+         return self.base_ptr + self.info["data_offsets"][0]
+
+     def get_shape(self):
+         return self.shape
+
+     @property
+     def shape(self):
+         return self.info["shape"]
+
+     @property
+     def dtype(self):
+         """Get dtype by numpy_dtype"""
+         if self.info["dtype"] == "BF16":
+             from mindspore.common import np_dtype
+             if not np_dtype.np_dtype_valid(True):
+                 raise TypeError(
+                     "The Numpy bfloat16 data type is not supported now, please ensure that the current "
+                     "Numpy version is not less than the version when the mindspore is compiled, "
+                     "and the major versions are same."
+                 )
+             return np_dtype.bfloat16
+         return numpy_dtype[self.info["dtype"]]
+
+     @property
+     def nelements(self):
+         return np.prod(self.info["shape"])
+
+     @property
+     def bits(self):
+         return dtype_size[self.info["dtype"]]
+
+     @property
+     def nbytes(self):
+         return self.nelements * dtype_size[self.info["dtype"]]
+
+
+ class _fast_safe_open:
+     """
+     Open a safetensors file and access its metadata and tensors efficiently.
+
+     This class is designed to work similarly to `safetensors.safe_open`,
+     providing a fast way to open and interact with safetensors files.
+     """
+
+     def __init__(self, filename, framework=None, device="cpu"):
+         self.filename = filename
+         self.framework = framework
+         self.file = open(self.filename, "rb")
+         self.file_mmap = mmap.mmap(self.file.fileno(), 0, access=mmap.ACCESS_COPY)
+         try:
+             self.base, self.tensors_decs, self.__metadata__ = read_metadata(self.file)
+         except ValueError:
+             raise ValueError(f"Fail to parse the input safetensors file: '{self.filename}'. "
+                              f"Please check the correctness of the file.")
+         self.tensors = OrderedDict()
+         for key, info in self.tensors_decs.items():
+             self.tensors[key] = PySafeSlice(info, self.file, self.base, self.file_mmap)
+             self.tensors[key].key = key
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, *args):
+         self.file.close()
+
+     def metadata(self):
+         return self.__metadata__
+
+     def keys(self):
+         return list(self.tensors.keys())
+
+     def get_tensor(self, name):
+         return self.tensors[name].get()
+
+
+ def _fast_load_file(filename):
+     """
+     Load safetensors info from a specified file.
+     """
+     result = {}
+     with _fast_safe_open(filename, framework="np") as f:
+         for k in f.keys():
+             result[k] = f.get_tensor(k)
+     return result


  def _progress_bar(iterable, total=None):
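The hunk above replaces `safetensors.safe_open` with an in-tree reader built on `mmap`. A safetensors file is an 8-byte little-endian header length, a JSON header mapping each tensor name to its dtype, shape, and byte offsets, then the raw tensor bytes; `read_metadata` re-validates exactly that layout. A minimal self-contained sketch of the same container format (names and values are illustrative, not from the package):

import json
import struct

import numpy as np

# One float32 tensor "w" of shape (2, 2): 16 bytes of data at offsets [0, 16).
data = np.arange(4, dtype=np.float32).tobytes()
header = json.dumps({"w": {"dtype": "F32", "shape": [2, 2], "data_offsets": [0, len(data)]},
                     "__metadata__": {"format": "ms"}}).encode()
blob = struct.pack("<Q", len(header)) + header + data  # 8-byte LE length prefix

# Parse the way read_metadata() does: length prefix, JSON header, offset check.
n = struct.unpack("<Q", blob[:8])[0]
tensors = json.loads(blob[8:8 + n])
meta = tensors.pop("__metadata__", None)
s, e = tensors["w"]["data_offsets"]
assert e - s == int(np.prod(tensors["w"]["shape"])) * 4  # 4 bytes per F32 element
w = np.frombuffer(blob[8 + n + s:8 + n + e], dtype=np.float32).reshape(2, 2)
print(meta, w)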
@@ -267,15 +509,22 @@ def _transform_safetensors_with_parallel(needed_rank_list_map, all_safetensor_fi
              pipe_param_list[layout[6][0]].append(name)
      part_list_dict = _distribute_files_by_size(all_safetensor_files_map, needed_rank_list_map, process_num)
      processes = []
-     for i in range(process_num):
-         p = mp.Process(target=_transform_safetensors_single, args=(
-             part_list_dict[i], all_safetensor_files_map, src_stage_device_num, dst_stage_device_num,
-             src_strategy_dict, dst_strategy_dict, origin_src_strategy_list, origin_dst_strategy_list,
-             ckpt_prefix, dst_safetensors_dir, output_format, _transform_param_list, pipe_param_list[i]))
-         p.start()
-         processes.append(p)
-     for p in processes:
-         p.join()
+     if process_num > 1:
+         for i in range(process_num):
+             p = mp.Process(target=_transform_safetensors_single, args=(
+                 part_list_dict[i], all_safetensor_files_map, src_stage_device_num, dst_stage_device_num,
+                 src_strategy_dict, dst_strategy_dict, origin_src_strategy_list, origin_dst_strategy_list,
+                 ckpt_prefix, dst_safetensors_dir, output_format, _transform_param_list, pipe_param_list[i]))
+             p.start()
+             processes.append(p)
+         for p in processes:
+             p.join()
+     else:
+         _transform_safetensors_single(part_list_dict[0], all_safetensor_files_map, src_stage_device_num,
+                                       dst_stage_device_num, src_strategy_dict, dst_strategy_dict,
+                                       origin_src_strategy_list, origin_dst_strategy_list, ckpt_prefix,
+                                       dst_safetensors_dir, output_format, _transform_param_list,
+                                       pipe_param_list[0])


  def _count_redundancy_list(rank_num, param_name, redundancy_dict, device_num):
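The rewritten dispatch only pays for `multiprocessing` when there is real work to split; with a single partition it calls the worker inline, avoiding process spawn and argument pickling (notably expensive under the spawn start method used on Windows). A generic sketch of the pattern, with a hypothetical `_work` standing in for `_transform_safetensors_single`:

import multiprocessing as mp

def _work(part):
    # stand-in for _transform_safetensors_single; the real code reports
    # results via files/shared state, so nothing is returned here
    print(sum(part))

def run(parts):
    if len(parts) > 1:
        procs = [mp.Process(target=_work, args=(p,)) for p in parts]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
    else:
        _work(parts[0])  # single partition: skip process creation entirely

if __name__ == "__main__":  # required under the Windows spawn start method
    run([[1, 2], [3, 4]])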
@@ -288,7 +537,7 @@ def _count_redundancy_list(rank_num, param_name, redundancy_dict, device_num):
      return set()


- def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict, saftensor_dict, redundancy_dict,
+ def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict, safetensor_dict, redundancy_dict,
                                      needed_rank, device_num, choice_func):
      """Find the rank_id under redundant groups."""
      io_time = 0
@@ -305,7 +554,7 @@ def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dic
                  break
          if open_file_id is not None:
              start_time = time.time()
-             output = file_dict[open_file_id].get_slice(param_name)
+             output = file_dict[open_file_id].get_tensor(param_name)
              end_time = time.time()
              cost_time = end_time - start_time
              io_time += cost_time
@@ -316,7 +565,7 @@ def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dic
              if not isinstance(choice_out, (bool, str)):
                  raise ValueError("For 'unified_safetensors', the return value type of the function "
                                   f"'choice_func' must be bool or str, but got {type(choice_out)}.")
-             saftensor_dict[param_name] = output
+             safetensor_dict[param_name] = output
          else:
              raise ValueError(f"For _transform_safetensors_single, {param_name} should be in "
                               f"{redundancy_ranks}, but in {single_param_dict[param_name]}.")
@@ -334,6 +583,7 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
      Transforms safetensors files to a specified format without using parallel processing.
      """
      io_cost_time = 0
+     meta_data = {"format": "ms"}
      if src_strategy_file is not None:
          from mindspore.train._utils import get_parameter_redundancy
          redundancy_dict_tmp = get_parameter_redundancy(src_strategy_file, initial_rank=0)
@@ -353,13 +603,15 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
      file_dict = {}
      single_param_dict = {}
      for file_id, _ in all_safetensor_files_map.items():
-         f = safe_open(all_safetensor_files_map.get(file_id), framework="np")
+         f = _fast_safe_open(all_safetensor_files_map.get(file_id), framework="np")
          file_dict[file_id] = f
          for param_name in f.keys():
              if param_name not in single_param_dict.keys():
                  single_param_dict[param_name] = {file_id}
              else:
                  single_param_dict[param_name].add(file_id)
+         if f.metadata() is not None:
+             meta_data.update(f.metadata())
      src_strategy_list_keys = _convert_to_list(src_strategy_dict).keys() if src_strategy_dict else []
      dst_strategy_list_keys = _convert_to_list(dst_strategy_dict).keys() if dst_strategy_dict else []
      for needed_rank_list_key, transform_rank_list in needed_rank_list_map.items():
@@ -368,27 +620,29 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
          needed_rank_list = needed_rank_list_key.split("-")
          for needed_rank in needed_rank_list:
              if pipe_param_list:
-                 saftensor_dict = dict()
+                 safetensor_dict = dict()
                  if src_strategy_file is not None:
                      io_time = _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict,
-                                                               saftensor_dict, redundancy_dict, needed_rank,
+                                                               safetensor_dict, redundancy_dict, needed_rank,
                                                                device_num, choice_func)
                      io_cost_time += io_time
                  else:
-                     with safe_open(all_safetensor_files_map.get(int(needed_rank)), framework="np") as f:
+                     with _fast_safe_open(all_safetensor_files_map.get(int(needed_rank)), framework="np") as f:
                          if not unified_flag:
                              all_param_name_set = set(f.keys())
                              src_param_name_set = set(src_strategy_list_keys)
                              dst_param_name_set = set(dst_strategy_list_keys)
                              hyper_param_set = all_param_name_set - (src_param_name_set & dst_param_name_set)
                              pipe_param_list.extend(list(hyper_param_set))
+                         if f.metadata() is not None:
+                             meta_data.update(f.metadata())
                          io_time = 0
                          for param_name in pipe_param_list:
                              if param_name not in f.keys():
                                  # param not in ckpt file, check reason
                                  continue
                              start_time = time.time()
-                             output = f.get_slice(param_name)
+                             output = f.get_tensor(param_name)
                              end_time = time.time()
                              cost_time = end_time - start_time
                              io_time += cost_time
@@ -400,15 +654,15 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
                              if not isinstance(choice_out, (bool, str)):
                                  raise ValueError("For 'unified_safetensors', the return value type of the function "
                                                   f"'choice_func' must be bool or str, but got {type(choice_out)}.")
-                             saftensor_dict[param_name] = output
+                             safetensor_dict[param_name] = output
              else:
                  start_time = time.time()
-                 saftensor_dict = load_file(all_safetensor_files_map.get(int(needed_rank)))
+                 safetensor_dict = load_file(all_safetensor_files_map.get(int(needed_rank)))
                  end_time = time.time()
                  cost_time = end_time - start_time
                  io_cost_time += cost_time

-             for param_name, param in saftensor_dict.items():
+             for param_name, param in safetensor_dict.items():
                  src_rank = int(needed_rank) % src_stage_device_num
                  param_total_dict[param_name][src_rank] = param
                  param_attr_dict[param_name][src_rank] = (True, False)
@@ -442,11 +696,11 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
      else:
          if transform_param_dict:
              if output_format == "safetensors":
-                 save_file(transform_param_dict, save_file_name)
+                 _save_file_atomically(transform_param_dict, save_file_name, metadata=meta_data)
              else:
-                 transform_param_dict = _load_and_transform(transform_param_dict,
-                                                            None, None, transform_func=
-                                                            lambda v, name: ms.Parameter(v, name=name))
+                 transform_param_dict = _load_and_transform(transform_param_dict, None, None,
+                                                            transform_func=lambda v, name: Parameter(v,
+                                                                                                     name=name))
                  ms.save_checkpoint(transform_param_dict, save_file_name)
      del param_total_dict_keys
      del param_total_dict
@@ -464,10 +718,10 @@ def _save_final_safetensors(_transform_param_list, output_format):
          new_transform_dict[save_file_name].update(transform_param_dict)
      for save_file_name, transform_param_dict in new_transform_dict.items():
          if output_format == "safetensors":
-             save_file(transform_param_dict, save_file_name)
+             _save_file_atomically(transform_param_dict, save_file_name, metadata={"format": "ms"})
          else:
              transform_param_dict = _load_and_transform(transform_param_dict, None, None,
-                                                        transform_func=lambda v, name: ms.Parameter(v, name=name))
+                                                        transform_func=lambda v, name: Parameter(v, name=name))
              ms.save_checkpoint(transform_param_dict, save_file_name)

@@ -501,8 +755,8 @@ def transform_safetensors_by_stage(src_safetensors_dir, dst_safetensors_dir, ckp
          if not os.path.exists(local_file):
              raise ValueError("safetensor file {} in rank {} not exits: ".format(local_file, rank))
      for rank, file_name in safetensor_files_map.items():
-         saftensor_dict = load_file(file_name)
-         for param_name, param in saftensor_dict.items():
+         safetensor_dict = load_file(file_name)
+         for param_name, param in safetensor_dict.items():
              # cut the parameter not in the pipeline stage.
              if _parameter_not_in_local_stage(param_name, origin_src_strategy_list, src_strategy_list) \
                      and _parameter_not_in_local_stage(param_name, origin_dst_strategy_list, dst_strategy_list):
@@ -520,7 +774,7 @@ def transform_safetensors_by_stage(src_safetensors_dir, dst_safetensors_dir, ckp
      if not os.path.exists(save_safetensor_file_dir):
          _make_dir(save_safetensor_file_dir, "path")
      save_safetensor_file_name = os.path.join(save_safetensor_file_dir, save_safetensor_file)
-     save_file(transform_param_dict, save_safetensor_file_name)
+     _save_file_atomically(transform_param_dict, save_safetensor_file_name, metadata={"format": "ms"})


  def transform_safetensors_by_rank(rank_id, safetensor_files_map, save_safetensor_file_name,
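All of the direct `save_file` call sites now route through `_save_file_atomically` (defined in the first hunk), which writes to a `.tmp` sibling and renames it into place, so a crash mid-write can no longer leave a truncated `.safetensors` behind. A minimal sketch of the same write-then-rename idiom (illustrative names; the shipped helper additionally manages `os.chmod` permissions):

import os

def atomic_write(path, payload):
    tmp = path + ".tmp"  # sibling path, same filesystem as the target
    with open(tmp, "wb") as f:
        f.write(payload)
    # Renaming a fully written file over the target is atomic on POSIX
    # (os.replace also overwrites on Windows): readers see the old file
    # or the new one, never a half-written mix.
    os.replace(tmp, path)

atomic_write("part0.safetensors", b"\x00" * 16)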
@@ -556,8 +810,8 @@ def transform_safetensors_by_rank(rank_id, safetensor_files_map, save_safetensor
      origin_dst_strategy_list = _extract_layout_map(dst_strategy_file)
      origin_src_strategy_list = _extract_layout_map(src_strategy_file)
      for rank, file_name in safetensor_files_map.items():
-         saftensor_dict = load_file(file_name)
-         for param_name, param in saftensor_dict.items():
+         safetensor_dict = load_file(file_name)
+         for param_name, param in safetensor_dict.items():
              # cut the parameter not in the pipeline stage.
              if _parameter_not_in_local_stage(param_name, origin_src_strategy_list, src_strategy_list) \
                      and _parameter_not_in_local_stage(param_name, origin_dst_strategy_list, dst_strategy_list):
@@ -572,7 +826,7 @@ def transform_safetensors_by_rank(rank_id, safetensor_files_map, save_safetensor
      transform_param_dict = _transform_parallel_safetensor(local_rank_id, param_total_dict,
                                                            param_attr_dict, src_strategy_list, dst_strategy_list,
                                                            param_type_dict)
-     save_file(transform_param_dict, save_safetensor_file_name)
+     _save_file_atomically(transform_param_dict, save_safetensor_file_name, metadata={"format": "ms"})


  def _extrace_number(file_name):
@@ -628,7 +882,7 @@ def _find_needed_ranks(src_strategy_dict, dst_strategy_dict):

  def load_file_by_param_name(filename, parme_name_list):
      result = {}
-     with safe_open(filename, framework="np") as f:
+     with _fast_safe_open(filename, framework="np") as f:
          for k in parme_name_list:
              result[k] = f.get_tensor(k)
      return result
@@ -644,10 +898,7 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
      device_num = -1
      param_total_dict_keys = list(param_total_dict.keys()) if param_total_dict_keys is None else param_total_dict_keys
      for param_name in param_total_dict_keys:
-         if str(type(list(param_total_dict[param_name].values())[0])) == "<class 'builtins.PySafeSlice'>":
-             tensor_shape = list(param_total_dict[param_name].values())[0].get_shape()
-         else:
-             tensor_shape = list(param_total_dict[param_name].values())[0].shape
+         tensor_shape = list(param_total_dict[param_name].values())[0].shape
          from_dev_matrix = [1]
          from_tensor_map = [-1] * len(tensor_shape)
          from_opt_shard_step = 0
@@ -695,7 +946,7 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
          # when the from_layout is less devices, the safetensor_map for map[device_num] should using map[0]
          device_list = list(range(0, np.prod(from_tensor_layout[0])))
          if rank_id % device_num not in param_attr_dict[param_name] and src_strategy_file is None:
-             raise ValueError("The safetensor of rank {} is missing.".format(rank_id % device_num))
+             raise ValueError("The param: {} in rank {} is missing.".format(param_name, rank_id % device_num))
          param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
                                                                              device_list, rank_id)
@@ -711,8 +962,6 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
          if isinstance(choice_out, str):
              param_name = choice_out
          transform_param_dict[param_name] = param_total_dict_copy[rank_id % device_num]
-         if str(type(transform_param_dict[param_name])) == "<class 'builtins.PySafeSlice'>":
-             transform_param_dict[param_name] = transform_param_dict[param_name][:]

      # Handle those parameter like learning_rate, global_step which not in strategy_file.
      for param_name in param_total_dict_keys:
@@ -722,33 +971,14 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
              continue
          if param_name not in transform_param_dict:
              transform_para = param_total_dict[param_name][rank_id % device_num]
-             if str(type(transform_para)) == "<class 'builtins.PySafeSlice'>":
-                 transform_para = transform_para[:]
              transform_param_dict[param_name] = transform_para
      return transform_param_dict


  def _cal_param_size(shape, dtype):
      """cal param size by dtype and shape"""
-     dtype_size = {
-         "BOOL": 1,
-         "U8": 1,
-         "I8": 1,
-         "F8_E5M2": 1,
-         "F8_E4M3": 1,
-         "I16": 2,
-         "U16": 2,
-         "I32": 4,
-         "U32": 4,
-         "I64": 8,
-         "U64": 8,
-         "F16": 2,
-         "BF16": 2,
-         "F32": 4,
-         "F64": 8,
-     }
      num_elements = math.prod(shape)
-     element_size = dtype_size.get(dtype, 4)
+     element_size = np_dtype_size.get(dtype, 4)
      total_bytes = num_elements * element_size
      return total_bytes

@@ -769,14 +999,15 @@ def _split_weight_dict(weights, num_groups):
  def _save_hyper_param(split_dst_file, all_safetensor_files_map, name_list, dst_dir):
      """save hyper param"""
      if not split_dst_file or (split_dst_file and split_dst_file[0] == 1):
-         with safe_open(all_safetensor_files_map.get(0), framework="np") as f:
+         with _fast_safe_open(all_safetensor_files_map.get(0), framework="np") as f:
              all_key = f.keys()
              hyper_parameter = set(all_key) - set(name_list)
              if hyper_parameter:
                  hyper_dict = {}
                  for key in hyper_parameter:
                      hyper_dict[key] = f.get_tensor(key)
-                 save_file(hyper_dict, os.path.join(dst_dir, "hyper_param.safetensors"))
+                 _save_file_atomically(hyper_dict, os.path.join(dst_dir, "hyper_param.safetensors"),
+                                       metadata={"format": "ms"})


  def _save_parameter_map_json(split_list, choice_func, split_dst_file, dst_dir, param_total_size):
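With `get_tensor` returning plain numpy arrays, `_cal_param_size` (two hunks up) now prices parameters via the shared `np_dtype_size` table keyed by numpy dtype names. The arithmetic is bytes = element count times element size: a `float32` tensor of shape `(1024, 1024)` costs 1024 * 1024 * 4 = 4194304 bytes (4 MiB). A quick check mirroring the helper:

import math

np_dtype_size = {"float32": 4, "float16": 2, "bfloat16": 2}  # excerpt of the table

def cal_param_size(shape, dtype):
    # unknown dtypes fall back to 4 bytes, as in _cal_param_size above
    return math.prod(shape) * np_dtype_size.get(dtype, 4)

assert cal_param_size((1024, 1024), "float32") == 4 * 1024 * 1024
assert cal_param_size((1024, 1024), "float16") == 2 * 1024 * 1024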
@@ -826,14 +1057,57 @@ def _get_dst_shape(param_name, param_shape, src_strategy_list):
      return to_full_tensor_shape


+ def _check_remove_redundancy(merge_with_redundancy, f):
+     """Check whether remove_redundancy is consistent with the safetensors file."""
+     if f.metadata() is not None and "remove_redundancy" in f.metadata().keys():
+         if f.metadata()["remove_redundancy"] == "True" and merge_with_redundancy:
+             logger.warning("For 'unified_safetensors', the safetensors file is deduplicated, "
+                            "but merge_with_redundancy is set to True.")
+             return False
+         if f.metadata()["remove_redundancy"] == "False" and not merge_with_redundancy:
+             logger.warning("For 'unified_safetensors', the safetensors file is non-deduplicated, "
+                            "but merge_with_redundancy is set to False.")
+             return True
+     return merge_with_redundancy
+
+
+ def set_affinity_pid():
+     """Set CPU affinity pid"""
+     pid = os.getpid()
+     total_cores = os.cpu_count()
+     all_cores = set(range(total_cores))
+     os.sched_setaffinity(pid, all_cores)
+
+
+ def _validate_safetensors_files(target_directory, expected_file_ids):
+     """Validate whether safetensors files are completely generated in the target directory."""
+     missing_file_ids = []
+     for file_id in expected_file_ids:
+         safetensors_file = os.path.join(target_directory, f"part{file_id}.safetensors")
+         if os.path.exists(safetensors_file):
+             continue
+         missing_file_ids.append(file_id)
+
+     if missing_file_ids:
+         logger.warning(
+             f"For unified_safetensors, target file part {missing_file_ids} does not exist. "
+             f"Possible causes: file rename failed, insufficient permissions, or disk space shortage."
+         )


  def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundancy=True, file_suffix=None,
                          max_process_num=64, choice_func=None, split_dst_file=()):
      """
      Merge multiple safetensor files into a unified safetensor file.

+     Note:
+         When merging weights, it will verify whether the `merge_with_redundancy` parameter differs from
+         the deduplication flag in the merged safetensors files. If they differ, the merging is performed
+         according to the deduplication flag recorded in the files.
+
      Args:
          src_dir (str): Source weight saving directory.
-         src_strategy_file (str): Source weight segmentation strategy file.
+         src_strategy_file (str): Source weight segmentation strategy file with the file extension `.ckpt`.
          dst_dir (str): Target save directory.
          merge_with_redundancy (bool, optional): Whether the merged source weight files are de-duplicated and
              saved safetensors files. Default: ``True``, indicating that the merged source weight files are complete.
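`_check_remove_redundancy` (added above) resolves a mismatch in favor of the flag recorded in the file's own metadata: a file stamped `remove_redundancy == "True"` is merged as deduplicated even if the caller passed `merge_with_redundancy=True`, and vice versa. A behavioral sketch, with a plain dict standing in for `f.metadata()`:

def check_remove_redundancy(merge_with_redundancy, file_meta):
    # mirrors _check_remove_redundancy: on mismatch, the file's flag wins
    flag = (file_meta or {}).get("remove_redundancy")
    if flag == "True" and merge_with_redundancy:
        return False  # file is deduplicated; ignore the caller's True
    if flag == "False" and not merge_with_redundancy:
        return True   # file is complete; ignore the caller's False
    return merge_with_redundancy

assert check_remove_redundancy(True, {"remove_redundancy": "True"}) is False
assert check_remove_redundancy(False, {"remove_redundancy": "False"}) is True
assert check_remove_redundancy(True, None) is True  # no metadata: keep the argument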
@@ -861,10 +1135,7 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
      >>> dst_dir = "/usr/safetensors/llama31B/merge_llama31B_4p/"
      >>> ms.parallel.unified_safetensors(src_dir, src_strategy_file, dst_dir)
      """
-     pid = os.getpid()
-     total_cores = os.cpu_count()
-     all_cores = set(range(total_cores))
-     os.sched_setaffinity(pid, all_cores)
+     set_affinity_pid()
      _check_transform_safetensors(src_dir, "", src_strategy_file, None)
      _make_dir(dst_dir, "path")
      if os.path.isfile(src_dir):
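One caveat with the extracted `set_affinity_pid`: `os.sched_setaffinity` exists only on Linux, so on the win_amd64 build this diff covers, the call would raise `AttributeError` if reached. A defensive variant (an assumption for illustration, not the shipped code) would guard it:

import os

def set_affinity_all_cores():
    # os.sched_setaffinity is Linux-only; skip silently on Windows/macOS
    if hasattr(os, "sched_setaffinity"):
        os.sched_setaffinity(os.getpid(), set(range(os.cpu_count() or 1)))

set_affinity_all_cores()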
@@ -890,8 +1161,9 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan

      actual_params = set()
      for _, file_name in all_safetensor_files_map.items():
-         with safe_open(file_name, framework="np") as f:
+         with _fast_safe_open(file_name, framework="np") as f:
              actual_params.update(f.keys())
+             merge_with_redundancy = _check_remove_redundancy(merge_with_redundancy, f)

      params_to_store = actual_params & set(layout_map.keys())

@@ -904,21 +1176,22 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
      param_size_dict = {}
      param_total_size = 0
      for _, file_name in all_safetensor_files_map.items():
-         with safe_open(file_name, framework="np") as f:
+         with _fast_safe_open(file_name, framework="np") as f:
              for k in f.keys():
                  if k in name_list:
-                     py_slice = f.get_slice(k)
-                     param_total_size += _cal_param_size(py_slice.get_shape(), py_slice.get_dtype())
-                     param_dst_shape = _get_dst_shape(k, py_slice.get_shape(), origin_src_strategy_list)
+                     py_slice = f.get_tensor(k)
+                     param_total_size += _cal_param_size(py_slice.shape, py_slice.dtype)
+                     param_dst_shape = _get_dst_shape(k, py_slice.shape, origin_src_strategy_list)
                      # Convert the shape of np.int32 type to int type to prevent overflow in subsequent calculations.
                      param_dst_shape = [int(item) for item in param_dst_shape]
                      if choice_func is not None:
                          choice_out = choice_func(k)
                          if isinstance(choice_out, bool):
                              if not choice_out:
+                                 name_list.remove(k)
                                  continue
                      if k not in param_size_dict:
-                         param_size_dict[k] = _cal_param_size(param_dst_shape, py_slice.get_dtype())
+                         param_size_dict[k] = _cal_param_size(param_dst_shape, py_slice.dtype)
      split_num = math.ceil(sum(param_size_dict.values()) / 1024 / 1024 / 1024 / 3)
      split_num = min(split_num, len(name_list))
      split_list = _split_weight_dict(param_size_dict, split_num)
@@ -932,37 +1205,44 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
         start_index = (avg_length * (current_machine_num - 1)) + min(current_machine_num - 1, remainder)
         end_index = start_index + avg_length + (1 if current_machine_num <= remainder else 0)
         sub_list = []
-        for i in range(len(split_list)):
+        for i, item in enumerate(split_list):
             if start_index <= i < end_index:
-                sub_list.append(split_list[i])
+                sub_list.append(item)
             else:
                 sub_list.append([-1])
+        split_num = end_index - start_index
+        res = list(range(start_index, end_index))
     else:
         sub_list = split_list
+        res = [i for i in range(split_num)]
 
     _save_hyper_param(split_dst_file, all_safetensor_files_map, name_list, dst_dir)
     _save_parameter_map_json(split_list, choice_func, split_dst_file, dst_dir, param_total_size)
 
-    if split_dst_file:
-        split_num = end_index - start_index
-        res = list(range(start_index, end_index))
-    else:
-        res = [i for i in range(split_num)]
     max_process = min(split_num, max_process_num)
+    file_ids = res[:]
     res = _split_list(res, max_process)
     processes = []
     src_strategy_name = None
     if not merge_with_redundancy:
         src_strategy_name = src_strategy_file
-    for i in range(max_process):
-        p = mp.Process(target=_transform_safetensors_single_semaphore, args=(
-            needed_rank_list_map, all_safetensor_files_map, src_stage_device_num, dst_stage_device_num,
-            src_strategy_dict, None, origin_src_strategy_list, origin_dst_strategy_list,
-            "", dst_dir, "safetensors", None, sub_list, res[i], True, src_strategy_name, choice_func))
-        p.start()
-        processes.append(p)
-    for p in processes:
-        p.join()
+    if max_process > 1:
+        for i in range(max_process):
+            p = mp.Process(target=_transform_safetensors_single_semaphore, args=(
+                needed_rank_list_map, all_safetensor_files_map, src_stage_device_num, dst_stage_device_num,
+                src_strategy_dict, None, origin_src_strategy_list, origin_dst_strategy_list,
+                "", dst_dir, "safetensors", None, sub_list, res[i], True, src_strategy_name, choice_func))
+            p.start()
+            processes.append(p)
+        for p in processes:
+            p.join()
+    else:
+        _transform_safetensors_single_semaphore(needed_rank_list_map, all_safetensor_files_map, src_stage_device_num,
+                                                dst_stage_device_num, src_strategy_dict, None,
+                                                origin_src_strategy_list, origin_dst_strategy_list, "",
+                                                dst_dir, "safetensors", None, sub_list,
+                                                res[0], True, src_strategy_name, choice_func)
+    _validate_safetensors_files(dst_dir, file_ids)
 
 
 def _transform_safetensors_single_semaphore(needed_rank_list_map, all_safetensor_files_map,
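The rewritten fan-out distributes the file ids in `res` across at most `max_process_num` worker processes, and the new `else` branch runs the transform in-process when a single worker suffices, avoiding the cost of spawning one `mp.Process`. The hunk header above shows `_split_list`'s signature; a minimal sketch of a chunking helper consistent with that signature (the actual distribution strategy is an assumption):

def _split_list(split_list, split_num):
    """Split `split_list` into `split_num` nearly equal contiguous chunks."""
    avg, rem = divmod(len(split_list), split_num)
    chunks, start = [], 0
    for i in range(split_num):
        end = start + avg + (1 if i < rem else 0)  # spread the remainder over the first chunks
        chunks.append(split_list[start:end])
        start = end
    return chunks

print(_split_list([0, 1, 2, 3, 4], 2))  # [[0, 1, 2], [3, 4]]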
@@ -997,7 +1277,7 @@ def _split_list(split_list, split_num):
 def _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_num):
     """apply safetensors object operators"""
     if not transform_operator_stack:
-        return sf_obj[:]
+        return sf_obj
     level = transform_operator_stack[-1][1]
     level_operators = []
     while True:
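Dropping `[:]` here follows from the earlier switch from `get_slice` to `get_tensor`: `safe_open(...).get_slice(k)` returns a lazy slice handle whose bytes are read only when indexed, so `[:]` was needed to materialize it, whereas `get_tensor(k)` already returns a full NumPy array. An illustration against the safetensors API, with `model.safetensors` and `weight` as placeholder names:

from safetensors import safe_open

with safe_open("model.safetensors", framework="np") as f:
    lazy = f.get_slice("weight")    # lazy handle; nothing is read yet
    arr_a = lazy[:]                 # indexing materializes an ndarray
    arr_b = f.get_tensor("weight")  # eager ndarray, no [:] needed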
@@ -1022,7 +1302,7 @@ def _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_n
                 allgather_list = [sf_obj for _ in operator[1][:-1]]
                 tmp_tensor_dict[rank_id % device_num] = _apply_operator(operator[0])(allgather_list, operator)
             if op_name == "AllConcat":
-                for rank, value in tmp_tensor_dict.items():
+                for _, value in tmp_tensor_dict.items():
                     sf_obj = value
         level_operators.clear()
         if not transform_operator_stack:
@@ -1037,13 +1317,26 @@ def _process_hyper_params(file_list, total_safetensors_dir, total_param):
     """process hyper params"""
     if 'hyper_param.safetensors' in file_list:
         hyper_parameter_file_name = os.path.join(total_safetensors_dir, "hyper_param.safetensors")
-        with safe_open(hyper_parameter_file_name, framework="np") as f:
+        with _fast_safe_open(hyper_parameter_file_name, framework="np") as f:
             for key in f.keys():
-                total_param[key] = ms.Parameter(ms.Tensor.from_numpy(f.get_tensor(key)))
+                total_param[key] = Parameter(Tensor.from_numpy(f.get_tensor(key)))
     return total_param
 
 
-def _cal_param_name_map_and_param_list(file_list, total_safetensors_dir, json_files, dst_strategy_file, rank_id):
+def _get_param_name_map_by_file(file_name, file_list, name_map):
+    """get param_name_map by file"""
+    with _fast_safe_open(file_name, framework="np") as f:
+        keys = f.keys()
+        values = len(keys) * [file_list[0]]
+        if name_map:
+            flipped_name_map = {value: key for key, value in name_map.items()}
+            keys = [flipped_name_map.get(key, key) for key in keys]
+        param_name_map = dict(zip(keys, values))
+    return param_name_map
+
+
+def _cal_param_name_map_and_param_list(file_list, total_safetensors_dir, json_files,
+                                       dst_strategy_file, rank_id, name_map=None):
     """calculate param_name_map and param_list"""
     if len(file_list) == 1:
         logger.info("There is only one weight file in the directory, which will be automatically mapped.")
@@ -1052,10 +1345,7 @@ def _cal_param_name_map_and_param_list(file_list, total_safetensors_dir, json_fi
         if not is_file:
             raise ValueError(f"For 'load_parallel_checkpoint', weight files must be included "
                              f"in the `unified_safetensors_dir`.")
-        with safe_open(file_name, framework="np") as f:
-            keys = f.keys()
-            values = len(keys) * [file_list[0]]
-            param_name_map = dict(zip(keys, values))
+        param_name_map = _get_param_name_map_by_file(file_name, file_list, name_map)
     else:
         if not json_files:
             raise ValueError(
@@ -1076,19 +1366,71 @@ def _cal_param_name_map_and_param_list(file_list, total_safetensors_dir, json_fi
     return param_name_map, param_list, dst_strategy_list
 
 
+def _cal_transform_operator_stack_and_device_num(from_dev_matrix, from_tensor_map, from_opt_shard_step,
+                                                 from_opt_shard_size, param_name, dst_strategy_list, tensor_shape,
+                                                 local_rank_id):
+    """cal transform_operator_stack and device_num"""
+    to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
+        dst_strategy_list.get(param_name))
+
+    device_num = np.prod(from_dev_matrix)
+    param_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
+    origin_tensor_shape = ()
+    for i, item in enumerate(tensor_shape):
+        if i == 0 and from_opt_shard_size > 0:
+            origin_tensor_shape += (item * param_strategy[i] * from_opt_shard_size,)
+            continue
+        origin_tensor_shape += (item * param_strategy[i],)
+
+    has_layout_from = any(isinstance(i, (list, tuple)) for i in from_tensor_map)
+    has_layout_to = any(isinstance(i, (list, tuple)) for i in to_tensor_map_origin)
+
+    from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
+        from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
+    to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
+        to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, origin_tensor_shape)
+    # Convert tensor layout to the same device num
+    from_tensor_layout, to_tensor_layout = _construct_from_to_tensor_layout(from_full_tensor_shape,
+                                                                            from_dev_matrix,
+                                                                            from_tensor_map,
+                                                                            to_full_tensor_shape,
+                                                                            to_dev_matrix, to_tensor_map)
+
+    # When the from_layout spans fewer devices, the safetensors map for map[device_num] should use map[0].
+    device_list = list(range(0, np.prod(from_tensor_layout[0])))
+    param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
+                                                                        device_list, local_rank_id)
+
+    from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
+    to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
+    _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
+    _insert_expand_layout_reshape(param_rank_map, from_info_tuple, to_info_tuple,
+                                  has_layout_from, has_layout_to)
+    transform_operator_stack = _generate_transform_operator_stack(param_rank_map, local_rank_id)
+    return transform_operator_stack, device_num
+
+
+def check_param_dtype(file, param_name):
+    """Check the safetensors metadata for a recorded original dtype of `param_name`."""
+    dtype_need_changed = False
+    changed_dtype = None
+    if file.metadata() is not None and param_name in file.metadata().keys():
+        dtype_need_changed = True
+        sf_dtype = file.metadata()[param_name]
+        changed_dtype = safetensors_to_mstype[sf_dtype]
+    return dtype_need_changed, changed_dtype
+
+
 def _load_parallel_checkpoint(file_info):
     """load parallel safetensors by merged file."""
     total_safetensors_dir, dst_strategy_file, net, dst_safetensors_dir, \
-        rank_id, output_format, name_map, return_param_dict = file_info
-    pid = os.getpid()
-    total_cores = os.cpu_count()
-    all_cores = set(range(total_cores))
-    os.sched_setaffinity(pid, all_cores)
+        rank_id, output_format, name_map, return_param_dict = file_info
+    set_affinity_pid()
     file_list = os.listdir(total_safetensors_dir)
     json_files = [file for file in file_list if file == "param_name_map.json"]
-    param_name_map, param_list, dst_strategy_list = _cal_param_name_map_and_param_list(file_list, total_safetensors_dir,
+    sf_files = [file for file in file_list if file.endswith('.safetensors')]
+    param_name_map, param_list, dst_strategy_list = _cal_param_name_map_and_param_list(sf_files, total_safetensors_dir,
                                                                                        json_files, dst_strategy_file,
-                                                                                       rank_id)
+                                                                                       rank_id, name_map)
     total_param = dict()
     dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) if dst_strategy_list \
         is not None else 1
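`check_param_dtype` relies on a metadata convention: if a parameter's name appears as a key in the file-level metadata, its value is read as the dtype string the tensor had before being stored, and the `safetensors_to_mstype` table translates that string back to a MindSpore dtype. The obvious use case is bfloat16, which NumPy cannot represent natively, so such tensors must be stored in a carrier dtype and cast back on load. A plausible shape for the table; the entries shown are assumptions:

import mindspore as ms

safetensors_to_mstype = {
    "BF16": ms.bfloat16,
    "FP16": ms.float16,
    "FP32": ms.float32,
}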
@@ -1098,13 +1440,14 @@ def _load_parallel_checkpoint(file_info):
         if param_name not in param_name_map:
             continue
         file_name = os.path.join(total_safetensors_dir, param_name_map[param_name])
-        with safe_open(file_name, framework="np") as f:
+        with _fast_safe_open(file_name, framework="np") as f:
            cur_param_name = name_map.get(param_name) if name_map is not None and param_name in name_map else param_name
            if cur_param_name not in f.keys():
                continue
-           sf_obj = f.get_slice(cur_param_name)
+           sf_obj = f.get_tensor(cur_param_name)
+           dtype_need_changed, changed_dtype = check_param_dtype(f, param_name)
 
-           tensor_shape = sf_obj.get_shape()
+           tensor_shape = sf_obj.shape
            from_dev_matrix = [1]
            from_tensor_map = [-1] * len(tensor_shape)
            from_opt_shard_step = 0
@@ -1112,43 +1455,14 @@ def _load_parallel_checkpoint(file_info):
            if dst_strategy_list is not None:
                if param_name not in dst_strategy_list:
                    continue
-               to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
-                   dst_strategy_list.get(param_name))
-
-               device_num = np.prod(from_dev_matrix)
-               param_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
-               origin_tensor_shape = ()
-               for i, item in enumerate(tensor_shape):
-                   if i == 0 and from_opt_shard_size > 0:
-                       origin_tensor_shape += (item * param_strategy[i] * from_opt_shard_size,)
-                       continue
-                   origin_tensor_shape += (item * param_strategy[i],)
-
-               has_layout_from = any(isinstance(i, (list, tuple)) for i in from_tensor_map)
-               has_layout_to = any(isinstance(i, (list, tuple)) for i in to_tensor_map_origin)
-
-               from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
-                   from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
-               to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
-                   to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, origin_tensor_shape)
-               # Convert tensor layout to same device num
-               from_tensor_layout, to_tensor_layout = _construct_from_to_tensor_layout(from_full_tensor_shape,
-                                                                                       from_dev_matrix,
-                                                                                       from_tensor_map,
-                                                                                       to_full_tensor_shape,
-                                                                                       to_dev_matrix, to_tensor_map)
-
-               # when the from_layout is less devices, the safetensor_map for map[device_num] should using map[0]
-               device_list = list(range(0, np.prod(from_tensor_layout[0])))
-               param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
-                                                                                   device_list, local_rank_id)
-
-               from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
-               to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
-               _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
-               _insert_expand_layout_reshape(param_rank_map, from_info_tuple, to_info_tuple,
-                                             has_layout_from, has_layout_to)
-               transform_operator_stack = _generate_transform_operator_stack(param_rank_map, local_rank_id)
+               transform_operator_stack, device_num = _cal_transform_operator_stack_and_device_num(from_dev_matrix,
+                                                                                                   from_tensor_map,
+                                                                                                   from_opt_shard_step,
+                                                                                                   from_opt_shard_size,
+                                                                                                   param_name,
+                                                                                                   dst_strategy_list,
+                                                                                                   tensor_shape,
+                                                                                                   local_rank_id)
                start_time = time.time()
                slice_param = _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_num)
                end_time = time.time()
@@ -1156,12 +1470,15 @@ def _load_parallel_checkpoint(file_info):
                total_io_cost_time += cost_time
            else:
                start_time = time.time()
-               slice_param = sf_obj[:]
+               slice_param = sf_obj
                end_time = time.time()
                cost_time = end_time - start_time
                total_io_cost_time += cost_time
            slice_param_copy = np.copy(slice_param)
-           total_param[param_name] = ms.Parameter(ms.Tensor.from_numpy(slice_param_copy))
+           if dtype_need_changed:
+               total_param[param_name] = Parameter(Tensor(slice_param_copy, dtype=changed_dtype))
+           else:
+               total_param[param_name] = Parameter(Tensor.from_numpy(slice_param_copy))
     vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
                f"load distributed safetensors io cost time:{total_io_cost_time}.")
     total_param = _process_hyper_params(file_list, total_safetensors_dir, total_param)
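The new branch is the consumer of `check_param_dtype`: `Tensor.from_numpy` wraps the array in its current dtype without copying, whereas `Tensor(arr, dtype=...)` converts, which is exactly what restoring a recorded dtype requires. A short illustration:

import numpy as np
import mindspore as ms

arr = np.ones((2, 2), dtype=np.float32)
t_keep = ms.Tensor.from_numpy(arr)          # stays float32, shares the buffer
t_cast = ms.Tensor(arr, dtype=ms.bfloat16)  # cast to the recorded dtype
print(t_keep.dtype, t_cast.dtype)           # Float32 BFloat16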
@@ -1178,28 +1495,5 @@ def _load_parallel_checkpoint(file_info):
     return None
 
 
-def _get_slice(rank_id, sf_obj, param_name, dst_strategy_list):
-    """get slice op"""
-    tensor_shape = sf_obj.get_shape()
-    to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
-        dst_strategy_list.get(param_name))
-    # Add optimizer sharding dim for tensor layout
-    to_dev_matrix, to_tensor_map, _ = _construct_tensor_layout_for_opt_shard(
-        to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, tensor_shape)
-    slice_op = _load_tensor_shape(to_dev_matrix, to_tensor_map, full_shape=tensor_shape, rank_id=rank_id)
-    shape = None
-    if to_opt_shard_size > 0:
-        to_tensor_strategy = _get_tensor_strategy(to_dev_matrix_origin, to_tensor_map_origin)
-        to_slice_tensor_shape = ()
-        for i, item in enumerate(tensor_shape):
-            if i == 0 and to_opt_shard_size > 0:
-                to_slice_tensor_shape += (item // (to_tensor_strategy[i] * to_opt_shard_size),)
-                continue
-            to_slice_tensor_shape += (item // to_tensor_strategy[i],)
-        shape = list(to_slice_tensor_shape)
-
-    return slice_op, shape
-
-
 __all__ = ["_transform_safetensors", "transform_safetensors_by_stage",
            "transform_safetensors_by_rank", "unified_safetensors"]