mindspore 2.5.0__cp311-cp311-win_amd64.whl → 2.6.0rc1__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (491)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +6 -4
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -33
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parse/__init__.py +6 -7
  14. mindspore/_extends/parse/compile_config.py +19 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +24 -193
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +97 -74
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +4 -4
  27. mindspore/atlprov.dll +0 -0
  28. mindspore/avcodec-59.dll +0 -0
  29. mindspore/avdevice-59.dll +0 -0
  30. mindspore/avfilter-8.dll +0 -0
  31. mindspore/avformat-59.dll +0 -0
  32. mindspore/avutil-57.dll +0 -0
  33. mindspore/boost/__init__.py +2 -2
  34. mindspore/boost/base.py +3 -7
  35. mindspore/boost/boost_cell_wrapper.py +2 -2
  36. mindspore/c1.dll +0 -0
  37. mindspore/c1xx.dll +0 -0
  38. mindspore/c2.dll +0 -0
  39. mindspore/common/__init__.py +4 -3
  40. mindspore/common/_grad_function.py +56 -0
  41. mindspore/common/_pijit_context.py +14 -5
  42. mindspore/common/_register_for_tensor.py +1 -1
  43. mindspore/common/_stub_tensor.py +5 -10
  44. mindspore/common/_tensor_cpp_method.py +1 -1
  45. mindspore/common/_tensor_docs.py +1915 -3287
  46. mindspore/common/api.py +341 -354
  47. mindspore/common/auto_dynamic_shape.py +41 -44
  48. mindspore/common/dtype.py +5 -2
  49. mindspore/common/dump.py +7 -5
  50. mindspore/common/file_system.py +3 -0
  51. mindspore/common/hook_handle.py +5 -3
  52. mindspore/common/initializer.py +10 -6
  53. mindspore/common/jit_begin_end.py +94 -0
  54. mindspore/common/jit_config.py +6 -1
  55. mindspore/common/jit_context.py +76 -0
  56. mindspore/common/jit_trace.py +378 -0
  57. mindspore/common/lazy_inline.py +2 -2
  58. mindspore/common/mutable.py +5 -4
  59. mindspore/common/parameter.py +106 -39
  60. mindspore/common/seed.py +2 -2
  61. mindspore/common/sparse_tensor.py +23 -17
  62. mindspore/common/tensor.py +297 -714
  63. mindspore/communication/__init__.py +7 -5
  64. mindspore/communication/_comm_helper.py +47 -2
  65. mindspore/communication/comm_func.py +70 -53
  66. mindspore/communication/management.py +83 -17
  67. mindspore/context.py +214 -560
  68. mindspore/dataset/__init__.py +44 -20
  69. mindspore/dataset/audio/__init__.py +2 -8
  70. mindspore/dataset/audio/transforms.py +3 -17
  71. mindspore/dataset/core/config.py +3 -3
  72. mindspore/dataset/engine/cache_client.py +1 -1
  73. mindspore/dataset/engine/datasets.py +102 -120
  74. mindspore/dataset/engine/datasets_audio.py +22 -22
  75. mindspore/dataset/engine/datasets_standard_format.py +43 -24
  76. mindspore/dataset/engine/datasets_text.py +78 -85
  77. mindspore/dataset/engine/datasets_user_defined.py +108 -76
  78. mindspore/dataset/engine/datasets_vision.py +111 -108
  79. mindspore/dataset/engine/iterators.py +5 -3
  80. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  81. mindspore/dataset/engine/samplers.py +279 -57
  82. mindspore/dataset/engine/serializer_deserializer.py +2 -1
  83. mindspore/dataset/engine/validators.py +10 -0
  84. mindspore/dataset/text/__init__.py +7 -6
  85. mindspore/dataset/text/transforms.py +6 -5
  86. mindspore/dataset/text/utils.py +3 -3
  87. mindspore/dataset/transforms/__init__.py +0 -9
  88. mindspore/dataset/transforms/transforms.py +3 -3
  89. mindspore/dataset/utils/browse_dataset.py +1 -1
  90. mindspore/dataset/vision/__init__.py +2 -9
  91. mindspore/dataset/vision/transforms.py +202 -158
  92. mindspore/dataset/vision/utils.py +7 -5
  93. mindspore/device_context/ascend/op_debug.py +60 -1
  94. mindspore/device_context/ascend/op_tuning.py +0 -4
  95. mindspore/device_manager.py +39 -3
  96. mindspore/dnnl.dll +0 -0
  97. mindspore/dpcmi.dll +0 -0
  98. mindspore/experimental/es/embedding_service.py +35 -27
  99. mindspore/experimental/map_parameter.py +4 -4
  100. mindspore/experimental/optim/adadelta.py +22 -26
  101. mindspore/experimental/optim/adagrad.py +4 -4
  102. mindspore/experimental/optim/adam.py +4 -0
  103. mindspore/experimental/optim/adamax.py +4 -4
  104. mindspore/experimental/optim/adamw.py +4 -0
  105. mindspore/experimental/optim/asgd.py +1 -1
  106. mindspore/experimental/optim/lr_scheduler.py +40 -22
  107. mindspore/experimental/optim/radam.py +5 -5
  108. mindspore/experimental/optim/rprop.py +1 -1
  109. mindspore/experimental/optim/sgd.py +1 -1
  110. mindspore/hal/contiguous_tensors_handle.py +6 -10
  111. mindspore/hal/device.py +55 -81
  112. mindspore/hal/event.py +38 -55
  113. mindspore/hal/memory.py +93 -144
  114. mindspore/hal/stream.py +81 -125
  115. mindspore/include/dataset/constants.h +7 -4
  116. mindspore/include/dataset/execute.h +2 -2
  117. mindspore/jpeg62.dll +0 -0
  118. mindspore/log.py +40 -2
  119. mindspore/mindrecord/__init__.py +20 -7
  120. mindspore/mindspore_backend_common.dll +0 -0
  121. mindspore/mindspore_backend_manager.dll +0 -0
  122. mindspore/mindspore_common.dll +0 -0
  123. mindspore/mindspore_core.dll +0 -0
  124. mindspore/mindspore_dump.dll +0 -0
  125. mindspore/mindspore_frontend.dll +0 -0
  126. mindspore/mindspore_glog.dll +0 -0
  127. mindspore/mindspore_memory_pool.dll +0 -0
  128. mindspore/mindspore_ms_backend.dll +0 -0
  129. mindspore/mindspore_ops.dll +0 -0
  130. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  131. mindspore/mindspore_ops_kernel_common.dll +0 -0
  132. mindspore/mindspore_profiler.dll +0 -0
  133. mindspore/mindspore_pyboost.dll +0 -0
  134. mindspore/mindspore_pynative.dll +0 -0
  135. mindspore/mindspore_res_manager.dll +0 -0
  136. mindspore/mindspore_runtime_pipeline.dll +0 -0
  137. mindspore/mint/__init__.py +131 -700
  138. mindspore/mint/distributed/__init__.py +5 -1
  139. mindspore/mint/distributed/distributed.py +194 -109
  140. mindspore/mint/linalg/__init__.py +2 -0
  141. mindspore/mint/nn/__init__.py +280 -18
  142. mindspore/mint/nn/functional.py +282 -64
  143. mindspore/mint/nn/layer/__init__.py +4 -0
  144. mindspore/mint/nn/layer/_functions.py +7 -3
  145. mindspore/mint/nn/layer/activation.py +120 -13
  146. mindspore/mint/nn/layer/conv.py +218 -24
  147. mindspore/mint/nn/layer/normalization.py +15 -16
  148. mindspore/mint/nn/layer/padding.py +1 -1
  149. mindspore/mint/nn/layer/pooling.py +66 -1
  150. mindspore/mint/optim/__init__.py +2 -1
  151. mindspore/mint/optim/sgd.py +171 -0
  152. mindspore/msobj140.dll +0 -0
  153. mindspore/mspdb140.dll +0 -0
  154. mindspore/mspdbcore.dll +0 -0
  155. mindspore/mspdbst.dll +0 -0
  156. mindspore/mspft140.dll +0 -0
  157. mindspore/msvcdis140.dll +0 -0
  158. mindspore/msvcp140_1.dll +0 -0
  159. mindspore/msvcp140_2.dll +0 -0
  160. mindspore/msvcp140_atomic_wait.dll +0 -0
  161. mindspore/msvcp140_codecvt_ids.dll +0 -0
  162. mindspore/nn/__init__.py +4 -1
  163. mindspore/nn/cell.py +1250 -176
  164. mindspore/nn/layer/activation.py +23 -21
  165. mindspore/nn/layer/basic.py +22 -16
  166. mindspore/nn/layer/container.py +1 -1
  167. mindspore/nn/layer/conv.py +22 -17
  168. mindspore/nn/layer/embedding.py +9 -8
  169. mindspore/nn/layer/normalization.py +48 -42
  170. mindspore/nn/layer/pooling.py +75 -31
  171. mindspore/nn/layer/transformer.py +11 -10
  172. mindspore/nn/learning_rate_schedule.py +4 -2
  173. mindspore/nn/loss/loss.py +27 -19
  174. mindspore/nn/optim/ada_grad.py +6 -5
  175. mindspore/nn/optim/adadelta.py +9 -7
  176. mindspore/nn/optim/adafactor.py +1 -1
  177. mindspore/nn/optim/adam.py +16 -12
  178. mindspore/nn/optim/adamax.py +8 -7
  179. mindspore/nn/optim/adasum.py +5 -5
  180. mindspore/nn/optim/asgd.py +1 -1
  181. mindspore/nn/optim/ftrl.py +11 -9
  182. mindspore/nn/optim/lamb.py +1 -1
  183. mindspore/nn/optim/lazyadam.py +12 -10
  184. mindspore/nn/optim/momentum.py +7 -6
  185. mindspore/nn/optim/optimizer.py +2 -2
  186. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  187. mindspore/nn/optim/rmsprop.py +13 -12
  188. mindspore/nn/optim/rprop.py +9 -7
  189. mindspore/nn/optim/sgd.py +9 -6
  190. mindspore/nn/optim/tft_wrapper.py +5 -2
  191. mindspore/nn/probability/bijector/bijector.py +17 -11
  192. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  193. mindspore/nn/probability/bijector/invert.py +2 -2
  194. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  195. mindspore/nn/probability/bijector/softplus.py +3 -2
  196. mindspore/nn/probability/distribution/beta.py +3 -3
  197. mindspore/nn/probability/distribution/categorical.py +1 -1
  198. mindspore/nn/probability/distribution/cauchy.py +4 -2
  199. mindspore/nn/probability/distribution/exponential.py +6 -7
  200. mindspore/nn/probability/distribution/gamma.py +2 -2
  201. mindspore/nn/probability/distribution/gumbel.py +2 -2
  202. mindspore/nn/probability/distribution/half_normal.py +5 -3
  203. mindspore/nn/probability/distribution/logistic.py +5 -3
  204. mindspore/nn/probability/distribution/poisson.py +1 -1
  205. mindspore/nn/probability/distribution/uniform.py +5 -3
  206. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  207. mindspore/nn/reinforcement/tensor_array.py +1 -1
  208. mindspore/nn/wrap/__init__.py +6 -6
  209. mindspore/nn/wrap/cell_wrapper.py +178 -117
  210. mindspore/nn/wrap/grad_reducer.py +45 -36
  211. mindspore/nn/wrap/loss_scale.py +3 -3
  212. mindspore/numpy/array_creations.py +3 -3
  213. mindspore/numpy/array_ops.py +1 -1
  214. mindspore/numpy/math_ops.py +4 -4
  215. mindspore/numpy/utils.py +1 -2
  216. mindspore/numpy/utils_const.py +1 -2
  217. mindspore/opencv_core452.dll +0 -0
  218. mindspore/opencv_imgcodecs452.dll +0 -0
  219. mindspore/opencv_imgproc452.dll +0 -0
  220. mindspore/ops/__init__.py +3 -2
  221. mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
  222. mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
  223. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  224. mindspore/ops/_register_for_op.py +0 -11
  225. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  226. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
  227. mindspore/ops/_vmap/vmap_array_ops.py +7 -6
  228. mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
  229. mindspore/ops/_vmap/vmap_math_ops.py +4 -7
  230. mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
  231. mindspore/ops/auto_generate/__init__.py +4 -3
  232. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +102 -49
  233. mindspore/ops/auto_generate/gen_extend_func.py +281 -135
  234. mindspore/ops/auto_generate/gen_ops_def.py +2574 -2326
  235. mindspore/ops/auto_generate/gen_ops_prim.py +8566 -2755
  236. mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
  237. mindspore/ops/composite/__init__.py +2 -1
  238. mindspore/ops/composite/base.py +19 -24
  239. mindspore/ops/composite/math_ops.py +6 -16
  240. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  241. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -3
  242. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  243. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  244. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  245. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  246. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  247. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  248. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  249. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  250. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  251. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  252. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  253. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  254. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  255. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  256. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  257. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  258. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  259. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  260. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  261. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  262. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  263. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  264. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  265. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  266. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  267. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
  268. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  269. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  270. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  271. mindspore/ops/function/__init__.py +28 -2
  272. mindspore/ops/function/_add_attr_func.py +58 -0
  273. mindspore/ops/function/array_func.py +1629 -2345
  274. mindspore/ops/function/clip_func.py +38 -45
  275. mindspore/ops/function/debug_func.py +36 -44
  276. mindspore/ops/function/grad/__init__.py +1 -0
  277. mindspore/ops/function/grad/grad_func.py +104 -71
  278. mindspore/ops/function/image_func.py +1 -1
  279. mindspore/ops/function/linalg_func.py +46 -78
  280. mindspore/ops/function/math_func.py +3035 -3705
  281. mindspore/ops/function/nn_func.py +676 -241
  282. mindspore/ops/function/other_func.py +159 -1
  283. mindspore/ops/function/parameter_func.py +17 -30
  284. mindspore/ops/function/random_func.py +204 -361
  285. mindspore/ops/function/reshard_func.py +4 -70
  286. mindspore/ops/function/sparse_func.py +3 -3
  287. mindspore/ops/function/sparse_unary_func.py +5 -5
  288. mindspore/ops/function/spectral_func.py +25 -58
  289. mindspore/ops/function/vmap_func.py +24 -17
  290. mindspore/ops/functional.py +6 -4
  291. mindspore/ops/functional_overload.py +547 -4
  292. mindspore/ops/op_info_register.py +32 -244
  293. mindspore/ops/operations/__init__.py +10 -5
  294. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  295. mindspore/ops/operations/_grad_ops.py +1 -10
  296. mindspore/ops/operations/_inner_ops.py +5 -76
  297. mindspore/ops/operations/_ms_kernel.py +4 -10
  298. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  299. mindspore/ops/operations/_scalar_ops.py +3 -2
  300. mindspore/ops/operations/_sequence_ops.py +1 -1
  301. mindspore/ops/operations/_tensor_array.py +1 -1
  302. mindspore/ops/operations/array_ops.py +37 -22
  303. mindspore/ops/operations/comm_ops.py +150 -107
  304. mindspore/ops/operations/custom_ops.py +221 -23
  305. mindspore/ops/operations/debug_ops.py +115 -16
  306. mindspore/ops/operations/inner_ops.py +1 -1
  307. mindspore/ops/operations/linalg_ops.py +1 -58
  308. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  309. mindspore/ops/operations/manually_defined/ops_def.py +746 -79
  310. mindspore/ops/operations/math_ops.py +21 -18
  311. mindspore/ops/operations/nn_ops.py +65 -191
  312. mindspore/ops/operations/other_ops.py +62 -9
  313. mindspore/ops/operations/random_ops.py +13 -7
  314. mindspore/ops/operations/reshard_ops.py +1 -1
  315. mindspore/ops/operations/sparse_ops.py +2 -2
  316. mindspore/ops/primitive.py +43 -32
  317. mindspore/ops/tensor_method.py +232 -13
  318. mindspore/ops_generate/__init__.py +0 -5
  319. mindspore/ops_generate/aclnn/__init__.py +0 -0
  320. mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
  321. mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
  322. mindspore/ops_generate/api/__init__.py +0 -0
  323. mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
  324. mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
  325. mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
  326. mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
  327. mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
  328. mindspore/ops_generate/api/gen_api.py +103 -0
  329. mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
  330. mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
  331. mindspore/ops_generate/common/__init__.py +0 -0
  332. mindspore/ops_generate/common/gen_constants.py +91 -0
  333. mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
  334. mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
  335. mindspore/ops_generate/{template.py → common/template.py} +96 -84
  336. mindspore/ops_generate/gen_ops.py +23 -325
  337. mindspore/ops_generate/op_def/__init__.py +0 -0
  338. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  339. mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
  340. mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -7
  341. mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
  342. mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
  343. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  344. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  345. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  346. mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
  347. mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
  348. mindspore/ops_generate/pyboost/__init__.py +0 -0
  349. mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
  350. mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
  351. mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
  352. mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
  353. mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
  354. mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
  355. mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
  356. mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
  357. mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
  358. mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
  359. mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
  360. mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
  361. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
  362. mindspore/ops_generate/resources/__init__.py +0 -0
  363. mindspore/ops_generate/resources/resource_list.py +30 -0
  364. mindspore/ops_generate/resources/resource_loader.py +36 -0
  365. mindspore/ops_generate/resources/resource_manager.py +64 -0
  366. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  367. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  368. mindspore/parallel/__init__.py +6 -2
  369. mindspore/parallel/_auto_parallel_context.py +133 -6
  370. mindspore/parallel/_cell_wrapper.py +130 -15
  371. mindspore/parallel/_parallel_serialization.py +95 -4
  372. mindspore/parallel/_ps_context.py +1 -1
  373. mindspore/parallel/_recovery_context.py +7 -2
  374. mindspore/parallel/_tensor.py +142 -18
  375. mindspore/parallel/_utils.py +198 -25
  376. mindspore/parallel/algo_parameter_config.py +3 -3
  377. mindspore/parallel/auto_parallel.py +732 -0
  378. mindspore/parallel/checkpoint_convert.py +159 -0
  379. mindspore/parallel/checkpoint_transform.py +656 -37
  380. mindspore/parallel/cluster/process_entity/_api.py +151 -19
  381. mindspore/parallel/cluster/run.py +1 -1
  382. mindspore/parallel/function/__init__.py +24 -0
  383. mindspore/parallel/function/reshard_func.py +259 -0
  384. mindspore/parallel/nn/__init__.py +25 -0
  385. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  386. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  387. mindspore/parallel/parameter_broadcast.py +24 -13
  388. mindspore/parallel/shard.py +137 -61
  389. mindspore/parallel/transform_safetensors.py +287 -95
  390. mindspore/pgodb140.dll +0 -0
  391. mindspore/pgort140.dll +0 -0
  392. mindspore/profiler/__init__.py +9 -5
  393. mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
  394. mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
  395. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
  396. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +22 -0
  397. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  398. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
  399. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
  400. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
  401. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
  402. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
  403. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
  404. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
  405. mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
  406. mindspore/profiler/common/constant.py +12 -0
  407. mindspore/profiler/common/msprof_cmd_tool.py +42 -23
  408. mindspore/profiler/common/path_manager.py +24 -0
  409. mindspore/profiler/common/profiler_context.py +26 -2
  410. mindspore/profiler/common/profiler_meta_data.py +74 -0
  411. mindspore/profiler/common/profiler_parameters.py +59 -18
  412. mindspore/profiler/common/profiler_path_manager.py +66 -7
  413. mindspore/profiler/dynamic_profiler.py +112 -79
  414. mindspore/profiler/envprofiler.py +26 -1
  415. mindspore/profiler/experimental_config.py +197 -0
  416. mindspore/profiler/mstx.py +57 -14
  417. mindspore/profiler/platform/npu_profiler.py +33 -7
  418. mindspore/profiler/profiler.py +541 -45
  419. mindspore/profiler/profiler_action_controller.py +1 -1
  420. mindspore/profiler/profiler_interface.py +4 -0
  421. mindspore/profiler/schedule.py +57 -22
  422. mindspore/rewrite/api/node.py +15 -13
  423. mindspore/rewrite/api/symbol_tree.py +1 -1
  424. mindspore/run_check/_check_version.py +25 -14
  425. mindspore/run_check/run_check.py +1 -1
  426. mindspore/runtime/__init__.py +2 -2
  427. mindspore/runtime/executor.py +40 -11
  428. mindspore/runtime/memory.py +25 -8
  429. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  430. mindspore/swresample-4.dll +0 -0
  431. mindspore/swscale-6.dll +0 -0
  432. mindspore/tbbmalloc.dll +0 -0
  433. mindspore/tinyxml2.dll +0 -0
  434. mindspore/train/__init__.py +8 -8
  435. mindspore/train/_utils.py +35 -7
  436. mindspore/train/amp.py +1 -1
  437. mindspore/train/callback/__init__.py +2 -2
  438. mindspore/train/callback/_callback.py +2 -16
  439. mindspore/train/callback/_checkpoint.py +24 -40
  440. mindspore/train/callback/_cluster_monitor.py +14 -18
  441. mindspore/train/callback/_flops_collector.py +2 -3
  442. mindspore/train/callback/_history.py +7 -4
  443. mindspore/train/callback/_lambda_callback.py +2 -2
  444. mindspore/train/callback/_landscape.py +0 -3
  445. mindspore/train/callback/_loss_monitor.py +2 -1
  446. mindspore/train/callback/_on_request_exit.py +6 -5
  447. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  448. mindspore/train/callback/_summary_collector.py +8 -13
  449. mindspore/train/callback/_time_monitor.py +2 -1
  450. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +179 -103
  451. mindspore/train/data_sink.py +25 -2
  452. mindspore/train/dataset_helper.py +4 -5
  453. mindspore/train/loss_scale_manager.py +8 -7
  454. mindspore/train/metrics/accuracy.py +3 -3
  455. mindspore/train/metrics/confusion_matrix.py +9 -9
  456. mindspore/train/metrics/error.py +3 -3
  457. mindspore/train/metrics/hausdorff_distance.py +4 -4
  458. mindspore/train/metrics/mean_surface_distance.py +3 -3
  459. mindspore/train/metrics/metric.py +0 -12
  460. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  461. mindspore/train/metrics/precision.py +8 -6
  462. mindspore/train/metrics/recall.py +9 -9
  463. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  464. mindspore/train/mind_ir_pb2.py +19 -12
  465. mindspore/train/model.py +176 -103
  466. mindspore/train/serialization.py +246 -988
  467. mindspore/train/summary/_summary_adapter.py +2 -2
  468. mindspore/train/summary/summary_record.py +1 -1
  469. mindspore/turbojpeg.dll +0 -0
  470. mindspore/utils/__init__.py +3 -2
  471. mindspore/utils/dryrun.py +4 -2
  472. mindspore/utils/hooks.py +81 -0
  473. mindspore/utils/utils.py +138 -4
  474. mindspore/vcmeta.dll +0 -0
  475. mindspore/vcruntime140.dll +0 -0
  476. mindspore/vcruntime140_1.dll +0 -0
  477. mindspore/version.py +1 -1
  478. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +2 -1
  479. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +483 -438
  480. mindspore/_install_custom.py +0 -43
  481. mindspore/common/_register_for_adapter.py +0 -74
  482. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  483. mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
  484. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  485. mindspore/ops_generate/gen_constants.py +0 -190
  486. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  487. mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
  488. /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
  489. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  490. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +0 -0
  491. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@
