mindspore-2.6.0rc1-cp39-cp39-win_amd64.whl → mindspore-2.7.0rc1-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (384)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +1 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +40 -9
  7. mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
  8. mindspore/_extends/optimize/cell_utils.py +96 -0
  9. mindspore/_extends/parse/__init__.py +2 -2
  10. mindspore/_extends/parse/compile_config.py +44 -22
  11. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -1
  12. mindspore/_extends/parse/parser.py +37 -62
  13. mindspore/_extends/parse/resources.py +39 -0
  14. mindspore/_extends/parse/standard_method.py +43 -13
  15. mindspore/_extends/parse/trope.py +8 -1
  16. mindspore/_extends/pijit/__init__.py +1 -2
  17. mindspore/amp.py +4 -4
  18. mindspore/avcodec-59.dll +0 -0
  19. mindspore/avdevice-59.dll +0 -0
  20. mindspore/avfilter-8.dll +0 -0
  21. mindspore/avformat-59.dll +0 -0
  22. mindspore/avutil-57.dll +0 -0
  23. mindspore/boost/adasum.py +1 -1
  24. mindspore/boost/boost_cell_wrapper.py +4 -4
  25. mindspore/common/__init__.py +27 -2
  26. mindspore/common/_grad_function.py +2 -1
  27. mindspore/common/_pijit_context.py +28 -7
  28. mindspore/common/_stub_tensor.py +1 -209
  29. mindspore/common/_tensor_cpp_method.py +1 -1
  30. mindspore/common/_tensor_docs.py +77 -16
  31. mindspore/common/api.py +238 -113
  32. mindspore/common/dtype.py +21 -11
  33. mindspore/common/dump.py +10 -15
  34. mindspore/common/generator.py +5 -3
  35. mindspore/common/hook_handle.py +11 -2
  36. mindspore/common/jit_config.py +1 -1
  37. mindspore/common/jit_trace.py +84 -105
  38. mindspore/common/parameter.py +26 -12
  39. mindspore/common/recompute.py +3 -3
  40. mindspore/common/sparse_tensor.py +0 -3
  41. mindspore/common/symbol.py +0 -1
  42. mindspore/common/tensor.py +81 -81
  43. mindspore/communication/_comm_helper.py +46 -4
  44. mindspore/communication/management.py +79 -7
  45. mindspore/context.py +58 -40
  46. mindspore/dataset/core/config.py +3 -3
  47. mindspore/dataset/engine/datasets.py +20 -7
  48. mindspore/dataset/engine/datasets_user_defined.py +33 -3
  49. mindspore/dataset/engine/iterators.py +2 -2
  50. mindspore/dataset/engine/obs/config_loader.py +2 -2
  51. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
  52. mindspore/dataset/transforms/py_transforms.py +7 -3
  53. mindspore/dataset/transforms/transforms.py +7 -3
  54. mindspore/dataset/vision/validators.py +1 -0
  55. mindspore/device_context/ascend/device.py +1 -1
  56. mindspore/device_context/gpu/__init__.py +2 -2
  57. mindspore/device_context/gpu/device.py +1 -1
  58. mindspore/device_context/gpu/op_precision.py +4 -2
  59. mindspore/device_context/gpu/op_tuning.py +6 -3
  60. mindspore/device_manager.py +16 -9
  61. mindspore/dnnl.dll +0 -0
  62. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +3 -7
  63. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  64. mindspore/experimental/optim/adadelta.py +13 -20
  65. mindspore/experimental/optim/adagrad.py +15 -22
  66. mindspore/experimental/optim/adam.py +17 -24
  67. mindspore/experimental/optim/adamax.py +14 -22
  68. mindspore/experimental/optim/adamw.py +28 -34
  69. mindspore/experimental/optim/asgd.py +15 -25
  70. mindspore/experimental/optim/lr_scheduler.py +27 -45
  71. mindspore/experimental/optim/nadam.py +14 -24
  72. mindspore/experimental/optim/optimizer.py +13 -23
  73. mindspore/experimental/optim/radam.py +18 -24
  74. mindspore/experimental/optim/rmsprop.py +14 -25
  75. mindspore/experimental/optim/rprop.py +15 -26
  76. mindspore/experimental/optim/sgd.py +9 -19
  77. mindspore/hal/__init__.py +4 -4
  78. mindspore/hal/contiguous_tensors_handle.py +2 -2
  79. mindspore/hal/memory.py +27 -7
  80. mindspore/include/api/cell.h +37 -1
  81. mindspore/include/api/delegate.h +10 -0
  82. mindspore/include/api/model.h +3 -0
  83. mindspore/include/api/types.h +2 -2
  84. mindspore/include/c_api/model_c.h +0 -58
  85. mindspore/include/c_api/tensor_c.h +0 -26
  86. mindspore/include/dataset/vision_ascend.h +1 -1
  87. mindspore/jpeg62.dll +0 -0
  88. mindspore/mindrecord/tools/cifar10.py +60 -11
  89. mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
  90. mindspore/mindspore_backend_common.dll +0 -0
  91. mindspore/mindspore_backend_manager.dll +0 -0
  92. mindspore/mindspore_common.dll +0 -0
  93. mindspore/mindspore_core.dll +0 -0
  94. mindspore/mindspore_cpu_res_manager.dll +0 -0
  95. mindspore/mindspore_dump.dll +0 -0
  96. mindspore/mindspore_frontend.dll +0 -0
  97. mindspore/mindspore_glog.dll +0 -0
  98. mindspore/mindspore_memory_pool.dll +0 -0
  99. mindspore/mindspore_ms_backend.dll +0 -0
  100. mindspore/mindspore_ops.dll +0 -0
  101. mindspore/mindspore_ops_host.dll +0 -0
  102. mindspore/mindspore_ops_kernel_common.dll +0 -0
  103. mindspore/mindspore_profiler.dll +0 -0
  104. mindspore/mindspore_pyboost.dll +0 -0
  105. mindspore/mindspore_pynative.dll +0 -0
  106. mindspore/mindspore_res_manager.dll +0 -0
  107. mindspore/mindspore_runtime_pipeline.dll +0 -0
  108. mindspore/mint/__init__.py +6 -46
  109. mindspore/mint/distributed/__init__.py +1 -0
  110. mindspore/mint/distributed/distributed.py +212 -9
  111. mindspore/mint/nn/__init__.py +1 -1
  112. mindspore/mint/nn/functional.py +53 -6
  113. mindspore/mint/nn/layer/_functions.py +164 -294
  114. mindspore/mint/nn/layer/activation.py +8 -6
  115. mindspore/mint/nn/layer/conv.py +137 -101
  116. mindspore/mint/nn/layer/normalization.py +8 -22
  117. mindspore/mint/optim/adam.py +19 -18
  118. mindspore/mint/optim/adamw.py +14 -8
  119. mindspore/mint/optim/sgd.py +5 -5
  120. mindspore/nn/cell.py +328 -502
  121. mindspore/nn/grad/cell_grad.py +11 -12
  122. mindspore/nn/layer/activation.py +32 -34
  123. mindspore/nn/layer/basic.py +67 -64
  124. mindspore/nn/layer/channel_shuffle.py +4 -4
  125. mindspore/nn/layer/combined.py +4 -2
  126. mindspore/nn/layer/conv.py +117 -110
  127. mindspore/nn/layer/dense.py +9 -7
  128. mindspore/nn/layer/embedding.py +50 -52
  129. mindspore/nn/layer/image.py +37 -39
  130. mindspore/nn/layer/math.py +111 -112
  131. mindspore/nn/layer/normalization.py +56 -44
  132. mindspore/nn/layer/pooling.py +58 -63
  133. mindspore/nn/layer/rnn_cells.py +33 -33
  134. mindspore/nn/layer/rnns.py +56 -56
  135. mindspore/nn/layer/thor_layer.py +74 -73
  136. mindspore/nn/layer/transformer.py +11 -1
  137. mindspore/nn/learning_rate_schedule.py +20 -20
  138. mindspore/nn/loss/loss.py +79 -81
  139. mindspore/nn/optim/adam.py +3 -3
  140. mindspore/nn/optim/adasum.py +2 -2
  141. mindspore/nn/optim/asgd.py +2 -0
  142. mindspore/nn/optim/optimizer.py +1 -1
  143. mindspore/nn/optim/thor.py +2 -2
  144. mindspore/nn/probability/distribution/exponential.py +2 -1
  145. mindspore/nn/probability/distribution/poisson.py +2 -1
  146. mindspore/nn/sparse/sparse.py +3 -3
  147. mindspore/nn/wrap/cell_wrapper.py +34 -37
  148. mindspore/nn/wrap/grad_reducer.py +37 -37
  149. mindspore/nn/wrap/loss_scale.py +72 -74
  150. mindspore/numpy/array_creations.py +5 -5
  151. mindspore/numpy/fft.py +1 -1
  152. mindspore/numpy/math_ops.py +5 -5
  153. mindspore/opencv_core452.dll +0 -0
  154. mindspore/opencv_imgcodecs452.dll +0 -0
  155. mindspore/opencv_imgproc452.dll +0 -0
  156. mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
  157. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
  158. mindspore/ops/_vmap/vmap_array_ops.py +31 -13
  159. mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
  160. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +42 -11
  161. mindspore/ops/auto_generate/gen_extend_func.py +23 -141
  162. mindspore/ops/auto_generate/gen_ops_def.py +727 -321
  163. mindspore/ops/auto_generate/gen_ops_prim.py +1721 -984
  164. mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
  165. mindspore/ops/composite/__init__.py +10 -0
  166. mindspore/ops/composite/base.py +8 -4
  167. mindspore/ops/composite/multitype_ops/__init__.py +12 -1
  168. mindspore/ops/composite/multitype_ops/_compile_utils.py +133 -109
  169. mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
  170. mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
  171. mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
  172. mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
  173. mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
  174. mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
  175. mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
  176. mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
  177. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
  178. mindspore/ops/function/__init__.py +3 -1
  179. mindspore/ops/function/_add_attr_func.py +11 -6
  180. mindspore/ops/function/array_func.py +9 -96
  181. mindspore/ops/function/debug_func.py +4 -3
  182. mindspore/ops/function/grad/grad_func.py +1 -1
  183. mindspore/ops/function/math_func.py +33 -540
  184. mindspore/ops/function/nn_func.py +28 -74
  185. mindspore/ops/function/other_func.py +4 -1
  186. mindspore/ops/function/random_func.py +44 -5
  187. mindspore/ops/function/vmap_func.py +2 -1
  188. mindspore/ops/functional.py +2 -3
  189. mindspore/ops/functional_overload.py +571 -6
  190. mindspore/ops/op_info_register.py +21 -0
  191. mindspore/ops/operations/__init__.py +16 -11
  192. mindspore/ops/operations/_custom_ops_utils.py +689 -34
  193. mindspore/ops/operations/_inner_ops.py +3 -6
  194. mindspore/ops/operations/_sequence_ops.py +1 -1
  195. mindspore/ops/operations/array_ops.py +2 -2
  196. mindspore/ops/operations/comm_ops.py +185 -26
  197. mindspore/ops/operations/custom_ops.py +294 -174
  198. mindspore/ops/operations/debug_ops.py +59 -4
  199. mindspore/ops/operations/image_ops.py +13 -13
  200. mindspore/ops/operations/manually_defined/ops_def.py +15 -16
  201. mindspore/ops/operations/math_ops.py +3 -4
  202. mindspore/ops/operations/nn_ops.py +7 -39
  203. mindspore/ops/primitive.py +6 -10
  204. mindspore/ops/tensor_method.py +47 -8
  205. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
  206. mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
  207. mindspore/ops_generate/api/functions_cc_generator.py +58 -10
  208. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
  209. mindspore/ops_generate/common/base_generator.py +14 -0
  210. mindspore/ops_generate/common/gen_constants.py +8 -3
  211. mindspore/ops_generate/common/gen_utils.py +0 -19
  212. mindspore/ops_generate/common/op_proto.py +11 -4
  213. mindspore/ops_generate/common/template.py +88 -11
  214. mindspore/ops_generate/gen_ops.py +1 -1
  215. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
  216. mindspore/ops_generate/op_def/ops_def_cc_generator.py +0 -3
  217. mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
  218. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
  219. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
  220. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
  221. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
  222. mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -0
  223. mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
  224. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
  225. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
  226. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
  227. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
  228. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
  229. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
  230. mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
  231. mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
  232. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
  233. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
  234. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
  235. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
  236. mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
  237. mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
  238. mindspore/parallel/_auto_parallel_context.py +11 -8
  239. mindspore/parallel/_cell_wrapper.py +113 -45
  240. mindspore/parallel/_parallel_serialization.py +1 -1
  241. mindspore/parallel/_ps_context.py +4 -6
  242. mindspore/parallel/_tensor.py +167 -12
  243. mindspore/parallel/_transformer/moe.py +1 -1
  244. mindspore/parallel/_transformer/transformer.py +13 -8
  245. mindspore/parallel/auto_parallel.py +14 -7
  246. mindspore/parallel/checkpoint_convert.py +3 -3
  247. mindspore/parallel/checkpoint_transform.py +11 -7
  248. mindspore/parallel/cluster/process_entity/_api.py +84 -48
  249. mindspore/parallel/cluster/process_entity/_utils.py +95 -7
  250. mindspore/parallel/cluster/run.py +43 -4
  251. mindspore/parallel/function/__init__.py +8 -1
  252. mindspore/parallel/function/reshard_func.py +6 -7
  253. mindspore/parallel/nn/__init__.py +15 -2
  254. mindspore/parallel/nn/parallel_cell_wrapper.py +9 -10
  255. mindspore/parallel/nn/parallel_grad_reducer.py +7 -6
  256. mindspore/parallel/shard.py +3 -4
  257. mindspore/parallel/transform_safetensors.py +463 -174
  258. mindspore/profiler/__init__.py +2 -1
  259. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
  260. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
  261. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +12 -6
  262. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
  263. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  264. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
  265. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
  266. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
  267. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
  268. mindspore/profiler/analysis/task_manager.py +1 -1
  269. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
  270. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
  271. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +42 -22
  272. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
  273. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
  274. mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
  275. mindspore/profiler/common/constant.py +16 -0
  276. mindspore/profiler/common/profiler_context.py +25 -27
  277. mindspore/profiler/common/profiler_info.py +0 -16
  278. mindspore/profiler/common/profiler_op_analyse.py +235 -0
  279. mindspore/profiler/common/profiler_output_path.py +23 -8
  280. mindspore/profiler/common/profiler_parameters.py +128 -35
  281. mindspore/profiler/dynamic_profile/__init__.py +0 -0
  282. mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
  283. mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
  284. mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
  285. mindspore/profiler/dynamic_profiler.py +305 -314
  286. mindspore/profiler/envprofiler.py +12 -7
  287. mindspore/profiler/experimental_config.py +96 -6
  288. mindspore/profiler/mstx.py +33 -12
  289. mindspore/profiler/platform/__init__.py +2 -3
  290. mindspore/profiler/platform/npu_profiler.py +29 -19
  291. mindspore/profiler/profiler.py +35 -19
  292. mindspore/profiler/profiler_action_controller.py +64 -76
  293. mindspore/profiler/schedule.py +10 -4
  294. mindspore/rewrite/common/config.py +1 -0
  295. mindspore/rewrite/common/namer.py +1 -0
  296. mindspore/rewrite/common/namespace.py +1 -0
  297. mindspore/rewrite/node/node.py +31 -11
  298. mindspore/rewrite/parsers/assign_parser.py +1 -1
  299. mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
  300. mindspore/run_check/_check_version.py +7 -10
  301. mindspore/runtime/__init__.py +5 -5
  302. mindspore/runtime/event.py +10 -4
  303. mindspore/runtime/executor.py +60 -45
  304. mindspore/runtime/memory.py +30 -32
  305. mindspore/runtime/thread_bind_core.py +298 -164
  306. mindspore/safeguard/rewrite_obfuscation.py +12 -13
  307. mindspore/swresample-4.dll +0 -0
  308. mindspore/swscale-6.dll +0 -0
  309. mindspore/tinyxml2.dll +0 -0
  310. mindspore/train/_utils.py +14 -4
  311. mindspore/train/amp.py +43 -20
  312. mindspore/train/callback/__init__.py +5 -5
  313. mindspore/train/callback/_checkpoint.py +3 -6
  314. mindspore/train/callback/_flops_collector.py +1 -1
  315. mindspore/train/callback/_landscape.py +0 -1
  316. mindspore/train/callback/_train_fault_tolerance.py +97 -16
  317. mindspore/train/data_sink.py +11 -2
  318. mindspore/train/dataset_helper.py +9 -0
  319. mindspore/train/model.py +135 -55
  320. mindspore/train/serialization.py +133 -111
  321. mindspore/train/summary/summary_record.py +13 -2
  322. mindspore/turbojpeg.dll +0 -0
  323. mindspore/utils/__init__.py +3 -2
  324. mindspore/utils/dryrun.py +0 -6
  325. mindspore/utils/runtime_execution_order_check.py +163 -77
  326. mindspore/utils/sdc_detect.py +68 -0
  327. mindspore/utils/utils.py +6 -9
  328. mindspore/version.py +1 -1
  329. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/METADATA +5 -4
  330. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/RECORD +333 -371
  331. mindspore/_deprecated/jit.py +0 -198
  332. mindspore/experimental/es/__init__.py +0 -22
  333. mindspore/experimental/es/embedding_service.py +0 -891
  334. mindspore/experimental/es/embedding_service_layer.py +0 -581
  335. mindspore/profiler/parser/__init__.py +0 -14
  336. mindspore/profiler/parser/aicpu_data_parser.py +0 -272
  337. mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
  338. mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
  339. mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
  340. mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
  341. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
  342. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
  343. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
  344. mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
  345. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
  346. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
  347. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
  348. mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
  349. mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
  350. mindspore/profiler/parser/ascend_flops_generator.py +0 -116
  351. mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
  352. mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
  353. mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
  354. mindspore/profiler/parser/ascend_memory_generator.py +0 -185
  355. mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
  356. mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
  357. mindspore/profiler/parser/ascend_op_generator.py +0 -334
  358. mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
  359. mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
  360. mindspore/profiler/parser/base_timeline_generator.py +0 -483
  361. mindspore/profiler/parser/container.py +0 -229
  362. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
  363. mindspore/profiler/parser/flops_parser.py +0 -531
  364. mindspore/profiler/parser/framework_enum.py +0 -111
  365. mindspore/profiler/parser/framework_parser.py +0 -464
  366. mindspore/profiler/parser/framework_struct.py +0 -61
  367. mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
  368. mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
  369. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
  370. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
  371. mindspore/profiler/parser/hccl_parser.py +0 -573
  372. mindspore/profiler/parser/hwts_log_parser.py +0 -122
  373. mindspore/profiler/parser/integrator.py +0 -526
  374. mindspore/profiler/parser/memory_usage_parser.py +0 -277
  375. mindspore/profiler/parser/minddata_analyzer.py +0 -800
  376. mindspore/profiler/parser/minddata_parser.py +0 -186
  377. mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
  378. mindspore/profiler/parser/op_intermediate_parser.py +0 -149
  379. mindspore/profiler/parser/optime_parser.py +0 -250
  380. mindspore/profiler/parser/profiler_info.py +0 -213
  381. mindspore/profiler/parser/step_trace_parser.py +0 -666
  382. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/WHEEL +0 -0
  383. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/entry_points.txt +0 -0
  384. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/top_level.txt +0 -0
mindspore/parallel/transform_safetensors.py

@@ -21,27 +21,263 @@ import glob
 import math
 import json
 import re
-from collections import defaultdict
+import mmap
+import stat
+from collections import defaultdict, OrderedDict
 
 import time
 import multiprocessing as mp
+
+from safetensors.numpy import save_file, load_file
 import psutil
 import numpy as np
-from safetensors.numpy import save_file, load_file
-from safetensors import safe_open
 
 import mindspore as ms
 from mindspore import log as logger
 from mindspore.log import vlog_print
+from mindspore.common.parameter import Parameter
+from mindspore.common.tensor import Tensor
+from mindspore.common import np_dtype
 from mindspore.parallel._parallel_serialization import _get_device_num_from_strategy, _make_dir, \
     _extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
     _insert_opt_shard_reshape, _extract_src_dst_layout_map_by_src, _insert_expand_layout_reshape
 from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_tensor_layout, \
     _get_needed_rank_transform_operator_map_by_layouts, \
     _generate_transform_operator_stack, _apply_tensor_transform_operators, _construct_tensor_layout_for_opt_shard, \
-    _extract_layout_item, _load_tensor_shape, _apply_operator
+    _extract_layout_item, _apply_operator
 from mindspore.parallel._parallel_serialization import _build_searched_strategy, _load_protobuf_strategy, \
     _convert_to_list
+from mindspore.common import dtype as mstype
+
+safetensors_to_mstype = {'Int4': mstype.qint4x2}
+
+np.bfloat16 = np_dtype.bfloat16
+
+MAX_HEADER_SIZE = 100 * 1000 * 1000
+
+dtype_size = {
+    "BOOL": 1,
+    "U8": 1,
+    "I8": 1,
+    "I16": 2,
+    "U16": 2,
+    "I32": 4,
+    "U32": 4,
+    "I64": 8,
+    "U64": 8,
+    "F16": 2,
+    "BF16": 2,
+    "F32": 4,
+    "F64": 8,
+}
+np_dtype_size = {
+    "bool_": 1,
+    "uint8": 1,
+    "int8": 1,
+    "int16": 2,
+    "uint16": 2,
+    "int32": 4,
+    "uint32": 4,
+    "int64": 8,
+    "uint64": 8,
+    "float16": 2,
+    "bfloat16": 2,
+    "float32": 4,
+    "float64": 8,
+}
+numpy_dtype = {
+    "BOOL": np.bool_,
+    "U8": np.uint8,
+    "I8": np.int8,
+    "I16": np.int16,
+    "U16": np.uint16,
+    "I32": np.int32,
+    "U32": np.uint32,
+    "I64": np.int64,
+    "U64": np.uint64,
+    "F16": np.float16,
+    "BF16": np.bfloat16,  # no bf16
+    "F32": np.float32,
+    "F64": np.float64,
+}
+
+
+def getSize(fileobject):
+    fileobject.seek(0, 2)  # move the cursor to the end of the file
+    size = fileobject.tell()
+    fileobject.seek(0)  # move the cursor to the start of the file
+    return size
+
+
+def _save_file_atomically(transform_param_dict, save_file_name, metadata=None):
+    """Atomically save file using temporary name and rename."""
+    if metadata is None:
+        metadata = {"format": "ms"}
+    file_name_list = list(os.path.splitext(save_file_name))
+    file_name_list[1] = file_name_list[1].replace('.safetensors', '.tmp')
+    tmp_name = ''.join(file_name_list)
+    try:
+        if os.path.exists(save_file_name):
+            os.chmod(save_file_name, stat.S_IWUSR)
+            os.remove(save_file_name)
+        if os.path.exists(tmp_name):
+            os.chmod(tmp_name, stat.S_IWUSR)
+            os.remove(tmp_name)
+        save_file(transform_param_dict, tmp_name, metadata=metadata)
+        os.rename(tmp_name, save_file_name)
+        os.chmod(save_file_name, stat.S_IRUSR)
+    except Exception as e:
+        if not os.path.exists(save_file_name):
+            logger.warning(f"Save failed, {save_file_name} not found. "
+                           f"This may indicate multiple processes modifying the same file "
+                           f"or insufficient disk space.")
+        raise e
+
+
+def metadata_validate(metadata):
+    """validation metadata"""
+    start = 0
+    for key, info in metadata.items():
+        s, e = info["data_offsets"]
+        if s != start or e < s:
+            raise ValueError(f"SafeTensorError::InvalidOffset({key})")
+        start = e
+        nelements = np.prod(info["shape"])
+        nbytes = nelements * dtype_size[info["dtype"]]
+        if (e - s) != nbytes:
+            raise ValueError("SafeTensorError::TensorInvalidInfo")
+    return start
+
+
+def read_metadata(buffer):
+    """read metadata by buffer"""
+    buffer_len = getSize(buffer)
+    if buffer_len < 8:
+        raise ValueError("SafeTensorError::HeaderTooSmall")
+
+    n = np.frombuffer(buffer.read(8), dtype=np.uint64).item()
+    if n > MAX_HEADER_SIZE:
+        raise ValueError("SafeTensorError::HeaderTooLarge")
+
+    stop = n + 8
+    if stop > buffer_len:
+        raise ValueError("SafeTensorError::InvalidHeaderLength")
+
+    tensors = json.loads(buffer.read(n), object_pairs_hook=OrderedDict)
+    metadata = tensors.pop("__metadata__", None)
+    buffer_end = metadata_validate(tensors)
+
+    if buffer_end + 8 + n != buffer_len:
+        raise ValueError("SafeTensorError::MetadataIncompleteBuffer")
+
+    return stop, tensors, metadata
+
+
+class PySafeSlice:
+    """Create PySafeSlice by file"""
+
+    def __init__(self, info, bufferfile, base_ptr, buffermmap):
+        self.info = info
+        self.bufferfile = bufferfile
+        self.buffermmap = buffermmap
+        self.base_ptr = base_ptr
+
+        self.start = [0 for dim in self.shape]
+        self.stop = [dim for dim in self.shape]
+        self.step = [1 for dim in self.shape]
+
+    @property
+    def ndim(self):
+        return len(self.shape)
+
+    def get(self, *args, **kwargs):
+        """Get tensor from buffer by data_offset"""
+        nbytes = int(np.prod(self.shape)) * np.dtype(self.dtype).itemsize
+        offset = self.start_offset
+        tensor = np.frombuffer(self.buffermmap, dtype=self.dtype, offset=offset,
+                               count=nbytes // np.dtype(self.dtype).itemsize)
+        tensor = tensor.reshape(self.shape)
+        if not tensor.flags["ALIGNED"]:
+            logger.info("This safetensors file is not aligned.")
+            tensor = tensor.copy()
+        return tensor
+
+    @property
+    def start_offset(self):
+        return self.base_ptr + self.info["data_offsets"][0]
+
+    def get_shape(self):
+        return self.shape
+
+    @property
+    def shape(self):
+        return self.info["shape"]
+
+    @property
+    def dtype(self):
+        return numpy_dtype[self.info["dtype"]]
+
+    @property
+    def nelements(self):
+        return np.prod(self.info["shape"])
+
+    @property
+    def bits(self):
+        return dtype_size[self.info["dtype"]]
+
+    @property
+    def nbytes(self):
+        return self.nelements * dtype_size[self.info["dtype"]]
+
+
+class _fast_safe_open:
+    """
+    Open a safetensors file and access its metadata and tensors efficiently.
+
+    This function is designed to work similarly to `safetensors.safe_open`,
+    providing a fast way to open and interact with safetensors files.
+    """
+
+    def __init__(self, filename, framework=None, device="cpu"):
+        self.filename = filename
+        self.framework = framework
+        self.file = open(self.filename, "rb")
+        self.file_mmap = mmap.mmap(self.file.fileno(), 0, access=mmap.ACCESS_COPY)
+        try:
+            self.base, self.tensors_decs, self.__metadata__ = read_metadata(self.file)
+        except ValueError:
+            raise ValueError(f"Fail to parse the input safetensors file: '{self.filename}'. "
+                             f"Please check the correctness of the file.")
+        self.tensors = OrderedDict()
+        for key, info in self.tensors_decs.items():
+            self.tensors[key] = PySafeSlice(info, self.file, self.base, self.file_mmap)
+            self.tensors[key].key = key
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.file.close()
+
+    def metadata(self):
+        return self.__metadata__
+
+    def keys(self):
+        return list(self.tensors.keys())
+
+    def get_tensor(self, name):
+        return self.tensors[name].get()
+
+
+def _fast_load_file(filename):
+    """
+    Load safetensors info from a specified file.
+    """
+    result = {}
+    with _fast_safe_open(filename, framework="np") as f:
+        for k in f.keys():
+            result[k] = f.get_tensor(k)
+    return result
 
 
 def _progress_bar(iterable, total=None):
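The block above replaces `safetensors.safe_open` with an in-package reader: `read_metadata` validates the header, `PySafeSlice` materializes tensors out of an `mmap`, and `_fast_safe_open` ties them together. For orientation, here is a minimal, self-contained sketch of the container layout that `read_metadata` walks — an 8-byte little-endian header length, a JSON header keyed by tensor name, then the raw data buffer. The file and tensor names are illustrative, not from the diff:

```python
# Hedged sketch of the safetensors layout parsed above; "demo.safetensors"
# and the tensor name "w" are illustrative.
import json
import struct

import numpy as np
from safetensors.numpy import save_file

save_file({"w": np.arange(6, dtype=np.float32).reshape(2, 3)}, "demo.safetensors")

with open("demo.safetensors", "rb") as f:
    (n,) = struct.unpack("<Q", f.read(8))  # 8-byte little-endian header size
    header = json.loads(f.read(n))         # {"w": {"dtype": "F32", "shape": [2, 3], "data_offsets": [...]}}
    info = header["w"]
    base = 8 + n                           # start of the data buffer (self.base in _fast_safe_open)
    start, end = info["data_offsets"]      # offsets are relative to the data buffer
    f.seek(base + start)
    w = np.frombuffer(f.read(end - start), dtype=np.float32).reshape(info["shape"])

print(w)  # [[0. 1. 2.] [3. 4. 5.]]
```

The diff's reader maps the file with `mmap.ACCESS_COPY` and builds numpy views with `np.frombuffer`, so tensor bytes are only copied when the mapping is unaligned.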
@@ -267,15 +503,22 @@ def _transform_safetensors_with_parallel(needed_rank_list_map, all_safetensor_files_map,
             pipe_param_list[layout[6][0]].append(name)
     part_list_dict = _distribute_files_by_size(all_safetensor_files_map, needed_rank_list_map, process_num)
     processes = []
-    for i in range(process_num):
-        p = mp.Process(target=_transform_safetensors_single, args=(
-            part_list_dict[i], all_safetensor_files_map, src_stage_device_num, dst_stage_device_num,
-            src_strategy_dict, dst_strategy_dict, origin_src_strategy_list, origin_dst_strategy_list,
-            ckpt_prefix, dst_safetensors_dir, output_format, _transform_param_list, pipe_param_list[i]))
-        p.start()
-        processes.append(p)
-    for p in processes:
-        p.join()
+    if process_num > 1:
+        for i in range(process_num):
+            p = mp.Process(target=_transform_safetensors_single, args=(
+                part_list_dict[i], all_safetensor_files_map, src_stage_device_num, dst_stage_device_num,
+                src_strategy_dict, dst_strategy_dict, origin_src_strategy_list, origin_dst_strategy_list,
+                ckpt_prefix, dst_safetensors_dir, output_format, _transform_param_list, pipe_param_list[i]))
+            p.start()
+            processes.append(p)
+        for p in processes:
+            p.join()
+    else:
+        _transform_safetensors_single(part_list_dict[0], all_safetensor_files_map, src_stage_device_num,
+                                      dst_stage_device_num, src_strategy_dict, dst_strategy_dict,
+                                      origin_src_strategy_list, origin_dst_strategy_list, ckpt_prefix,
+                                      dst_safetensors_dir, output_format, _transform_param_list,
+                                      pipe_param_list[0])
 
 
 def _count_redundancy_list(rank_num, param_name, redundancy_dict, device_num):
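The rewrite above spawns `mp.Process` workers only when `process_num > 1` and otherwise calls `_transform_safetensors_single` inline, avoiding process start-up and argument-pickling overhead for the single-partition case. A generic sketch of that dispatch pattern, with illustrative `work`/`parts` names:

```python
# Hedged sketch of the "fork only when it pays off" dispatch introduced
# above; `work` and `parts` are illustrative stand-ins, not diff names.
import multiprocessing as mp

def work(part):
    print(sum(part))

def dispatch(parts):
    if len(parts) > 1:
        procs = [mp.Process(target=work, args=(p,)) for p in parts]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
    else:
        work(parts[0])  # run inline; no fork/spawn cost for one partition

if __name__ == "__main__":
    dispatch([[1, 2], [3, 4]])
```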
@@ -288,7 +531,7 @@ def _count_redundancy_list(rank_num, param_name, redundancy_dict, device_num):
     return set()
 
 
-def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict, saftensor_dict, redundancy_dict,
+def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict, safetensor_dict, redundancy_dict,
                                     needed_rank, device_num, choice_func):
     """Find the rank_id under redundant groups."""
     io_time = 0
@@ -305,7 +548,7 @@ def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict,
                     break
         if open_file_id is not None:
             start_time = time.time()
-            output = file_dict[open_file_id].get_slice(param_name)
+            output = file_dict[open_file_id].get_tensor(param_name)
             end_time = time.time()
             cost_time = end_time - start_time
             io_time += cost_time
@@ -316,7 +559,7 @@ def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict,
             if not isinstance(choice_out, (bool, str)):
                 raise ValueError("For 'unified_safetensors', the return value type of the function "
                                  f"'choice_func' must be bool or str, but got {type(choice_out)}.")
-            saftensor_dict[param_name] = output
+            safetensor_dict[param_name] = output
         else:
             raise ValueError(f"For _transform_safetensors_single, {param_name} should be in "
                              f"{redundancy_ranks}, but in {single_param_dict[param_name]}.")
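The validation above pins down the `choice_func` contract: it receives a parameter name and must return `bool` (keep or drop the parameter) or `str` (store it under the returned name). A hypothetical callback honoring that contract — the prefixes are illustrative only:

```python
# Hypothetical choice_func for unified_safetensors; the prefixes are
# invented. bool return -> keep/drop, str return -> save under new name.
def choice_func(param_name):
    if param_name.startswith("adam_"):      # drop optimizer state (illustrative)
        return False
    if param_name.startswith("network."):   # strip a wrapper prefix (illustrative)
        return param_name[len("network."):]
    return True
```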
@@ -334,6 +577,7 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map,
     Transforms safetensors files to a specified format without using parallel processing.
     """
     io_cost_time = 0
+    meta_data = {"format": "ms"}
     if src_strategy_file is not None:
         from mindspore.train._utils import get_parameter_redundancy
         redundancy_dict_tmp = get_parameter_redundancy(src_strategy_file, initial_rank=0)
@@ -353,13 +597,15 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map,
     file_dict = {}
     single_param_dict = {}
     for file_id, _ in all_safetensor_files_map.items():
-        f = safe_open(all_safetensor_files_map.get(file_id), framework="np")
+        f = _fast_safe_open(all_safetensor_files_map.get(file_id), framework="np")
         file_dict[file_id] = f
         for param_name in f.keys():
             if param_name not in single_param_dict.keys():
                 single_param_dict[param_name] = {file_id}
             else:
                 single_param_dict[param_name].add(file_id)
+        if f.metadata() is not None:
+            meta_data.update(f.metadata())
     src_strategy_list_keys = _convert_to_list(src_strategy_dict).keys() if src_strategy_dict else []
     dst_strategy_list_keys = _convert_to_list(dst_strategy_dict).keys() if dst_strategy_dict else []
     for needed_rank_list_key, transform_rank_list in needed_rank_list_map.items():
@@ -368,27 +614,29 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map,
         needed_rank_list = needed_rank_list_key.split("-")
         for needed_rank in needed_rank_list:
             if pipe_param_list:
-                saftensor_dict = dict()
+                safetensor_dict = dict()
                 if src_strategy_file is not None:
                     io_time = _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict,
-                                                              saftensor_dict, redundancy_dict, needed_rank,
+                                                              safetensor_dict, redundancy_dict, needed_rank,
                                                               device_num, choice_func)
                     io_cost_time += io_time
                 else:
-                    with safe_open(all_safetensor_files_map.get(int(needed_rank)), framework="np") as f:
+                    with _fast_safe_open(all_safetensor_files_map.get(int(needed_rank)), framework="np") as f:
                         if not unified_flag:
                             all_param_name_set = set(f.keys())
                             src_param_name_set = set(src_strategy_list_keys)
                             dst_param_name_set = set(dst_strategy_list_keys)
                             hyper_param_set = all_param_name_set - (src_param_name_set & dst_param_name_set)
                             pipe_param_list.extend(list(hyper_param_set))
+                        if f.metadata() is not None:
+                            meta_data.update(f.metadata())
                         io_time = 0
                         for param_name in pipe_param_list:
                             if param_name not in f.keys():
                                 # param not in ckpt file, check reason
                                 continue
                             start_time = time.time()
-                            output = f.get_slice(param_name)
+                            output = f.get_tensor(param_name)
                             end_time = time.time()
                             cost_time = end_time - start_time
                             io_time += cost_time
@@ -400,15 +648,15 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map,
                             if not isinstance(choice_out, (bool, str)):
                                 raise ValueError("For 'unified_safetensors', the return value type of the function "
                                                  f"'choice_func' must be bool or str, but got {type(choice_out)}.")
-                            saftensor_dict[param_name] = output
+                            safetensor_dict[param_name] = output
             else:
                 start_time = time.time()
-                saftensor_dict = load_file(all_safetensor_files_map.get(int(needed_rank)))
+                safetensor_dict = load_file(all_safetensor_files_map.get(int(needed_rank)))
                 end_time = time.time()
                 cost_time = end_time - start_time
                 io_cost_time += cost_time
 
-            for param_name, param in saftensor_dict.items():
+            for param_name, param in safetensor_dict.items():
                 src_rank = int(needed_rank) % src_stage_device_num
                 param_total_dict[param_name][src_rank] = param
                 param_attr_dict[param_name][src_rank] = (True, False)
@@ -442,11 +690,11 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map,
         else:
             if transform_param_dict:
                 if output_format == "safetensors":
-                    save_file(transform_param_dict, save_file_name)
+                    _save_file_atomically(transform_param_dict, save_file_name, metadata=meta_data)
                 else:
-                    transform_param_dict = _load_and_transform(transform_param_dict,
-                                                               None, None, transform_func=
-                                                               lambda v, name: ms.Parameter(v, name=name))
+                    transform_param_dict = _load_and_transform(transform_param_dict, None, None,
+                                                               transform_func=lambda v, name: Parameter(v,
+                                                                                                        name=name))
                     ms.save_checkpoint(transform_param_dict, save_file_name)
     del param_total_dict_keys
     del param_total_dict
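The `_save_file_atomically` calls introduced throughout this file replace bare `save_file`, so a crash mid-write cannot leave a truncated `.safetensors` behind. The pattern in isolation — write under a temporary name, then rename:

```python
# Generic write-then-rename sketch of what _save_file_atomically does;
# "out.safetensors" is illustrative.
import os
import stat

def atomic_write(path, data: bytes):
    tmp = path + ".tmp"
    with open(tmp, "wb") as f:
        f.write(data)
    os.rename(tmp, path)           # readers never observe a half-written file
    os.chmod(path, stat.S_IRUSR)   # owner read-only, as in the diff

atomic_write("out.safetensors", b"\x00" * 16)
```

Note that the diff removes any existing target before renaming; besides refreshing permissions, that matters on Windows, where `os.rename` refuses to overwrite an existing file.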
@@ -464,10 +712,10 @@ def _save_final_safetensors(_transform_param_list, output_format):
             new_transform_dict[save_file_name].update(transform_param_dict)
     for save_file_name, transform_param_dict in new_transform_dict.items():
         if output_format == "safetensors":
-            save_file(transform_param_dict, save_file_name)
+            _save_file_atomically(transform_param_dict, save_file_name, metadata={"format": "ms"})
         else:
             transform_param_dict = _load_and_transform(transform_param_dict, None, None,
-                                                       transform_func=lambda v, name: ms.Parameter(v, name=name))
+                                                       transform_func=lambda v, name: Parameter(v, name=name))
             ms.save_checkpoint(transform_param_dict, save_file_name)
 
 
@@ -501,8 +749,8 @@ def transform_safetensors_by_stage(src_safetensors_dir, dst_safetensors_dir, ckpt_prefix,
         if not os.path.exists(local_file):
             raise ValueError("safetensor file {} in rank {} not exits: ".format(local_file, rank))
     for rank, file_name in safetensor_files_map.items():
-        saftensor_dict = load_file(file_name)
-        for param_name, param in saftensor_dict.items():
+        safetensor_dict = load_file(file_name)
+        for param_name, param in safetensor_dict.items():
             # cut the parameter not in the pipeline stage.
             if _parameter_not_in_local_stage(param_name, origin_src_strategy_list, src_strategy_list) \
                     and _parameter_not_in_local_stage(param_name, origin_dst_strategy_list, dst_strategy_list):
@@ -520,7 +768,7 @@ def transform_safetensors_by_stage(src_safetensors_dir, dst_safetensors_dir, ckpt_prefix,
     if not os.path.exists(save_safetensor_file_dir):
         _make_dir(save_safetensor_file_dir, "path")
     save_safetensor_file_name = os.path.join(save_safetensor_file_dir, save_safetensor_file)
-    save_file(transform_param_dict, save_safetensor_file_name)
+    _save_file_atomically(transform_param_dict, save_safetensor_file_name, metadata={"format": "ms"})
 
 
 def transform_safetensors_by_rank(rank_id, safetensor_files_map, save_safetensor_file_name,
@@ -556,8 +804,8 @@ def transform_safetensors_by_rank(rank_id, safetensor_files_map, save_safetensor_file_name,
     origin_dst_strategy_list = _extract_layout_map(dst_strategy_file)
     origin_src_strategy_list = _extract_layout_map(src_strategy_file)
     for rank, file_name in safetensor_files_map.items():
-        saftensor_dict = load_file(file_name)
-        for param_name, param in saftensor_dict.items():
+        safetensor_dict = load_file(file_name)
+        for param_name, param in safetensor_dict.items():
             # cut the parameter not in the pipeline stage.
             if _parameter_not_in_local_stage(param_name, origin_src_strategy_list, src_strategy_list) \
                     and _parameter_not_in_local_stage(param_name, origin_dst_strategy_list, dst_strategy_list):
@@ -572,7 +820,7 @@ def transform_safetensors_by_rank(rank_id, safetensor_files_map, save_safetensor_file_name,
     transform_param_dict = _transform_parallel_safetensor(local_rank_id, param_total_dict,
                                                           param_attr_dict, src_strategy_list, dst_strategy_list,
                                                           param_type_dict)
-    save_file(transform_param_dict, save_safetensor_file_name)
+    _save_file_atomically(transform_param_dict, save_safetensor_file_name, metadata={"format": "ms"})
 
 
 def _extrace_number(file_name):
@@ -628,7 +876,7 @@ def _find_needed_ranks(src_strategy_dict, dst_strategy_dict):
 
 def load_file_by_param_name(filename, parme_name_list):
     result = {}
-    with safe_open(filename, framework="np") as f:
+    with _fast_safe_open(filename, framework="np") as f:
         for k in parme_name_list:
             result[k] = f.get_tensor(k)
     return result
@@ -644,10 +892,7 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, src_strategy_list,
     device_num = -1
     param_total_dict_keys = list(param_total_dict.keys()) if param_total_dict_keys is None else param_total_dict_keys
     for param_name in param_total_dict_keys:
-        if str(type(list(param_total_dict[param_name].values())[0])) == "<class 'builtins.PySafeSlice'>":
-            tensor_shape = list(param_total_dict[param_name].values())[0].get_shape()
-        else:
-            tensor_shape = list(param_total_dict[param_name].values())[0].shape
+        tensor_shape = list(param_total_dict[param_name].values())[0].shape
         from_dev_matrix = [1]
         from_tensor_map = [-1] * len(tensor_shape)
         from_opt_shard_step = 0
@@ -695,7 +940,7 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, src_strategy_list,
         # when the from_layout is less devices, the safetensor_map for map[device_num] should using map[0]
         device_list = list(range(0, np.prod(from_tensor_layout[0])))
         if rank_id % device_num not in param_attr_dict[param_name] and src_strategy_file is None:
-            raise ValueError("The safetensor of rank {} is missing.".format(rank_id % device_num))
+            raise ValueError("The param: {} in rank {} is missing.".format(param_name, rank_id % device_num))
         param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
                                                                             device_list, rank_id)
 
@@ -711,8 +956,6 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, src_strategy_list,
         if isinstance(choice_out, str):
             param_name = choice_out
         transform_param_dict[param_name] = param_total_dict_copy[rank_id % device_num]
-        if str(type(transform_param_dict[param_name])) == "<class 'builtins.PySafeSlice'>":
-            transform_param_dict[param_name] = transform_param_dict[param_name][:]
 
     # Handle those parameter like learning_rate, global_step which not in strategy_file.
     for param_name in param_total_dict_keys:
@@ -722,33 +965,14 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, src_strategy_list,
             continue
         if param_name not in transform_param_dict:
             transform_para = param_total_dict[param_name][rank_id % device_num]
-            if str(type(transform_para)) == "<class 'builtins.PySafeSlice'>":
-                transform_para = transform_para[:]
             transform_param_dict[param_name] = transform_para
     return transform_param_dict
 
 
 def _cal_param_size(shape, dtype):
     """cal param size by dtype and shape"""
-    dtype_size = {
-        "BOOL": 1,
-        "U8": 1,
-        "I8": 1,
-        "F8_E5M2": 1,
-        "F8_E4M3": 1,
-        "I16": 2,
-        "U16": 2,
-        "I32": 4,
-        "U32": 4,
-        "I64": 8,
-        "U64": 8,
-        "F16": 2,
-        "BF16": 2,
-        "F32": 4,
-        "F64": 8,
-    }
     num_elements = math.prod(shape)
-    element_size = dtype_size.get(dtype, 4)
+    element_size = np_dtype_size.get(dtype, 4)
     total_bytes = num_elements * element_size
     return total_bytes
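`_cal_param_size` now shares the module-level `np_dtype_size` table (keyed by numpy dtype names) instead of rebuilding a local map keyed by safetensors codes: bytes = prod(shape) × element size, with a 4-byte fallback for unknown dtypes. A quick worked example:

```python
# Worked example of the formula above: a float16 tensor of shape
# (4096, 4096) occupies 4096 * 4096 * 2 bytes = 32 MiB.
import math

num_elements = math.prod((4096, 4096))  # 16_777_216
print(num_elements * 2)                 # 33554432 bytes
```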
@@ -769,14 +993,15 @@ def _split_weight_dict(weights, num_groups):
 def _save_hyper_param(split_dst_file, all_safetensor_files_map, name_list, dst_dir):
     """save hyper param"""
     if not split_dst_file or (split_dst_file and split_dst_file[0] == 1):
-        with safe_open(all_safetensor_files_map.get(0), framework="np") as f:
+        with _fast_safe_open(all_safetensor_files_map.get(0), framework="np") as f:
             all_key = f.keys()
             hyper_parameter = set(all_key) - set(name_list)
             if hyper_parameter:
                 hyper_dict = {}
                 for key in hyper_parameter:
                     hyper_dict[key] = f.get_tensor(key)
-                save_file(hyper_dict, os.path.join(dst_dir, "hyper_param.safetensors"))
+                _save_file_atomically(hyper_dict, os.path.join(dst_dir, "hyper_param.safetensors"),
+                                      metadata={"format": "ms"})
 
 
 def _save_parameter_map_json(split_list, choice_func, split_dst_file, dst_dir, param_total_size):
@@ -826,14 +1051,57 @@ def _get_dst_shape(param_name, param_shape, src_strategy_list):
     return to_full_tensor_shape
 
 
+def _check_remove_redundancy(merge_with_redundancy, f):
+    """Check whether remove_redundancy is consistent with the safetensors file."""
+    if f.metadata() is not None and "remove_redundancy" in f.metadata().keys():
+        if f.metadata()["remove_redundancy"] == "True" and merge_with_redundancy:
+            logger.warning("For 'unified_safetensors', the safetensors file is deduplicated, "
+                           "but merge_with_redundancy is set to True.")
+            return False
+        if f.metadata()["remove_redundancy"] == "False" and not merge_with_redundancy:
+            logger.warning("For 'unified_safetensors', the safetensors file is non-deduplicated, "
+                           "but merge_with_redundancy is set to False.")
+            return True
+    return merge_with_redundancy
+
+
+def set_affinity_pid():
+    """Set CPU affinity pid"""
+    pid = os.getpid()
+    total_cores = os.cpu_count()
+    all_cores = set(range(total_cores))
+    os.sched_setaffinity(pid, all_cores)
+
+
+def _validate_safetensors_files(target_directory, expected_file_ids):
+    """Validate whether safetensors files are completely generated in the target directory."""
+    missing_file_ids = []
+    for file_id in expected_file_ids:
+        safetensors_file = os.path.join(target_directory, f"part{file_id}.safetensors")
+        if os.path.exists(safetensors_file):
+            continue
+        missing_file_ids.append(file_id)
+
+    if missing_file_ids:
+        logger.warning(
+            f"For unified_safetensors, target file part {missing_file_ids} does not exist. "
+            f"Possible causes: file rename failed, insufficient permissions, or disk space shortage."
+        )
+
+
 def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundancy=True, file_suffix=None,
                         max_process_num=64, choice_func=None, split_dst_file=()):
     """
     Merge multiple safetensor files into a unified safetensor file.
 
+    Note:
+        When merging weights, it will verify whether the `merge_with_redundancy` parameter differs from
+        the deduplication flag in the merged safetensors files. If they are the same, the merging will be performed
+        according to the deduplication flag in the files.
+
     Args:
         src_dir (str): Source weight saving directory.
-        src_strategy_file (str): Source weight segmentation strategy file.
+        src_strategy_file (str): Source weight segmentation strategy file with the file extension `.ckpt`.
         dst_dir (str): Target save directory.
         merge_with_redundancy (bool, optional): Whether the merged source weight files are de-duplicated and
             saved safetensors files. Default: ``True``, indicating that the merged source weight files are complete.
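Of the three helpers added above, note that `set_affinity_pid` relies on `os.sched_setaffinity`, which the Python documentation lists as Linux-only; on a Windows wheel like this one, a caller would need to guard the call. A hedged, portable sketch:

```python
# Portability sketch (an assumption, not the diff's code): widen CPU
# affinity only where the API exists. os.sched_setaffinity is Linux-only;
# pid 0 means "the current process".
import os

def widen_affinity():
    if hasattr(os, "sched_setaffinity"):
        os.sched_setaffinity(0, range(os.cpu_count()))
```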
@@ -861,10 +1129,7 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundancy=True,
     >>> dst_dir = "/usr/safetensors/llama31B/merge_llama31B_4p/"
     >>> ms.parallel.unified_safetensors(src_dir, src_strategy_file, dst_dir)
     """
-    pid = os.getpid()
-    total_cores = os.cpu_count()
-    all_cores = set(range(total_cores))
-    os.sched_setaffinity(pid, all_cores)
+    set_affinity_pid()
     _check_transform_safetensors(src_dir, "", src_strategy_file, None)
     _make_dir(dst_dir, "path")
     if os.path.isfile(src_dir):
@@ -890,8 +1155,9 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundancy=True,
 
     actual_params = set()
     for _, file_name in all_safetensor_files_map.items():
-        with safe_open(file_name, framework="np") as f:
+        with _fast_safe_open(file_name, framework="np") as f:
             actual_params.update(f.keys())
+            merge_with_redundancy = _check_remove_redundancy(merge_with_redundancy, f)
 
     params_to_store = actual_params & set(layout_map.keys())
 
@@ -904,21 +1170,22 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundancy=True,
     param_size_dict = {}
     param_total_size = 0
     for _, file_name in all_safetensor_files_map.items():
-        with safe_open(file_name, framework="np") as f:
+        with _fast_safe_open(file_name, framework="np") as f:
             for k in f.keys():
                 if k in name_list:
-                    py_slice = f.get_slice(k)
-                    param_total_size += _cal_param_size(py_slice.get_shape(), py_slice.get_dtype())
-                    param_dst_shape = _get_dst_shape(k, py_slice.get_shape(), origin_src_strategy_list)
+                    py_slice = f.get_tensor(k)
+                    param_total_size += _cal_param_size(py_slice.shape, py_slice.dtype)
+                    param_dst_shape = _get_dst_shape(k, py_slice.shape, origin_src_strategy_list)
                     # Convert the shape of np.int32 type to int type to prevent overflow in subsequent calculations.
                     param_dst_shape = [int(item) for item in param_dst_shape]
                     if choice_func is not None:
                         choice_out = choice_func(k)
                         if isinstance(choice_out, bool):
                             if not choice_out:
+                                name_list.remove(k)
                                 continue
                     if k not in param_size_dict:
-                        param_size_dict[k] = _cal_param_size(param_dst_shape, py_slice.get_dtype())
+                        param_size_dict[k] = _cal_param_size(param_dst_shape, py_slice.dtype)
     split_num = math.ceil(sum(param_size_dict.values()) / 1024 / 1024 / 1024 / 3)
     split_num = min(split_num, len(name_list))
     split_list = _split_weight_dict(param_size_dict, split_num)
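The two lines above size the output: total destination bytes are divided by 3 GiB (`1024**3 * 3`) and rounded up, then capped at the number of parameters, so each merged part file lands near 3 GiB. For example:

```python
# Worked example of the part-count formula above: ~10 GiB of parameters
# yields ceil(10 / 3) = 4 part files (before the len(name_list) cap).
import math

total_bytes = 10 * 1024 ** 3
print(math.ceil(total_bytes / 1024 / 1024 / 1024 / 3))  # 4
```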
@@ -932,37 +1199,44 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundancy=True,
         start_index = (avg_length * (current_machine_num - 1)) + min(current_machine_num - 1, remainder)
         end_index = start_index + avg_length + (1 if current_machine_num <= remainder else 0)
         sub_list = []
-        for i in range(len(split_list)):
+        for i, item in enumerate(split_list):
             if start_index <= i < end_index:
-                sub_list.append(split_list[i])
+                sub_list.append(item)
             else:
                 sub_list.append([-1])
+        split_num = end_index - start_index
+        res = list(range(start_index, end_index))
     else:
         sub_list = split_list
+        res = [i for i in range(split_num)]
 
     _save_hyper_param(split_dst_file, all_safetensor_files_map, name_list, dst_dir)
     _save_parameter_map_json(split_list, choice_func, split_dst_file, dst_dir, param_total_size)
 
-    if split_dst_file:
-        split_num = end_index - start_index
-        res = list(range(start_index, end_index))
-    else:
-        res = [i for i in range(split_num)]
     max_process = min(split_num, max_process_num)
+    file_ids = res[:]
     res = _split_list(res, max_process)
     processes = []
     src_strategy_name = None
     if not merge_with_redundancy:
         src_strategy_name = src_strategy_file
-    for i in range(max_process):
-        p = mp.Process(target=_transform_safetensors_single_semaphore, args=(
-            needed_rank_list_map, all_safetensor_files_map, src_stage_device_num, dst_stage_device_num,
-            src_strategy_dict, None, origin_src_strategy_list, origin_dst_strategy_list,
-            "", dst_dir, "safetensors", None, sub_list, res[i], True, src_strategy_name, choice_func))
-        p.start()
-        processes.append(p)
-    for p in processes:
-        p.join()
+    if max_process > 1:
+        for i in range(max_process):
+            p = mp.Process(target=_transform_safetensors_single_semaphore, args=(
+                needed_rank_list_map, all_safetensor_files_map, src_stage_device_num, dst_stage_device_num,
+                src_strategy_dict, None, origin_src_strategy_list, origin_dst_strategy_list,
+                "", dst_dir, "safetensors", None, sub_list, res[i], True, src_strategy_name, choice_func))
+            p.start()
+            processes.append(p)
+        for p in processes:
+            p.join()
+    else:
+        _transform_safetensors_single_semaphore(needed_rank_list_map, all_safetensor_files_map, src_stage_device_num,
+                                                dst_stage_device_num, src_strategy_dict, None,
+                                                origin_src_strategy_list, origin_dst_strategy_list, "",
+                                                dst_dir, "safetensors", None, sub_list,
+                                                res[0], True, src_strategy_name, choice_func)
+    _validate_safetensors_files(dst_dir, file_ids)
 
 
 def _transform_safetensors_single_semaphore(needed_rank_list_map, all_safetensor_files_map,
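`file_ids = res[:]` snapshots the part indices before `_split_list` chunks them across at most `max_process` workers, so `_validate_safetensors_files` can later check that every `part{id}.safetensors` was produced. `_split_list` itself is not shown in this diff; a round-robin split of that shape would look like the following sketch (an assumption, not the actual implementation):

```python
# Assumed round-robin chunking of part indices across workers; the real
# _split_list in transform_safetensors.py is not visible in this diff.
def split_list(items, n):
    return [items[i::n] for i in range(n)]

print(split_list(list(range(7)), 3))  # [[0, 3, 6], [1, 4], [2, 5]]
```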
@@ -997,7 +1271,7 @@ def _split_list(split_list, split_num):
 def _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_num):
     """apply safetensors object operators"""
     if not transform_operator_stack:
-        return sf_obj[:]
+        return sf_obj
     level = transform_operator_stack[-1][1]
     level_operators = []
     while True:
@@ -1022,7 +1296,7 @@ def _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_num):
                 allgather_list = [sf_obj for _ in operator[1][:-1]]
                 tmp_tensor_dict[rank_id % device_num] = _apply_operator(operator[0])(allgather_list, operator)
             if op_name == "AllConcat":
-                for rank, value in tmp_tensor_dict.items():
+                for _, value in tmp_tensor_dict.items():
                     sf_obj = value
         level_operators.clear()
         if not transform_operator_stack:
@@ -1037,13 +1311,26 @@ def _process_hyper_params(file_list, total_safetensors_dir, total_param):
     """process hyper params"""
     if 'hyper_param.safetensors' in file_list:
         hyper_parameter_file_name = os.path.join(total_safetensors_dir, "hyper_param.safetensors")
-        with safe_open(hyper_parameter_file_name, framework="np") as f:
+        with _fast_safe_open(hyper_parameter_file_name, framework="np") as f:
             for key in f.keys():
-                total_param[key] = ms.Parameter(ms.Tensor.from_numpy(f.get_tensor(key)))
+                total_param[key] = Parameter(Tensor.from_numpy(f.get_tensor(key)))
     return total_param
 
 
-def _cal_param_name_map_and_param_list(file_list, total_safetensors_dir, json_files, dst_strategy_file, rank_id):
+def _get_param_name_map_by_file(file_name, file_list, name_map):
+    """get param_name_map by file"""
+    with _fast_safe_open(file_name, framework="np") as f:
+        keys = f.keys()
+        values = len(keys) * [file_list[0]]
+        if name_map:
+            flipped_name_map = {value: key for key, value in name_map.items()}
+            keys = [flipped_name_map.get(key, key) for key in keys]
+        param_name_map = dict(zip(keys, values))
+    return param_name_map
+
+
+def _cal_param_name_map_and_param_list(file_list, total_safetensors_dir, json_files,
+                                       dst_strategy_file, rank_id, name_map=None):
     """calculate param_name_map and param_list"""
     if len(file_list) == 1:
         logger.info("There is only one weight file in the directory, which will be automatically mapped.")
@@ -1052,10 +1339,7 @@ def _cal_param_name_map_and_param_list(file_list, total_safetensors_dir, json_fi
         if not is_file:
             raise ValueError(f"For 'load_parallel_checkpoint', weight files must be included "
                              f"in the `unified_safetensors_dir`.")
-        with safe_open(file_name, framework="np") as f:
-            keys = f.keys()
-            values = len(keys) * [file_list[0]]
-        param_name_map = dict(zip(keys, values))
+        param_name_map = _get_param_name_map_by_file(file_name, file_list, name_map)
     else:
         if not json_files:
             raise ValueError(
@@ -1076,19 +1360,71 @@ def _cal_param_name_map_and_param_list(file_list, total_safetensors_dir, json_fi
     return param_name_map, param_list, dst_strategy_list
 
 
+def _cal_transform_operator_stack_and_device_num(from_dev_matrix, from_tensor_map, from_opt_shard_step,
+                                                 from_opt_shard_size, param_name, dst_strategy_list, tensor_shape,
+                                                 local_rank_id):
+    """calculate transform_operator_stack and device_num"""
+    to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
+        dst_strategy_list.get(param_name))
+
+    device_num = np.prod(from_dev_matrix)
+    param_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
+    origin_tensor_shape = ()
+    for i, item in enumerate(tensor_shape):
+        if i == 0 and from_opt_shard_size > 0:
+            origin_tensor_shape += (item * param_strategy[i] * from_opt_shard_size,)
+            continue
+        origin_tensor_shape += (item * param_strategy[i],)
+
+    has_layout_from = any(isinstance(i, (list, tuple)) for i in from_tensor_map)
+    has_layout_to = any(isinstance(i, (list, tuple)) for i in to_tensor_map_origin)
+
+    from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
+        from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
+    to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
+        to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, origin_tensor_shape)
+    # Convert tensor layouts to the same device num
+    from_tensor_layout, to_tensor_layout = _construct_from_to_tensor_layout(from_full_tensor_shape,
+                                                                            from_dev_matrix,
+                                                                            from_tensor_map,
+                                                                            to_full_tensor_shape,
+                                                                            to_dev_matrix, to_tensor_map)
+
+    # When the from_layout uses fewer devices, the safetensors map for map[device_num] should use map[0]
+    device_list = list(range(0, np.prod(from_tensor_layout[0])))
+    param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
+                                                                        device_list, local_rank_id)
+
+    from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
+    to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
+    _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
+    _insert_expand_layout_reshape(param_rank_map, from_info_tuple, to_info_tuple,
+                                  has_layout_from, has_layout_to)
+    transform_operator_stack = _generate_transform_operator_stack(param_rank_map, local_rank_id)
+    return transform_operator_stack, device_num
+
+
+def check_param_dtype(file, param_name):
+    """check whether metadata records an original dtype for the parameter"""
+    dtype_need_changed = False
+    changed_dtype = None
+    if file.metadata() is not None and param_name in file.metadata().keys():
+        dtype_need_changed = True
+        sf_dtype = file.metadata()[param_name]
+        changed_dtype = safetensors_to_mstype[sf_dtype]
+    return dtype_need_changed, changed_dtype
+
+
 def _load_parallel_checkpoint(file_info):
     """load parallel safetensors by merged file."""
     total_safetensors_dir, dst_strategy_file, net, dst_safetensors_dir, \
-        rank_id, output_format, name_map, return_param_dict = file_info
-    pid = os.getpid()
-    total_cores = os.cpu_count()
-    all_cores = set(range(total_cores))
-    os.sched_setaffinity(pid, all_cores)
+        rank_id, output_format, name_map, return_param_dict = file_info
+    set_affinity_pid()
     file_list = os.listdir(total_safetensors_dir)
     json_files = [file for file in file_list if file == "param_name_map.json"]
-    param_name_map, param_list, dst_strategy_list = _cal_param_name_map_and_param_list(file_list, total_safetensors_dir,
+    sf_files = [file for file in file_list if file.endswith('.safetensors')]
+    param_name_map, param_list, dst_strategy_list = _cal_param_name_map_and_param_list(sf_files, total_safetensors_dir,
                                                                                        json_files, dst_strategy_file,
-                                                                                       rank_id)
+                                                                                       rank_id, name_map)
     total_param = dict()
     dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) if dst_strategy_list \
         is not None else 1
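`check_param_dtype` drives a metadata-based dtype restore: if the writer recorded a parameter's original dtype in the safetensors header metadata, the loader casts back to it when building the `Parameter` (see the final hunks). A sketch under those assumptions; `safetensors_to_mstype` is defined elsewhere in this module, so the one-entry mapping here is illustrative only:

    from safetensors import safe_open
    import mindspore as ms

    safetensors_to_mstype = {"BF16": ms.bfloat16}        # illustrative subset

    with safe_open("part0.safetensors", framework="np") as f:
        meta = f.metadata()                              # dict[str, str] or None
        arr = f.get_tensor("w")
        if meta is not None and "w" in meta:
            param = ms.Parameter(ms.Tensor(arr, dtype=safetensors_to_mstype[meta["w"]]))
        else:
            param = ms.Parameter(ms.Tensor.from_numpy(arr))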
@@ -1098,13 +1434,14 @@ def _load_parallel_checkpoint(file_info):
         if param_name not in param_name_map:
             continue
         file_name = os.path.join(total_safetensors_dir, param_name_map[param_name])
-        with safe_open(file_name, framework="np") as f:
+        with _fast_safe_open(file_name, framework="np") as f:
             cur_param_name = name_map.get(param_name) if name_map is not None and param_name in name_map else param_name
             if cur_param_name not in f.keys():
                 continue
-            sf_obj = f.get_slice(cur_param_name)
+            sf_obj = f.get_tensor(cur_param_name)
+            dtype_need_changed, changed_dtype = check_param_dtype(f, param_name)
 
-            tensor_shape = sf_obj.get_shape()
+            tensor_shape = sf_obj.shape
             from_dev_matrix = [1]
             from_tensor_map = [-1] * len(tensor_shape)
             from_opt_shard_step = 0
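A readability note on the `cur_param_name` line above: the conditional expression is equivalent to a `dict.get` with a default, which may read more plainly (names below are illustrative):

    name_map = {"net.w": "model.w"}                      # expected -> stored
    param_name = "net.w"
    cur_param_name = name_map.get(param_name, param_name) if name_map is not None else param_name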
@@ -1112,43 +1449,14 @@ def _load_parallel_checkpoint(file_info):
             if dst_strategy_list is not None:
                 if param_name not in dst_strategy_list:
                     continue
-                to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
-                    dst_strategy_list.get(param_name))
-
-                device_num = np.prod(from_dev_matrix)
-                param_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
-                origin_tensor_shape = ()
-                for i, item in enumerate(tensor_shape):
-                    if i == 0 and from_opt_shard_size > 0:
-                        origin_tensor_shape += (item * param_strategy[i] * from_opt_shard_size,)
-                        continue
-                    origin_tensor_shape += (item * param_strategy[i],)
-
-                has_layout_from = any(isinstance(i, (list, tuple)) for i in from_tensor_map)
-                has_layout_to = any(isinstance(i, (list, tuple)) for i in to_tensor_map_origin)
-
-                from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
-                    from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
-                to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
-                    to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, origin_tensor_shape)
-                # Convert tensor layout to same device num
-                from_tensor_layout, to_tensor_layout = _construct_from_to_tensor_layout(from_full_tensor_shape,
-                                                                                        from_dev_matrix,
-                                                                                        from_tensor_map,
-                                                                                        to_full_tensor_shape,
-                                                                                        to_dev_matrix, to_tensor_map)
-
-                # when the from_layout is less devices, the safetensor_map for map[device_num] should using map[0]
-                device_list = list(range(0, np.prod(from_tensor_layout[0])))
-                param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
-                                                                                    device_list, local_rank_id)
-
-                from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
-                to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
-                _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
-                _insert_expand_layout_reshape(param_rank_map, from_info_tuple, to_info_tuple,
-                                              has_layout_from, has_layout_to)
-                transform_operator_stack = _generate_transform_operator_stack(param_rank_map, local_rank_id)
+                transform_operator_stack, device_num = _cal_transform_operator_stack_and_device_num(from_dev_matrix,
+                                                                                                    from_tensor_map,
+                                                                                                    from_opt_shard_step,
+                                                                                                    from_opt_shard_size,
+                                                                                                    param_name,
+                                                                                                    dst_strategy_list,
+                                                                                                    tensor_shape,
+                                                                                                    local_rank_id)
                 start_time = time.time()
                 slice_param = _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_num)
                 end_time = time.time()
@@ -1156,11 +1464,15 @@ def _load_parallel_checkpoint(file_info):
                 total_io_cost_time += cost_time
             else:
                 start_time = time.time()
-                slice_param = sf_obj[:]
+                slice_param = sf_obj
                 end_time = time.time()
                 cost_time = end_time - start_time
                 total_io_cost_time += cost_time
-            total_param[param_name] = ms.Parameter(ms.Tensor.from_numpy(slice_param))
+            slice_param_copy = np.copy(slice_param)
+            if dtype_need_changed:
+                total_param[param_name] = Parameter(Tensor(slice_param_copy, dtype=changed_dtype))
+            else:
+                total_param[param_name] = Parameter(Tensor.from_numpy(slice_param_copy))
     vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
                f"load distributed safetensors io cost time:{total_io_cost_time}.")
     total_param = _process_hyper_params(file_list, total_safetensors_dir, total_param)
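The `np.copy` before wrapping is the notable part: `Tensor.from_numpy` shares memory with the source array, and arrays handed out by `safe_open` can be backed by the memory-mapped file, which is no longer valid once the `with` block closes; copying first gives the `Parameter` its own storage. That reading is an inference from the diff, not a documented contract. A sketch with illustrative data:

    import numpy as np
    import mindspore as ms

    arr = np.arange(4, dtype=np.float32)                        # stands in for the loaded slice
    owned = np.copy(arr)                                        # detach from any shared/mapped buffer
    p_same = ms.Parameter(ms.Tensor.from_numpy(owned))          # zero-copy wrap of the owned array
    p_cast = ms.Parameter(ms.Tensor(owned, dtype=ms.float16))   # cast while wrapping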
@@ -1177,28 +1489,5 @@ def _load_parallel_checkpoint(file_info):
         return None
 
 
-def _get_slice(rank_id, sf_obj, param_name, dst_strategy_list):
-    """get slice op"""
-    tensor_shape = sf_obj.get_shape()
-    to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
-        dst_strategy_list.get(param_name))
-    # Add optimizer sharding dim for tensor layout
-    to_dev_matrix, to_tensor_map, _ = _construct_tensor_layout_for_opt_shard(
-        to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, tensor_shape)
-    slice_op = _load_tensor_shape(to_dev_matrix, to_tensor_map, full_shape=tensor_shape, rank_id=rank_id)
-    shape = None
-    if to_opt_shard_size > 0:
-        to_tensor_strategy = _get_tensor_strategy(to_dev_matrix_origin, to_tensor_map_origin)
-        to_slice_tensor_shape = ()
-        for i, item in enumerate(tensor_shape):
-            if i == 0 and to_opt_shard_size > 0:
-                to_slice_tensor_shape += (item // (to_tensor_strategy[i] * to_opt_shard_size),)
-                continue
-            to_slice_tensor_shape += (item // to_tensor_strategy[i],)
-        shape = list(to_slice_tensor_shape)
-
-    return slice_op, shape
-
-
 __all__ = ["_transform_safetensors", "transform_safetensors_by_stage",
            "transform_safetensors_by_rank", "unified_safetensors"]