mindspore 2.4.10__cp311-cp311-win_amd64.whl → 2.6.0__cp311-cp311-win_amd64.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (602)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +13 -6
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -38
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +6 -7
  15. mindspore/_extends/parse/compile_config.py +83 -0
  16. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  17. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  18. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  19. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  20. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  21. mindspore/_extends/parse/parser.py +47 -198
  22. mindspore/_extends/parse/resources.py +1 -5
  23. mindspore/_extends/parse/standard_method.py +229 -99
  24. mindspore/_extends/pijit/__init__.py +2 -2
  25. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  26. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  27. mindspore/_extends/utils.py +1 -1
  28. mindspore/amp.py +11 -5
  29. mindspore/atlprov.dll +0 -0
  30. mindspore/avcodec-59.dll +0 -0
  31. mindspore/avdevice-59.dll +0 -0
  32. mindspore/avfilter-8.dll +0 -0
  33. mindspore/avformat-59.dll +0 -0
  34. mindspore/avutil-57.dll +0 -0
  35. mindspore/boost/__init__.py +2 -2
  36. mindspore/boost/base.py +3 -7
  37. mindspore/boost/boost_cell_wrapper.py +138 -43
  38. mindspore/c1.dll +0 -0
  39. mindspore/c1xx.dll +0 -0
  40. mindspore/c2.dll +0 -0
  41. mindspore/common/__init__.py +6 -3
  42. mindspore/common/_grad_function.py +56 -0
  43. mindspore/common/_pijit_context.py +14 -5
  44. mindspore/common/_register_for_tensor.py +1 -2
  45. mindspore/common/_stub_tensor.py +30 -14
  46. mindspore/common/_tensor_cpp_method.py +17 -0
  47. mindspore/common/_tensor_docs.py +4760 -0
  48. mindspore/common/api.py +480 -372
  49. mindspore/common/auto_dynamic_shape.py +41 -44
  50. mindspore/common/dtype.py +39 -36
  51. mindspore/common/dump.py +9 -6
  52. mindspore/common/file_system.py +9 -1
  53. mindspore/common/generator.py +5 -0
  54. mindspore/common/hook_handle.py +6 -2
  55. mindspore/common/initializer.py +13 -10
  56. mindspore/common/jit_begin_end.py +94 -0
  57. mindspore/common/jit_config.py +6 -1
  58. mindspore/common/jit_context.py +76 -0
  59. mindspore/common/jit_trace.py +378 -0
  60. mindspore/common/lazy_inline.py +9 -3
  61. mindspore/common/mindir_util.py +10 -2
  62. mindspore/common/mutable.py +5 -4
  63. mindspore/common/parameter.py +135 -52
  64. mindspore/common/seed.py +2 -2
  65. mindspore/common/sparse_tensor.py +23 -17
  66. mindspore/common/tensor.py +975 -1981
  67. mindspore/communication/__init__.py +7 -5
  68. mindspore/communication/_comm_helper.py +52 -2
  69. mindspore/communication/comm_func.py +240 -181
  70. mindspore/communication/management.py +95 -26
  71. mindspore/context.py +324 -573
  72. mindspore/dataset/__init__.py +65 -37
  73. mindspore/dataset/audio/__init__.py +2 -8
  74. mindspore/dataset/audio/transforms.py +3 -17
  75. mindspore/dataset/callback/ds_callback.py +2 -1
  76. mindspore/dataset/core/config.py +87 -6
  77. mindspore/dataset/engine/cache_admin.py +3 -3
  78. mindspore/dataset/engine/cache_client.py +6 -5
  79. mindspore/dataset/engine/datasets.py +292 -267
  80. mindspore/dataset/engine/datasets_audio.py +22 -8
  81. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  82. mindspore/dataset/engine/datasets_text.py +78 -48
  83. mindspore/dataset/engine/datasets_user_defined.py +183 -117
  84. mindspore/dataset/engine/datasets_vision.py +120 -44
  85. mindspore/dataset/engine/iterators.py +283 -63
  86. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  87. mindspore/dataset/engine/obs/util.py +8 -0
  88. mindspore/dataset/engine/queue.py +40 -0
  89. mindspore/dataset/engine/samplers.py +289 -43
  90. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  91. mindspore/dataset/engine/validators.py +53 -11
  92. mindspore/dataset/text/__init__.py +7 -6
  93. mindspore/dataset/text/transforms.py +6 -5
  94. mindspore/dataset/text/utils.py +3 -3
  95. mindspore/dataset/transforms/__init__.py +0 -9
  96. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  97. mindspore/dataset/transforms/transforms.py +31 -14
  98. mindspore/dataset/utils/browse_dataset.py +1 -1
  99. mindspore/dataset/vision/__init__.py +2 -9
  100. mindspore/dataset/vision/transforms.py +202 -158
  101. mindspore/dataset/vision/utils.py +7 -5
  102. mindspore/dataset/vision/validators.py +1 -2
  103. mindspore/device_context/__init__.py +21 -0
  104. mindspore/device_context/ascend/__init__.py +25 -0
  105. mindspore/device_context/ascend/device.py +72 -0
  106. mindspore/device_context/ascend/op_debug.py +153 -0
  107. mindspore/device_context/ascend/op_precision.py +193 -0
  108. mindspore/device_context/ascend/op_tuning.py +123 -0
  109. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  110. mindspore/device_context/cpu/device.py +62 -0
  111. mindspore/device_context/cpu/op_tuning.py +43 -0
  112. mindspore/device_context/gpu/__init__.py +21 -0
  113. mindspore/device_context/gpu/device.py +70 -0
  114. mindspore/device_context/gpu/op_precision.py +67 -0
  115. mindspore/device_context/gpu/op_tuning.py +175 -0
  116. mindspore/device_manager.py +170 -0
  117. mindspore/dnnl.dll +0 -0
  118. mindspore/dpcmi.dll +0 -0
  119. mindspore/experimental/es/embedding_service.py +35 -27
  120. mindspore/experimental/llm_boost/__init__.py +1 -0
  121. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  122. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
  123. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  124. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  125. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  126. mindspore/experimental/llm_boost/register.py +1 -0
  127. mindspore/experimental/map_parameter.py +4 -4
  128. mindspore/experimental/optim/adadelta.py +6 -6
  129. mindspore/experimental/optim/adagrad.py +4 -4
  130. mindspore/experimental/optim/adam.py +7 -0
  131. mindspore/experimental/optim/adamax.py +4 -4
  132. mindspore/experimental/optim/adamw.py +4 -0
  133. mindspore/experimental/optim/asgd.py +1 -1
  134. mindspore/experimental/optim/lr_scheduler.py +73 -46
  135. mindspore/experimental/optim/radam.py +34 -31
  136. mindspore/experimental/optim/rprop.py +1 -1
  137. mindspore/experimental/optim/sgd.py +1 -1
  138. mindspore/hal/contiguous_tensors_handle.py +6 -10
  139. mindspore/hal/device.py +55 -53
  140. mindspore/hal/event.py +52 -52
  141. mindspore/hal/memory.py +179 -120
  142. mindspore/hal/stream.py +150 -109
  143. mindspore/include/api/context.h +0 -1
  144. mindspore/include/dataset/constants.h +7 -4
  145. mindspore/include/dataset/execute.h +2 -2
  146. mindspore/jpeg62.dll +0 -0
  147. mindspore/log.py +50 -0
  148. mindspore/mindrecord/__init__.py +21 -8
  149. mindspore/mindrecord/config.py +17 -316
  150. mindspore/mindrecord/filereader.py +1 -9
  151. mindspore/mindrecord/filewriter.py +5 -15
  152. mindspore/mindrecord/mindpage.py +1 -9
  153. mindspore/mindspore_backend_common.dll +0 -0
  154. mindspore/mindspore_backend_manager.dll +0 -0
  155. mindspore/mindspore_common.dll +0 -0
  156. mindspore/mindspore_core.dll +0 -0
  157. mindspore/mindspore_dump.dll +0 -0
  158. mindspore/mindspore_frontend.dll +0 -0
  159. mindspore/mindspore_glog.dll +0 -0
  160. mindspore/mindspore_memory_pool.dll +0 -0
  161. mindspore/mindspore_ms_backend.dll +0 -0
  162. mindspore/mindspore_ops.dll +0 -0
  163. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  164. mindspore/mindspore_ops_kernel_common.dll +0 -0
  165. mindspore/mindspore_profiler.dll +0 -0
  166. mindspore/mindspore_pyboost.dll +0 -0
  167. mindspore/mindspore_pynative.dll +0 -0
  168. mindspore/mindspore_res_manager.dll +0 -0
  169. mindspore/mindspore_runtime_pipeline.dll +0 -0
  170. mindspore/mint/__init__.py +798 -761
  171. mindspore/mint/distributed/__init__.py +70 -4
  172. mindspore/mint/distributed/distributed.py +2679 -44
  173. mindspore/mint/linalg/__init__.py +8 -0
  174. mindspore/mint/nn/__init__.py +743 -22
  175. mindspore/mint/nn/functional.py +716 -23
  176. mindspore/mint/nn/layer/__init__.py +21 -4
  177. mindspore/mint/nn/layer/_functions.py +334 -0
  178. mindspore/mint/nn/layer/activation.py +276 -1
  179. mindspore/mint/nn/layer/basic.py +123 -0
  180. mindspore/mint/nn/layer/conv.py +933 -0
  181. mindspore/mint/nn/layer/normalization.py +223 -28
  182. mindspore/mint/nn/layer/padding.py +797 -0
  183. mindspore/mint/nn/layer/pooling.py +235 -0
  184. mindspore/mint/optim/__init__.py +3 -1
  185. mindspore/mint/optim/adam.py +223 -0
  186. mindspore/mint/optim/adamw.py +26 -19
  187. mindspore/mint/optim/sgd.py +171 -0
  188. mindspore/mint/special/__init__.py +2 -1
  189. mindspore/msobj140.dll +0 -0
  190. mindspore/mspdb140.dll +0 -0
  191. mindspore/mspdbcore.dll +0 -0
  192. mindspore/mspdbst.dll +0 -0
  193. mindspore/mspft140.dll +0 -0
  194. mindspore/msvcdis140.dll +0 -0
  195. mindspore/msvcp140_1.dll +0 -0
  196. mindspore/msvcp140_2.dll +0 -0
  197. mindspore/msvcp140_atomic_wait.dll +0 -0
  198. mindspore/msvcp140_codecvt_ids.dll +0 -0
  199. mindspore/multiprocessing/__init__.py +5 -0
  200. mindspore/nn/__init__.py +4 -1
  201. mindspore/nn/cell.py +1373 -192
  202. mindspore/nn/dynamic_lr.py +2 -1
  203. mindspore/nn/layer/activation.py +29 -27
  204. mindspore/nn/layer/basic.py +51 -35
  205. mindspore/nn/layer/channel_shuffle.py +3 -3
  206. mindspore/nn/layer/container.py +1 -1
  207. mindspore/nn/layer/conv.py +53 -42
  208. mindspore/nn/layer/embedding.py +12 -11
  209. mindspore/nn/layer/normalization.py +56 -49
  210. mindspore/nn/layer/padding.py +4 -3
  211. mindspore/nn/layer/pooling.py +120 -42
  212. mindspore/nn/layer/rnn_cells.py +1 -1
  213. mindspore/nn/layer/rnns.py +2 -1
  214. mindspore/nn/layer/timedistributed.py +5 -5
  215. mindspore/nn/layer/transformer.py +59 -36
  216. mindspore/nn/learning_rate_schedule.py +8 -4
  217. mindspore/nn/loss/loss.py +58 -55
  218. mindspore/nn/optim/ada_grad.py +7 -5
  219. mindspore/nn/optim/adadelta.py +11 -9
  220. mindspore/nn/optim/adafactor.py +1 -1
  221. mindspore/nn/optim/adam.py +19 -15
  222. mindspore/nn/optim/adamax.py +8 -7
  223. mindspore/nn/optim/adasum.py +5 -5
  224. mindspore/nn/optim/asgd.py +3 -1
  225. mindspore/nn/optim/ftrl.py +11 -9
  226. mindspore/nn/optim/lamb.py +1 -1
  227. mindspore/nn/optim/lars.py +1 -4
  228. mindspore/nn/optim/lazyadam.py +12 -10
  229. mindspore/nn/optim/momentum.py +7 -6
  230. mindspore/nn/optim/optimizer.py +3 -3
  231. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  232. mindspore/nn/optim/rmsprop.py +13 -12
  233. mindspore/nn/optim/rprop.py +11 -9
  234. mindspore/nn/optim/sgd.py +9 -6
  235. mindspore/nn/optim/tft_wrapper.py +5 -2
  236. mindspore/nn/optim/thor.py +2 -1
  237. mindspore/nn/probability/bijector/bijector.py +17 -11
  238. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  239. mindspore/nn/probability/bijector/invert.py +2 -2
  240. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  241. mindspore/nn/probability/bijector/softplus.py +3 -2
  242. mindspore/nn/probability/distribution/beta.py +3 -3
  243. mindspore/nn/probability/distribution/categorical.py +1 -1
  244. mindspore/nn/probability/distribution/cauchy.py +4 -2
  245. mindspore/nn/probability/distribution/exponential.py +6 -7
  246. mindspore/nn/probability/distribution/gamma.py +2 -2
  247. mindspore/nn/probability/distribution/gumbel.py +2 -2
  248. mindspore/nn/probability/distribution/half_normal.py +5 -3
  249. mindspore/nn/probability/distribution/logistic.py +5 -3
  250. mindspore/nn/probability/distribution/poisson.py +1 -1
  251. mindspore/nn/probability/distribution/uniform.py +5 -3
  252. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  253. mindspore/nn/reinforcement/tensor_array.py +1 -1
  254. mindspore/nn/utils/init.py +13 -11
  255. mindspore/nn/wrap/__init__.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +181 -122
  257. mindspore/nn/wrap/grad_reducer.py +45 -36
  258. mindspore/nn/wrap/loss_scale.py +6 -7
  259. mindspore/numpy/array_creations.py +63 -65
  260. mindspore/numpy/array_ops.py +149 -144
  261. mindspore/numpy/logic_ops.py +41 -42
  262. mindspore/numpy/math_ops.py +361 -359
  263. mindspore/numpy/utils.py +17 -18
  264. mindspore/numpy/utils_const.py +5 -6
  265. mindspore/opencv_core452.dll +0 -0
  266. mindspore/opencv_imgcodecs452.dll +0 -0
  267. mindspore/opencv_imgproc452.dll +0 -0
  268. mindspore/ops/__init__.py +5 -3
  269. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  270. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  271. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  272. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  273. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  274. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  275. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  276. mindspore/ops/_register_for_op.py +0 -11
  277. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  278. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  279. mindspore/ops/_vmap/vmap_array_ops.py +52 -25
  280. mindspore/ops/_vmap/vmap_base.py +0 -2
  281. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  282. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  283. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  284. mindspore/ops/auto_generate/__init__.py +4 -3
  285. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
  286. mindspore/ops/auto_generate/gen_extend_func.py +757 -185
  287. mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
  288. mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
  289. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  290. mindspore/ops/composite/__init__.py +2 -1
  291. mindspore/ops/composite/base.py +20 -25
  292. mindspore/ops/composite/math_ops.py +6 -16
  293. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  294. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  295. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  296. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  301. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  302. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  304. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  305. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  306. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  308. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  309. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  310. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  311. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  312. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  313. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  314. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  315. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  316. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  317. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  320. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  321. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  324. mindspore/ops/function/__init__.py +40 -2
  325. mindspore/ops/function/_add_attr_func.py +58 -0
  326. mindspore/ops/function/array_func.py +2089 -2403
  327. mindspore/ops/function/clip_func.py +80 -23
  328. mindspore/ops/function/debug_func.py +57 -57
  329. mindspore/ops/function/grad/__init__.py +1 -0
  330. mindspore/ops/function/grad/grad_func.py +104 -71
  331. mindspore/ops/function/image_func.py +2 -2
  332. mindspore/ops/function/linalg_func.py +47 -78
  333. mindspore/ops/function/math_func.py +4351 -3813
  334. mindspore/ops/function/nn_func.py +1712 -637
  335. mindspore/ops/function/other_func.py +159 -1
  336. mindspore/ops/function/parameter_func.py +18 -84
  337. mindspore/ops/function/random_func.py +452 -387
  338. mindspore/ops/function/reshard_func.py +4 -70
  339. mindspore/ops/function/sparse_func.py +3 -3
  340. mindspore/ops/function/sparse_unary_func.py +6 -6
  341. mindspore/ops/function/spectral_func.py +25 -58
  342. mindspore/ops/function/vmap_func.py +26 -18
  343. mindspore/ops/functional.py +23 -7
  344. mindspore/ops/functional_overload.py +1548 -0
  345. mindspore/ops/op_info_register.py +32 -244
  346. mindspore/ops/operations/__init__.py +23 -15
  347. mindspore/ops/operations/_custom_ops_utils.py +235 -0
  348. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  349. mindspore/ops/operations/_grad_ops.py +2 -43
  350. mindspore/ops/operations/_infer_ops.py +2 -1
  351. mindspore/ops/operations/_inner_ops.py +43 -84
  352. mindspore/ops/operations/_ms_kernel.py +4 -10
  353. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  354. mindspore/ops/operations/_scalar_ops.py +3 -2
  355. mindspore/ops/operations/_sequence_ops.py +1 -1
  356. mindspore/ops/operations/_tensor_array.py +1 -1
  357. mindspore/ops/operations/array_ops.py +81 -324
  358. mindspore/ops/operations/comm_ops.py +154 -108
  359. mindspore/ops/operations/custom_ops.py +298 -87
  360. mindspore/ops/operations/debug_ops.py +157 -59
  361. mindspore/ops/operations/inner_ops.py +7 -5
  362. mindspore/ops/operations/linalg_ops.py +1 -57
  363. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  364. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  365. mindspore/ops/operations/math_ops.py +32 -234
  366. mindspore/ops/operations/nn_ops.py +212 -531
  367. mindspore/ops/operations/other_ops.py +62 -9
  368. mindspore/ops/operations/random_ops.py +13 -7
  369. mindspore/ops/operations/reshard_ops.py +1 -1
  370. mindspore/ops/operations/sparse_ops.py +2 -2
  371. mindspore/ops/primitive.py +66 -53
  372. mindspore/ops/tensor_method.py +1895 -0
  373. mindspore/ops_generate/__init__.py +0 -5
  374. mindspore/ops_generate/aclnn/__init__.py +0 -0
  375. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  376. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  377. mindspore/ops_generate/api/__init__.py +0 -0
  378. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  379. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  380. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  381. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  382. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  383. mindspore/ops_generate/api/gen_api.py +103 -0
  384. mindspore/ops_generate/api/op_api_proto.py +235 -0
  385. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  386. mindspore/ops_generate/common/__init__.py +0 -0
  387. mindspore/ops_generate/common/base_generator.py +11 -0
  388. mindspore/ops_generate/common/gen_constants.py +91 -0
  389. mindspore/ops_generate/common/gen_utils.py +348 -0
  390. mindspore/ops_generate/common/op_proto.py +473 -0
  391. mindspore/ops_generate/common/template.py +523 -0
  392. mindspore/ops_generate/gen_ops.py +22 -1069
  393. mindspore/ops_generate/op_def/__init__.py +0 -0
  394. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  395. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  396. mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
  397. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  398. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  399. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  400. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  401. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  402. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  403. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  404. mindspore/ops_generate/pyboost/__init__.py +0 -0
  405. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  406. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  407. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  408. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  409. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  410. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  411. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  412. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  413. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  414. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  415. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  416. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  417. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  418. mindspore/ops_generate/resources/__init__.py +0 -0
  419. mindspore/ops_generate/resources/resource_list.py +30 -0
  420. mindspore/ops_generate/resources/resource_loader.py +36 -0
  421. mindspore/ops_generate/resources/resource_manager.py +64 -0
  422. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  423. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  424. mindspore/parallel/__init__.py +7 -3
  425. mindspore/parallel/_auto_parallel_context.py +159 -40
  426. mindspore/parallel/_cell_wrapper.py +132 -15
  427. mindspore/parallel/_parallel_serialization.py +107 -5
  428. mindspore/parallel/_ps_context.py +1 -1
  429. mindspore/parallel/_recovery_context.py +7 -2
  430. mindspore/parallel/_tensor.py +142 -18
  431. mindspore/parallel/_utils.py +199 -23
  432. mindspore/parallel/algo_parameter_config.py +4 -4
  433. mindspore/parallel/auto_parallel.py +732 -0
  434. mindspore/parallel/checkpoint_convert.py +159 -0
  435. mindspore/parallel/checkpoint_transform.py +700 -35
  436. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  437. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  438. mindspore/parallel/cluster/run.py +21 -4
  439. mindspore/parallel/function/__init__.py +24 -0
  440. mindspore/parallel/function/reshard_func.py +258 -0
  441. mindspore/parallel/nn/__init__.py +25 -0
  442. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  443. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  444. mindspore/parallel/parameter_broadcast.py +25 -14
  445. mindspore/parallel/shard.py +137 -59
  446. mindspore/parallel/transform_safetensors.py +364 -305
  447. mindspore/pgodb140.dll +0 -0
  448. mindspore/pgort140.dll +0 -0
  449. mindspore/profiler/__init__.py +22 -5
  450. mindspore/profiler/analysis/__init__.py +0 -0
  451. mindspore/profiler/analysis/parser/__init__.py +0 -0
  452. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  453. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  454. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  455. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  456. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  457. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  458. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  459. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  460. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
  461. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  462. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  463. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  464. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  465. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  466. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  467. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  468. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  469. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  470. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  471. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  472. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  473. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  474. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  475. mindspore/profiler/analysis/task_manager.py +131 -0
  476. mindspore/profiler/analysis/time_converter.py +84 -0
  477. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  478. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  479. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  480. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  481. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  482. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  483. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  484. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  485. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  486. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  487. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  488. mindspore/profiler/analysis/work_flow.py +73 -0
  489. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  490. mindspore/profiler/common/command_executor.py +90 -0
  491. mindspore/profiler/common/constant.py +186 -3
  492. mindspore/profiler/common/file_manager.py +208 -0
  493. mindspore/profiler/common/log.py +130 -0
  494. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  495. mindspore/profiler/common/path_manager.py +395 -0
  496. mindspore/profiler/common/process_bar.py +168 -0
  497. mindspore/profiler/common/process_pool.py +9 -3
  498. mindspore/profiler/common/profiler_context.py +500 -0
  499. mindspore/profiler/common/profiler_info.py +304 -0
  500. mindspore/profiler/common/profiler_meta_data.py +74 -0
  501. mindspore/profiler/common/profiler_output_path.py +284 -0
  502. mindspore/profiler/common/profiler_parameters.py +251 -0
  503. mindspore/profiler/common/profiler_path_manager.py +179 -0
  504. mindspore/profiler/common/record_function.py +76 -0
  505. mindspore/profiler/common/tlv_decoder.py +76 -0
  506. mindspore/profiler/common/util.py +75 -2
  507. mindspore/profiler/dynamic_profiler.py +341 -75
  508. mindspore/profiler/envprofiler.py +163 -0
  509. mindspore/profiler/experimental_config.py +197 -0
  510. mindspore/profiler/mstx.py +242 -0
  511. mindspore/profiler/platform/__init__.py +21 -0
  512. mindspore/profiler/platform/base_profiler.py +40 -0
  513. mindspore/profiler/platform/cpu_profiler.py +124 -0
  514. mindspore/profiler/platform/gpu_profiler.py +74 -0
  515. mindspore/profiler/platform/npu_profiler.py +335 -0
  516. mindspore/profiler/profiler.py +1073 -90
  517. mindspore/profiler/profiler_action_controller.py +187 -0
  518. mindspore/profiler/profiler_interface.py +118 -0
  519. mindspore/profiler/schedule.py +243 -0
  520. mindspore/rewrite/api/node.py +15 -13
  521. mindspore/rewrite/api/symbol_tree.py +2 -3
  522. mindspore/run_check/_check_version.py +27 -20
  523. mindspore/run_check/run_check.py +1 -1
  524. mindspore/runtime/__init__.py +37 -0
  525. mindspore/runtime/device.py +27 -0
  526. mindspore/runtime/event.py +209 -0
  527. mindspore/runtime/executor.py +177 -0
  528. mindspore/runtime/memory.py +416 -0
  529. mindspore/runtime/stream.py +460 -0
  530. mindspore/runtime/thread_bind_core.py +401 -0
  531. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  532. mindspore/swresample-4.dll +0 -0
  533. mindspore/swscale-6.dll +0 -0
  534. mindspore/tbbmalloc.dll +0 -0
  535. mindspore/tinyxml2.dll +0 -0
  536. mindspore/train/__init__.py +8 -8
  537. mindspore/train/_utils.py +96 -27
  538. mindspore/train/amp.py +9 -5
  539. mindspore/train/callback/__init__.py +2 -2
  540. mindspore/train/callback/_callback.py +2 -16
  541. mindspore/train/callback/_checkpoint.py +53 -55
  542. mindspore/train/callback/_cluster_monitor.py +14 -18
  543. mindspore/train/callback/_early_stop.py +1 -1
  544. mindspore/train/callback/_flops_collector.py +103 -68
  545. mindspore/train/callback/_history.py +8 -5
  546. mindspore/train/callback/_lambda_callback.py +2 -2
  547. mindspore/train/callback/_landscape.py +0 -3
  548. mindspore/train/callback/_loss_monitor.py +2 -1
  549. mindspore/train/callback/_on_request_exit.py +6 -5
  550. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  551. mindspore/train/callback/_summary_collector.py +52 -19
  552. mindspore/train/callback/_time_monitor.py +2 -1
  553. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
  554. mindspore/train/data_sink.py +25 -2
  555. mindspore/train/dataset_helper.py +15 -16
  556. mindspore/train/loss_scale_manager.py +8 -7
  557. mindspore/train/metrics/accuracy.py +3 -3
  558. mindspore/train/metrics/confusion_matrix.py +9 -9
  559. mindspore/train/metrics/error.py +3 -3
  560. mindspore/train/metrics/hausdorff_distance.py +4 -4
  561. mindspore/train/metrics/mean_surface_distance.py +3 -3
  562. mindspore/train/metrics/metric.py +0 -12
  563. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  564. mindspore/train/metrics/precision.py +11 -10
  565. mindspore/train/metrics/recall.py +9 -9
  566. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  567. mindspore/train/mind_ir_pb2.py +174 -46
  568. mindspore/train/model.py +269 -136
  569. mindspore/train/serialization.py +622 -978
  570. mindspore/train/summary/_summary_adapter.py +2 -2
  571. mindspore/train/summary/summary_record.py +2 -3
  572. mindspore/train/train_thor/model_thor.py +1 -1
  573. mindspore/turbojpeg.dll +0 -0
  574. mindspore/utils/__init__.py +6 -3
  575. mindspore/utils/dryrun.py +140 -0
  576. mindspore/utils/hooks.py +81 -0
  577. mindspore/utils/runtime_execution_order_check.py +552 -0
  578. mindspore/utils/utils.py +138 -4
  579. mindspore/vcmeta.dll +0 -0
  580. mindspore/vcruntime140.dll +0 -0
  581. mindspore/vcruntime140_1.dll +0 -0
  582. mindspore/version.py +1 -1
  583. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
  584. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +587 -418
  585. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
  586. mindspore/_install_custom.py +0 -43
  587. mindspore/common/_register_for_adapter.py +0 -74
  588. mindspore/common/_tensor_overload.py +0 -139
  589. mindspore/mindspore_np_dtype.dll +0 -0
  590. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  591. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  592. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  593. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  594. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  595. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  596. mindspore/ops_generate/gen_utils.py +0 -209
  597. mindspore/ops_generate/op_proto.py +0 -145
  598. mindspore/ops_generate/template.py +0 -261
  599. mindspore/profiler/envprofiling.py +0 -254
  600. mindspore/profiler/profiling.py +0 -1926
  601. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
  602. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
