mindspore 2.5.0__cp311-cp311-win_amd64.whl → 2.6.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mindspore might be problematic.
Files changed (493)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +6 -4
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -33
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parse/__init__.py +6 -7
  14. mindspore/_extends/parse/compile_config.py +19 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +25 -194
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +109 -75
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +4 -4
  27. mindspore/atlprov.dll +0 -0
  28. mindspore/avcodec-59.dll +0 -0
  29. mindspore/avdevice-59.dll +0 -0
  30. mindspore/avfilter-8.dll +0 -0
  31. mindspore/avformat-59.dll +0 -0
  32. mindspore/avutil-57.dll +0 -0
  33. mindspore/boost/__init__.py +2 -2
  34. mindspore/boost/base.py +3 -7
  35. mindspore/boost/boost_cell_wrapper.py +2 -2
  36. mindspore/c1.dll +0 -0
  37. mindspore/c1xx.dll +0 -0
  38. mindspore/c2.dll +0 -0
  39. mindspore/common/__init__.py +4 -3
  40. mindspore/common/_grad_function.py +56 -0
  41. mindspore/common/_pijit_context.py +14 -5
  42. mindspore/common/_register_for_tensor.py +1 -1
  43. mindspore/common/_stub_tensor.py +5 -10
  44. mindspore/common/_tensor_cpp_method.py +1 -1
  45. mindspore/common/_tensor_docs.py +2014 -3386
  46. mindspore/common/api.py +386 -355
  47. mindspore/common/auto_dynamic_shape.py +41 -44
  48. mindspore/common/dtype.py +5 -2
  49. mindspore/common/dump.py +7 -5
  50. mindspore/common/file_system.py +3 -0
  51. mindspore/common/generator.py +3 -0
  52. mindspore/common/hook_handle.py +5 -3
  53. mindspore/common/initializer.py +10 -6
  54. mindspore/common/jit_begin_end.py +94 -0
  55. mindspore/common/jit_config.py +6 -1
  56. mindspore/common/jit_context.py +76 -0
  57. mindspore/common/jit_trace.py +378 -0
  58. mindspore/common/lazy_inline.py +2 -2
  59. mindspore/common/mutable.py +5 -4
  60. mindspore/common/parameter.py +106 -39
  61. mindspore/common/seed.py +2 -2
  62. mindspore/common/sparse_tensor.py +23 -17
  63. mindspore/common/tensor.py +332 -714
  64. mindspore/communication/__init__.py +7 -5
  65. mindspore/communication/_comm_helper.py +47 -2
  66. mindspore/communication/comm_func.py +70 -53
  67. mindspore/communication/management.py +83 -17
  68. mindspore/context.py +228 -571
  69. mindspore/dataset/__init__.py +44 -20
  70. mindspore/dataset/audio/__init__.py +2 -8
  71. mindspore/dataset/audio/transforms.py +3 -17
  72. mindspore/dataset/core/config.py +3 -3
  73. mindspore/dataset/engine/cache_client.py +1 -1
  74. mindspore/dataset/engine/datasets.py +102 -120
  75. mindspore/dataset/engine/datasets_audio.py +22 -22
  76. mindspore/dataset/engine/datasets_standard_format.py +43 -24
  77. mindspore/dataset/engine/datasets_text.py +78 -85
  78. mindspore/dataset/engine/datasets_user_defined.py +109 -77
  79. mindspore/dataset/engine/datasets_vision.py +111 -108
  80. mindspore/dataset/engine/iterators.py +5 -3
  81. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  82. mindspore/dataset/engine/samplers.py +279 -57
  83. mindspore/dataset/engine/serializer_deserializer.py +2 -1
  84. mindspore/dataset/engine/validators.py +10 -0
  85. mindspore/dataset/text/__init__.py +7 -6
  86. mindspore/dataset/text/transforms.py +6 -5
  87. mindspore/dataset/text/utils.py +3 -3
  88. mindspore/dataset/transforms/__init__.py +0 -9
  89. mindspore/dataset/transforms/transforms.py +3 -3
  90. mindspore/dataset/utils/browse_dataset.py +1 -1
  91. mindspore/dataset/vision/__init__.py +2 -9
  92. mindspore/dataset/vision/transforms.py +202 -158
  93. mindspore/dataset/vision/utils.py +7 -5
  94. mindspore/device_context/ascend/op_debug.py +60 -1
  95. mindspore/device_context/ascend/op_tuning.py +0 -4
  96. mindspore/device_manager.py +39 -3
  97. mindspore/dnnl.dll +0 -0
  98. mindspore/dpcmi.dll +0 -0
  99. mindspore/experimental/es/embedding_service.py +35 -27
  100. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -2
  101. mindspore/experimental/map_parameter.py +4 -4
  102. mindspore/experimental/optim/adadelta.py +22 -26
  103. mindspore/experimental/optim/adagrad.py +4 -4
  104. mindspore/experimental/optim/adam.py +4 -0
  105. mindspore/experimental/optim/adamax.py +4 -4
  106. mindspore/experimental/optim/adamw.py +4 -0
  107. mindspore/experimental/optim/asgd.py +1 -1
  108. mindspore/experimental/optim/lr_scheduler.py +40 -22
  109. mindspore/experimental/optim/radam.py +5 -5
  110. mindspore/experimental/optim/rprop.py +1 -1
  111. mindspore/experimental/optim/sgd.py +1 -1
  112. mindspore/hal/contiguous_tensors_handle.py +6 -10
  113. mindspore/hal/device.py +55 -81
  114. mindspore/hal/event.py +38 -55
  115. mindspore/hal/memory.py +115 -147
  116. mindspore/hal/stream.py +81 -125
  117. mindspore/include/dataset/constants.h +7 -4
  118. mindspore/include/dataset/execute.h +2 -2
  119. mindspore/jpeg62.dll +0 -0
  120. mindspore/log.py +40 -2
  121. mindspore/mindrecord/__init__.py +20 -7
  122. mindspore/mindspore_backend_common.dll +0 -0
  123. mindspore/mindspore_backend_manager.dll +0 -0
  124. mindspore/mindspore_common.dll +0 -0
  125. mindspore/mindspore_core.dll +0 -0
  126. mindspore/mindspore_dump.dll +0 -0
  127. mindspore/mindspore_frontend.dll +0 -0
  128. mindspore/mindspore_glog.dll +0 -0
  129. mindspore/mindspore_memory_pool.dll +0 -0
  130. mindspore/mindspore_ms_backend.dll +0 -0
  131. mindspore/mindspore_ops.dll +0 -0
  132. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  133. mindspore/mindspore_ops_kernel_common.dll +0 -0
  134. mindspore/mindspore_profiler.dll +0 -0
  135. mindspore/mindspore_pyboost.dll +0 -0
  136. mindspore/mindspore_pynative.dll +0 -0
  137. mindspore/mindspore_res_manager.dll +0 -0
  138. mindspore/mindspore_runtime_pipeline.dll +0 -0
  139. mindspore/mint/__init__.py +133 -702
  140. mindspore/mint/distributed/__init__.py +5 -1
  141. mindspore/mint/distributed/distributed.py +198 -113
  142. mindspore/mint/linalg/__init__.py +2 -0
  143. mindspore/mint/nn/__init__.py +280 -18
  144. mindspore/mint/nn/functional.py +282 -64
  145. mindspore/mint/nn/layer/__init__.py +4 -0
  146. mindspore/mint/nn/layer/_functions.py +7 -3
  147. mindspore/mint/nn/layer/activation.py +120 -13
  148. mindspore/mint/nn/layer/conv.py +234 -28
  149. mindspore/mint/nn/layer/normalization.py +15 -16
  150. mindspore/mint/nn/layer/padding.py +1 -1
  151. mindspore/mint/nn/layer/pooling.py +66 -1
  152. mindspore/mint/optim/__init__.py +2 -1
  153. mindspore/mint/optim/sgd.py +171 -0
  154. mindspore/msobj140.dll +0 -0
  155. mindspore/mspdb140.dll +0 -0
  156. mindspore/mspdbcore.dll +0 -0
  157. mindspore/mspdbst.dll +0 -0
  158. mindspore/mspft140.dll +0 -0
  159. mindspore/msvcdis140.dll +0 -0
  160. mindspore/msvcp140_1.dll +0 -0
  161. mindspore/msvcp140_2.dll +0 -0
  162. mindspore/msvcp140_atomic_wait.dll +0 -0
  163. mindspore/msvcp140_codecvt_ids.dll +0 -0
  164. mindspore/nn/__init__.py +4 -1
  165. mindspore/nn/cell.py +1253 -179
  166. mindspore/nn/layer/activation.py +23 -21
  167. mindspore/nn/layer/basic.py +22 -16
  168. mindspore/nn/layer/container.py +1 -1
  169. mindspore/nn/layer/conv.py +53 -42
  170. mindspore/nn/layer/embedding.py +9 -8
  171. mindspore/nn/layer/normalization.py +48 -42
  172. mindspore/nn/layer/pooling.py +75 -31
  173. mindspore/nn/layer/transformer.py +11 -10
  174. mindspore/nn/learning_rate_schedule.py +4 -2
  175. mindspore/nn/loss/loss.py +27 -19
  176. mindspore/nn/optim/ada_grad.py +6 -5
  177. mindspore/nn/optim/adadelta.py +9 -7
  178. mindspore/nn/optim/adafactor.py +1 -1
  179. mindspore/nn/optim/adam.py +18 -14
  180. mindspore/nn/optim/adamax.py +8 -7
  181. mindspore/nn/optim/adasum.py +5 -5
  182. mindspore/nn/optim/asgd.py +3 -1
  183. mindspore/nn/optim/ftrl.py +11 -9
  184. mindspore/nn/optim/lamb.py +1 -1
  185. mindspore/nn/optim/lazyadam.py +12 -10
  186. mindspore/nn/optim/momentum.py +7 -6
  187. mindspore/nn/optim/optimizer.py +2 -2
  188. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  189. mindspore/nn/optim/rmsprop.py +13 -12
  190. mindspore/nn/optim/rprop.py +9 -7
  191. mindspore/nn/optim/sgd.py +9 -6
  192. mindspore/nn/optim/tft_wrapper.py +5 -2
  193. mindspore/nn/probability/bijector/bijector.py +17 -11
  194. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  195. mindspore/nn/probability/bijector/invert.py +2 -2
  196. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  197. mindspore/nn/probability/bijector/softplus.py +3 -2
  198. mindspore/nn/probability/distribution/beta.py +3 -3
  199. mindspore/nn/probability/distribution/categorical.py +1 -1
  200. mindspore/nn/probability/distribution/cauchy.py +4 -2
  201. mindspore/nn/probability/distribution/exponential.py +6 -7
  202. mindspore/nn/probability/distribution/gamma.py +2 -2
  203. mindspore/nn/probability/distribution/gumbel.py +2 -2
  204. mindspore/nn/probability/distribution/half_normal.py +5 -3
  205. mindspore/nn/probability/distribution/logistic.py +5 -3
  206. mindspore/nn/probability/distribution/poisson.py +1 -1
  207. mindspore/nn/probability/distribution/uniform.py +5 -3
  208. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  209. mindspore/nn/reinforcement/tensor_array.py +1 -1
  210. mindspore/nn/wrap/__init__.py +6 -6
  211. mindspore/nn/wrap/cell_wrapper.py +178 -117
  212. mindspore/nn/wrap/grad_reducer.py +45 -36
  213. mindspore/nn/wrap/loss_scale.py +3 -3
  214. mindspore/numpy/array_creations.py +3 -3
  215. mindspore/numpy/array_ops.py +1 -1
  216. mindspore/numpy/utils.py +1 -2
  217. mindspore/numpy/utils_const.py +1 -2
  218. mindspore/opencv_core452.dll +0 -0
  219. mindspore/opencv_imgcodecs452.dll +0 -0
  220. mindspore/opencv_imgproc452.dll +0 -0
  221. mindspore/ops/__init__.py +3 -2
  222. mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
  223. mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
  224. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  225. mindspore/ops/_register_for_op.py +0 -11
  226. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  227. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
  228. mindspore/ops/_vmap/vmap_array_ops.py +32 -6
  229. mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
  230. mindspore/ops/_vmap/vmap_math_ops.py +4 -7
  231. mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
  232. mindspore/ops/auto_generate/__init__.py +4 -3
  233. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +127 -52
  234. mindspore/ops/auto_generate/gen_extend_func.py +286 -208
  235. mindspore/ops/auto_generate/gen_ops_def.py +2783 -2335
  236. mindspore/ops/auto_generate/gen_ops_prim.py +8992 -2686
  237. mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
  238. mindspore/ops/composite/__init__.py +2 -1
  239. mindspore/ops/composite/base.py +19 -24
  240. mindspore/ops/composite/math_ops.py +6 -16
  241. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  242. mindspore/ops/composite/multitype_ops/_compile_utils.py +4 -5
  243. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  244. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  245. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  246. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  247. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  248. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  249. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  250. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  251. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  252. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  253. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  254. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  255. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  256. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  257. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  258. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  259. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  260. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  261. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  262. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  263. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  264. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  265. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  266. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  267. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  268. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
  269. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  270. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  271. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  272. mindspore/ops/function/__init__.py +28 -2
  273. mindspore/ops/function/_add_attr_func.py +58 -0
  274. mindspore/ops/function/array_func.py +1631 -2347
  275. mindspore/ops/function/clip_func.py +38 -45
  276. mindspore/ops/function/debug_func.py +36 -44
  277. mindspore/ops/function/grad/__init__.py +1 -0
  278. mindspore/ops/function/grad/grad_func.py +104 -71
  279. mindspore/ops/function/image_func.py +1 -1
  280. mindspore/ops/function/linalg_func.py +46 -78
  281. mindspore/ops/function/math_func.py +3024 -3855
  282. mindspore/ops/function/nn_func.py +678 -274
  283. mindspore/ops/function/other_func.py +159 -1
  284. mindspore/ops/function/parameter_func.py +17 -30
  285. mindspore/ops/function/random_func.py +216 -361
  286. mindspore/ops/function/reshard_func.py +4 -70
  287. mindspore/ops/function/sparse_func.py +3 -3
  288. mindspore/ops/function/sparse_unary_func.py +5 -5
  289. mindspore/ops/function/spectral_func.py +25 -58
  290. mindspore/ops/function/vmap_func.py +26 -18
  291. mindspore/ops/functional.py +8 -5
  292. mindspore/ops/functional_overload.py +655 -4
  293. mindspore/ops/op_info_register.py +32 -244
  294. mindspore/ops/operations/__init__.py +21 -14
  295. mindspore/ops/operations/_custom_ops_utils.py +235 -0
  296. mindspore/ops/operations/_grad_ops.py +1 -10
  297. mindspore/ops/operations/_inner_ops.py +5 -76
  298. mindspore/ops/operations/_ms_kernel.py +4 -10
  299. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  300. mindspore/ops/operations/_scalar_ops.py +3 -2
  301. mindspore/ops/operations/_sequence_ops.py +1 -1
  302. mindspore/ops/operations/_tensor_array.py +1 -1
  303. mindspore/ops/operations/array_ops.py +39 -24
  304. mindspore/ops/operations/comm_ops.py +150 -107
  305. mindspore/ops/operations/custom_ops.py +287 -32
  306. mindspore/ops/operations/debug_ops.py +119 -16
  307. mindspore/ops/operations/inner_ops.py +1 -1
  308. mindspore/ops/operations/linalg_ops.py +1 -58
  309. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  310. mindspore/ops/operations/manually_defined/ops_def.py +746 -79
  311. mindspore/ops/operations/math_ops.py +21 -18
  312. mindspore/ops/operations/nn_ops.py +67 -224
  313. mindspore/ops/operations/other_ops.py +62 -9
  314. mindspore/ops/operations/random_ops.py +13 -7
  315. mindspore/ops/operations/reshard_ops.py +1 -1
  316. mindspore/ops/operations/sparse_ops.py +2 -2
  317. mindspore/ops/primitive.py +43 -32
  318. mindspore/ops/tensor_method.py +243 -17
  319. mindspore/ops_generate/__init__.py +0 -5
  320. mindspore/ops_generate/aclnn/__init__.py +0 -0
  321. mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
  322. mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
  323. mindspore/ops_generate/api/__init__.py +0 -0
  324. mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
  325. mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
  326. mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
  327. mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
  328. mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
  329. mindspore/ops_generate/api/gen_api.py +103 -0
  330. mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
  331. mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
  332. mindspore/ops_generate/common/__init__.py +0 -0
  333. mindspore/ops_generate/common/gen_constants.py +91 -0
  334. mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
  335. mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
  336. mindspore/ops_generate/{template.py → common/template.py} +96 -84
  337. mindspore/ops_generate/gen_ops.py +23 -325
  338. mindspore/ops_generate/op_def/__init__.py +0 -0
  339. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  340. mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
  341. mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -10
  342. mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
  343. mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
  344. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  345. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  346. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  347. mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
  348. mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
  349. mindspore/ops_generate/pyboost/__init__.py +0 -0
  350. mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
  351. mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
  352. mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
  353. mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
  354. mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
  355. mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
  356. mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
  357. mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
  358. mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
  359. mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
  360. mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
  361. mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
  362. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
  363. mindspore/ops_generate/resources/__init__.py +0 -0
  364. mindspore/ops_generate/resources/resource_list.py +30 -0
  365. mindspore/ops_generate/resources/resource_loader.py +36 -0
  366. mindspore/ops_generate/resources/resource_manager.py +64 -0
  367. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  368. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  369. mindspore/parallel/__init__.py +6 -2
  370. mindspore/parallel/_auto_parallel_context.py +140 -12
  371. mindspore/parallel/_cell_wrapper.py +132 -15
  372. mindspore/parallel/_parallel_serialization.py +95 -4
  373. mindspore/parallel/_ps_context.py +1 -1
  374. mindspore/parallel/_recovery_context.py +7 -2
  375. mindspore/parallel/_tensor.py +142 -18
  376. mindspore/parallel/_utils.py +198 -25
  377. mindspore/parallel/algo_parameter_config.py +3 -3
  378. mindspore/parallel/auto_parallel.py +732 -0
  379. mindspore/parallel/checkpoint_convert.py +159 -0
  380. mindspore/parallel/checkpoint_transform.py +658 -37
  381. mindspore/parallel/cluster/process_entity/_api.py +151 -19
  382. mindspore/parallel/cluster/run.py +1 -1
  383. mindspore/parallel/function/__init__.py +24 -0
  384. mindspore/parallel/function/reshard_func.py +258 -0
  385. mindspore/parallel/nn/__init__.py +25 -0
  386. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  387. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  388. mindspore/parallel/parameter_broadcast.py +24 -13
  389. mindspore/parallel/shard.py +137 -62
  390. mindspore/parallel/transform_safetensors.py +288 -95
  391. mindspore/pgodb140.dll +0 -0
  392. mindspore/pgort140.dll +0 -0
  393. mindspore/profiler/__init__.py +9 -5
  394. mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
  395. mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
  396. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
  397. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +25 -0
  398. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  399. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
  400. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
  401. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
  402. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
  403. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
  404. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
  405. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
  406. mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
  407. mindspore/profiler/common/constant.py +12 -0
  408. mindspore/profiler/common/msprof_cmd_tool.py +42 -23
  409. mindspore/profiler/common/path_manager.py +24 -0
  410. mindspore/profiler/common/profiler_context.py +26 -2
  411. mindspore/profiler/common/profiler_meta_data.py +74 -0
  412. mindspore/profiler/common/profiler_parameters.py +59 -18
  413. mindspore/profiler/common/profiler_path_manager.py +66 -7
  414. mindspore/profiler/dynamic_profiler.py +112 -79
  415. mindspore/profiler/envprofiler.py +26 -1
  416. mindspore/profiler/experimental_config.py +197 -0
  417. mindspore/profiler/mstx.py +57 -14
  418. mindspore/profiler/platform/npu_profiler.py +33 -7
  419. mindspore/profiler/profiler.py +541 -45
  420. mindspore/profiler/profiler_action_controller.py +1 -1
  421. mindspore/profiler/profiler_interface.py +4 -0
  422. mindspore/profiler/schedule.py +57 -22
  423. mindspore/rewrite/api/node.py +15 -13
  424. mindspore/rewrite/api/symbol_tree.py +1 -1
  425. mindspore/run_check/_check_version.py +25 -14
  426. mindspore/run_check/run_check.py +1 -1
  427. mindspore/runtime/__init__.py +2 -2
  428. mindspore/runtime/executor.py +40 -11
  429. mindspore/runtime/memory.py +37 -13
  430. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  431. mindspore/swresample-4.dll +0 -0
  432. mindspore/swscale-6.dll +0 -0
  433. mindspore/tbbmalloc.dll +0 -0
  434. mindspore/tinyxml2.dll +0 -0
  435. mindspore/train/__init__.py +8 -8
  436. mindspore/train/_utils.py +43 -9
  437. mindspore/train/amp.py +1 -1
  438. mindspore/train/callback/__init__.py +2 -2
  439. mindspore/train/callback/_callback.py +2 -16
  440. mindspore/train/callback/_checkpoint.py +24 -40
  441. mindspore/train/callback/_cluster_monitor.py +14 -18
  442. mindspore/train/callback/_flops_collector.py +2 -3
  443. mindspore/train/callback/_history.py +7 -4
  444. mindspore/train/callback/_lambda_callback.py +2 -2
  445. mindspore/train/callback/_landscape.py +0 -3
  446. mindspore/train/callback/_loss_monitor.py +2 -1
  447. mindspore/train/callback/_on_request_exit.py +6 -5
  448. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  449. mindspore/train/callback/_summary_collector.py +8 -13
  450. mindspore/train/callback/_time_monitor.py +2 -1
  451. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -105
  452. mindspore/train/data_sink.py +25 -2
  453. mindspore/train/dataset_helper.py +4 -5
  454. mindspore/train/loss_scale_manager.py +8 -7
  455. mindspore/train/metrics/accuracy.py +3 -3
  456. mindspore/train/metrics/confusion_matrix.py +9 -9
  457. mindspore/train/metrics/error.py +3 -3
  458. mindspore/train/metrics/hausdorff_distance.py +4 -4
  459. mindspore/train/metrics/mean_surface_distance.py +3 -3
  460. mindspore/train/metrics/metric.py +0 -12
  461. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  462. mindspore/train/metrics/precision.py +8 -6
  463. mindspore/train/metrics/recall.py +9 -9
  464. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  465. mindspore/train/mind_ir_pb2.py +19 -12
  466. mindspore/train/model.py +262 -127
  467. mindspore/train/serialization.py +246 -988
  468. mindspore/train/summary/_summary_adapter.py +2 -2
  469. mindspore/train/summary/summary_record.py +1 -1
  470. mindspore/turbojpeg.dll +0 -0
  471. mindspore/utils/__init__.py +3 -2
  472. mindspore/utils/dryrun.py +4 -2
  473. mindspore/utils/hooks.py +81 -0
  474. mindspore/utils/runtime_execution_order_check.py +2 -0
  475. mindspore/utils/utils.py +138 -4
  476. mindspore/vcmeta.dll +0 -0
  477. mindspore/vcruntime140.dll +0 -0
  478. mindspore/vcruntime140_1.dll +0 -0
  479. mindspore/version.py +1 -1
  480. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/METADATA +2 -1
  481. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/RECORD +485 -440
  482. mindspore/_install_custom.py +0 -43
  483. mindspore/common/_register_for_adapter.py +0 -74
  484. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  485. mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
  486. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  487. mindspore/ops_generate/gen_constants.py +0 -190
  488. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  489. mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
  490. /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
  491. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
  492. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +0 -0
  493. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
mindspore/parallel/_cell_wrapper.py
@@ -18,17 +18,20 @@ from __future__ import division
 
 import numpy as np
 
+import mindspore.log as logger
 from mindspore import context
 from mindspore.nn.cell import Cell
 from mindspore.ops import operations as P
 from mindspore.ops.operations.comm_ops import AllGather
-from mindspore.communication import GlobalComm
+from mindspore.communication import GlobalComm, get_rank
 from mindspore.common import jit
-from mindspore.communication import create_group, destroy_group
+from mindspore.communication import create_group, destroy_group, get_group_size
 from mindspore.communication._comm_helper import _get_group_map
 from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy
+from mindspore.parallel.shard import Layout
 
 _ALLGATHER_CELL = None
+ALLREDUCE_GROUP_LIST = []
 
 
 class AllGatherCell(Cell):
@@ -134,7 +137,7 @@ def _restore_parallel_context(origin_parallel_mode, origin_dataset_strategy):
 
 def _get_group_name(group_map, group):
     """get group name"""
-    group_name = str(group)
+    group_name = "remove_redundancy" + str(group)
     is_manual_communication_group = True
     if group_map:
         for name, rank_list in group_map.items():
@@ -142,20 +145,37 @@ def _get_group_name(group_map, group):
             group_name = name
             is_manual_communication_group = False
             break
-    if is_manual_communication_group:
-        create_group(str(group), list(group))
     return group_name, is_manual_communication_group
 
 
-def _single_parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
+def _get_param_redundancy_reversed(param_redundancy, cur_rank):
+    """Generate the reverse mapping of parameter redundancy based on the current rank."""
+    param_redundancy_reversed = {}
+    for key, redundancy in param_redundancy.items():
+        for item in redundancy:
+            if len(item) == 1:
+                continue
+            if cur_rank in item:
+                param_redundancy_reversed.setdefault(item, []).append(key)
+    return param_redundancy_reversed
+
+
+def _remove_param_not_load(param_name, param_not_load):
+    """Remove param_name from param_not_load."""
+    if param_not_load is not None and param_name in param_not_load:
+        param_not_load.remove(param_name)
+
+
+def _single_parameter_broadcast(net, layout, param_not_load=None):
     """
     Broadcast single parameter to other rank in data parallel dimension.
     """
     from mindspore import Tensor
     origin_parallel_mode = context.get_auto_parallel_context("parallel_mode")
     origin_dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
+    cur_rank = get_rank()
     if layout:
-        param_redundancy = get_parameter_redundancy(layout, initial_rank)
+        param_redundancy = get_parameter_redundancy(layout)
     else:
         param_redundancy = get_parameter_redundancy(net)
     if not param_redundancy:
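
Note: the new `_get_param_redundancy_reversed` helper above is a pure function, so its behavior is easy to check in isolation. A minimal sketch with a toy redundancy map (the input shape, parameter name mapped to tuples of ranks holding identical copies, is inferred from the function body; real inputs come from `get_parameter_redundancy`):

    def get_param_redundancy_reversed(param_redundancy, cur_rank):
        reversed_map = {}
        for key, redundancy in param_redundancy.items():
            for item in redundancy:
                if len(item) == 1:       # a single-rank entry has no redundant copy
                    continue
                if cur_rank in item:     # keep only groups containing this rank
                    reversed_map.setdefault(item, []).append(key)
        return reversed_map

    redundancy = {"w1": ((0, 1),), "w2": ((2,),), "w3": ((0, 1), (2, 3))}
    print(get_param_redundancy_reversed(redundancy, cur_rank=0))
    # {(0, 1): ['w1', 'w3']}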
@@ -163,33 +183,130 @@ def _single_parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
     single_params = remove_param_redundancy(param_redundancy)
     if not single_params:
         return
-    param_redundancy_reversed = {}
-    for key, redundancy in param_redundancy.items():
-        for item in redundancy:
-            if len(item) == 1:
-                continue
-            if cur_rank in item:
-                param_redundancy_reversed.setdefault(item, []).append(key)
+    param_redundancy_reversed = _get_param_redundancy_reversed(param_redundancy, cur_rank)
     if not param_redundancy_reversed or cur_rank not in single_params:
         return
     net_param_dict = net.parameters_dict()
     _chang_parallel_context(origin_dataset_strategy)
     group_map = _get_group_map()
+    if group_map:
+        group_map = {key: group_map[key] for key in sorted(group_map.keys())}
     for group, params in param_redundancy_reversed.items():
         group_name, is_manual_communication_group = _get_group_name(group_map, group)
         allreduce_input = []
         for param in params:
             if param not in net_param_dict:
                 continue
+            if param.startswith("accu_grads") or param.endswith("expert_load"):
+                continue
             real_param = net_param_dict[param]
+            _remove_param_not_load(real_param.name, param_not_load)
             if param not in single_params[cur_rank]:
                 real_param.set_data(Tensor(np.zeros(real_param.shape), dtype=real_param.dtype), real_param.sliced)
             allreduce_input.append(real_param)
         if not allreduce_input:
             continue
+        if is_manual_communication_group:
+            create_group(group_name, list(group))
+        allreduce_input.sort(key=lambda param: (str(param.shape), str(param.dtype)))
         communicator = SingleCommunicator(group_name)
         for real_param in allreduce_input:
-            real_param.set_data(communicator(real_param), real_param.sliced)
+            real_param.set_data(communicator(Tensor(real_param)), real_param.sliced)
         if is_manual_communication_group:
             destroy_group(group_name)
     _restore_parallel_context(origin_parallel_mode, origin_dataset_strategy)
+
+
+def _insert_virtual_pp_dim(layout):
+    """insert virtual pp dim in device matrix and create new layout"""
+    if len(layout.to_dict()["rank_list"]) == get_group_size():
+        return layout
+    remain_pp = get_group_size() // len(layout.to_dict()["rank_list"])
+    layout_info = layout.to_dict()
+    device_matrix = layout_info["device_matrix"]
+    tensor_map = layout_info["tensor_map"]
+    alias_name = layout_info["alias_name"]
+    new_devmat = Layout((remain_pp,) + device_matrix, ("remain_pp",) + alias_name)
+    tensor_map_alias_name = []
+    for val in tensor_map:
+        sub_alias_name = []
+        if isinstance(val, tuple):
+            for sub_val in val:
+                if sub_val == -1:
+                    sub_alias_name.append("None")
+                else:
+                    sub_alias_name.append(alias_name[len(device_matrix) - sub_val - 1])
+            tensor_map_alias_name.append(tuple(sub_alias_name))
+        else:
+            if val == -1:
+                tensor_map_alias_name.append("None")
+            else:
+                tensor_map_alias_name.append(alias_name[len(device_matrix) - val - 1])
+    new_layout = new_devmat(*tensor_map_alias_name)
+    return new_layout
+
+
+class CommTensorDataForPP(Cell):
+    """Communicate tensor data for pipeline parallel scenario."""
+
+    def __init__(self, src_dtensor_info, dst_dtensor_info):
+        super().__init__()
+        self.zeros = P.Zeros()
+
+        self._current_rank_id = get_rank()
+        self._from_dev_num_in_stage = len(src_dtensor_info.layout.to_dict()["rank_list"])
+        self._from_rank_id = src_dtensor_info.layout.to_dict()["rank_list"]
+        self._current_rank_has_data = self._current_rank_id in src_dtensor_info.layout.to_dict()["rank_list"]
+        self._diff_rank_id = [
+            rank_id for rank_id in dst_dtensor_info.layout.to_dict()["rank_list"] if rank_id not in self._from_rank_id]
+        self._group, self._root_idx = self._create_all_reduce_group()
+
+    def comm_data(self, comm_data):
+        """communicate data"""
+        from mindspore import mint
+        comm_handle = mint.distributed.broadcast(comm_data, self._root_idx, self._group, async_op=False)
+        return comm_handle
+
+    def _create_all_reduce_group(self):
+        """create all reduce group"""
+        global ALLREDUCE_GROUP_LIST
+        current_rank_stage_id = self._current_rank_id // self._from_dev_num_in_stage
+        end_stage = self._from_dev_num_in_stage * (current_rank_stage_id + 1)
+        rank_pos_in_stage = [rank_id for rank_id in range(self._from_dev_num_in_stage * current_rank_stage_id,
+                                                          end_stage)].index(self._current_rank_id)
+        root_idx = self._from_rank_id[rank_pos_in_stage]
+        all_reduce_rank_list = [self._from_rank_id[rank_pos_in_stage]]
+        while rank_pos_in_stage < len(self._diff_rank_id):
+            all_reduce_rank_list.append(self._diff_rank_id[rank_pos_in_stage])
+            rank_pos_in_stage += self._from_dev_num_in_stage
+        all_reduce_rank_list.sort()
+        str_rank_list = '-'.join([str(rank) for rank in all_reduce_rank_list])
+        all_reduce_group = f"pp_allreduce_group-{str_rank_list}"
+        if all_reduce_group in ALLREDUCE_GROUP_LIST:
+            return all_reduce_group, root_idx
+        ALLREDUCE_GROUP_LIST.append(all_reduce_group)
+        create_group(all_reduce_group, all_reduce_rank_list)
+        logger.debug(f"Create group {all_reduce_group} for tensor data communication.")
+        return all_reduce_group, root_idx
+
+
+class RedistributionCell(Cell):
+    """Redistribute src_layout to dst_layout"""
+
+    def __init__(self, src_layout, dst_layout):
+        super().__init__()
+        if src_layout is None or dst_layout is None:
+            raise ValueError("src_layout and dst_layout should not be None.")
+        self._total_dev_num = get_group_size()
+        src_layout = _insert_virtual_pp_dim(src_layout)
+        dst_layout = _insert_virtual_pp_dim(dst_layout)
+        self.src_identity = P.Identity().shard(in_strategy=(src_layout,), out_strategy=(src_layout,))
+        self.src_identity.add_prim_attr("self_define_shard", True)
+        self.dst_identity = P.Identity().shard(in_strategy=(dst_layout,), out_strategy=(dst_layout,))
+        self.dst_identity.add_prim_attr("self_define_shard", True)
+
+    def construct(self, input_tensor):
+        """run redistribution"""
+        src_tensor = self.src_identity(input_tensor)
+        dst_tensor = self.dst_identity(src_tensor)
+        return dst_tensor
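
Note: `_insert_virtual_pp_dim` only widens the device matrix; the tensor sharding itself is untouched. A pure-Python sketch of the arithmetic, assuming an 8-device job whose source layout covers 4 ranks (the concrete values stand in for what `layout.to_dict()` and `get_group_size()` return):

    world_size = 8                     # stand-in for get_group_size()
    rank_list = [0, 1, 2, 3]           # layout.to_dict()["rank_list"]
    device_matrix = (2, 2)             # layout.to_dict()["device_matrix"]
    alias_name = ("dp", "mp")          # layout.to_dict()["alias_name"]

    remain_pp = world_size // len(rank_list)        # 2 virtual pipeline copies
    new_devmat = (remain_pp,) + device_matrix       # (2, 2, 2)
    new_alias = ("remain_pp",) + alias_name         # ("remain_pp", "dp", "mp")
    print(new_devmat, new_alias)

The real code then rebuilds the tensor_map alias names against the original device matrix, so existing shard dims keep their meaning under the widened matrix.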
mindspore/parallel/_parallel_serialization.py
@@ -19,6 +19,7 @@ import os
 import json
 import numpy as np
 import mindspore as ms
+from mindspore import _checkparam as Validator
 from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_tensor_layout, \
     _get_needed_rank_list_by_layouts, _get_needed_rank_transform_operator_map_by_layouts, \
     _generate_transform_operator_stack, _apply_tensor_transform_operators, _construct_tensor_layout_for_opt_shard, \
@@ -34,7 +35,12 @@ def _convert_to_list(strategy, rank_id=None):
         try:
             layout = strategy.get(param_name)
             dev_mat = list(layout.dev_matrix[0].dim)
-            tensor_map = list(layout.tensor_map[0].dim)
+            # for layout one axis two slices, layout(("dp", "mp"), "None")
+            if len(layout.tensor_map) > 1:
+                tensor_map = [list(tensor_map.dim) for tensor_map in layout.tensor_map
+                              if list(tensor_map.dim)]
+            else:
+                tensor_map = list(layout.tensor_map[0].dim)
             param_split_shape = list(layout.param_split_shape[0].dim)
             field_size = int(layout.field)
             shard_stride = int(layout.opt_weight_shard_step)
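
Note: the new branch in `_convert_to_list` keeps one dim-list per tensor axis when a strategy layout splits a single axis across several device dimensions, dropping empty entries. A sketch with a hypothetical `Dim` stand-in for the strategy protobuf's repeated dim message:

    class Dim:                                   # hypothetical stand-in for the proto dim message
        def __init__(self, dim):
            self.dim = dim

    tensor_map = [Dim([2, 1]), Dim([-1]), Dim([])]   # empty dim lists are dropped
    if len(tensor_map) > 1:
        result = [list(tm.dim) for tm in tensor_map if list(tm.dim)]
    else:
        result = list(tensor_map[0].dim)
    print(result)                                # [[2, 1], [-1]]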
@@ -417,7 +423,7 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
         from_opt_shard_size = 0
         if src_strategy_list is not None:
             if param_name not in src_strategy_list:
-                ms.log.warning("The parameter {} is not in src_strategy.".format(param_name))
+                ms.log.info("The parameter {} is not in src_strategy.".format(param_name))
                 continue
             from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size = _extract_layout_item(
                 src_strategy_list.get(param_name))
@@ -427,7 +433,7 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
         to_opt_shard_size = 0
         if dst_strategy_list is not None:
             if param_name not in dst_strategy_list:
-                ms.log.warning("The parameter {} is not in dst_strategy.".format(param_name))
+                ms.log.info("The parameter {} is not in dst_strategy.".format(param_name))
                 continue
             to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
                 dst_strategy_list.get(param_name))
@@ -441,6 +447,9 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
                 continue
             origin_tensor_shape += (item * param_strategy[i],)
 
+        has_layout_from = any(isinstance(i, (list, tuple)) for i in from_tensor_map)
+        has_layout_to = any(isinstance(i, (list, tuple)) for i in to_tensor_map_origin)
+
         from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
             from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
         to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
@@ -460,6 +469,7 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
         from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
         to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
         _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
+        _insert_expand_layout_reshape(param_rank_map, from_info_tuple, to_info_tuple, has_layout_from, has_layout_to)
         transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id)
         param_total_dict_copy = param_total_dict[param_name].copy()
         _apply_tensor_transform_operators(transform_operator_stack, param_total_dict_copy, device_num)
@@ -556,6 +566,32 @@ def _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple):
             param_rank_map.get(param_rank).append(('Reshape', list(to_slice_tensor_shape)))
 
 
+def _insert_expand_layout_reshape(param_rank_map, from_info_tuple, to_info_tuple,
+                                  insert_from_reshape, insert_to_reshape):
+    """ insert layout expand op reshape """
+    from_opt_shard_size = from_info_tuple[0]
+    from_dev_matrix = from_info_tuple[1]
+    from_tensor_map = from_info_tuple[2]
+    from_full_tensor_shape = from_info_tuple[3]
+    to_opt_shard_size = to_info_tuple[0]
+    to_dev_matrix_origin = to_info_tuple[1]
+    to_tensor_map_origin = to_info_tuple[2]
+    origin_tensor_shape = to_info_tuple[3]
+    for param_rank, _ in param_rank_map.items():
+        if from_opt_shard_size == 0 and insert_from_reshape:
+            from_slice_tensor_shape = ()
+            from_tensor_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
+            for i, item in enumerate(from_full_tensor_shape):
+                from_slice_tensor_shape += (item // from_tensor_strategy[i],)
+            param_rank_map.get(param_rank).insert(0, ('Reshape', list(from_slice_tensor_shape)))
+        if to_opt_shard_size == 0 and insert_to_reshape:
+            to_tensor_strategy = _get_tensor_strategy(to_dev_matrix_origin, to_tensor_map_origin)
+            to_slice_tensor_shape = ()
+            for i, item in enumerate(origin_tensor_shape):
+                to_slice_tensor_shape += (item // to_tensor_strategy[i],)
+            param_rank_map.get(param_rank).append(('Reshape', list(to_slice_tensor_shape)))
+
+
 def _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_sharded_device_index, rank):
     """Calculate rank list for optimizer parallel when first dim of parameter is sharded by other parallel method"""
     total_device_num = 1
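
Note: both branches of `_insert_expand_layout_reshape` compute the local slice shape by an even division, full_shape[i] // strategy[i], and then prepend or append a Reshape operator. Worked example (shapes chosen for illustration):

    full_shape = (4, 2, 8)            # expanded full tensor shape
    strategy = [4, 2, 1]              # per-axis shard counts from _get_tensor_strategy
    slice_shape = tuple(dim // s for dim, s in zip(full_shape, strategy))
    print(slice_shape)                # (1, 1, 8) -> target of the inserted Reshape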
@@ -569,4 +605,59 @@ def _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_sharded
     start = rank - offset
     param_total_list = list(range(start, start + range_size))
     return param_total_list
-
+
+
+def _gather_tasks_load_dis(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, dst_device_num,
+                           output_format, name_map, return_param_dict):
+    """gather transform tasks"""
+    tasks = []
+    for rank in range(0, dst_device_num):
+        tasks.append(
+            (unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, rank, output_format, name_map,
+             return_param_dict))
+    return tasks
+
+
+def _check_checkpoint_file(checkpoint_filenames):
+    """Check checkpoint file name."""
+    for index, filename in enumerate(checkpoint_filenames):
+        if not isinstance(filename, str) or not os.path.exists(filename) \
+                or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0:
+            raise ValueError(f"For 'load_distributed_checkpoint', please check 'checkpoint_filenames', and "
+                             f"make sure the {filename} at index {index} is a valid checkpoint file, it must "
+                             f"be a string ending with '.ckpt', and the checkpoint file it represents must "
+                             f"be exist and not empty.")
+
+
+def _check_predict_strategy(predict_strategy):
+    """Check predict strategy."""
+
+    def _check_int_list(arg):
+        if not isinstance(arg, list):
+            return False
+        for item in arg:
+            if not isinstance(item, int):
+                return False
+        return True
+
+    if predict_strategy is None:
+        return
+
+    flag = True
+    predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict)
+    for key in predict_strategy.keys():
+        if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \
+                or len(predict_strategy[key]) < 4:
+            flag = False
+        dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4]
+        if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \
+                not (_check_int_list(param_split_shape) or not param_split_shape) or \
+                not (isinstance(field_size, int) and field_size == 0):
+            flag = False
+
+    if not flag:
+        raise ValueError(f"For 'load_distributed_checkpoint', the argument 'predict_strategy' is dict, "
+                         f"the key of it must be string, and the value of it must be list or tuple that "
+                         f"the first four elements must be dev_matrix (list[int]), tensor_map (list[int]), "
+                         f"param_split_shape (list[int]) and field_size (int, which value is 0)."
+                         f"Please check whether 'predict_strategy' is correct.")
mindspore/parallel/_ps_context.py
@@ -115,7 +115,7 @@ def _set_ps_context(**kwargs):
         enable_ps (bool): Whether to enable parameter server training mode.
             Only after enable_ps is set True, the environment variables will be effective.
            Default: ``False``.
-        config_file_path (string): Configuration file path used by recovery. Default: ''.
+        config_file_path (str): Configuration file path used by recovery. Default: ''.
         scheduler_manage_port (int): scheduler manage port used to scale out/in. Default: 11202.
         enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: ``False``.
         client_password (str): Password to decrypt the secret key stored in the client certificate. Default: ''.
mindspore/parallel/_recovery_context.py
@@ -33,18 +33,23 @@ def recovery_context():
         RECOVERY_CONTEXT = RecoveryContext.get_instance()
     return RECOVERY_CONTEXT
 
+
 _set_recovery_context_func_map = {
     "ckpt_path": recovery_context().set_ckpt_path,
-    "need_reset": recovery_context().set_need_reset
+    "need_reset": recovery_context().set_need_reset,
+    "is_reboot_node": recovery_context().set_is_reboot_node,
+    "is_arf": recovery_context().set_is_arf
 }
 
 _get_recovery_context_func_map = {
     "enable_recovery": recovery_context().enable_recovery,
+    "enable_repeat_register": recovery_context().enable_repeat_register,
     "latest_ckpt_file": recovery_context().latest_ckpt_file,
     "latest_ckpt_epoch": recovery_context().latest_ckpt_epoch,
     "latest_ckpt_step": recovery_context().latest_ckpt_step,
     "need_reset": recovery_context().need_reset,
     "recovery_path": recovery_context().recovery_path,
+    "is_arf": recovery_context().is_arf,
     "ckpt_path": recovery_context().ckpt_path
 }
 
@@ -64,7 +69,7 @@ def _set_recovery_context(**kwargs):
         MS_RECOVERY_INTERVAL # The persistent interval for recovery
 
     Args:
-        ckpt_path (string): Set the recovery path used to save checkpoint. Default: ''.
+        ckpt_path (str): Set the recovery path used to save checkpoint. Default: ''.
         need_reset (bool): Set whether should call reset minddata and load ckpt for disaster recovery.
             Default: ``False``.
 
mindspore/parallel/_tensor.py
@@ -38,10 +38,17 @@ def _get_tensor_strategy(dev_mat, tensor_map):
     """
     tensor_strategy = []
     for dim in tensor_map:
-        if dim == -1:
-            tensor_strategy.append(1)
+        if isinstance(dim, (tuple, list)):
+            acc_stra = 1
+            for i in dim:
+                if i != -1:
+                    acc_stra *= dev_mat[len(dev_mat) - i - 1]
+            tensor_strategy.append(acc_stra)
         else:
-            tensor_strategy.append(dev_mat[-dim - 1])
+            if dim == -1:
+                tensor_strategy.append(1)
+            else:
+                tensor_strategy.append(dev_mat[-dim - 1])
     return tensor_strategy
 
 
@@ -182,7 +189,7 @@ def _get_slice_index(dev_mat, tensor_map, opt_shard_group):
     Args:
         dev_mat (list): The device matrix of devices.
         tensor_map (list): The split strategy of tensor.
-        opt_shard_group(string): The group of optimizer shard
+        opt_shard_group(str): The group of optimizer shard
 
     Returns:
         Integer, the slice index for slice on this device.
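
Note on the `_get_tensor_strategy` change above: with the new nested-map branch, a tensor axis mapped to several device dims multiplies their sizes into a single shard count. A standalone restatement with worked calls:

    def get_tensor_strategy(dev_mat, tensor_map):
        tensor_strategy = []
        for dim in tensor_map:
            if isinstance(dim, (tuple, list)):    # one axis sharded by several device dims
                acc_stra = 1
                for i in dim:
                    if i != -1:
                        acc_stra *= dev_mat[len(dev_mat) - i - 1]
                tensor_strategy.append(acc_stra)
            elif dim == -1:                       # unsharded axis
                tensor_strategy.append(1)
            else:
                tensor_strategy.append(dev_mat[-dim - 1])
        return tensor_strategy

    print(get_tensor_strategy([4, 2, 2], [(2, 1), 0]))   # [8, 2]
    print(get_tensor_strategy([4, 2, 2], [2, -1]))       # [4, 1]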
@@ -388,6 +395,124 @@ def _construct_from_to_tensor_layout(from_full_tensor_shape, from_dev_matrix,
     return from_tensor_layout, to_tensor_layout
 
 
+def _expand_layout(dev_matrix, tensor_map, tensor_shape):
+    """
+    expand nested tensor_map and reshape tensor shape according to tensor_map
+    dev_matrix = [4, 2, 2]
+    tensor_map = [[2, 1], 0]
+    tensor_shape = [8, 8]
+    =>
+    expanded_tensor_map = [2, 1, 0]
+    expanded_tensor_map = [4, 8/4, 8]
+    """
+    new_tensor_map = []
+    new_tensor_shape = []
+    for index, dim in enumerate(tensor_map):
+        if isinstance(dim, (tuple, list)):
+            accu_shape = 1
+            for i in range(len(dim) - 1):
+                new_tensor_map.append(dim[i])
+                new_tensor_shape.append(dev_matrix[len(dev_matrix) - 1 - dim[i]])
+                accu_shape *= dev_matrix[len(dev_matrix) - 1 - dim[i]]
+            new_tensor_map.append(dim[-1])
+            new_tensor_shape.append(tensor_shape[index] // accu_shape)
+        else:
+            new_tensor_map.append(dim)
+            new_tensor_shape.append(tensor_shape[index])
+    return dev_matrix, new_tensor_map, new_tensor_shape
+
+
+def _construct_tensor_layout_for_opt_shard_by_layout(dev_matrix, tensor_map, opt_shard_step, opt_shard_size,
+                                                     origin_full_tensor_shape):
+    """
+    Construct tensor layout for optimizer parallel when using layout.
+    For example, For Tensor with shape (4,2)
+    dev_matrix = [2, 2, 2, 2]
+    tensor_map = [[1, 0], -1]
+    opt_shard_size = 2
+    ==>
+    dev_matrix = [2, 2, 2, 2]
+    tensor_map = [[1, 0], 2, -1]
+    the new strategy is [4, 2, 1]
+    the tensor_shape should reshape to (model_parallel_size, -1, xx, xx)
+    first 4 means the model parallel sharding of data_dim
+    second 2 means the opt sharding of data_dim.
+    """
+    if opt_shard_step == 0 or opt_shard_size == 0:
+        return dev_matrix, tensor_map, list(origin_full_tensor_shape)
+    tensor_strategy = _get_tensor_strategy(dev_matrix, tensor_map)
+    repeated_dim = []
+    dev_sharded_index = []
+    dev_matrix, expanded_tensor_map, _ = _expand_layout(dev_matrix, tensor_map, origin_full_tensor_shape)
+    for dim in expanded_tensor_map:
+        if dim != -1:
+            dev_sharded_index.append(len(dev_matrix) - dim - 1)
+    for index, value in enumerate(dev_matrix):
+        if index not in dev_sharded_index and value > 1:
+            repeated_dim.append(index)
+    if not repeated_dim:
+        raise ValueError("The device_matrix {} and tensor_map {} cannot sharding opt_shard".
+                         format(dev_matrix, tensor_map))
+    return _construct_tensor_layout_helper(dev_matrix, tensor_map, opt_shard_size, origin_full_tensor_shape,
+                                           tensor_strategy, repeated_dim)
+
+
+def _construct_tensor_layout_helper(dev_matrix, tensor_map, opt_shard_size, origin_full_tensor_shape,
+                                    tensor_strategy, repeated_dim):
+    """
+    helper function to assign repeated device_matrix dim for opt shard.
+    """
+    new_dev_matrix = list(copy.deepcopy(dev_matrix))
+    new_dev_matrix_map = list(range(len(dev_matrix)))
+    opt_shard_dim = []
+    remained_opt_shard_size = opt_shard_size if opt_shard_size != -1 else \
+        int(np.prod([dev_matrix[i] for i in repeated_dim]))
+    for dim in repeated_dim[::-1]:
+        opt_sharding_size = dev_matrix[dim]
+        if remained_opt_shard_size // opt_sharding_size == 0:
+            if opt_sharding_size % remained_opt_shard_size != 0:
+                raise ValueError("dev_matrix value {} at dim {} cannot be divided by needed opt sharding "
+                                 "size {}".format(dev_matrix[dim], len(dev_matrix) - dim - 1,
+                                                  remained_opt_shard_size))
+            opt_sharding_size = remained_opt_shard_size
+        # update dev_matrix
+        new_dev_matrix[dim] = dev_matrix[dim] // opt_sharding_size
+        new_dev_matrix.insert(dim + 1, opt_sharding_size)
+        for i in range(len(dev_matrix) - dim - 1, len(dev_matrix)):
+            new_dev_matrix_map[i] += 1
+        if remained_opt_shard_size % opt_sharding_size != 0:
+            raise ValueError("Remained opt_shard_size {} cannot be divided by current sharding size {}, "
+                             "the repeat dim is {} with dev_matrix value {}".
+                             format(remained_opt_shard_size, opt_sharding_size,
+                                    len(dev_matrix) - dim - 1, dev_matrix[dim]))
+        remained_opt_shard_size //= opt_sharding_size
+        opt_shard_dim.insert(0, dim)
+        if remained_opt_shard_size == 1:
+            break
+    tensor_map_new = list(copy.deepcopy(tensor_map))
+    if len(new_dev_matrix) != len(dev_matrix):
+        opt_shard_dim = list(map(lambda x: x + 1, opt_shard_dim))
+    for index, item in enumerate(tensor_map_new):
+        if isinstance(item, (tuple, list)):
+            item = list(map(lambda x: new_dev_matrix_map[x] if x >= 0 else x, item))
+            tensor_map_new[index] = item
+        else:
+            if item >= 0:
+                tensor_map_new[index] = new_dev_matrix_map[item]
+    tensor_shape_new = list(copy.deepcopy(origin_full_tensor_shape))
+    tensor_shape_new[0] = tensor_strategy[0]
+    first_dim_no_sharding_size = origin_full_tensor_shape[0] // tensor_strategy[0]
+    accu_shape = 1
+    for i in range(len(opt_shard_dim) - 1):
+        opt_sharding_size = new_dev_matrix[opt_shard_dim[i]]
+        tensor_shape_new.insert(i + 1, opt_sharding_size)
+        accu_shape = accu_shape * opt_sharding_size
+    tensor_shape_new.insert(len(opt_shard_dim), first_dim_no_sharding_size // accu_shape)
+    for index, r_dim in enumerate(opt_shard_dim):
+        tensor_map_new.insert(index + 1, len(new_dev_matrix) - r_dim - 1)
+    return list(new_dev_matrix), tensor_map_new, tensor_shape_new
+
+
 def _construct_tensor_layout_for_opt_shard(dev_matrix, tensor_map, opt_shard_step, opt_shard_size,
                                            origin_full_tensor_shape):
     """
@@ -404,6 +529,11 @@ def _construct_tensor_layout_for_opt_shard(dev_matrix, tensor_map, opt_shard_ste
     And the model parallel sharding dim is the right of opt sharding dim, so it would be 0-1-2-3 model parallel sharding
     then 0-4 optimizer sharding.
     """
+    has_layout = any(isinstance(i, (list, tuple)) for i in tensor_map)
+    if has_layout:
+        output = _construct_tensor_layout_for_opt_shard_by_layout(dev_matrix, tensor_map, opt_shard_step,
+                                                                  opt_shard_size, origin_full_tensor_shape)
+        return _expand_layout(*output)
 
     if opt_shard_step == 0 or opt_shard_size == 0:
         return dev_matrix, tensor_map, list(origin_full_tensor_shape)
@@ -424,18 +554,8 @@ def _construct_tensor_layout_for_opt_shard(dev_matrix, tensor_map, opt_shard_ste
                          format(opt_shard_step, np.prod(dev_matrix[repeated_dim[0] + 1:])))
     first_dim_no_sharding_size = origin_full_tensor_shape[0] // tensor_strategy[0]
     if (len(repeated_dim) < len(dev_matrix) and len(repeated_dim) > 1) or repeated_dim[0] > 0:
-        tensor_shape_new = list(origin_full_tensor_shape)
-        tensor_shape_new[0] = tensor_strategy[0]
-        accu_shp = 1
-        for i in range(len(repeated_dim) - 1):
-            opt_sharding_size = dev_matrix[repeated_dim[i]]
-            tensor_shape_new.insert(i + 1, opt_sharding_size)
-            accu_shp = accu_shp * opt_sharding_size
-        tensor_shape_new.insert(len(repeated_dim), first_dim_no_sharding_size // accu_shp)
-        tensor_map_new = list(copy.deepcopy(tensor_map))
-        for index, r_dim in enumerate(repeated_dim):
-            tensor_map_new.insert(index + 1, len(dev_matrix) - r_dim - 1)
-        return list(dev_matrix), tensor_map_new, tensor_shape_new
+        return _construct_tensor_layout_helper(dev_matrix, tensor_map, opt_shard_size, origin_full_tensor_shape,
+                                               tensor_strategy, repeated_dim)
 
     full_tensor_shape = list(origin_full_tensor_shape)
     full_tensor_shape[0] = tensor_strategy[0]
@@ -610,9 +730,13 @@ def _apply_operator(operator_name):
         """
         if not isinstance(numpy_data_list, list):
             raise TypeError("The data_list should be a list.")
+        new_numpy_data_list = []
        for numpy_data in numpy_data_list:
-            if not isinstance(numpy_data, np.ndarray):
-                raise TypeError("The data should be a numpy.ndarray.")
+            if str(type(numpy_data)) == "<class 'builtins.PySafeSlice'>":
+                new_numpy_data_list.append(numpy_data[:])
+            else:
+                new_numpy_data_list.append(numpy_data)
+        numpy_data_list = new_numpy_data_list
         _check_operator(allgather_op)
         concat_group = allgather_op[1][:-1]
         if len(concat_group) != len(numpy_data_list):
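
Note: the rewritten loop replaces the old hard type check. Instead of rejecting anything that is not an ndarray, a lazily loaded safetensors slice (whose type prints as builtins.PySafeSlice) is materialized with a full-range index, which yields a real numpy array; everything else passes through unchanged. A minimal sketch of the same pattern:

    import numpy as np

    def materialize(data_list):
        out = []
        for item in data_list:
            if str(type(item)) == "<class 'builtins.PySafeSlice'>":
                out.append(item[:])   # full read turns the lazy slice into np.ndarray
            else:
                out.append(item)
        return out

    print(materialize([np.zeros((2, 2))])[0].shape)   # (2, 2)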