16
16
  from __future__ import absolute_import
17
17
 
18
18
  import os
19
+ import sys
19
20
  import glob
20
21
  import math
21
22
  import json
@@ -24,15 +25,17 @@ from collections import defaultdict
24
25
 
25
26
  import time
26
27
  import multiprocessing as mp
28
+ import psutil
27
29
  import numpy as np
28
30
  from safetensors.numpy import save_file, load_file
29
31
  from safetensors import safe_open
30
32
 
31
33
  import mindspore as ms
32
34
  from mindspore import log as logger
35
+ from mindspore.log import vlog_print
33
36
  from mindspore.parallel._parallel_serialization import _get_device_num_from_strategy, _make_dir, \
34
37
  _extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
35
- _insert_opt_shard_reshape, _extract_src_dst_layout_map_by_src
38
+ _insert_opt_shard_reshape, _extract_src_dst_layout_map_by_src, _insert_expand_layout_reshape
36
39
  from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_tensor_layout, \
37
40
  _get_needed_rank_transform_operator_map_by_layouts, \
38
41
  _generate_transform_operator_stack, _apply_tensor_transform_operators, _construct_tensor_layout_for_opt_shard, \
@@ -65,6 +68,7 @@ def _progress_bar(iterable, total=None):
65
68
  elapsed_time_str = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