@@ -19,6 +19,7 @@ import os
19
19
  import json
20
20
  import numpy as np
21
21
  import mindspore as ms
22
+ from mindspore import _checkparam as Validator
22
23
  from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_tensor_layout, \
23
24
  _get_needed_rank_list_by_layouts, _get_needed_rank_transform_operator_map_by_layouts, \
24
25
  _generate_transform_operator_stack, _apply_tensor_transform_operators, _construct_tensor_layout_for_opt_shard, \
@@ -34,7 +35,12 @@ def _convert_to_list(strategy, rank_id=None):
34
35
  try:
35
36
  layout = strategy.get(param_name)
36
37
  dev_mat = list(layout.dev_matrix[0].dim)
37
- tensor_map = list(layout.tensor_map[0].dim)
38
+ # for layout one axis two slices, layout(("dp", "mp"), "None")
39
+ if len(layout.tensor_map) > 1:
40
+ tensor_map = [list(tensor_map.dim) for tensor_map in layout.tensor_map
41
+ if list(tensor_map.dim)]
42
+ else:
43
+ tensor_map = list(layout.tensor_map[0].dim)
38
44
  param_split_shape = list(layout.param_split_shape[0].dim)
39
45
  field_size = int(layout.field)
40
46
  shard_stride = int(layout.opt_weight_shard_step)
