mindspore 2.4.10__cp311-cp311-win_amd64.whl → 2.6.0rc1__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (602)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +13 -6
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -38
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +6 -7
  15. mindspore/_extends/parse/compile_config.py +83 -0
  16. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  17. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  18. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  19. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  20. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  21. mindspore/_extends/parse/parser.py +46 -197
  22. mindspore/_extends/parse/resources.py +1 -5
  23. mindspore/_extends/parse/standard_method.py +217 -98
  24. mindspore/_extends/pijit/__init__.py +2 -2
  25. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  26. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  27. mindspore/_extends/utils.py +1 -1
  28. mindspore/amp.py +11 -5
  29. mindspore/atlprov.dll +0 -0
  30. mindspore/avcodec-59.dll +0 -0
  31. mindspore/avdevice-59.dll +0 -0
  32. mindspore/avfilter-8.dll +0 -0
  33. mindspore/avformat-59.dll +0 -0
  34. mindspore/avutil-57.dll +0 -0
  35. mindspore/boost/__init__.py +2 -2
  36. mindspore/boost/base.py +3 -7
  37. mindspore/boost/boost_cell_wrapper.py +138 -43
  38. mindspore/c1.dll +0 -0
  39. mindspore/c1xx.dll +0 -0
  40. mindspore/c2.dll +0 -0
  41. mindspore/common/__init__.py +6 -3
  42. mindspore/common/_grad_function.py +56 -0
  43. mindspore/common/_pijit_context.py +14 -5
  44. mindspore/common/_register_for_tensor.py +1 -2
  45. mindspore/common/_stub_tensor.py +30 -14
  46. mindspore/common/_tensor_cpp_method.py +17 -0
  47. mindspore/common/_tensor_docs.py +4760 -0
  48. mindspore/common/api.py +435 -371
  49. mindspore/common/auto_dynamic_shape.py +41 -44
  50. mindspore/common/dtype.py +39 -36
  51. mindspore/common/dump.py +9 -6
  52. mindspore/common/file_system.py +9 -1
  53. mindspore/common/generator.py +2 -0
  54. mindspore/common/hook_handle.py +6 -2
  55. mindspore/common/initializer.py +13 -10
  56. mindspore/common/jit_begin_end.py +94 -0
  57. mindspore/common/jit_config.py +6 -1
  58. mindspore/common/jit_context.py +76 -0
  59. mindspore/common/jit_trace.py +378 -0
  60. mindspore/common/lazy_inline.py +9 -3
  61. mindspore/common/mindir_util.py +10 -2
  62. mindspore/common/mutable.py +5 -4
  63. mindspore/common/parameter.py +135 -52
  64. mindspore/common/seed.py +2 -2
  65. mindspore/common/sparse_tensor.py +23 -17
  66. mindspore/common/tensor.py +951 -1992
  67. mindspore/communication/__init__.py +7 -5
  68. mindspore/communication/_comm_helper.py +52 -2
  69. mindspore/communication/comm_func.py +240 -181
  70. mindspore/communication/management.py +95 -26
  71. mindspore/context.py +314 -566
  72. mindspore/dataset/__init__.py +65 -37
  73. mindspore/dataset/audio/__init__.py +2 -8
  74. mindspore/dataset/audio/transforms.py +3 -17
  75. mindspore/dataset/callback/ds_callback.py +2 -1
  76. mindspore/dataset/core/config.py +87 -6
  77. mindspore/dataset/engine/cache_admin.py +3 -3
  78. mindspore/dataset/engine/cache_client.py +6 -5
  79. mindspore/dataset/engine/datasets.py +292 -267
  80. mindspore/dataset/engine/datasets_audio.py +22 -8
  81. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  82. mindspore/dataset/engine/datasets_text.py +78 -48
  83. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  84. mindspore/dataset/engine/datasets_vision.py +120 -44
  85. mindspore/dataset/engine/iterators.py +283 -63
  86. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  87. mindspore/dataset/engine/obs/util.py +8 -0
  88. mindspore/dataset/engine/queue.py +40 -0
  89. mindspore/dataset/engine/samplers.py +289 -43
  90. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  91. mindspore/dataset/engine/validators.py +53 -11
  92. mindspore/dataset/text/__init__.py +7 -6
  93. mindspore/dataset/text/transforms.py +6 -5
  94. mindspore/dataset/text/utils.py +3 -3
  95. mindspore/dataset/transforms/__init__.py +0 -9
  96. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  97. mindspore/dataset/transforms/transforms.py +31 -14
  98. mindspore/dataset/utils/browse_dataset.py +1 -1
  99. mindspore/dataset/vision/__init__.py +2 -9
  100. mindspore/dataset/vision/transforms.py +202 -158
  101. mindspore/dataset/vision/utils.py +7 -5
  102. mindspore/dataset/vision/validators.py +1 -2
  103. mindspore/device_context/__init__.py +21 -0
  104. mindspore/device_context/ascend/__init__.py +25 -0
  105. mindspore/device_context/ascend/device.py +72 -0
  106. mindspore/device_context/ascend/op_debug.py +153 -0
  107. mindspore/device_context/ascend/op_precision.py +193 -0
  108. mindspore/device_context/ascend/op_tuning.py +123 -0
  109. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  110. mindspore/device_context/cpu/device.py +62 -0
  111. mindspore/device_context/cpu/op_tuning.py +43 -0
  112. mindspore/device_context/gpu/__init__.py +21 -0
  113. mindspore/device_context/gpu/device.py +70 -0
  114. mindspore/device_context/gpu/op_precision.py +67 -0
  115. mindspore/device_context/gpu/op_tuning.py +175 -0
  116. mindspore/device_manager.py +170 -0
  117. mindspore/dnnl.dll +0 -0
  118. mindspore/dpcmi.dll +0 -0
  119. mindspore/experimental/es/embedding_service.py +35 -27
  120. mindspore/experimental/llm_boost/__init__.py +1 -0
  121. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  122. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  123. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  124. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  125. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  126. mindspore/experimental/llm_boost/register.py +1 -0
  127. mindspore/experimental/map_parameter.py +4 -4
  128. mindspore/experimental/optim/adadelta.py +6 -6
  129. mindspore/experimental/optim/adagrad.py +4 -4
  130. mindspore/experimental/optim/adam.py +7 -0
  131. mindspore/experimental/optim/adamax.py +4 -4
  132. mindspore/experimental/optim/adamw.py +4 -0
  133. mindspore/experimental/optim/asgd.py +1 -1
  134. mindspore/experimental/optim/lr_scheduler.py +73 -46
  135. mindspore/experimental/optim/radam.py +34 -31
  136. mindspore/experimental/optim/rprop.py +1 -1
  137. mindspore/experimental/optim/sgd.py +1 -1
  138. mindspore/hal/contiguous_tensors_handle.py +6 -10
  139. mindspore/hal/device.py +55 -53
  140. mindspore/hal/event.py +52 -52
  141. mindspore/hal/memory.py +157 -117
  142. mindspore/hal/stream.py +150 -109
  143. mindspore/include/api/context.h +0 -1
  144. mindspore/include/dataset/constants.h +7 -4
  145. mindspore/include/dataset/execute.h +2 -2
  146. mindspore/jpeg62.dll +0 -0
  147. mindspore/log.py +50 -0
  148. mindspore/mindrecord/__init__.py +21 -8
  149. mindspore/mindrecord/config.py +17 -316
  150. mindspore/mindrecord/filereader.py +1 -9
  151. mindspore/mindrecord/filewriter.py +5 -15
  152. mindspore/mindrecord/mindpage.py +1 -9
  153. mindspore/mindspore_backend_common.dll +0 -0
  154. mindspore/mindspore_backend_manager.dll +0 -0
  155. mindspore/mindspore_common.dll +0 -0
  156. mindspore/mindspore_core.dll +0 -0
  157. mindspore/mindspore_dump.dll +0 -0
  158. mindspore/mindspore_frontend.dll +0 -0
  159. mindspore/mindspore_glog.dll +0 -0
  160. mindspore/mindspore_memory_pool.dll +0 -0
  161. mindspore/mindspore_ms_backend.dll +0 -0
  162. mindspore/mindspore_ops.dll +0 -0
  163. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  164. mindspore/mindspore_ops_kernel_common.dll +0 -0
  165. mindspore/mindspore_profiler.dll +0 -0
  166. mindspore/mindspore_pyboost.dll +0 -0
  167. mindspore/mindspore_pynative.dll +0 -0
  168. mindspore/mindspore_res_manager.dll +0 -0
  169. mindspore/mindspore_runtime_pipeline.dll +0 -0
  170. mindspore/mint/__init__.py +796 -759
  171. mindspore/mint/distributed/__init__.py +70 -4
  172. mindspore/mint/distributed/distributed.py +2679 -44
  173. mindspore/mint/linalg/__init__.py +8 -0
  174. mindspore/mint/nn/__init__.py +743 -22
  175. mindspore/mint/nn/functional.py +716 -23
  176. mindspore/mint/nn/layer/__init__.py +21 -4
  177. mindspore/mint/nn/layer/_functions.py +334 -0
  178. mindspore/mint/nn/layer/activation.py +276 -1
  179. mindspore/mint/nn/layer/basic.py +123 -0
  180. mindspore/mint/nn/layer/conv.py +921 -0
  181. mindspore/mint/nn/layer/normalization.py +223 -28
  182. mindspore/mint/nn/layer/padding.py +797 -0
  183. mindspore/mint/nn/layer/pooling.py +235 -0
  184. mindspore/mint/optim/__init__.py +3 -1
  185. mindspore/mint/optim/adam.py +223 -0
  186. mindspore/mint/optim/adamw.py +26 -19
  187. mindspore/mint/optim/sgd.py +171 -0
  188. mindspore/mint/special/__init__.py +2 -1
  189. mindspore/msobj140.dll +0 -0
  190. mindspore/mspdb140.dll +0 -0
  191. mindspore/mspdbcore.dll +0 -0
  192. mindspore/mspdbst.dll +0 -0
  193. mindspore/mspft140.dll +0 -0
  194. mindspore/msvcdis140.dll +0 -0
  195. mindspore/msvcp140_1.dll +0 -0
  196. mindspore/msvcp140_2.dll +0 -0
  197. mindspore/msvcp140_atomic_wait.dll +0 -0
  198. mindspore/msvcp140_codecvt_ids.dll +0 -0
  199. mindspore/multiprocessing/__init__.py +5 -0
  200. mindspore/nn/__init__.py +4 -1
  201. mindspore/nn/cell.py +1370 -189
  202. mindspore/nn/dynamic_lr.py +2 -1
  203. mindspore/nn/layer/activation.py +29 -27
  204. mindspore/nn/layer/basic.py +51 -35
  205. mindspore/nn/layer/channel_shuffle.py +3 -3
  206. mindspore/nn/layer/container.py +1 -1
  207. mindspore/nn/layer/conv.py +22 -17
  208. mindspore/nn/layer/embedding.py +12 -11
  209. mindspore/nn/layer/normalization.py +56 -49
  210. mindspore/nn/layer/padding.py +4 -3
  211. mindspore/nn/layer/pooling.py +120 -42
  212. mindspore/nn/layer/rnn_cells.py +1 -1
  213. mindspore/nn/layer/rnns.py +2 -1
  214. mindspore/nn/layer/timedistributed.py +5 -5
  215. mindspore/nn/layer/transformer.py +59 -36
  216. mindspore/nn/learning_rate_schedule.py +8 -4
  217. mindspore/nn/loss/loss.py +58 -55
  218. mindspore/nn/optim/ada_grad.py +7 -5
  219. mindspore/nn/optim/adadelta.py +11 -9
  220. mindspore/nn/optim/adafactor.py +1 -1
  221. mindspore/nn/optim/adam.py +17 -13
  222. mindspore/nn/optim/adamax.py +8 -7
  223. mindspore/nn/optim/adasum.py +5 -5
  224. mindspore/nn/optim/asgd.py +1 -1
  225. mindspore/nn/optim/ftrl.py +11 -9
  226. mindspore/nn/optim/lamb.py +1 -1
  227. mindspore/nn/optim/lars.py +1 -4
  228. mindspore/nn/optim/lazyadam.py +12 -10
  229. mindspore/nn/optim/momentum.py +7 -6
  230. mindspore/nn/optim/optimizer.py +3 -3
  231. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  232. mindspore/nn/optim/rmsprop.py +13 -12
  233. mindspore/nn/optim/rprop.py +11 -9
  234. mindspore/nn/optim/sgd.py +9 -6
  235. mindspore/nn/optim/tft_wrapper.py +5 -2
  236. mindspore/nn/optim/thor.py +2 -1
  237. mindspore/nn/probability/bijector/bijector.py +17 -11
  238. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  239. mindspore/nn/probability/bijector/invert.py +2 -2
  240. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  241. mindspore/nn/probability/bijector/softplus.py +3 -2
  242. mindspore/nn/probability/distribution/beta.py +3 -3
  243. mindspore/nn/probability/distribution/categorical.py +1 -1
  244. mindspore/nn/probability/distribution/cauchy.py +4 -2
  245. mindspore/nn/probability/distribution/exponential.py +6 -7
  246. mindspore/nn/probability/distribution/gamma.py +2 -2
  247. mindspore/nn/probability/distribution/gumbel.py +2 -2
  248. mindspore/nn/probability/distribution/half_normal.py +5 -3
  249. mindspore/nn/probability/distribution/logistic.py +5 -3
  250. mindspore/nn/probability/distribution/poisson.py +1 -1
  251. mindspore/nn/probability/distribution/uniform.py +5 -3
  252. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  253. mindspore/nn/reinforcement/tensor_array.py +1 -1
  254. mindspore/nn/utils/init.py +13 -11
  255. mindspore/nn/wrap/__init__.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +181 -122
  257. mindspore/nn/wrap/grad_reducer.py +45 -36
  258. mindspore/nn/wrap/loss_scale.py +6 -7
  259. mindspore/numpy/array_creations.py +63 -65
  260. mindspore/numpy/array_ops.py +149 -144
  261. mindspore/numpy/logic_ops.py +41 -42
  262. mindspore/numpy/math_ops.py +365 -363
  263. mindspore/numpy/utils.py +17 -18
  264. mindspore/numpy/utils_const.py +5 -6
  265. mindspore/opencv_core452.dll +0 -0
  266. mindspore/opencv_imgcodecs452.dll +0 -0
  267. mindspore/opencv_imgproc452.dll +0 -0
  268. mindspore/ops/__init__.py +5 -3
  269. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  270. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  271. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  272. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  273. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  274. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  275. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  276. mindspore/ops/_register_for_op.py +0 -11
  277. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  278. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  279. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  280. mindspore/ops/_vmap/vmap_base.py +0 -2
  281. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  282. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  283. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  284. mindspore/ops/auto_generate/__init__.py +4 -3
  285. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  286. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  287. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  288. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  289. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  290. mindspore/ops/composite/__init__.py +2 -1
  291. mindspore/ops/composite/base.py +20 -25
  292. mindspore/ops/composite/math_ops.py +6 -16
  293. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  294. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  295. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  296. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  301. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  302. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  304. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  305. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  306. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  308. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  309. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  310. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  311. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  312. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  313. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  314. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  315. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  316. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  317. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  320. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  321. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  324. mindspore/ops/function/__init__.py +40 -2
  325. mindspore/ops/function/_add_attr_func.py +58 -0
  326. mindspore/ops/function/array_func.py +2089 -2403
  327. mindspore/ops/function/clip_func.py +80 -23
  328. mindspore/ops/function/debug_func.py +57 -57
  329. mindspore/ops/function/grad/__init__.py +1 -0
  330. mindspore/ops/function/grad/grad_func.py +104 -71
  331. mindspore/ops/function/image_func.py +2 -2
  332. mindspore/ops/function/linalg_func.py +47 -78
  333. mindspore/ops/function/math_func.py +4501 -3802
  334. mindspore/ops/function/nn_func.py +1726 -620
  335. mindspore/ops/function/other_func.py +159 -1
  336. mindspore/ops/function/parameter_func.py +18 -84
  337. mindspore/ops/function/random_func.py +440 -387
  338. mindspore/ops/function/reshard_func.py +4 -70
  339. mindspore/ops/function/sparse_func.py +3 -3
  340. mindspore/ops/function/sparse_unary_func.py +6 -6
  341. mindspore/ops/function/spectral_func.py +25 -58
  342. mindspore/ops/function/vmap_func.py +24 -17
  343. mindspore/ops/functional.py +22 -7
  344. mindspore/ops/functional_overload.py +1440 -0
  345. mindspore/ops/op_info_register.py +32 -244
  346. mindspore/ops/operations/__init__.py +13 -7
  347. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  348. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  349. mindspore/ops/operations/_grad_ops.py +2 -43
  350. mindspore/ops/operations/_infer_ops.py +2 -1
  351. mindspore/ops/operations/_inner_ops.py +43 -84
  352. mindspore/ops/operations/_ms_kernel.py +4 -10
  353. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  354. mindspore/ops/operations/_scalar_ops.py +3 -2
  355. mindspore/ops/operations/_sequence_ops.py +1 -1
  356. mindspore/ops/operations/_tensor_array.py +1 -1
  357. mindspore/ops/operations/array_ops.py +81 -324
  358. mindspore/ops/operations/comm_ops.py +154 -108
  359. mindspore/ops/operations/custom_ops.py +232 -78
  360. mindspore/ops/operations/debug_ops.py +153 -59
  361. mindspore/ops/operations/inner_ops.py +7 -5
  362. mindspore/ops/operations/linalg_ops.py +1 -57
  363. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  364. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  365. mindspore/ops/operations/math_ops.py +32 -234
  366. mindspore/ops/operations/nn_ops.py +210 -498
  367. mindspore/ops/operations/other_ops.py +62 -9
  368. mindspore/ops/operations/random_ops.py +13 -7
  369. mindspore/ops/operations/reshard_ops.py +1 -1
  370. mindspore/ops/operations/sparse_ops.py +2 -2
  371. mindspore/ops/primitive.py +66 -53
  372. mindspore/ops/tensor_method.py +1888 -0
  373. mindspore/ops_generate/__init__.py +0 -5
  374. mindspore/ops_generate/aclnn/__init__.py +0 -0
  375. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  376. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  377. mindspore/ops_generate/api/__init__.py +0 -0
  378. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  379. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  380. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  381. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  382. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  383. mindspore/ops_generate/api/gen_api.py +103 -0
  384. mindspore/ops_generate/api/op_api_proto.py +235 -0
  385. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  386. mindspore/ops_generate/common/__init__.py +0 -0
  387. mindspore/ops_generate/common/base_generator.py +11 -0
  388. mindspore/ops_generate/common/gen_constants.py +91 -0
  389. mindspore/ops_generate/common/gen_utils.py +348 -0
  390. mindspore/ops_generate/common/op_proto.py +473 -0
  391. mindspore/ops_generate/common/template.py +523 -0
  392. mindspore/ops_generate/gen_ops.py +22 -1069
  393. mindspore/ops_generate/op_def/__init__.py +0 -0
  394. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  395. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  396. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  397. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  398. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  399. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  400. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  401. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  402. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  403. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  404. mindspore/ops_generate/pyboost/__init__.py +0 -0
  405. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  406. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  407. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  408. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  409. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  410. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  411. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  412. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  413. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  414. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  415. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  416. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  417. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  418. mindspore/ops_generate/resources/__init__.py +0 -0
  419. mindspore/ops_generate/resources/resource_list.py +30 -0
  420. mindspore/ops_generate/resources/resource_loader.py +36 -0
  421. mindspore/ops_generate/resources/resource_manager.py +64 -0
  422. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  423. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  424. mindspore/parallel/__init__.py +7 -3
  425. mindspore/parallel/_auto_parallel_context.py +152 -34
  426. mindspore/parallel/_cell_wrapper.py +130 -15
  427. mindspore/parallel/_parallel_serialization.py +107 -5
  428. mindspore/parallel/_ps_context.py +1 -1
  429. mindspore/parallel/_recovery_context.py +7 -2
  430. mindspore/parallel/_tensor.py +142 -18
  431. mindspore/parallel/_utils.py +199 -23
  432. mindspore/parallel/algo_parameter_config.py +4 -4
  433. mindspore/parallel/auto_parallel.py +732 -0
  434. mindspore/parallel/checkpoint_convert.py +159 -0
  435. mindspore/parallel/checkpoint_transform.py +698 -35
  436. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  437. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  438. mindspore/parallel/cluster/run.py +21 -4
  439. mindspore/parallel/function/__init__.py +24 -0
  440. mindspore/parallel/function/reshard_func.py +259 -0
  441. mindspore/parallel/nn/__init__.py +25 -0
  442. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  443. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  444. mindspore/parallel/parameter_broadcast.py +25 -14
  445. mindspore/parallel/shard.py +137 -58
  446. mindspore/parallel/transform_safetensors.py +363 -305
  447. mindspore/pgodb140.dll +0 -0
  448. mindspore/pgort140.dll +0 -0
  449. mindspore/profiler/__init__.py +22 -5
  450. mindspore/profiler/analysis/__init__.py +0 -0
  451. mindspore/profiler/analysis/parser/__init__.py +0 -0
  452. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  453. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  454. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  455. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  456. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  457. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  458. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  459. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  460. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  461. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  462. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  463. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  464. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  465. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  466. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  467. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  468. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  469. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  470. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  471. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  472. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  473. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  474. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  475. mindspore/profiler/analysis/task_manager.py +131 -0
  476. mindspore/profiler/analysis/time_converter.py +84 -0
  477. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  478. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  479. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  480. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  481. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  482. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  483. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  484. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  485. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  486. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  487. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  488. mindspore/profiler/analysis/work_flow.py +73 -0
  489. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  490. mindspore/profiler/common/command_executor.py +90 -0
  491. mindspore/profiler/common/constant.py +186 -3
  492. mindspore/profiler/common/file_manager.py +208 -0
  493. mindspore/profiler/common/log.py +130 -0
  494. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  495. mindspore/profiler/common/path_manager.py +395 -0
  496. mindspore/profiler/common/process_bar.py +168 -0
  497. mindspore/profiler/common/process_pool.py +9 -3
  498. mindspore/profiler/common/profiler_context.py +500 -0
  499. mindspore/profiler/common/profiler_info.py +304 -0
  500. mindspore/profiler/common/profiler_meta_data.py +74 -0
  501. mindspore/profiler/common/profiler_output_path.py +284 -0
  502. mindspore/profiler/common/profiler_parameters.py +251 -0
  503. mindspore/profiler/common/profiler_path_manager.py +179 -0
  504. mindspore/profiler/common/record_function.py +76 -0
  505. mindspore/profiler/common/tlv_decoder.py +76 -0
  506. mindspore/profiler/common/util.py +75 -2
  507. mindspore/profiler/dynamic_profiler.py +341 -75
  508. mindspore/profiler/envprofiler.py +163 -0
  509. mindspore/profiler/experimental_config.py +197 -0
  510. mindspore/profiler/mstx.py +242 -0
  511. mindspore/profiler/platform/__init__.py +21 -0
  512. mindspore/profiler/platform/base_profiler.py +40 -0
  513. mindspore/profiler/platform/cpu_profiler.py +124 -0
  514. mindspore/profiler/platform/gpu_profiler.py +74 -0
  515. mindspore/profiler/platform/npu_profiler.py +335 -0
  516. mindspore/profiler/profiler.py +1073 -90
  517. mindspore/profiler/profiler_action_controller.py +187 -0
  518. mindspore/profiler/profiler_interface.py +118 -0
  519. mindspore/profiler/schedule.py +243 -0
  520. mindspore/rewrite/api/node.py +15 -13
  521. mindspore/rewrite/api/symbol_tree.py +2 -3
  522. mindspore/run_check/_check_version.py +27 -20
  523. mindspore/run_check/run_check.py +1 -1
  524. mindspore/runtime/__init__.py +37 -0
  525. mindspore/runtime/device.py +27 -0
  526. mindspore/runtime/event.py +209 -0
  527. mindspore/runtime/executor.py +177 -0
  528. mindspore/runtime/memory.py +409 -0
  529. mindspore/runtime/stream.py +460 -0
  530. mindspore/runtime/thread_bind_core.py +401 -0
  531. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  532. mindspore/swresample-4.dll +0 -0
  533. mindspore/swscale-6.dll +0 -0
  534. mindspore/tbbmalloc.dll +0 -0
  535. mindspore/tinyxml2.dll +0 -0
  536. mindspore/train/__init__.py +8 -8
  537. mindspore/train/_utils.py +88 -25
  538. mindspore/train/amp.py +9 -5
  539. mindspore/train/callback/__init__.py +2 -2
  540. mindspore/train/callback/_callback.py +2 -16
  541. mindspore/train/callback/_checkpoint.py +53 -55
  542. mindspore/train/callback/_cluster_monitor.py +14 -18
  543. mindspore/train/callback/_early_stop.py +1 -1
  544. mindspore/train/callback/_flops_collector.py +103 -68
  545. mindspore/train/callback/_history.py +8 -5
  546. mindspore/train/callback/_lambda_callback.py +2 -2
  547. mindspore/train/callback/_landscape.py +0 -3
  548. mindspore/train/callback/_loss_monitor.py +2 -1
  549. mindspore/train/callback/_on_request_exit.py +6 -5
  550. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  551. mindspore/train/callback/_summary_collector.py +52 -19
  552. mindspore/train/callback/_time_monitor.py +2 -1
  553. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  554. mindspore/train/data_sink.py +25 -2
  555. mindspore/train/dataset_helper.py +15 -16
  556. mindspore/train/loss_scale_manager.py +8 -7
  557. mindspore/train/metrics/accuracy.py +3 -3
  558. mindspore/train/metrics/confusion_matrix.py +9 -9
  559. mindspore/train/metrics/error.py +3 -3
  560. mindspore/train/metrics/hausdorff_distance.py +4 -4
  561. mindspore/train/metrics/mean_surface_distance.py +3 -3
  562. mindspore/train/metrics/metric.py +0 -12
  563. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  564. mindspore/train/metrics/precision.py +11 -10
  565. mindspore/train/metrics/recall.py +9 -9
  566. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  567. mindspore/train/mind_ir_pb2.py +174 -46
  568. mindspore/train/model.py +184 -113
  569. mindspore/train/serialization.py +622 -978
  570. mindspore/train/summary/_summary_adapter.py +2 -2
  571. mindspore/train/summary/summary_record.py +2 -3
  572. mindspore/train/train_thor/model_thor.py +1 -1
  573. mindspore/turbojpeg.dll +0 -0
  574. mindspore/utils/__init__.py +6 -3
  575. mindspore/utils/dryrun.py +140 -0
  576. mindspore/utils/hooks.py +81 -0
  577. mindspore/utils/runtime_execution_order_check.py +550 -0
  578. mindspore/utils/utils.py +138 -4
  579. mindspore/vcmeta.dll +0 -0
  580. mindspore/vcruntime140.dll +0 -0
  581. mindspore/vcruntime140_1.dll +0 -0
  582. mindspore/version.py +1 -1
  583. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  584. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
  585. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  586. mindspore/_install_custom.py +0 -43
  587. mindspore/common/_register_for_adapter.py +0 -74
  588. mindspore/common/_tensor_overload.py +0 -139
  589. mindspore/mindspore_np_dtype.dll +0 -0
  590. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  591. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  592. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  593. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  594. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  595. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  596. mindspore/ops_generate/gen_utils.py +0 -209
  597. mindspore/ops_generate/op_proto.py +0 -145
  598. mindspore/ops_generate/template.py +0 -261
  599. mindspore/profiler/envprofiling.py +0 -254
  600. mindspore/profiler/profiling.py +0 -1926
  601. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  602. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -16,19 +16,26 @@
  from __future__ import absolute_import

  import os
- import time
+ import sys
  import glob
- import re
  import math
  import json
+ import re
  from collections import defaultdict

+ import time
  import multiprocessing as mp
+ import psutil
  import numpy as np
+ from safetensors.numpy import save_file, load_file
+ from safetensors import safe_open
+
  import mindspore as ms
+ from mindspore import log as logger
+ from mindspore.log import vlog_print
  from mindspore.parallel._parallel_serialization import _get_device_num_from_strategy, _make_dir, \
      _extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
-     _insert_opt_shard_reshape, _extract_src_dst_layout_map_by_src
+     _insert_opt_shard_reshape, _extract_src_dst_layout_map_by_src, _insert_expand_layout_reshape
  from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_tensor_layout, \
      _get_needed_rank_transform_operator_map_by_layouts, \
      _generate_transform_operator_stack, _apply_tensor_transform_operators, _construct_tensor_layout_for_opt_shard, \
@@ -36,70 +43,6 @@ from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_
  from mindspore.parallel._parallel_serialization import _build_searched_strategy, _load_protobuf_strategy, \
      _convert_to_list

- from safetensors.numpy import save_file, load_file
- from safetensors import safe_open
-
-
- def _load_and_transform(path, name_map, load_func, transform_func):
-     if load_func is not None:
-         param_dict = load_func(path)
-     else:
-         param_dict = path
-     transform_dict = {}
-     for k, v in param_dict.items():
-         new_name = name_map.get(k, k) if name_map is not None else k
-         transform_dict[new_name] = transform_func(v, new_name)
-     return transform_dict
-
-
- def _transform_tensor_to_numpy(path, name_map=None):
-     return _load_and_transform(path, name_map, ms.load_checkpoint, lambda v, new_name: v.asnumpy())
-
-
- def _transform_numpy_to_tensor(path, name_map=None):
-     return _load_and_transform(path, name_map, load_file, lambda v, new_name: ms.Parameter(v, name=new_name))
-
-
- def _process_file(file_info):
-     cur_ckpt_path, name_map, save_path, file = file_info
-     param_dict_numpy = _transform_tensor_to_numpy(cur_ckpt_path, name_map)
-     safetensors_filename = file.replace(".ckpt", ".safetensors")
-     dst_file = os.path.join(save_path, safetensors_filename)
-     save_file(param_dict_numpy, dst_file)
-
-
- def _process_file_safetensors(file_info):
-     cur_safe_path, name_map, save_path, file = file_info
-     param_dict_tensor = _transform_numpy_to_tensor(cur_safe_path, name_map)
-     ckpt_filename = file.replace(".safetensors", ".ckpt")
-     dst_file = os.path.join(save_path, ckpt_filename)
-     ms.save_checkpoint(param_dict_tensor, dst_file)
-
-
- def _gather_tasks(file_path, save_path, file_name_regex, name_map):
-     """gather transform rank together"""
-     tasks = []
-     for root, dirs, _ in os.walk(file_path):
-         if root != file_path:
-             continue
-
-         rank_dirs = [d for d in dirs if d.startswith('rank')]
-         if not rank_dirs:
-             raise ValueError(
-                 f"For 'ckpt_to_safetensors', no directories starting with 'rank' found in {file_path}")
-
-         for rank_dir in rank_dirs:
-             rank_dir_path = os.path.join(root, rank_dir)
-             dst_root = os.path.join(save_path,
-                                     os.path.relpath(rank_dir_path, file_path)) if save_path else rank_dir_path
-             os.makedirs(dst_root, exist_ok=True)
-             tasks.extend(
-                 (os.path.join(rank_dir_path, file), name_map, dst_root, file)
-                 for file in os.listdir(rank_dir_path)
-                 if file.endswith(".ckpt") and (file_name_regex is None or re.findall(file_name_regex, file))
-             )
-     return tasks
-

  def _progress_bar(iterable, total=None):
      """
@@ -125,6 +68,7 @@ def _progress_bar(iterable, total=None):
          elapsed_time_str = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
          remaining_time_str = time.strftime("%H:%M:%S", time.gmtime(remaining_time))

+         sys.stdout.reconfigure(encoding="utf-8")
          print(f'\r{percent}%|{bar}|[{elapsed_time_str}<{remaining_time_str}]', end='')
          if iteration == total:
              print()
@@ -134,155 +78,16 @@ def _progress_bar(iterable, total=None):
          print_progress_bar(i)


- def ckpt_to_safetensors(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
-     """
-     Converts MindSpore checkpoint files into safetensors format and saves them to `save_path`.
-     Safetensors is a reliable and portable machine learning model storage format introduced by Huggingface,
-     used for securely storing Tensors with fast speed (zero copy).
-
-     Note:
-         The number of multiprocess settings is related to the size of the host, and it is not recommended to set it
-         too large, otherwise it may cause freezing.
-         The safetensors format does not support the enc verification function. If ckpt is enabled to save enc
-         verification, an error will be generated when performing the conversion.
-         The safetensors format currently does not support crc verification function. If ckpt contains crc verification
-         information, the crc verification information will be lost after conversion to safetensors.
-
-     Args:
-         file_path (str): Path to the directory containing checkpoint files or a single checkpoint file (.ckpt).
-         save_path (str, optional): Directory path where safetensors files will be saved. Defaults: ``None``.
-         name_map (dict, optional): Dictionary mapping original parameter names to new names. Defaults: ``None``.
-         file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
-             Defaults: ``None``.
-         processes_num (int, optional): Number of processes to use for parallel processing. Defaults: 1.
-     Raises:
-         ValueError: If the input path is invalid or the save_path is not a directory,
-             or the file_path does not end with '.ckpt'.
-
-     Supported Platforms:
-         ``Ascend`` ``GPU`` ``CPU``
-
-     Examples:
-         >>> import mindspore as ms
-         >>> ms.ckpt_to_safetensors("./ckpt_save_path")
-         >>> ms.ckpt_to_safetensors("./ckpt_save_path/rank0/checkpoint_0.ckpt")
-         >>> ms.ckpt_to_safetensors(file_path="./ckpt_save_path/rank0/checkpoint_0.ckpt", save_path="./new_path/")
-         >>> namemap = {"lin.weight":"new_name"}
-         >>> ms.ckpt_to_safetensors("./ckpt_save_path/rank0/checkpoint_0.ckpt", "./new_path/", namemap)
-     """
-     is_dir = os.path.isdir(file_path)
-     is_file = os.path.isfile(file_path)
-     if not is_dir and not is_file:
-         raise ValueError(f"For 'ckpt_to_safetensors', the input path must be a valid path or file, but got {file_path}")
-     if save_path and os.path.splitext(save_path)[1]:
-         raise ValueError(f"For 'ckpt_to_safetensors', the save_path must be a directory, but got '{save_path}'")
-     if name_map is not None and not isinstance(name_map, dict):
-         raise ValueError(
-             f"For 'ckpt_to_safetensors', the type of 'name_map' must be a directory, but got '{type(name_map)}'")
-
-     if is_dir:
-         tasks = _gather_tasks(file_path, save_path, file_name_regex, name_map)
-         with mp.Pool(processes=processes_num) as pool:
-             list(_progress_bar(pool.imap(_process_file, tasks), total=len(tasks)))
-     elif is_file:
-         if not file_path.endswith(".ckpt"):
-             raise ValueError(f"For 'ckpt_to_safetensors', the input file must be a .ckpt file, but got {file_path}")
-         if file_name_regex is not None and not re.findall(file_name_regex, file_path):
-             raise ValueError(f"For 'ckpt_to_safetensors', the input file does not match the regular expression.")
-         if save_path and not os.path.exists(save_path):
-             os.makedirs(save_path, exist_ok=True)
-
-         param_dict_numpy = _transform_tensor_to_numpy(file_path, name_map)
-         safetensors_filename = os.path.basename(file_path).replace(".ckpt", ".safetensors")
-         dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), safetensors_filename)
-         save_file(param_dict_numpy, dst_file)
-
-
- def _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map):
-     """gather transform rank together"""
-     tasks = []
-     for root, dirs, _ in os.walk(file_path):
-         if root != file_path:
-             continue
-
-         rank_dirs = [d for d in dirs if d.startswith('rank')]
-         if not rank_dirs:
-             raise ValueError(
-                 f"For 'safetensors_to_ckpt', no directories starting with 'rank' found in {file_path}")
-
-         for rank_dir in rank_dirs:
-             rank_dir_path = os.path.join(root, rank_dir)
-             dst_root = os.path.join(save_path,
-                                     os.path.relpath(rank_dir_path, file_path)) if save_path else rank_dir_path
-             os.makedirs(dst_root, exist_ok=True)
-             tasks.extend(
-                 (os.path.join(rank_dir_path, file), name_map, dst_root, file)
-                 for file in os.listdir(rank_dir_path)
-                 if file.endswith(".safetensors") and (file_name_regex is None or re.findall(file_name_regex, file))
-             )
-     return tasks
-
-
- def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
-     """
-     Converts safetensors files into MindSpore checkpoint format and saves them to `save_path`.
-     Safetensors is a reliable and portable machine learning model storage format introduced by Huggingface,
-     used for securely storing Tensors with fast speed (zero copy).
-
-     Note:
-         The number of multiprocess settings is related to the size of the host, and it is not recommended to set it
-         too large, otherwise it may cause freezing.
-
-     Args:
-         file_path (str): Path to the directory containing safetensors files or a single safetensors file (.safetensors).
-         save_path (str, optional): Directory path where checkpoint files will be saved. Defaults: ``None``.
-         name_map (dict, optional): Dictionary mapping original parameter names to new names. Defaults: ``None``.
-         file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
-             Defaults: ``None``.
-         processes_num (int, optional): Number of processes to use for parallel processing. Defaults: 1.
-
-     Raises:
-         ValueError: If the input path is invalid, the save_path is not a directory,
-             or the file_path does not end with '.safetensors'.
-
-     Supported Platforms:
-         ``Ascend`` ``GPU`` ``CPU``
-
-     Examples:
-         >>> import mindspore as ms
-         >>> ms.safetensors_to_ckpt("./safetensors_save_path")
-         >>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors")
-         >>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors", "./new_path/")
-         >>> namemap = {"lin.weight":"new_name"}
-         >>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors", "./new_path/", namemap)
-     """
-     is_dir = os.path.isdir(file_path)
-     is_file = os.path.isfile(file_path)
-     if not is_dir and not is_file:
-         raise ValueError(f"For 'safetensors_to_ckpt', the input path must be a valid path or file, but got {file_path}")
-     if save_path and os.path.splitext(save_path)[1]:
-         raise ValueError(f"For 'safetensors_to_ckpt', the save_path must be a directory, but got '{save_path}'")
-     if name_map is not None and not isinstance(name_map, dict):
-         raise ValueError(
-             f"For 'safetensors_to_ckpt', the type of 'name_map' must be a directory, but got '{type(name_map)}'")
-
-     if is_dir:
-         tasks = _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map)
-         with mp.Pool(processes=processes_num) as pool:
-             list(_progress_bar(pool.imap(_process_file_safetensors, tasks), total=len(tasks)))
-     elif is_file:
-         if not file_path.endswith(".safetensors"):
-             raise ValueError(
-                 f"For 'safetensors_to_ckpt', the input file must be a .safetensors file, but got {file_path}")
-         if file_name_regex is not None and not re.findall(file_name_regex, file_path):
-             raise ValueError(f"For 'safetensors_to_ckpt', the input file does not match the regular expression.")
-         if save_path and not os.path.exists(save_path):
-             os.makedirs(save_path, exist_ok=True)
-
-         param_dict_tensor = _transform_numpy_to_tensor(file_path, name_map)
-         ckpt_filename = os.path.basename(file_path).replace(".safetensors", ".ckpt")
-         dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), ckpt_filename)
-         ms.save_checkpoint(param_dict_tensor, dst_file)
+ def _load_and_transform(path, name_map, load_func, transform_func):
+     if load_func is not None:
+         param_dict = load_func(path)
+     else:
+         param_dict = path
+     transform_dict = {}
+     for k, v in param_dict.items():
+         new_name = name_map.get(k, k) if name_map is not None else k
+         transform_dict[new_name] = transform_func(v, new_name)
+     return transform_dict


  def _check_transform_safetensors(src_safetensors_dir, ckpt_prefix, src_strategy_file, dst_strategy_file):
@@ -460,7 +265,6 @@ def _transform_safetensors_with_parallel(needed_rank_list_map, all_safetensor_fi

      for name, layout in layout_map.items():
          pipe_param_list[layout[6][0]].append(name)
-
      part_list_dict = _distribute_files_by_size(all_safetensor_files_map, needed_rank_list_map, process_num)
      processes = []
      for i in range(process_num):
@@ -485,8 +289,9 @@ def _count_redundancy_list(rank_num, param_name, redundancy_dict, device_num):


  def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict, saftensor_dict, redundancy_dict,
-                                     needed_rank, device_num):
+                                     needed_rank, device_num, choice_func):
      """Find the rank_id under redundant groups."""
+     io_time = 0
      for param_name in pipe_param_list:
          rank_num = int(needed_rank)
          redundancy_ranks = _count_redundancy_list(rank_num, param_name, redundancy_dict, device_num)
@@ -499,11 +304,23 @@ def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dic
                  open_file_id = real_rank
                  break
          if open_file_id is not None:
-             output = file_dict[open_file_id].get_tensor(param_name)
+             start_time = time.time()
+             output = file_dict[open_file_id].get_slice(param_name)
+             end_time = time.time()
+             cost_time = end_time - start_time
+             io_time += cost_time
+             if choice_func is not None:
+                 choice_out = choice_func(param_name)
+                 if isinstance(choice_out, bool) and not choice_out:
+                     continue
+                 if not isinstance(choice_out, (bool, str)):
+                     raise ValueError("For 'unified_safetensors', the return value type of the function "
+                                      f"'choice_func' must be bool or str, but got {type(choice_out)}.")
              saftensor_dict[param_name] = output
          else:
              raise ValueError(f"For _transform_safetensors_single, {param_name} should be in "
                               f"{redundancy_ranks}, but in {single_param_dict[param_name]}.")
+     return io_time


  def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map, src_stage_device_num,
@@ -512,13 +329,14 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
                                    origin_dst_strategy_list,
                                    ckpt_prefix, dst_safetensors_dir, output_format,
                                    _transform_param_list, pipe_param_list=None, file_index=None, unified_flag=False,
-                                   src_strategy_file=None):
+                                   src_strategy_file=None, choice_func=None):
      """
      Transforms safetensors files to a specified format without using parallel processing.
      """
+     io_cost_time = 0
      if src_strategy_file is not None:
          from mindspore.train._utils import get_parameter_redundancy
-         redundancy_dict_tmp = get_parameter_redundancy(src_strategy_file)
+         redundancy_dict_tmp = get_parameter_redundancy(src_strategy_file, initial_rank=0)
          redundancy_dict = {}
          device_num = 0
          for param_name, redundancy in redundancy_dict_tmp.items():
@@ -552,8 +370,10 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
          if pipe_param_list:
              saftensor_dict = dict()
              if src_strategy_file is not None:
-                 _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict, saftensor_dict,
-                                                 redundancy_dict, needed_rank, device_num)
+                 io_time = _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict,
+                                                           saftensor_dict, redundancy_dict, needed_rank,
+                                                           device_num, choice_func)
+                 io_cost_time += io_time
              else:
                  with safe_open(all_safetensor_files_map.get(int(needed_rank)), framework="np") as f:
                      if not unified_flag:
@@ -562,14 +382,32 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
                          dst_param_name_set = set(dst_strategy_list_keys)
                          hyper_param_set = all_param_name_set - (src_param_name_set & dst_param_name_set)
                          pipe_param_list.extend(list(hyper_param_set))
+                     io_time = 0
                      for param_name in pipe_param_list:
                          if param_name not in f.keys():
                              # param not in ckpt file, check reason
                              continue
-                         output = f.get_tensor(param_name)
+                         start_time = time.time()
+                         output = f.get_slice(param_name)
+                         end_time = time.time()
+                         cost_time = end_time - start_time
+                         io_time += cost_time
+                         io_cost_time += io_time
+                         if choice_func is not None:
+                             choice_out = choice_func(param_name)
+                             if isinstance(choice_out, bool) and not choice_out:
+                                 continue
+                             if not isinstance(choice_out, (bool, str)):
+                                 raise ValueError("For 'unified_safetensors', the return value type of the function "
+                                                  f"'choice_func' must be bool or str, but got {type(choice_out)}.")
                          saftensor_dict[param_name] = output
          else:
+             start_time = time.time()
              saftensor_dict = load_file(all_safetensor_files_map.get(int(needed_rank)))
+             end_time = time.time()
+             cost_time = end_time - start_time
+             io_cost_time += cost_time
+
          for param_name, param in saftensor_dict.items():
              src_rank = int(needed_rank) % src_stage_device_num
              param_total_dict[param_name][src_rank] = param
@@ -588,7 +426,7 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
          local_rank_id = transform_rank % dst_stage_device_num
          transform_param_dict = _transform_parallel_safetensor(local_rank_id, param_total_dict,
                                                                param_attr_dict, src_strategy_list, dst_strategy_list,
-                                                               param_total_dict_keys, src_strategy_file)
+                                                               param_total_dict_keys, src_strategy_file, choice_func)
          if file_index is not None:
              save_safetensor_file = f"part{file_index}.{output_format}"
              save_safetensor_file_dir = dst_safetensors_dir
@@ -602,15 +440,17 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
          if _transform_param_list is not None:
              _transform_param_list.append({save_file_name: transform_param_dict})
          else:
-             if output_format == "safetensors":
-                 save_file(transform_param_dict, save_file_name)
-             else:
-                 transform_param_dict = _load_and_transform(transform_param_dict, None, None,
-                                                            transform_func=lambda v, name: ms.Parameter(v,
-                                                                                                         name=name))
-                 ms.save_checkpoint(transform_param_dict, save_file_name)
+             if transform_param_dict:
+                 if output_format == "safetensors":
+                     save_file(transform_param_dict, save_file_name)
+                 else:
+                     transform_param_dict = _load_and_transform(transform_param_dict,
+                                                                None, None, transform_func=
+                                                                lambda v, name: ms.Parameter(v, name=name))
+                     ms.save_checkpoint(transform_param_dict, save_file_name)
      del param_total_dict_keys
      del param_total_dict
+     return io_cost_time


  def _save_final_safetensors(_transform_param_list, output_format):
@@ -735,6 +575,13 @@ def transform_safetensors_by_rank(rank_id, safetensor_files_map, save_safetensor
      save_file(transform_param_dict, save_safetensor_file_name)


+ def _extrace_number(file_name):
+     """get file last two number"""
+     number_ls = re.findall(r'\d+', file_name)
+     number_ls = [int(i) for i in number_ls]
+     return number_ls[-2:]
+
+
  def _collect_safetensor_files(src_safetensors_dir, format='safetensors', file_suffix=None):
      """
      Collects all safetensors files from the specified directory and its subdirectories.
@@ -758,12 +605,9 @@ def _collect_safetensor_files(src_safetensors_dir, format='safetensors', file_su
          else:
              safetensor_file_name = os.path.join(safetensor_dir, f"*{file_suffix}.{format}")
          rank_ckpts = glob.glob(safetensor_file_name)
-         rank_ckpts.sort()
-         for safetensor_file in rank_ckpts:
-             if not os.path.isfile(safetensor_file):
-                 ms.log.warning("{} is not a safetensor file.".format(safetensor_file))
-                 continue
-             all_safetensor_files_map[rank_id] = safetensor_file
+         rank_ckpts.sort(key=_extrace_number)
+         if rank_ckpts:
+             all_safetensor_files_map[rank_id] = rank_ckpts[-1]
      return all_safetensor_files_map


@@ -775,7 +619,7 @@ def _find_needed_ranks(src_strategy_dict, dst_strategy_dict):
      dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
      dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
      dst_device_num = dst_stage_device_num * dst_stage_num
-     for rank in _progress_bar(range(dst_device_num)):
+     for rank in range(dst_device_num):
          needed_rank_list = ms.rank_list_for_transform(rank, src_strategy_dict, dst_strategy_dict)
          needed_rank_list_key = "-".join([str(r) for r in needed_rank_list])
          needed_rank_list_map[needed_rank_list_key].append(rank)
@@ -791,7 +635,8 @@ def load_file_by_param_name(filename, parme_name_list):


  def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, src_strategy_list,
-                                    dst_strategy_list, param_total_dict_keys=None, src_strategy_file=None):
+                                    dst_strategy_list, param_total_dict_keys=None, src_strategy_file=None,
+                                    choice_func=None):
      """
      Transform model parallel dimension for distributed safetensor files.
      """
@@ -799,7 +644,10 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
  device_num = -1
  param_total_dict_keys = list(param_total_dict.keys()) if param_total_dict_keys is None else param_total_dict_keys
  for param_name in param_total_dict_keys:
- tensor_shape = list(param_total_dict[param_name].values())[0].shape
+ if str(type(list(param_total_dict[param_name].values())[0])) == "<class 'builtins.PySafeSlice'>":
+ tensor_shape = list(param_total_dict[param_name].values())[0].get_shape()
+ else:
+ tensor_shape = list(param_total_dict[param_name].values())[0].shape
  from_dev_matrix = [1]
  from_tensor_map = [-1] * len(tensor_shape)
  from_opt_shard_step = 0
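The `str(type(...))` check above distinguishes a lazily opened safetensors slice from an in-memory array: `safe_open(...).get_slice(name)` returns a `PySafeSlice`, whose shape comes from `get_shape()` and whose data is only read when it is indexed. A small self-contained sketch of the two access paths (file name and key are hypothetical), assuming the standard safetensors numpy API:

    import numpy as np
    from safetensors import safe_open
    from safetensors.numpy import save_file

    save_file({"w": np.zeros((4, 8), dtype=np.float32)}, "demo.safetensors")
    with safe_open("demo.safetensors", framework="np") as f:
        lazy = f.get_slice("w")      # PySafeSlice: header only, no data read yet
        print(lazy.get_shape())      # [4, 8], obtained without loading the tensor
        eager = f.get_tensor("w")    # np.ndarray, fully materialized
        print(eager.shape)           # (4, 8), plain .shape attribute
        data = lazy[:]               # slicing the PySafeSlice reads the bytes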
@@ -832,6 +680,9 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
  continue
  origin_tensor_shape += (item * param_strategy[i],)

+ has_layout_from = any(isinstance(i, (list, tuple)) for i in from_tensor_map)
+ has_layout_to = any(isinstance(i, (list, tuple)) for i in to_tensor_map_origin)
+
  from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
  from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
  to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
@@ -851,21 +702,132 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
  from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
  to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
  _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
+ _insert_expand_layout_reshape(param_rank_map, from_info_tuple, to_info_tuple, has_layout_from, has_layout_to)
  transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id)
  param_total_dict_copy = param_total_dict[param_name].copy()
  _apply_tensor_transform_operators(transform_operator_stack, param_total_dict_copy, device_num)
-
+ if choice_func is not None:
+ choice_out = choice_func(param_name)
+ if isinstance(choice_out, str):
+ param_name = choice_out
  transform_param_dict[param_name] = param_total_dict_copy[rank_id % device_num]
+ if str(type(transform_param_dict[param_name])) == "<class 'builtins.PySafeSlice'>":
+ transform_param_dict[param_name] = transform_param_dict[param_name][:]

  # Handle those parameter like learning_rate, global_step which not in strategy_file.
  for param_name in param_total_dict_keys:
+ if choice_func is not None:
+ choice_out = choice_func(param_name)
+ if isinstance(choice_out, str):
+ continue
  if param_name not in transform_param_dict:
  transform_para = param_total_dict[param_name][rank_id % device_num]
+ if str(type(transform_para)) == "<class 'builtins.PySafeSlice'>":
+ transform_para = transform_para[:]
  transform_param_dict[param_name] = transform_para
  return transform_param_dict


- def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundancy=True, file_suffix=None):
+ def _cal_param_size(shape, dtype):
+ """cal param size by dtype and shape"""
+ dtype_size = {
+ "BOOL": 1,
+ "U8": 1,
+ "I8": 1,
+ "F8_E5M2": 1,
+ "F8_E4M3": 1,
+ "I16": 2,
+ "U16": 2,
+ "I32": 4,
+ "U32": 4,
+ "I64": 8,
+ "U64": 8,
+ "F16": 2,
+ "BF16": 2,
+ "F32": 4,
+ "F64": 8,
+ }
+ num_elements = math.prod(shape)
+ element_size = dtype_size.get(dtype, 4)
+ total_bytes = num_elements * element_size
+ return total_bytes
+
+
+ def _split_weight_dict(weights, num_groups):
+ """split weights by num"""
+ sorted_items = sorted(weights.items(), key=lambda x: -x[1])
+ groups = [[] for _ in range(num_groups)]
+ total_bytes = [0] * num_groups
+ for weight_name, byte_size in sorted_items:
+ min_index = total_bytes.index(min(total_bytes))
+ groups[min_index].append(weight_name)
+ total_bytes[min_index] += byte_size
+
+ return groups
+
+
+ def _save_hyper_param(split_dst_file, all_safetensor_files_map, name_list, dst_dir):
+ """save hyper param"""
+ if not split_dst_file or (split_dst_file and split_dst_file[0] == 1):
+ with safe_open(all_safetensor_files_map.get(0), framework="np") as f:
+ all_key = f.keys()
+ hyper_parameter = set(all_key) - set(name_list)
+ if hyper_parameter:
+ hyper_dict = {}
+ for key in hyper_parameter:
+ hyper_dict[key] = f.get_tensor(key)
+ save_file(hyper_dict, os.path.join(dst_dir, "hyper_param.safetensors"))
+
+
+ def _save_parameter_map_json(split_list, choice_func, split_dst_file, dst_dir, param_total_size):
+ """save parameter map json file"""
+ param_name_dict = dict()
+ for index, part_list in enumerate(split_list):
+ for name in part_list:
+ save_param_name = name
+ if choice_func is not None:
+ choice_out = choice_func(name)
+ if isinstance(choice_out, str):
+ save_param_name = choice_out
+ if save_param_name == -1:
+ break
+ param_name_dict[save_param_name] = f"part{index}.safetensors"
+ output_dict = {"metadata": {"total_size": param_total_size}, "weight_map": param_name_dict}
+ if not split_dst_file or (split_dst_file and split_dst_file[0] == 1):
+ json_str = json.dumps(output_dict, indent=4)
+ map_file = os.path.join(dst_dir, "param_name_map.json")
+ with open(map_file, 'w') as f:
+ f.write(json_str)
+
+
+ def _get_dst_shape(param_name, param_shape, src_strategy_list):
+ """get dst shape by strategy"""
+ from_dev_matrix = [1]
+ from_tensor_map = [-1] * len(param_shape)
+ from_opt_shard_size = 0
+ if src_strategy_list is not None:
+ from_dev_matrix, from_tensor_map, _, from_opt_shard_size = _extract_layout_item(
+ src_strategy_list.get(param_name))
+ to_dev_matrix_origin = [1]
+ to_tensor_map_origin = [-1] * len(param_shape)
+ to_opt_shard_step = 0
+ to_opt_shard_size = 0
+
+ param_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
+ origin_tensor_shape = ()
+ for i, item in enumerate(param_shape):
+ if i == 0 and from_opt_shard_size > 0:
+ origin_tensor_shape += (item * param_strategy[i] * from_opt_shard_size,)
+ continue
+ origin_tensor_shape += (item * param_strategy[i],)
+
+ _, _, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
+ to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, origin_tensor_shape)
+ return to_full_tensor_shape
+
+
+ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundancy=True, file_suffix=None,
+ max_process_num=64, choice_func=None, split_dst_file=()):
  """
  Merge multiple safetensor files into a unified safetensor file.

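The new `_split_weight_dict` helper balances the merged output across `part{index}.safetensors` files: parameters are taken largest first (sizes computed by `_cal_param_size` on the destination shapes) and each is appended to whichever group currently holds the fewest bytes. A small worked example with hypothetical sizes:

    weights = {"a": 6, "b": 5, "c": 4, "d": 3}          # name -> size in bytes
    # largest first: a -> group 0 (totals 6, 0); b -> group 1 (6, 5)
    # c -> group 1 (6, 9); d -> group 0 (9, 9)
    print(_split_weight_dict(weights, num_groups=2))    # [['a', 'd'], ['b', 'c']]

`_save_parameter_map_json` then records the resulting placement in param_name_map.json as {"metadata": {"total_size": ...}, "weight_map": {"a": "part0.safetensors", ...}}.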
@@ -877,6 +839,14 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
  saved safetensors files. Default: ``True``, indicating that the merged source weight files are complete.
  file_suffix (str, optional): Specify the filename suffix for merging safetensors files. Default: ``None``,
  meaning all safetensors files in the source weight directory will be merged.
+ max_process_num (int, optional): Maximum number of processes. Default: ``64``.
+ choice_func (callable, optional): A callable function used to filter parameters or modify parameter names.
+ The return value of the function must be of type str (string) or bool (boolean). Default: ``None``.
+ split_dst_file (tuple, optional) - A parameter used to manually split a task into multiple subtasks for
+ execution, represented as a tuple containing two elements. The first element indicates the number of
+ the current subtask, and the second element indicates the total number of tasks. This parameter supports
+ splitting and executing tasks multiple times on a single machine, and also supports executing different
+ subtasks on multiple machines respectively. Default: ``()``.

  Raises:
  ValueError: If the safetensors file of rank is missing.
@@ -889,8 +859,12 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
  >>> src_dir = "/usr/safetensors/llama31B/4p_safetensors/"
  >>> src_strategy_file = "/usr/safetensors/llama31B/strategy_4p.ckpt"
  >>> dst_dir = "/usr/safetensors/llama31B/merge_llama31B_4p/"
- >>> ms.unified_safetensors(src_dir, src_strategy_file, dst_dir)
+ >>> ms.parallel.unified_safetensors(src_dir, src_strategy_file, dst_dir)
  """
+ pid = os.getpid()
+ total_cores = os.cpu_count()
+ all_cores = set(range(total_cores))
+ os.sched_setaffinity(pid, all_cores)
  _check_transform_safetensors(src_dir, "", src_strategy_file, None)
  _make_dir(dst_dir, "path")
  if os.path.isfile(src_dir):
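For reference, a hypothetical call exercising the new keyword arguments (the paths reuse the docstring example above; the filter itself is illustrative): `choice_func` may return False to drop a parameter or a string to rename it, and `split_dst_file=(1, 2)` runs the first of two subtasks:

    import mindspore as ms

    def choice_func(name):
        if name.startswith("adam_"):              # bool return: drop this parameter
            return False
        return name.replace("network.", "")       # str return: save under a new name

    src_dir = "/usr/safetensors/llama31B/4p_safetensors/"
    src_strategy_file = "/usr/safetensors/llama31B/strategy_4p.ckpt"
    dst_dir = "/usr/safetensors/llama31B/merge_llama31B_4p/"
    ms.parallel.unified_safetensors(src_dir, src_strategy_file, dst_dir,
                                    max_process_num=32, choice_func=choice_func,
                                    split_dst_file=(1, 2))   # subtask 1 of 2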
@@ -914,13 +888,11 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
  "but it is missing.".format(needed_rank, rank))
  layout_map = _convert_to_list(src_strategy_dict)

- total_size = 0
  actual_params = set()
  for _, file_name in all_safetensor_files_map.items():
- total_size += os.path.getsize(file_name) / 1024 / 1024 / 1024
  with safe_open(file_name, framework="np") as f:
  actual_params.update(f.keys())
- split_num = math.ceil(total_size / 3)
+
  params_to_store = actual_params & set(layout_map.keys())

  name_list = []
@@ -928,29 +900,55 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
  if name.startswith("accu_grads"):
  continue
  name_list.append(name)
- split_list = _split_list(name_list, split_num)
-
- with safe_open(all_safetensor_files_map.get(0), framework="np") as f:
- all_key = f.keys()
- hyper_parameter = set(all_key) - set(name_list)
- if hyper_parameter:
- hyper_dict = {}
- for key in hyper_parameter:
- hyper_dict[key] = f.get_tensor(key)
- save_file(hyper_dict, os.path.join(dst_dir, "hyper_param.safetensors"))
-
- # save parameter map json
- param_name_dict = dict()
- for index, part_list in enumerate(split_list):
- for name in part_list:
- param_name_dict[name] = f"part{index}.safetensors"
- json_str = json.dumps(param_name_dict, indent=4)
- map_file = os.path.join(dst_dir, "param_name_map.json")
- with open(map_file, 'w') as f:
- f.write(json_str)
-
- max_process = min(split_num, 100)
- res = [i for i in range(split_num)]
+
+ param_size_dict = {}
+ param_total_size = 0
+ for _, file_name in all_safetensor_files_map.items():
+ with safe_open(file_name, framework="np") as f:
+ for k in f.keys():
+ if k in name_list:
+ py_slice = f.get_slice(k)
+ param_total_size += _cal_param_size(py_slice.get_shape(), py_slice.get_dtype())
+ param_dst_shape = _get_dst_shape(k, py_slice.get_shape(), origin_src_strategy_list)
+ # Convert the shape of np.int32 type to int type to prevent overflow in subsequent calculations.
+ param_dst_shape = [int(item) for item in param_dst_shape]
+ if choice_func is not None:
+ choice_out = choice_func(k)
+ if isinstance(choice_out, bool):
+ if not choice_out:
+ continue
+ if k not in param_size_dict:
+ param_size_dict[k] = _cal_param_size(param_dst_shape, py_slice.get_dtype())
+ split_num = math.ceil(sum(param_size_dict.values()) / 1024 / 1024 / 1024 / 3)
+ split_num = min(split_num, len(name_list))
+ split_list = _split_weight_dict(param_size_dict, split_num)
+
+ if split_dst_file:
+ current_machine_num = split_dst_file[0]
+ total_machine_num = split_dst_file[1]
+ n = len(split_list)
+ avg_length = n // total_machine_num
+ remainder = n % total_machine_num
+ start_index = (avg_length * (current_machine_num - 1)) + min(current_machine_num - 1, remainder)
+ end_index = start_index + avg_length + (1 if current_machine_num <= remainder else 0)
+ sub_list = []
+ for i in range(len(split_list)):
+ if start_index <= i < end_index:
+ sub_list.append(split_list[i])
+ else:
+ sub_list.append([-1])
+ else:
+ sub_list = split_list
+
+ _save_hyper_param(split_dst_file, all_safetensor_files_map, name_list, dst_dir)
+ _save_parameter_map_json(split_list, choice_func, split_dst_file, dst_dir, param_total_size)
+
+ if split_dst_file:
+ split_num = end_index - start_index
+ res = list(range(start_index, end_index))
+ else:
+ res = [i for i in range(split_num)]
+ max_process = min(split_num, max_process_num)
  res = _split_list(res, max_process)
  processes = []
  src_strategy_name = None
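The `split_dst_file` arithmetic above gives each subtask a contiguous, near-equal share of the `split_list` parts, with the first `remainder` subtasks taking one extra part; parts outside the share are masked with `[-1]` so they are skipped when the map file is written. A worked example (hypothetical: 7 parts over 3 subtasks):

    n, total = 7, 3
    avg_length, remainder = n // total, n % total             # 2, 1
    for cur in (1, 2, 3):
        start = avg_length * (cur - 1) + min(cur - 1, remainder)
        end = start + avg_length + (1 if cur <= remainder else 0)
        print(cur, list(range(start, end)))
    # 1 [0, 1, 2]
    # 2 [3, 4]
    # 3 [5, 6]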
@@ -960,7 +958,7 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
  p = mp.Process(target=_transform_safetensors_single_semaphore, args=(
  needed_rank_list_map, all_safetensor_files_map, src_stage_device_num, dst_stage_device_num,
  src_strategy_dict, None, origin_src_strategy_list, origin_dst_strategy_list,
- "", dst_dir, "safetensors", None, split_list, res[i], True, src_strategy_name))
+ "", dst_dir, "safetensors", None, sub_list, res[i], True, src_strategy_name, choice_func))
  p.start()
  processes.append(p)
  for p in processes:
@@ -974,13 +972,21 @@ def _transform_safetensors_single_semaphore(needed_rank_list_map, all_safetensor
  origin_dst_strategy_list,
  ckpt_prefix, dst_safetensors_dir, output_format,
  _transform_param_list, pipe_param_list=None, file_index=None,
- unified_flag=False, src_strategy_file=None):
+ unified_flag=False, src_strategy_file=None, choice_func=None):
+ """transform safetensors single semaphore"""
+ total_io_cost_time = 0
  for i in file_index:
- _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map, src_stage_device_num,
- dst_stage_device_num, src_strategy_dict, dst_strategy_dict,
- origin_src_strategy_list,
- origin_dst_strategy_list, ckpt_prefix, dst_safetensors_dir, output_format,
- _transform_param_list, pipe_param_list[i], i, unified_flag, src_strategy_file)
+ io_cost_time = _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map,
+ src_stage_device_num, dst_stage_device_num, src_strategy_dict,
+ dst_strategy_dict, origin_src_strategy_list,
+ origin_dst_strategy_list, ckpt_prefix, dst_safetensors_dir,
+ output_format, _transform_param_list, pipe_param_list[i], i,
+ unified_flag, src_strategy_file, choice_func)
+ while psutil.virtual_memory().percent > 50:
+ time.sleep(1)
+ total_io_cost_time += io_cost_time
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+ f"Unified safetensors io cost time:{total_io_cost_time}.")


  def _split_list(split_list, split_num):
@@ -1027,36 +1033,76 @@ def _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_n
  return sf_obj


- def _load_parallel_checkpoint(total_safetensors_dir, dst_strategy_file, net=None, dst_safetensors_dir=None,
- rank_id=None):
- """load parallel safetensors by merged file."""
- file_list = os.listdir(total_safetensors_dir)
- json_files = [file for file in file_list if file.endswith('.json')]
- if len(json_files) != 1:
- raise ValueError(f"For 'load_parallel_checkpoint', the number of json files in 'total_safetensors_dir' "
- f"must be 1, but got {len(json_files)}.")
- param_name_json = os.path.join(total_safetensors_dir, json_files[0])
- with open(param_name_json, 'r') as f:
- param_name_map = json.load(f)
+ def _process_hyper_params(file_list, total_safetensors_dir, total_param):
+ """process hyper params"""
+ if 'hyper_param.safetensors' in file_list:
+ hyper_parameter_file_name = os.path.join(total_safetensors_dir, "hyper_param.safetensors")
+ with safe_open(hyper_parameter_file_name, framework="np") as f:
+ for key in f.keys():
+ total_param[key] = ms.Parameter(ms.Tensor.from_numpy(f.get_tensor(key)))
+ return total_param
+
+
+ def _cal_param_name_map_and_param_list(file_list, total_safetensors_dir, json_files, dst_strategy_file, rank_id):
+ """calculate param_name_map and param_list"""
+ if len(file_list) == 1:
+ logger.info("There is only one weight file in the directory, which will be automatically mapped.")
+ file_name = os.path.join(total_safetensors_dir, file_list[0])
+ is_file = os.path.isfile(file_name)
+ if not is_file:
+ raise ValueError(f"For 'load_parallel_checkpoint', weight files must be included "
+ f"in the `unified_safetensors_dir`.")
+ with safe_open(file_name, framework="np") as f:
+ keys = f.keys()
+ values = len(keys) * [file_list[0]]
+ param_name_map = dict(zip(keys, values))
+ else:
+ if not json_files:
+ raise ValueError(
+ f"For 'load_parallel_checkpoint', there must be a JSON file named 'param_name_map.json' in "
+ f"the 'total_safetensors_dir'.")
+ param_name_json = os.path.join(total_safetensors_dir, json_files[0])
+ with open(param_name_json, 'r') as f:
+ param_name_map = json.load(f)
+ if "weight_map" in param_name_map:
+ param_name_map = param_name_map["weight_map"]
+
  if dst_strategy_file is not None:
  _, dst_strategy_list = _extract_src_dst_layout_map(rank_id, None, dst_strategy_file)
  param_list = dst_strategy_list.keys()
  else:
  dst_strategy_list = None
  param_list = param_name_map.keys()
+ return param_name_map, param_list, dst_strategy_list
+

+ def _load_parallel_checkpoint(file_info):
+ """load parallel safetensors by merged file."""
+ total_safetensors_dir, dst_strategy_file, net, dst_safetensors_dir, \
+ rank_id, output_format, name_map, return_param_dict = file_info
+ pid = os.getpid()
+ total_cores = os.cpu_count()
+ all_cores = set(range(total_cores))
+ os.sched_setaffinity(pid, all_cores)
+ file_list = os.listdir(total_safetensors_dir)
+ json_files = [file for file in file_list if file == "param_name_map.json"]
+ param_name_map, param_list, dst_strategy_list = _cal_param_name_map_and_param_list(file_list, total_safetensors_dir,
+ json_files, dst_strategy_file,
+ rank_id)
  total_param = dict()
  dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) if dst_strategy_list \
  is not None else 1
  local_rank_id = rank_id % dst_stage_device_num
- for param_name in param_list:
+ total_io_cost_time = 0
+ for param_name in _progress_bar(param_list):
  if param_name not in param_name_map:
  continue
  file_name = os.path.join(total_safetensors_dir, param_name_map[param_name])
  with safe_open(file_name, framework="np") as f:
- if param_name not in f.keys():
+ cur_param_name = name_map.get(param_name) if name_map is not None and param_name in name_map else param_name
+ if cur_param_name not in f.keys():
  continue
- sf_obj = f.get_slice(param_name)
+ sf_obj = f.get_slice(cur_param_name)

  tensor_shape = sf_obj.get_shape()
  from_dev_matrix = [1]
@@ -1078,6 +1124,9 @@ def _load_parallel_checkpoint(total_safetensors_dir, dst_strategy_file, net=None
  continue
  origin_tensor_shape += (item * param_strategy[i],)

+ has_layout_from = any(isinstance(i, (list, tuple)) for i in from_tensor_map)
+ has_layout_to = any(isinstance(i, (list, tuple)) for i in to_tensor_map_origin)
+
  from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
  from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
  to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
@@ -1097,25 +1146,34 @@ def _load_parallel_checkpoint(total_safetensors_dir, dst_strategy_file, net=None
  from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
  to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
  _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
+ _insert_expand_layout_reshape(param_rank_map, from_info_tuple, to_info_tuple,
+ has_layout_from, has_layout_to)
  transform_operator_stack = _generate_transform_operator_stack(param_rank_map, local_rank_id)
-
+ start_time = time.time()
  slice_param = _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_num)
+ end_time = time.time()
+ cost_time = end_time - start_time
+ total_io_cost_time += cost_time
  else:
+ start_time = time.time()
  slice_param = sf_obj[:]
-
- total_param[param_name] = ms.Parameter(slice_param)
-
- if 'hyper_param.safetensors' in file_list:
- hyper_parameter_file_name = os.path.join(total_safetensors_dir, "hyper_param.safetensors")
- with safe_open(hyper_parameter_file_name, framework="np") as f:
- for key in f.keys():
- total_param[key] = ms.Parameter(f.get_tensor(key))
+ end_time = time.time()
+ cost_time = end_time - start_time
+ total_io_cost_time += cost_time
+ total_param[param_name] = ms.Parameter(ms.Tensor.from_numpy(slice_param))
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+ f"load distributed safetensors io cost time:{total_io_cost_time}.")
+ total_param = _process_hyper_params(file_list, total_safetensors_dir, total_param)
  if net is not None:
- param_not_load, ckpt_not_load = ms.load_param_into_net(net, total_param)
- return param_not_load, ckpt_not_load
+ if not return_param_dict:
+ logger.info("start load param into net...")
+ param_not_load, ckpt_not_load = ms.load_param_into_net(net, total_param)
+ logger.info("load param into net is end...")
+ return param_not_load, ckpt_not_load
+ return total_param
  _make_dir(os.path.join(dst_safetensors_dir, f"rank_{rank_id}"), "path")
- ms.save_checkpoint(total_param, os.path.join(dst_safetensors_dir, f"rank_{rank_id}", f"net.safetensors"),
- format='safetensors')
+ ms.save_checkpoint(total_param, os.path.join(dst_safetensors_dir, f"rank_{rank_id}", f"net.{output_format}"),
+ format=output_format)
  return None


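Note that `_load_parallel_checkpoint` now receives one packed `file_info` tuple instead of separate keyword arguments. A minimal call sketch (variable values are illustrative) matching the unpacking order shown above:

    file_info = (unified_dir,        # total_safetensors_dir holding part*.safetensors
                 dst_strategy_file,  # destination strategy file, or None
                 net,                # network to load into, or None to save per-rank files
                 None,               # dst_safetensors_dir, used only when net is None
                 rank_id,            # global rank of this process
                 "safetensors",      # output_format for the per-rank save path
                 None,               # name_map: optional {param name: name stored in file}
                 False)              # return_param_dict
    param_not_load, ckpt_not_load = _load_parallel_checkpoint(file_info)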
@@ -1143,4 +1201,4 @@ def _get_slice(rank_id, sf_obj, param_name, dst_strategy_list):


  __all__ = ["_transform_safetensors", "transform_safetensors_by_stage",
- "transform_safetensors_by_rank", "ckpt_to_safetensors", "safetensors_to_ckpt", "unified_safetensors"]
+ "transform_safetensors_by_rank", "unified_safetensors"]