66
69
  remaining_time_str = time.strftime("%H:%M:%S", time.gmtime(remaining_time))
67
70
 
71
+ sys.stdout.reconfigure(encoding="utf-8")
68
72
  print(f'\r{percent}%|{bar}|[{elapsed_time_str}<{remaining_time_str}]', end='')
69
73
  if iteration == total:
70
74
  print()
@@ -285,8 +289,9 @@ def _count_redundancy_list(rank_num, param_name, redundancy_dict, device_num):
285
289
 
286
290
 
287
291
  def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict, saftensor_dict, redundancy_dict,
288
- needed_rank, device_num):
292
+ needed_rank, device_num, choice_func):
289
293
  """Find the rank_id under redundant groups."""
294
+ io_time = 0
290
295
  for param_name in pipe_param_list:
291
296
  rank_num = int(needed_rank)
292
297
  redundancy_ranks = _count_redundancy_list(rank_num, param_name, redundancy_dict, device_num)
@@ -299,11 +304,23 @@ def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dic
299
304
  open_file_id = real_rank
300
305
  break
301
306
  if open_file_id is not None:
302
- output = file_dict[open_file_id].get_tensor(param_name)
307
+ start_time = time.time()
308
+ output = file_dict[open_file_id].get_slice(param_name)
309
+ end_time = time.time()
310
+ cost_time = end_time - start_time
311
+ io_time += cost_time
312
+ if choice_func is not None:
313
+ choice_out = choice_func(param_name)
314
+ if isinstance(choice_out, bool) and not choice_out:
315
+ continue
316
+ if not isinstance(choice_out, (bool, str)):
317
+ raise ValueError("For 'unified_safetensors', the return value type of the function "
318
+ f"'choice_func' must be bool or str, but got {type(choice_out)}.")
303
319
  saftensor_dict[param_name] = output