@@ -115,11 +121,15 @@ def _check_strategy_file(strategy_filename):
115
121
  f"be empty. Please check whether the 'strategy_filename' is correct.")
116
122
 
117
123
 
118
- def _load_protobuf_strategy(strategy_filename):
124
+ def _load_protobuf_strategy(strategy_filename, strategy_set=None):
119
125
  """load strategy from protobuf file"""
120
126
  parallel_strategy_map = ms.train.node_strategy_pb2.ParallelStrategyMap()
121
127
  with open(strategy_filename, 'rb') as f:
122
128
  pb_content = f.read()
129
+ if strategy_set is not None:
130
+ if pb_content in strategy_set:
131
+ return {}
132
+ strategy_set.add(pb_content)
123
133
  try:
124
134
  parallel_strategy_map.ParseFromString(pb_content)
125
135
  except BaseException as e:
@@ -188,8 +198,11 @@ def _merge_protobuf_strategy(src_strategy_files, dst_strategy_file):
188
198
  """merge protobuf strategy"""
189
199
  dst_parallel_strategy_map = ms.train.node_strategy_pb2.ParallelStrategyMap()
190
200
  merged_stage = []
201
+ strategy_set = set()
191
202
  for src_strategy_file in src_strategy_files:
192
- src_parallel_strategy_map = _load_protobuf_strategy(src_strategy_file)
203
+ src_parallel_strategy_map = _load_protobuf_strategy(src_strategy_file, strategy_set=strategy_set)
204
+ if not src_parallel_strategy_map:
205
+ continue
193
206
  strategy_items = src_parallel_strategy_map.parallel_strategy_item