304
320
  else:
305
321
  raise ValueError(f"For _transform_safetensors_single, {param_name} should be in "
306
322
  f"{redundancy_ranks}, but in {single_param_dict[param_name]}.")
323
+ return io_time
307
324
 
308
325
 
309
326
  def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map, src_stage_device_num,
@@ -316,9 +333,10 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
316
333
  """
317
334
  Transforms safetensors files to a specified format without using parallel processing.
318
335
  """
336
+ io_cost_time = 0
319
337
  if src_strategy_file is not None:
320
338
  from mindspore.train._utils import get_parameter_redundancy
321
- redundancy_dict_tmp = get_parameter_redundancy(src_strategy_file)
339
+ redundancy_dict_tmp = get_parameter_redundancy(src_strategy_file, initial_rank=0)
322
340
  redundancy_dict = {}
323
341
  device_num = 0
324
342
  for param_name, redundancy in redundancy_dict_tmp.items():
@@ -352,8 +370,10 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
352
370
  if pipe_param_list:
353
371
  saftensor_dict = dict()
354
372
  if src_strategy_file is not None:
355
- _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict, saftensor_dict,
356
- redundancy_dict, needed_rank, device_num)
373
+ io_time = _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict,
374
+ saftensor_dict, redundancy_dict, needed_rank,
375
+ device_num, choice_func)
376
+ io_cost_time += io_time
357
377
  else:
358
378
  with safe_open(all_safetensor_files_map.get(int(needed_rank)), framework="np") as f:
359
379
  if not unified_flag:
@@ -362,25 +382,32 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
362
382
  dst_param_name_set = set(dst_strategy_list_keys)
363
383
  hyper_param_set = all_param_name_set - (src_param_name_set & dst_param_name_set)
364
384
  pipe_param_list.extend(list(hyper_param_set))
385
+ io_time = 0
365
386
  for param_name in pipe_param_list:
366
387
  if param_name not in f.keys():
367
388
  # param not in ckpt file, check reason
368
389
  continue
369
- output = f.get_tensor(param_name)
370
- save_param_name = param_name
390
+ start_time = time.time()
391
+ output = f.get_slice(param_name)
392
+ end_time = time.time()
393
+ cost_time = end_time - start_time
394
+ io_time += cost_time
395
+ io_cost_time += io_time
371
396
  if choice_func is not None:
372
397
  choice_out = choice_func(param_name)
373
- if isinstance(choice_out, bool):
374
- if not choice_out:
375
- continue
376
- elif isinstance(choice_out, str):
377
- save_param_name = choice_out
378
- else:
398
+ if isinstance(choice_out, bool) and not choice_out:
399
+ continue
400
+ if not isinstance(choice_out, (bool, str)):
379
401
  raise ValueError("For 'unified_safetensors', the return value type of the function "
380
402
  f"'choice_func' must be bool or str, but got {type(choice_out)}.")
381
- saftensor_dict[save_param_name] = output
403
+ saftensor_dict[param_name] = output
382
404
  else:
405
+ start_time = time.time()
383
406
  saftensor_dict = load_file(all_safetensor_files_map.get(int(needed_rank)))
407
+ end_time = time.time()
408
+ cost_time = end_time - start_time
409
+ io_cost_time += cost_time
410
+
384
411
  for param_name, param in saftensor_dict.items():
385
412
  src_rank = int(needed_rank) % src_stage_device_num
386
413
  param_total_dict[param_name][src_rank] = param
@@ -399,7 +426,7 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
399
426
  local_rank_id = transform_rank % dst_stage_device_num
400
427
  transform_param_dict = _transform_parallel_safetensor(local_rank_id, param_total_dict,
401
428
  param_attr_dict, src_strategy_list, dst_strategy_list,
402
- param_total_dict_keys, src_strategy_file)
429
+ param_total_dict_keys, src_strategy_file, choice_func)
403
430
  if file_index is not None:
404
431
  save_safetensor_file = f"part{file_index}.{output_format}"
405
432
  save_safetensor_file_dir = dst_safetensors_dir
@@ -413,15 +440,17 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
413
440
  if _transform_param_list is not None:
414
441
  _transform_param_list.append({save_file_name: transform_param_dict})
415
442
  else:
416
- if output_format == "safetensors":
417
- save_file(transform_param_dict, save_file_name)
418
- else:
419
- transform_param_dict = _load_and_transform(transform_param_dict, None, None,
420
- transform_func=lambda v, name: ms.Parameter(v,
421
- name=name))
422
- ms.save_checkpoint(transform_param_dict, save_file_name)
443
+ if transform_param_dict:
444
+ if output_format == "safetensors":
445
+ save_file(transform_param_dict, save_file_name)
446
+ else:
447
+ transform_param_dict = _load_and_transform(transform_param_dict,
448
+ None, None, transform_func=
449
+ lambda v, name: ms.Parameter(v, name=name))
450
+ ms.save_checkpoint(transform_param_dict, save_file_name)
423
451
  del param_total_dict_keys
424
452
  del param_total_dict
453
+ return io_cost_time
425
454
 
426
455
 
427
456
  def _save_final_safetensors(_transform_param_list, output_format):
@@ -552,6 +581,7 @@ def _extrace_number(file_name):
552
581
  number_ls = [int(i) for i in number_ls]
553
582
  return number_ls[-2:]
554
583
 
584
+
555
585
  def _collect_safetensor_files(src_safetensors_dir, format='safetensors', file_suffix=None):
556
586
  """
557
587
  Collects all safetensors files from the specified directory and its subdirectories.
@@ -589,7 +619,7 @@ def _find_needed_ranks(src_strategy_dict, dst_strategy_dict):
589
619
  dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
590
620
  dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
591
621
  dst_device_num = dst_stage_device_num * dst_stage_num
592
- for rank in _progress_bar(range(dst_device_num)):
622
+ for rank in range(dst_device_num):
593
623
  needed_rank_list = ms.rank_list_for_transform(rank, src_strategy_dict, dst_strategy_dict)
594
624
  needed_rank_list_key = "-".join([str(r) for r in needed_rank_list])
595
625
  needed_rank_list_map[needed_rank_list_key].append(rank)
@@ -605,7 +635,8 @@ def load_file_by_param_name(filename, parme_name_list):
605
635
 
606
636
 
607
637
  def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, src_strategy_list,
608
- dst_strategy_list, param_total_dict_keys=None, src_strategy_file=None):
638
+ dst_strategy_list, param_total_dict_keys=None, src_strategy_file=None,
639
+ choice_func=None):
609
640
  """
610
641
  Transform model parallel dimension for distributed safetensor files.