194
207
  layout_items = src_parallel_strategy_map.parallel_layout_item
195
208
  if not strategy_items or not layout_items:
@@ -339,6 +352,9 @@ def _get_device_num_from_strategy(strategy_file=None):
339
352
  src_strategy = strategy_file
340
353
  strategy_list = _convert_to_list(src_strategy)
341
354
  device_mat = list(strategy_list.values())[0][0]
355
+ if not device_mat:
356
+ raise ValueError("The parallel strategy file only contains pipeline-parallelism, which is not supported for "
357
+ "parallel strategy conversion now.")
342
358
  return np.prod(device_mat)
343
359
 
344
360
 
@@ -407,7 +423,7 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
407
423
  from_opt_shard_size = 0
408
424
  if src_strategy_list is not None:
409
425
  if param_name not in src_strategy_list:
410
- ms.log.warning("The parameter {} is not in src_strategy.".format(param_name))
426
+ ms.log.info("The parameter {} is not in src_strategy.".format(param_name))
411
427
  continue
412
428
  from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size = _extract_layout_item(
413
429
  src_strategy_list.get(param_name))
@@ -417,7 +433,7 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
417
433
  to_opt_shard_size = 0
418
434
  if dst_strategy_list is not None:
419
435
  if param_name not in dst_strategy_list:
420
- ms.log.warning("The parameter {} is not in dst_strategy.".format(param_name))
436
+ ms.log.info("The parameter {} is not in dst_strategy.".format(param_name))
421
437
  continue
422
438
  to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
423
439
  dst_strategy_list.get(param_name))
@@ -431,6 +447,9 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
431
447
  continue
432
448
  origin_tensor_shape += (item * param_strategy[i],)
433
449
 
450
+ has_layout_from = any(isinstance(i, (list, tuple)) for i in from_tensor_map)
451
+ has_layout_to = any(isinstance(i, (list, tuple)) for i in to_tensor_map_origin)
452
+
434
453
  from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
435
454
  from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
436
455
  to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
@@ -450,6 +469,7 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
450
469
  from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
451
470
  to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
452
471
  _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
472
+ _insert_expand_layout_reshape(param_rank_map, from_info_tuple, to_info_tuple, has_layout_from, has_layout_to)
453
473
  transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id)
454
474
  param_total_dict_copy = param_total_dict[param_name].copy()
455
475
  _apply_tensor_transform_operators(transform_operator_stack, param_total_dict_copy, device_num)
@@ -546,6 +566,32 @@ def _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple):
546
566
  param_rank_map.get(param_rank).append(('Reshape', list(to_slice_tensor_shape)))
547
567
 
548
568
 