611
642
  """
@@ -613,7 +644,10 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
613
644
  device_num = -1
614
645
  param_total_dict_keys = list(param_total_dict.keys()) if param_total_dict_keys is None else param_total_dict_keys
615
646
  for param_name in param_total_dict_keys:
616
- tensor_shape = list(param_total_dict[param_name].values())[0].shape
647
+ if str(type(list(param_total_dict[param_name].values())[0])) == "<class 'builtins.PySafeSlice'>":
648
+ tensor_shape = list(param_total_dict[param_name].values())[0].get_shape()
649
+ else:
650
+ tensor_shape = list(param_total_dict[param_name].values())[0].shape
617
651
  from_dev_matrix = [1]
618
652
  from_tensor_map = [-1] * len(tensor_shape)
619
653
  from_opt_shard_step = 0
@@ -646,6 +680,9 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
646
680
  continue
647
681
  origin_tensor_shape += (item * param_strategy[i],)
648
682
 
683
+ has_layout_from = any(isinstance(i, (list, tuple)) for i in from_tensor_map)
684
+ has_layout_to = any(isinstance(i, (list, tuple)) for i in to_tensor_map_origin)
685
+
649
686
  from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
650
687
  from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
651
688
  to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
@@ -665,22 +702,132 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
665
702
  from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
666
703
  to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
667
704
  _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
705
+ _insert_expand_layout_reshape(param_rank_map, from_info_tuple, to_info_tuple, has_layout_from, has_layout_to)
668
706
  transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id)
669
707
  param_total_dict_copy = param_total_dict[param_name].copy()
670
708
  _apply_tensor_transform_operators(transform_operator_stack, param_total_dict_copy, device_num)
671
-
709
+ if choice_func is not None:
710
+ choice_out = choice_func(param_name)
711
+ if isinstance(choice_out, str):
712
+ param_name = choice_out
672
713
  transform_param_dict[param_name] = param_total_dict_copy[rank_id % device_num]
714
+ if str(type(transform_param_dict[param_name])) == "<class 'builtins.PySafeSlice'>":
715
+ transform_param_dict[param_name] = transform_param_dict[param_name][:]
673
716
 
674
717
  # Handle those parameter like learning_rate, global_step which not in strategy_file.
675
718
  for param_name in param_total_dict_keys:
719
+ if choice_func is not None:
720
+ choice_out = choice_func(param_name)
721
+ if isinstance(choice_out, str):
722
+ continue
676
723
  if param_name not in transform_param_dict:
677
724
  transform_para = param_total_dict[param_name][rank_id % device_num]
725
+ if str(type(transform_para)) == "<class 'builtins.PySafeSlice'>":
726
+ transform_para = transform_para[:]
678
727
  transform_param_dict[param_name] = transform_para
679
728
  return transform_param_dict
680
729
 
681
730
 
731
+ def _cal_param_size(shape, dtype):
732
+ """cal param size by dtype and shape"""
733
+ dtype_size = {
734
+ "BOOL": 1,
735
+ "U8": 1,
736
+ "I8": 1,
737
+ "F8_E5M2": 1,
738
+ "F8_E4M3": 1,
739
+ "I16": 2,
740
+ "U16": 2,
741
+ "I32": 4,
742
+ "U32": 4,
743
+ "I64": 8,
744
+ "U64": 8,
745
+ "F16": 2,
746
+ "BF16": 2,
747
+ "F32": 4,
748
+ "F64": 8,
749
+ }
750
+ num_elements = math.prod(shape)
751
+ element_size = dtype_size.get(dtype, 4)
752
+ total_bytes = num_elements * element_size
753
+ return total_bytes
754
+
755
+
756
+ def _split_weight_dict(weights, num_groups):
757
+ """split weights by num"""
758
+ sorted_items = sorted(weights.items(), key=lambda x: -x[1])
759
+ groups = [[] for _ in range(num_groups)]
760
+ total_bytes = [0] * num_groups
761
+ for weight_name, byte_size in sorted_items:
762
+ min_index = total_bytes.index(min(total_bytes))
763
+ groups[min_index].append(weight_name)
764
+ total_bytes[min_index] += byte_size
765
+
766
+ return groups
767
+
768
+
769
+ def _save_hyper_param(split_dst_file, all_safetensor_files_map, name_list, dst_dir):
770
+ """save hyper param"""
771
+ if not split_dst_file or (split_dst_file and split_dst_file[0] == 1):
772
+ with safe_open(all_safetensor_files_map.get(0), framework="np") as f:
773
+ all_key = f.keys()
774
+ hyper_parameter = set(all_key) - set(name_list)
775
+ if hyper_parameter:
776
+ hyper_dict = {}
777
+ for key in hyper_parameter:
778
+ hyper_dict[key] = f.get_tensor(key)
779
+ save_file(hyper_dict, os.path.join(dst_dir, "hyper_param.safetensors"))
780
+
781
+
782
+ def _save_parameter_map_json(split_list, choice_func, split_dst_file, dst_dir, param_total_size):
783
+ """save parameter map json file"""
784
+ param_name_dict = dict()
785
+ for index, part_list in enumerate(split_list):
786
+ for name in part_list:
787
+ save_param_name = name
788
+ if choice_func is not None:
789
+ choice_out = choice_func(name)
790
+ if isinstance(choice_out, str):
791
+ save_param_name = choice_out
792
+ if save_param_name == -1:
793
+ break
794
+ param_name_dict[save_param_name] = f"part{index}.safetensors"
795
+ output_dict = {"metadata": {"total_size": param_total_size}, "weight_map": param_name_dict}
796
+ if not split_dst_file or (split_dst_file and split_dst_file[0] == 1):
797
+ json_str = json.dumps(output_dict, indent=4)
798
+ map_file = os.path.join(dst_dir, "param_name_map.json")
799
+ with open(map_file, 'w') as f:
800
+ f.write(json_str)
801
+
802
+
803
+ def _get_dst_shape(param_name, param_shape, src_strategy_list):
804
+ """get dst shape by strategy"""
805
+ from_dev_matrix = [1]
806
+ from_tensor_map = [-1] * len(param_shape)
807
+ from_opt_shard_size = 0
808
+ if src_strategy_list is not None:
809
+ from_dev_matrix, from_tensor_map, _, from_opt_shard_size = _extract_layout_item(
810
+ src_strategy_list.get(param_name))
811
+ to_dev_matrix_origin = [1]
812
+ to_tensor_map_origin = [-1] * len(param_shape)
813
+ to_opt_shard_step = 0
814
+ to_opt_shard_size = 0
815
+
816
+ param_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
817
+ origin_tensor_shape = ()
818
+ for i, item in enumerate(param_shape):
819
+ if i == 0 and from_opt_shard_size > 0:
820
+ origin_tensor_shape += (item * param_strategy[i] * from_opt_shard_size,)
821
+ continue
822
+ origin_tensor_shape += (item * param_strategy[i],)
823
+
824
+ _, _, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
825
+ to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, origin_tensor_shape)
826
+ return to_full_tensor_shape
827
+
828
+
682
829
  def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundancy=True, file_suffix=None,
683
- max_process_num=64, choice_func=None):
830
+ max_process_num=64, choice_func=None, split_dst_file=()):
684
831
  """
685
832
  Merge multiple safetensor files into a unified safetensor file.
686
833
 
@@ -692,9 +839,14 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
692
839
  saved safetensors files. Default: ``True``, indicating that the merged source weight files are complete.
693
840
  file_suffix (str, optional): Specify the filename suffix for merging safetensors files. Default: ``None``,
694
841
  meaning all safetensors files in the source weight directory will be merged.
695
- max_process_num (int): Maximum number of processes. Default: 64.
696
- choice_func (callable): A callable function used to filter parameters or modify parameter names.
697
- The return value of the function must be of type str (string) or bool (boolean). Default: None.
842
+ max_process_num (int, optional): Maximum number of processes. Default: ``64``.
843
+ choice_func (callable, optional): A callable function used to filter parameters or modify parameter names.
844
+ The return value of the function must be of type str (string) or bool (boolean). Default: ``None``.
845
+ split_dst_file (tuple, optional) - A parameter used to manually split a task into multiple subtasks for
846
+ execution, represented as a tuple containing two elements. The first element indicates the number of
847
+ the current subtask, and the second element indicates the total number of tasks. This parameter supports
848
+ splitting and executing tasks multiple times on a single machine, and also supports executing different
849
+ subtasks on multiple machines respectively. Default: ``()``.
698
850
 
699
851
  Raises:
700
852
  ValueError: If the safetensors file of rank is missing.
@@ -707,8 +859,12 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
707
859
  >>> src_dir = "/usr/safetensors/llama31B/4p_safetensors/"
708
860
  >>> src_strategy_file = "/usr/safetensors/llama31B/strategy_4p.ckpt"
709
861
  >>> dst_dir = "/usr/safetensors/llama31B/merge_llama31B_4p/"