569
def _insert_expand_layout_reshape(param_rank_map, from_info_tuple, to_info_tuple,
                                  insert_from_reshape, insert_to_reshape):
    """
    Add 'Reshape' operators around the per-rank transform operator lists.

    When requested, a ('Reshape', from_slice_shape) operator is put at the front of every rank's operator list
    and a ('Reshape', to_slice_shape) operator is appended at the end. Each slice shape is the full tensor shape
    divided, dimension by dimension, by the strategy returned by `_get_tensor_strategy`.

    Args:
        param_rank_map (dict): Maps a rank to its list of transform operators. The lists are modified in place.
        from_info_tuple (tuple): (opt_shard_size, dev_matrix, tensor_map, full_tensor_shape) of the source.
        to_info_tuple (tuple): (opt_shard_size, dev_matrix, tensor_map, full_tensor_shape) of the destination.
        insert_from_reshape (bool): Whether the leading Reshape is wanted. It is only inserted when the
            source opt_shard_size is 0.
        insert_to_reshape (bool): Whether the trailing Reshape is wanted. It is only inserted when the
            destination opt_shard_size is 0.
    """
    # Nothing to update; also keeps the "no rank, no strategy computation" behaviour.
    if not param_rank_map:
        return

    from_opt_shard_size = from_info_tuple[0]
    from_dev_matrix = from_info_tuple[1]
    from_tensor_map = from_info_tuple[2]
    from_full_tensor_shape = from_info_tuple[3]
    to_opt_shard_size = to_info_tuple[0]
    to_dev_matrix_origin = to_info_tuple[1]
    to_tensor_map_origin = to_info_tuple[2]
    origin_tensor_shape = to_info_tuple[3]

    # The slice shapes do not depend on the rank, so they are computed once instead of once per rank.
    from_slice_tensor_shape = None
    if from_opt_shard_size == 0 and insert_from_reshape:
        from_tensor_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
        from_slice_tensor_shape = [item // from_tensor_strategy[i]
                                   for i, item in enumerate(from_full_tensor_shape)]

    to_slice_tensor_shape = None
    if to_opt_shard_size == 0 and insert_to_reshape:
        to_tensor_strategy = _get_tensor_strategy(to_dev_matrix_origin, to_tensor_map_origin)
        to_slice_tensor_shape = [item // to_tensor_strategy[i]
                                 for i, item in enumerate(origin_tensor_shape)]

    for operators in param_rank_map.values():
        # A fresh list per rank so that the operator lists never share a mutable shape object.
        if from_slice_tensor_shape is not None:
            operators.insert(0, ('Reshape', list(from_slice_tensor_shape)))
        if to_slice_tensor_shape is not None:
            operators.append(('Reshape', list(to_slice_tensor_shape)))
593
+
594
+
549
595
  def _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_sharded_device_index, rank):
550
596
  """Calculate rank list for optimizer parallel when first dim of parameter is sharded by other parallel method"""
551
597
  total_device_num = 1
@@ -559,3 +605,59 @@ def _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_sharded
559
605
  start = rank - offset
560
606
  param_total_list = list(range(start, start + range_size))
561
607
  return param_total_list
608
+
609
+
610
def _gather_tasks_load_dis(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, dst_device_num,
                           output_format, name_map, return_param_dict):
    """
    Gather transform tasks, one per destination rank.

    Returns:
        list[tuple], one tuple for every rank in ``range(dst_device_num)`` of the form
        (unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, rank, output_format, name_map,
        return_param_dict). All arguments other than the rank are forwarded untouched.
    """
    return [
        (unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, rank, output_format, name_map,
         return_param_dict)
        for rank in range(dst_device_num)
    ]
619
+
620
+
621
def _check_checkpoint_file(checkpoint_filenames):
    """
    Check checkpoint file names.

    Args:
        checkpoint_filenames (Iterable): The candidate checkpoint file names.

    Raises:
        ValueError: If an entry is not a str, does not end with '.ckpt', does not exist or has size 0.
    """
    for index, filename in enumerate(checkpoint_filenames):
        # The cheap checks on the name come first so the file system is only queried for plausible names.
        is_valid = isinstance(filename, str) and filename.endswith(".ckpt") \
            and os.path.exists(filename) and os.path.getsize(filename) != 0
        if not is_valid:
            raise ValueError(f"For 'load_distributed_checkpoint', please check 'checkpoint_filenames', and "
                             f"make sure the {filename} at index {index} is a valid checkpoint file, it must "
                             f"be a string ending with '.ckpt', and the checkpoint file it represents must "
                             f"be exist and not empty.")
630
+
631
+
632
def _check_predict_strategy(predict_strategy):
    """
    Check predict strategy.

    Args:
        predict_strategy (Union[dict, None]): ``None`` is accepted as is. Otherwise every key must be a str and
            every value a list or tuple whose first four elements are dev_matrix (list[int]),
            tensor_map (list[int]), param_split_shape (list[int] or empty) and field_size (int equal to 0).

    Raises:
        ValueError: If any item of `predict_strategy` does not follow the format above.
    """

    def _check_int_list(arg):
        return isinstance(arg, list) and all(isinstance(item, int) for item in arg)

    def _is_valid_item(key, value):
        # Reject malformed items before unpacking: unpacking a value that is too short or not a sequence
        # would otherwise escape as an unrelated TypeError/ValueError instead of the message below.
        if not isinstance(key, str) or not isinstance(value, (list, tuple)) or len(value) < 4:
            return False
        dev_matrix, tensor_map, param_split_shape, field_size = value[:4]
        return _check_int_list(dev_matrix) and _check_int_list(tensor_map) \
            and (_check_int_list(param_split_shape) or not param_split_shape) \
            and isinstance(field_size, int) and field_size == 0

    if predict_strategy is None:
        return

    predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict)
    if not all(_is_valid_item(key, value) for key, value in predict_strategy.items()):
        raise ValueError("For 'load_distributed_checkpoint', the argument 'predict_strategy' is dict, "
                         "the key of it must be string, and the value of it must be list or tuple that "
                         "the first four elements must be dev_matrix (list[int]), tensor_map (list[int]), "
                         "param_split_shape (list[int]) and field_size (int, which value is 0)."
                         "Please check whether 'predict_strategy' is correct.")
@@ -115,7 +115,7 @@ def _set_ps_context(**kwargs):
115
115
  enable_ps (bool): Whether to enable parameter server training mode.
116
116
  Only after enable_ps is set True, the environment variables will be effective.
117
117
  Default: ``False``.
118
- config_file_path (string): Configuration file path used by recovery. Default: ''.
118
+ config_file_path (str): Configuration file path used by recovery. Default: ''.
119
119
  scheduler_manage_port (int): scheduler manage port used to scale out/in. Default: 11202.
120
120
  enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: ``False``.
121
121
  client_password (str): Password to decrypt the secret key stored in the client certificate. Default: ''.
@@ -33,18 +33,23 @@ def recovery_context():
33
33
  RECOVERY_CONTEXT = RecoveryContext.get_instance()
34
34
  return RECOVERY_CONTEXT
35
35
 
36
+
36
37
  _set_recovery_context_func_map = {
37
38
  "ckpt_path": recovery_context().set_ckpt_path,
38
- "need_reset": recovery_context().set_need_reset
39
+ "need_reset": recovery_context().set_need_reset,
40
+ "is_reboot_node": recovery_context().set_is_reboot_node,
41
+ "is_arf": recovery_context().set_is_arf
39
42
  }
40
43
 
41
44
  _get_recovery_context_func_map = {
42
45
  "enable_recovery": recovery_context().enable_recovery,
46
+ "enable_repeat_register": recovery_context().enable_repeat_register,
43
47
  "latest_ckpt_file": recovery_context().latest_ckpt_file,
44
48
  "latest_ckpt_epoch": recovery_context().latest_ckpt_epoch,
45
49
  "latest_ckpt_step": recovery_context().latest_ckpt_step,
46
50
  "need_reset": recovery_context().need_reset,
47
51
  "recovery_path": recovery_context().recovery_path,
52
+ "is_arf": recovery_context().is_arf,
48
53
  "ckpt_path": recovery_context().ckpt_path
49
54
  }