710
- >>> ms.unified_safetensors(src_dir, src_strategy_file, dst_dir)
862
+ >>> ms.parallel.unified_safetensors(src_dir, src_strategy_file, dst_dir)
711
863
  """
864
+ pid = os.getpid()
865
+ total_cores = os.cpu_count()
866
+ all_cores = set(range(total_cores))
867
+ os.sched_setaffinity(pid, all_cores)
712
868
  _check_transform_safetensors(src_dir, "", src_strategy_file, None)
713
869
  _make_dir(dst_dir, "path")
714
870
  if os.path.isfile(src_dir):
@@ -732,13 +888,11 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
732
888
  "but it is missing.".format(needed_rank, rank))
733
889
  layout_map = _convert_to_list(src_strategy_dict)
734
890
 
735
- total_size = 0
736
891
  actual_params = set()
737
892
  for _, file_name in all_safetensor_files_map.items():
738
- total_size += os.path.getsize(file_name) / 1024 / 1024 / 1024
739
893
  with safe_open(file_name, framework="np") as f:
740
894
  actual_params.update(f.keys())
741
- split_num = math.ceil(total_size / 3)
895
+
742
896
  params_to_store = actual_params & set(layout_map.keys())
743
897
 
744
898
  name_list = []
@@ -746,37 +900,55 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
746
900
  if name.startswith("accu_grads"):
747
901
  continue
748
902
  name_list.append(name)
749
- split_list = _split_list(name_list, split_num)
750
-
751
- with safe_open(all_safetensor_files_map.get(0), framework="np") as f:
752
- all_key = f.keys()
753
- hyper_parameter = set(all_key) - set(name_list)
754
- if hyper_parameter:
755
- hyper_dict = {}
756
- for key in hyper_parameter:
757
- hyper_dict[key] = f.get_tensor(key)
758
- save_file(hyper_dict, os.path.join(dst_dir, "hyper_param.safetensors"))
759
-
760
- # save parameter map json
761
- param_name_dict = dict()
762
- for index, part_list in enumerate(split_list):
763
- for name in part_list:
764
- save_param_name = name
765
- if choice_func is not None:
766
- choice_out = choice_func(name)
767
- if isinstance(choice_out, bool):
768
- if not choice_out:
769
- continue
770
- elif isinstance(choice_out, str):
771
- save_param_name = choice_out
772
- param_name_dict[save_param_name] = f"part{index}.safetensors"
773
- json_str = json.dumps(param_name_dict, indent=4)
774
- map_file = os.path.join(dst_dir, "param_name_map.json")
775
- with open(map_file, 'w') as f:
776
- f.write(json_str)
777
903
 
904
+ param_size_dict = {}
905
+ param_total_size = 0
906
+ for _, file_name in all_safetensor_files_map.items():
907
+ with safe_open(file_name, framework="np") as f:
908
+ for k in f.keys():
909
+ if k in name_list:
910
+ py_slice = f.get_slice(k)
911
+ param_total_size += _cal_param_size(py_slice.get_shape(), py_slice.get_dtype())
912
+ param_dst_shape = _get_dst_shape(k, py_slice.get_shape(), origin_src_strategy_list)
913
+ # Convert the shape of np.int32 type to int type to prevent overflow in subsequent calculations.
914
+ param_dst_shape = [int(item) for item in param_dst_shape]
915
+ if choice_func is not None:
916
+ choice_out = choice_func(k)
917
+ if isinstance(choice_out, bool):
918
+ if not choice_out:
919
+ continue
920
+ if k not in param_size_dict:
921
+ param_size_dict[k] = _cal_param_size(param_dst_shape, py_slice.get_dtype())
922
+ split_num = math.ceil(sum(param_size_dict.values()) / 1024 / 1024 / 1024 / 3)
923
+ split_num = min(split_num, len(name_list))
924
+ split_list = _split_weight_dict(param_size_dict, split_num)
925
+
926
+ if split_dst_file:
927
+ current_machine_num = split_dst_file[0]
928
+ total_machine_num = split_dst_file[1]
929
+ n = len(split_list)
930
+ avg_length = n // total_machine_num
931
+ remainder = n % total_machine_num
932
+ start_index = (avg_length * (current_machine_num - 1)) + min(current_machine_num - 1, remainder)
933
+ end_index = start_index + avg_length + (1 if current_machine_num <= remainder else 0)
934
+ sub_list = []
935
+ for i in range(len(split_list)):
936
+ if start_index <= i < end_index:
937
+ sub_list.append(split_list[i])
938
+ else:
939
+ sub_list.append([-1])
940
+ else:
941
+ sub_list = split_list
942
+
943
+ _save_hyper_param(split_dst_file, all_safetensor_files_map, name_list, dst_dir)
944
+ _save_parameter_map_json(split_list, choice_func, split_dst_file, dst_dir, param_total_size)
945
+
946
+ if split_dst_file:
947
+ split_num = end_index - start_index
948
+ res = list(range(start_index, end_index))
949
+ else:
950
+ res = [i for i in range(split_num)]
778
951
  max_process = min(split_num, max_process_num)
779
- res = [i for i in range(split_num)]
780
952
  res = _split_list(res, max_process)
781
953
  processes = []
782
954
  src_strategy_name = None
@@ -786,7 +958,7 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
786
958
  p = mp.Process(target=_transform_safetensors_single_semaphore, args=(
787
959
  needed_rank_list_map, all_safetensor_files_map, src_stage_device_num, dst_stage_device_num,
788
960
  src_strategy_dict, None, origin_src_strategy_list, origin_dst_strategy_list,
789
- "", dst_dir, "safetensors", None, split_list, res[i], True, src_strategy_name, choice_func))
961
+ "", dst_dir, "safetensors", None, sub_list, res[i], True, src_strategy_name, choice_func))
790
962
  p.start()
791
963
  processes.append(p)
792
964
  for p in processes:
@@ -801,13 +973,20 @@ def _transform_safetensors_single_semaphore(needed_rank_list_map, all_safetensor
801
973
  ckpt_prefix, dst_safetensors_dir, output_format,
802
974
  _transform_param_list, pipe_param_list=None, file_index=None,
803
975
  unified_flag=False, src_strategy_file=None, choice_func=None):
976
+ """transform safetensors single semaphore"""
977
+ total_io_cost_time = 0
804
978
  for i in file_index:
805
- _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map, src_stage_device_num,
806
- dst_stage_device_num, src_strategy_dict, dst_strategy_dict,
807
- origin_src_strategy_list,
808
- origin_dst_strategy_list, ckpt_prefix, dst_safetensors_dir, output_format,
809
- _transform_param_list, pipe_param_list[i], i, unified_flag, src_strategy_file,
810
- choice_func)
979
+ io_cost_time = _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map,
980
+ src_stage_device_num, dst_stage_device_num, src_strategy_dict,
981
+ dst_strategy_dict, origin_src_strategy_list,
982
+ origin_dst_strategy_list, ckpt_prefix, dst_safetensors_dir,
983
+ output_format, _transform_param_list, pipe_param_list[i], i,
984
+ unified_flag, src_strategy_file, choice_func)
985
+ while psutil.virtual_memory().percent > 50:
986
+ time.sleep(1)
987
+ total_io_cost_time += io_cost_time
988
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
989
+ f"Unified safetensors io cost time:{total_io_cost_time}.")
811
990
 
812
991
 
813
992
  def _split_list(split_list, split_num):
@@ -854,22 +1033,13 @@ def _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_n
854
1033
  return sf_obj
855
1034
 
856
1035
 
857
- def _check_name_map_value_is_str(value):
858
- """check input is bool"""
859
- if not isinstance(value, str):
860
- raise ValueError(
861
- f"For 'load_distributed_checkpoint', the value of name_map must be str, but got {type(value)}.")
862
-
863
-
864
- def _process_hyper_params(file_list, total_safetensors_dir, name_map, total_param):
1036
+ def _process_hyper_params(file_list, total_safetensors_dir, total_param):
865
1037
  """process hyper params"""
866
1038
  if 'hyper_param.safetensors' in file_list:
867
1039
  hyper_parameter_file_name = os.path.join(total_safetensors_dir, "hyper_param.safetensors")
868
1040
  with safe_open(hyper_parameter_file_name, framework="np") as f:
869
1041
  for key in f.keys():
870
- cur_param_name = name_map.get(key) if name_map is not None and key in name_map else key
871
- _check_name_map_value_is_str(cur_param_name)
872
- total_param[cur_param_name] = ms.Parameter(ms.Tensor.from_numpy(f.get_tensor(key)))
1042
+ total_param[key] = ms.Parameter(ms.Tensor.from_numpy(f.get_tensor(key)))
873
1043
  return total_param
874
1044
 
875
1045
 
@@ -887,12 +1057,15 @@ def _cal_param_name_map_and_param_list(file_list, total_safetensors_dir, json_fi
887
1057
  values = len(keys) * [file_list[0]]
888
1058
  param_name_map = dict(zip(keys, values))
889
1059
  else:
890
- if len(json_files) != 1:
891
- raise ValueError(f"For 'load_parallel_checkpoint', the number of json files in 'total_safetensors_dir' "
892
- f"must be 1, but got {len(json_files)}.")
1060
+ if not json_files:
1061
+ raise ValueError(
1062
+ f"For 'load_parallel_checkpoint', there must be a JSON file named 'param_name_map.json' in "
1063
+ f"the 'total_safetensors_dir'.")
893
1064
  param_name_json = os.path.join(total_safetensors_dir, json_files[0])
894
1065
  with open(param_name_json, 'r') as f:
895
1066
  param_name_map = json.load(f)
1067
+ if "weight_map" in param_name_map:
1068
+ param_name_map = param_name_map["weight_map"]
896
1069
 
897
1070
  if dst_strategy_file is not None:
898
1071
  _, dst_strategy_list = _extract_src_dst_layout_map(rank_id, None, dst_strategy_file)
@@ -907,8 +1080,12 @@ def _load_parallel_checkpoint(file_info):
907
1080
  """load parallel safetensors by merged file."""
908
1081
  total_safetensors_dir, dst_strategy_file, net, dst_safetensors_dir, \
909
1082
  rank_id, output_format, name_map, return_param_dict = file_info
1083
+ pid = os.getpid()
1084
+ total_cores = os.cpu_count()
1085
+ all_cores = set(range(total_cores))
1086
+ os.sched_setaffinity(pid, all_cores)
910
1087
  file_list = os.listdir(total_safetensors_dir)
911
- json_files = [file for file in file_list if file.endswith('.json')]
1088
+ json_files = [file for file in file_list if file == "param_name_map.json"]
912
1089
  param_name_map, param_list, dst_strategy_list = _cal_param_name_map_and_param_list(file_list, total_safetensors_dir,
913
1090
  json_files, dst_strategy_file,
914
1091
  rank_id)
@@ -916,14 +1093,16 @@ def _load_parallel_checkpoint(file_info):
916
1093
  dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) if dst_strategy_list \
917
1094
  is not None else 1
918
1095
  local_rank_id = rank_id % dst_stage_device_num
919
- for param_name in param_list:
1096
+ total_io_cost_time = 0
1097
+ for param_name in _progress_bar(param_list):
920
1098
  if param_name not in param_name_map:
921
1099
  continue
922
1100
  file_name = os.path.join(total_safetensors_dir, param_name_map[param_name])
923
1101
  with safe_open(file_name, framework="np") as f:
924
- if param_name not in f.keys():
1102
+ cur_param_name = name_map.get(param_name) if name_map is not None and param_name in name_map else param_name
1103
+ if cur_param_name not in f.keys():
925
1104
  continue
926
- sf_obj = f.get_slice(param_name)
1105
+ sf_obj = f.get_slice(cur_param_name)
927
1106
 
928
1107
  tensor_shape = sf_obj.get_shape()
929
1108
  from_dev_matrix = [1]
@@ -945,6 +1124,9 @@ def _load_parallel_checkpoint(file_info):
945
1124
  continue
946
1125
  origin_tensor_shape += (item * param_strategy[i],)
947
1126
 
1127
+ has_layout_from = any(isinstance(i, (list, tuple)) for i in from_tensor_map)
1128
+ has_layout_to = any(isinstance(i, (list, tuple)) for i in to_tensor_map_origin)
1129
+
948
1130
  from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
949
1131
  from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
950
1132
  to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
@@ -964,19 +1146,29 @@ def _load_parallel_checkpoint(file_info):
964
1146
  from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
965
1147
  to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
966
1148
  _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
1149
+ _insert_expand_layout_reshape(param_rank_map, from_info_tuple, to_info_tuple,
1150
+ has_layout_from, has_layout_to)
967
1151
  transform_operator_stack = _generate_transform_operator_stack(param_rank_map, local_rank_id)
968
-
1152
+ start_time = time.time()
969
1153
  slice_param = _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_num)
1154
+ end_time = time.time()
1155
+ cost_time = end_time - start_time
1156
+ total_io_cost_time += cost_time
970
1157
  else:
1158
+ start_time = time.time()
971
1159
  slice_param = sf_obj[:]
972
- cur_param_name = name_map.get(param_name) if name_map is not None and param_name in name_map else param_name
973
- _check_name_map_value_is_str(cur_param_name)
974
- total_param[cur_param_name] = ms.Parameter(ms.Tensor.from_numpy(slice_param))
975
-
976
- total_param = _process_hyper_params(file_list, total_safetensors_dir, name_map, total_param)
1160
+ end_time = time.time()
1161
+ cost_time = end_time - start_time
1162
+ total_io_cost_time += cost_time
1163
+ total_param[param_name] = ms.Parameter(ms.Tensor.from_numpy(slice_param))
1164
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
1165
+ f"load distributed safetensors io cost time:{total_io_cost_time}.")
1166
+ total_param = _process_hyper_params(file_list, total_safetensors_dir, total_param)
977
1167
  if net is not None:
978
1168
  if not return_param_dict:
1169
+ logger.info("start load param into net...")
979
1170
  param_not_load, ckpt_not_load = ms.load_param_into_net(net, total_param)
1171
+ logger.info("load param into net is end...")
980
1172
  return param_not_load, ckpt_not_load
981
1173
  return total_param
982
1174
  _make_dir(os.path.join(dst_safetensors_dir, f"rank_{rank_id}"), "path")
mindspore/pgodb140.dll CHANGED
Binary file
mindspore/pgort140.dll CHANGED
Binary file