50
55
 
@@ -64,7 +69,7 @@ def _set_recovery_context(**kwargs):
64
69
  MS_RECOVERY_INTERVAL # The persistent interval for recovery
65
70
 
66
71
  Args:
67
- ckpt_path (string): Set the recovery path used to save checkpoint. Default: ''.
72
+ ckpt_path (str): Set the recovery path used to save checkpoint. Default: ''.
68
73
  need_reset (bool): Set whether should call reset minddata and load ckpt for disaster recovery.
69
74
  Default: ``False``.
70
75
 
@@ -38,10 +38,17 @@ def _get_tensor_strategy(dev_mat, tensor_map):
38
38
  """
39
39
  tensor_strategy = []
40
40
  for dim in tensor_map:
41
- if dim == -1:
42
- tensor_strategy.append(1)
41
+ if isinstance(dim, (tuple, list)):
42
+ acc_stra = 1
43
+ for i in dim:
44
+ if i != -1:
45
+ acc_stra *= dev_mat[len(dev_mat) - i - 1]
46
+ tensor_strategy.append(acc_stra)
43
47
  else:
44
- tensor_strategy.append(dev_mat[-dim - 1])
48
+ if dim == -1:
49
+ tensor_strategy.append(1)
50
+ else:
51
+ tensor_strategy.append(dev_mat[-dim - 1])
45
52
  return tensor_strategy
46
53
 
47
54
 
@@ -182,7 +189,7 @@ def _get_slice_index(dev_mat, tensor_map, opt_shard_group):
182
189
  Args:
183
190
  dev_mat (list): The device matrix of devices.
184
191
  tensor_map (list): The split strategy of tensor.
185
- opt_shard_group(string): The group of optimizer shard
192
+ opt_shard_group(str): The group of optimizer shard
186
193
 
187
194
  Returns:
188
195
  Integer, the slice index for slice on this device.
@@ -388,6 +395,124 @@ def _construct_from_to_tensor_layout(from_full_tensor_shape, from_dev_matrix,
388
395
  return from_tensor_layout, to_tensor_layout
389
396
 
390
397
 
398
def _expand_layout(dev_matrix, tensor_map, tensor_shape):
    """
    Expand a nested tensor_map and reshape the tensor shape according to it.

    Every nested entry of `tensor_map` is flattened. Each of its elements except the last one gets a new
    dimension whose size is the device number it is sharded by, and the last element keeps what remains of
    the original dimension. For example:

        dev_matrix = [4, 2, 2]
        tensor_map = [[2, 1], 0]
        tensor_shape = [8, 8]
        =>
        expanded_tensor_map = [2, 1, 0]
        expanded_tensor_shape = [4, 8/4, 8]

    Args:
        dev_matrix (list): The device matrix.
        tensor_map (list): The tensor map, whose entries are ints or lists/tuples of ints.
        tensor_shape (list): The tensor shape, one size per entry of `tensor_map`.

    Returns:
        tuple, (dev_matrix, expanded_tensor_map, expanded_tensor_shape). `dev_matrix` is returned as passed in.
    """
    new_tensor_map = []
    new_tensor_shape = []
    dev_dims = len(dev_matrix)
    for index, dim in enumerate(tensor_map):
        if not isinstance(dim, (tuple, list)):
            new_tensor_map.append(dim)
            new_tensor_shape.append(tensor_shape[index])
            continue
        accu_shape = 1
        for sub_dim in dim[:-1]:
            # -1 means "not sharded" and contributes a factor of 1, as in _get_tensor_strategy. Indexing
            # dev_matrix with it would run past the end of the list.
            shard_size = 1 if sub_dim == -1 else dev_matrix[dev_dims - 1 - sub_dim]
            new_tensor_map.append(sub_dim)
            new_tensor_shape.append(shard_size)
            accu_shape *= shard_size
        new_tensor_map.append(dim[-1])
        new_tensor_shape.append(tensor_shape[index] // accu_shape)
    return dev_matrix, new_tensor_map, new_tensor_shape
423
+
424
+
425
def _construct_tensor_layout_for_opt_shard_by_layout(dev_matrix, tensor_map, opt_shard_step, opt_shard_size,
                                                     origin_full_tensor_shape):
    """
    Construct the tensor layout for optimizer parallel when the tensor_map is given in layout (nested) form.

    When `opt_shard_step` or `opt_shard_size` is 0 the inputs are returned unchanged (the shape as a list).
    Otherwise the positions of `dev_matrix` that have a size larger than 1 and that no entry of the expanded
    tensor_map refers to are collected and handed to `_construct_tensor_layout_helper`.

    For example, for a tensor with shape (4, 2):
        dev_matrix = [2, 2, 2, 2]
        tensor_map = [[1, 0], -1]
        opt_shard_size = 2
        ==>
        dev_matrix = [2, 2, 2, 2]
        tensor_map = [[1, 0], 2, -1]
    The new strategy is [4, 2, 1]: the 4 is the model parallel sharding of the first dimension and the 2 is
    its optimizer sharding.

    Raises:
        ValueError: If `dev_matrix` has no unused position with a size larger than 1.
    """
    opt_shard_disabled = opt_shard_step == 0 or opt_shard_size == 0
    if opt_shard_disabled:
        return dev_matrix, tensor_map, list(origin_full_tensor_shape)

    tensor_strategy = _get_tensor_strategy(dev_matrix, tensor_map)
    dev_matrix, flat_tensor_map, _ = _expand_layout(dev_matrix, tensor_map, origin_full_tensor_shape)

    # Tensor map values count device-matrix positions from the right; convert them to list indices.
    used_positions = [len(dev_matrix) - axis - 1 for axis in flat_tensor_map if axis != -1]
    repeated_dim = [position for position, size in enumerate(dev_matrix)
                    if position not in used_positions and size > 1]

    if repeated_dim:
        return _construct_tensor_layout_helper(dev_matrix, tensor_map, opt_shard_size, origin_full_tensor_shape,
                                               tensor_strategy, repeated_dim)
    raise ValueError("The device_matrix {} and tensor_map {} cannot sharding opt_shard".
                     format(dev_matrix, tensor_map))
458
+
459
+
460
def _construct_tensor_layout_helper(dev_matrix, tensor_map, opt_shard_size, origin_full_tensor_shape,
                                    tensor_strategy, repeated_dim):
    """
    Helper function to assign repeated device_matrix dims for opt shard.

    The positions in `repeated_dim` are consumed from right to left until `opt_shard_size` is used up. When the
    remaining size is smaller than the device number at a position, that position is split in two and the
    tensor_map values are renumbered to follow the longer device matrix. The consumed positions are inserted
    into the tensor_map right after its first entry, and the first tensor dimension is split to match.

    Args:
        dev_matrix (list): The device matrix.
        tensor_map (list): The tensor map, whose entries are ints or lists/tuples of ints.
        opt_shard_size (int): The optimizer shard size. -1 means the product of all repeated dims.
        origin_full_tensor_shape (Union[list, tuple]): The full shape of the tensor.
        tensor_strategy (list): The sharding strategy of `tensor_map` on `dev_matrix`.
        repeated_dim (list): Indices into `dev_matrix` that may be used for optimizer sharding.

    Returns:
        tuple, (new_dev_matrix, new_tensor_map, new_tensor_shape), all newly created lists.

    Raises:
        ValueError: If the device number of a position and the remaining opt shard size do not divide evenly.
    """
    dev_dims = len(dev_matrix)
    new_dev_matrix = list(dev_matrix)
    new_dev_matrix_map = list(range(dev_dims))
    opt_shard_dim = []
    if opt_shard_size != -1:
        remained_opt_shard_size = opt_shard_size
    else:
        remained_opt_shard_size = int(np.prod([dev_matrix[i] for i in repeated_dim]))

    for dim in reversed(repeated_dim):
        opt_sharding_size = dev_matrix[dim]
        if remained_opt_shard_size // opt_sharding_size == 0:
            # Fewer shards are left than this position offers: split it into (rest, opt shard).
            if opt_sharding_size % remained_opt_shard_size != 0:
                raise ValueError("dev_matrix value {} at dim {} cannot be divided by needed opt sharding "
                                 "size {}".format(dev_matrix[dim], dev_dims - dim - 1,
                                                  remained_opt_shard_size))
            opt_sharding_size = remained_opt_shard_size
            # update dev_matrix
            new_dev_matrix[dim] = dev_matrix[dim] // opt_sharding_size
            new_dev_matrix.insert(dim + 1, opt_sharding_size)
            # Tensor map values count from the right, so values at or left of the split move up by one.
            for i in range(dev_dims - dim - 1, dev_dims):
                new_dev_matrix_map[i] += 1
        if remained_opt_shard_size % opt_sharding_size != 0:
            raise ValueError("Remained opt_shard_size {} cannot be divided by current sharding size {}, "
                             "the repeat dim is {} with dev_matrix value {}".
                             format(remained_opt_shard_size, opt_sharding_size,
                                    dev_dims - dim - 1, dev_matrix[dim]))
        remained_opt_shard_size //= opt_sharding_size
        opt_shard_dim.insert(0, dim)
        if remained_opt_shard_size == 1:
            break

    if len(new_dev_matrix) != dev_dims:
        # The opt shard part of a split position sits one index to the right of the original position.
        opt_shard_dim = [dim + 1 for dim in opt_shard_dim]

    def _remap(axis):
        return new_dev_matrix_map[axis] if axis >= 0 else axis

    tensor_map_new = [[_remap(axis) for axis in item] if isinstance(item, (tuple, list)) else _remap(item)
                      for item in tensor_map]

    tensor_shape_new = list(origin_full_tensor_shape)
    tensor_shape_new[0] = tensor_strategy[0]
    first_dim_no_sharding_size = origin_full_tensor_shape[0] // tensor_strategy[0]
    accu_shape = 1
    for i, dim in enumerate(opt_shard_dim[:-1]):
        opt_sharding_size = new_dev_matrix[dim]
        tensor_shape_new.insert(i + 1, opt_sharding_size)
        accu_shape *= opt_sharding_size
    tensor_shape_new.insert(len(opt_shard_dim), first_dim_no_sharding_size // accu_shape)
    for index, r_dim in enumerate(opt_shard_dim):
        tensor_map_new.insert(index + 1, len(new_dev_matrix) - r_dim - 1)
    return new_dev_matrix, tensor_map_new, tensor_shape_new
514
+
515
+
391
516
  def _construct_tensor_layout_for_opt_shard(dev_matrix, tensor_map, opt_shard_step, opt_shard_size,
392
517
  origin_full_tensor_shape):
393
518
  """
@@ -404,6 +529,11 @@ def _construct_tensor_layout_for_opt_shard(dev_matrix, tensor_map, opt_shard_ste
404
529
  And the model parallel sharding dim is the right of opt sharding dim, so it would be 0-1-2-3 model parallel sharding
405
530
  then 0-4 optimizer sharding.
406
531
  """
532
+ has_layout = any(isinstance(i, (list, tuple)) for i in tensor_map)
533
+ if has_layout:
534
+ output = _construct_tensor_layout_for_opt_shard_by_layout(dev_matrix, tensor_map, opt_shard_step,
535
+ opt_shard_size, origin_full_tensor_shape)
536
+ return _expand_layout(*output)
407
537
 
408
538
  if opt_shard_step == 0 or opt_shard_size == 0:
409
539
  return dev_matrix, tensor_map, list(origin_full_tensor_shape)
@@ -424,18 +554,8 @@ def _construct_tensor_layout_for_opt_shard(dev_matrix, tensor_map, opt_shard_ste
424
554
  format(opt_shard_step, np.prod(dev_matrix[repeated_dim[0] + 1:])))
425
555
  first_dim_no_sharding_size = origin_full_tensor_shape[0] // tensor_strategy[0]
426
556
  if (len(repeated_dim) < len(dev_matrix) and len(repeated_dim) > 1) or repeated_dim[0] > 0:
427
- tensor_shape_new = list(origin_full_tensor_shape)
428
- tensor_shape_new[0] = tensor_strategy[0]
429
- accu_shp = 1
430
- for i in range(len(repeated_dim) - 1):
431
- opt_sharding_size = dev_matrix[repeated_dim[i]]
432
- tensor_shape_new.insert(i + 1, opt_sharding_size)
433
- accu_shp = accu_shp * opt_sharding_size
434
- tensor_shape_new.insert(len(repeated_dim), first_dim_no_sharding_size // accu_shp)
435
- tensor_map_new = list(copy.deepcopy(tensor_map))
436
- for index, r_dim in enumerate(repeated_dim):
437
- tensor_map_new.insert(index + 1, len(dev_matrix) - r_dim - 1)
438
- return list(dev_matrix), tensor_map_new, tensor_shape_new
557
+ return _construct_tensor_layout_helper(dev_matrix, tensor_map, opt_shard_size, origin_full_tensor_shape,
558
+ tensor_strategy, repeated_dim)
439
559
 
440
560
  full_tensor_shape = list(origin_full_tensor_shape)
441
561
  full_tensor_shape[0] = tensor_strategy[0]
@@ -610,9 +730,13 @@ def _apply_operator(operator_name):
610
730
  """
611
731
  if not isinstance(numpy_data_list, list):
612
732
  raise TypeError("The data_list should be a list.")
733
+ new_numpy_data_list = []
613
734
  for numpy_data in numpy_data_list:
614
- if not isinstance(numpy_data, np.ndarray):
615
- raise TypeError("The data should be a numpy.ndarray.")
735
+ if str(type(numpy_data)) == "<class 'builtins.PySafeSlice'>":
736
+ new_numpy_data_list.append(numpy_data[:])
737
+ else:
738
+ new_numpy_data_list.append(numpy_data)
739
+ numpy_data_list = new_numpy_data_list
616
740
  _check_operator(allgather_op)
617
741
  concat_group = allgather_op[1][:-1]
618
742
  if len(concat_group) != len(numpy_data_list):