mindspore 2.2.14__cp39-cp39-win_amd64.whl → 2.4.0__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (1217)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +8 -5
  5. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +124 -25
  9. mindspore/_extends/builtin_operations.py +2 -1
  10. mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
  11. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
  12. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
  13. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
  14. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  15. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
  16. mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
  17. mindspore/_extends/parse/__init__.py +18 -14
  18. mindspore/_extends/parse/compile_config.py +299 -0
  19. mindspore/_extends/parse/namespace.py +2 -2
  20. mindspore/_extends/parse/parser.py +182 -68
  21. mindspore/_extends/parse/resources.py +45 -14
  22. mindspore/_extends/parse/standard_method.py +192 -252
  23. mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
  24. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  25. mindspore/_extends/remote/kernel_build_server.py +2 -0
  26. mindspore/_profiler.py +30 -0
  27. mindspore/amp.py +67 -26
  28. mindspore/atlprov.dll +0 -0
  29. mindspore/avcodec-59.dll +0 -0
  30. mindspore/avdevice-59.dll +0 -0
  31. mindspore/avfilter-8.dll +0 -0
  32. mindspore/avformat-59.dll +0 -0
  33. mindspore/avutil-57.dll +0 -0
  34. mindspore/boost/adasum.py +1 -1
  35. mindspore/boost/base.py +1 -1
  36. mindspore/boost/boost_cell_wrapper.py +2 -2
  37. mindspore/boost/grad_freeze.py +2 -2
  38. mindspore/boost/group_loss_scale_manager.py +1 -1
  39. mindspore/boost/less_batch_normalization.py +9 -6
  40. mindspore/c1.dll +0 -0
  41. mindspore/c1xx.dll +0 -0
  42. mindspore/c2.dll +0 -0
  43. mindspore/common/__init__.py +20 -7
  44. mindspore/common/_jit_fallback_utils.py +2 -3
  45. mindspore/common/_pijit_context.py +190 -0
  46. mindspore/common/_register_for_adapter.py +7 -0
  47. mindspore/common/_register_for_recompute.py +48 -0
  48. mindspore/common/_register_for_tensor.py +10 -10
  49. mindspore/common/_stub_tensor.py +7 -1
  50. mindspore/common/_tensor_overload.py +139 -0
  51. mindspore/common/_utils.py +5 -17
  52. mindspore/common/api.py +449 -129
  53. mindspore/common/auto_dynamic_shape.py +27 -14
  54. mindspore/common/dtype.py +17 -10
  55. mindspore/common/dump.py +8 -11
  56. mindspore/common/file_system.py +48 -0
  57. mindspore/common/generator.py +254 -0
  58. mindspore/common/hook_handle.py +65 -30
  59. mindspore/common/initializer.py +1 -1
  60. mindspore/common/jit_config.py +34 -14
  61. mindspore/common/lazy_inline.py +72 -19
  62. mindspore/common/mindir_util.py +12 -2
  63. mindspore/common/mutable.py +79 -14
  64. mindspore/common/no_inline.py +54 -0
  65. mindspore/common/np_dtype.py +25 -0
  66. mindspore/common/parameter.py +73 -21
  67. mindspore/common/recompute.py +292 -0
  68. mindspore/common/seed.py +9 -9
  69. mindspore/common/sparse_tensor.py +276 -24
  70. mindspore/common/symbol.py +122 -0
  71. mindspore/common/tensor.py +668 -514
  72. mindspore/communication/__init__.py +6 -11
  73. mindspore/communication/_comm_helper.py +43 -3
  74. mindspore/communication/comm_func.py +1395 -0
  75. mindspore/communication/management.py +117 -104
  76. mindspore/config/op_info.config +22 -54
  77. mindspore/context.py +455 -71
  78. mindspore/dataset/__init__.py +5 -5
  79. mindspore/dataset/audio/__init__.py +6 -6
  80. mindspore/dataset/audio/transforms.py +711 -158
  81. mindspore/dataset/callback/ds_callback.py +2 -2
  82. mindspore/dataset/core/config.py +7 -0
  83. mindspore/dataset/core/validator_helpers.py +7 -0
  84. mindspore/dataset/engine/cache_client.py +2 -2
  85. mindspore/dataset/engine/datasets.py +201 -116
  86. mindspore/dataset/engine/datasets_audio.py +14 -14
  87. mindspore/dataset/engine/datasets_standard_format.py +83 -3
  88. mindspore/dataset/engine/datasets_text.py +39 -39
  89. mindspore/dataset/engine/datasets_user_defined.py +230 -141
  90. mindspore/dataset/engine/datasets_vision.py +78 -74
  91. mindspore/dataset/engine/iterators.py +29 -0
  92. mindspore/dataset/engine/obs/util.py +7 -0
  93. mindspore/dataset/engine/offload.py +5 -7
  94. mindspore/dataset/engine/queue.py +138 -66
  95. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  96. mindspore/dataset/engine/validators.py +41 -15
  97. mindspore/dataset/text/__init__.py +2 -5
  98. mindspore/dataset/text/transforms.py +408 -121
  99. mindspore/dataset/text/utils.py +9 -9
  100. mindspore/dataset/transforms/__init__.py +0 -3
  101. mindspore/dataset/transforms/transforms.py +261 -76
  102. mindspore/dataset/utils/browse_dataset.py +9 -9
  103. mindspore/dataset/utils/line_reader.py +2 -0
  104. mindspore/dataset/vision/__init__.py +7 -10
  105. mindspore/dataset/vision/c_transforms.py +10 -10
  106. mindspore/dataset/vision/py_transforms_util.py +1 -1
  107. mindspore/dataset/vision/transforms.py +2844 -549
  108. mindspore/dataset/vision/utils.py +161 -10
  109. mindspore/dataset/vision/validators.py +16 -3
  110. mindspore/dnnl.dll +0 -0
  111. mindspore/dpcmi.dll +0 -0
  112. mindspore/{rewrite/ast_creator_register.py → experimental/es/__init__.py} +5 -20
  113. mindspore/experimental/es/embedding_service.py +883 -0
  114. mindspore/experimental/es/embedding_service_layer.py +581 -0
  115. mindspore/experimental/llm_boost/__init__.py +21 -0
  116. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  117. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  118. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  119. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  120. mindspore/experimental/llm_boost/register.py +129 -0
  121. mindspore/experimental/llm_boost/utils.py +31 -0
  122. mindspore/experimental/optim/__init__.py +12 -2
  123. mindspore/experimental/optim/adadelta.py +161 -0
  124. mindspore/experimental/optim/adagrad.py +168 -0
  125. mindspore/experimental/optim/adam.py +35 -34
  126. mindspore/experimental/optim/adamax.py +170 -0
  127. mindspore/experimental/optim/adamw.py +124 -15
  128. mindspore/experimental/optim/asgd.py +153 -0
  129. mindspore/experimental/optim/lr_scheduler.py +66 -121
  130. mindspore/experimental/optim/nadam.py +157 -0
  131. mindspore/experimental/optim/optimizer.py +18 -8
  132. mindspore/experimental/optim/radam.py +194 -0
  133. mindspore/experimental/optim/rmsprop.py +154 -0
  134. mindspore/experimental/optim/rprop.py +164 -0
  135. mindspore/experimental/optim/sgd.py +28 -19
  136. mindspore/hal/__init__.py +40 -0
  137. mindspore/hal/_ascend.py +57 -0
  138. mindspore/hal/_base.py +57 -0
  139. mindspore/hal/_cpu.py +56 -0
  140. mindspore/hal/_gpu.py +57 -0
  141. mindspore/hal/contiguous_tensors_handle.py +175 -0
  142. mindspore/hal/device.py +356 -0
  143. mindspore/hal/event.py +179 -0
  144. mindspore/hal/memory.py +326 -0
  145. mindspore/hal/stream.py +357 -0
  146. mindspore/include/api/data_type.h +2 -2
  147. mindspore/include/api/dual_abi_helper.h +16 -3
  148. mindspore/include/api/model.h +4 -3
  149. mindspore/include/api/model_group.h +13 -1
  150. mindspore/include/api/status.h +14 -0
  151. mindspore/include/api/types.h +10 -10
  152. mindspore/include/c_api/model_c.h +173 -0
  153. mindspore/include/c_api/types_c.h +19 -0
  154. mindspore/include/dataset/config.h +2 -2
  155. mindspore/include/dataset/constants.h +2 -2
  156. mindspore/include/dataset/execute.h +3 -5
  157. mindspore/include/dataset/vision.h +58 -2
  158. mindspore/jpeg62.dll +0 -0
  159. mindspore/log.py +3 -3
  160. mindspore/mindrecord/__init__.py +5 -1
  161. mindspore/mindrecord/config.py +809 -0
  162. mindspore/mindrecord/filereader.py +25 -0
  163. mindspore/mindrecord/filewriter.py +138 -103
  164. mindspore/mindrecord/mindpage.py +40 -6
  165. mindspore/mindrecord/shardutils.py +3 -2
  166. mindspore/mindrecord/shardwriter.py +7 -0
  167. mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
  168. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
  169. mindspore/mindrecord/tools/csv_to_mr.py +4 -9
  170. mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
  171. mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
  172. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
  173. mindspore/mindspore_backend.dll +0 -0
  174. mindspore/mindspore_common.dll +0 -0
  175. mindspore/mindspore_core.dll +0 -0
  176. mindspore/mindspore_glog.dll +0 -0
  177. mindspore/mindspore_np_dtype.dll +0 -0
  178. mindspore/mindspore_ops.dll +0 -0
  179. mindspore/mint/__init__.py +1586 -0
  180. mindspore/mint/distributed/__init__.py +31 -0
  181. mindspore/mint/distributed/distributed.py +254 -0
  182. mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
  183. mindspore/mint/nn/__init__.py +757 -0
  184. mindspore/mint/nn/functional.py +679 -0
  185. mindspore/mint/nn/layer/__init__.py +39 -0
  186. mindspore/mint/nn/layer/activation.py +133 -0
  187. mindspore/mint/nn/layer/normalization.py +477 -0
  188. mindspore/mint/nn/layer/pooling.py +110 -0
  189. mindspore/mint/optim/__init__.py +24 -0
  190. mindspore/mint/optim/adamw.py +206 -0
  191. mindspore/mint/special/__init__.py +63 -0
  192. mindspore/msobj140.dll +0 -0
  193. mindspore/mspdb140.dll +0 -0
  194. mindspore/mspdbcore.dll +0 -0
  195. mindspore/mspdbst.dll +0 -0
  196. mindspore/mspft140.dll +0 -0
  197. mindspore/msvcdis140.dll +0 -0
  198. mindspore/msvcp140_1.dll +0 -0
  199. mindspore/msvcp140_2.dll +0 -0
  200. mindspore/msvcp140_atomic_wait.dll +0 -0
  201. mindspore/msvcp140_codecvt_ids.dll +0 -0
  202. mindspore/multiprocessing/__init__.py +73 -0
  203. mindspore/nn/cell.py +461 -323
  204. mindspore/nn/dynamic_lr.py +2 -2
  205. mindspore/nn/layer/activation.py +292 -135
  206. mindspore/nn/layer/basic.py +288 -83
  207. mindspore/nn/layer/channel_shuffle.py +3 -16
  208. mindspore/nn/layer/container.py +3 -3
  209. mindspore/nn/layer/conv.py +75 -66
  210. mindspore/nn/layer/embedding.py +221 -45
  211. mindspore/nn/layer/image.py +4 -7
  212. mindspore/nn/layer/math.py +1 -1
  213. mindspore/nn/layer/normalization.py +150 -68
  214. mindspore/nn/layer/padding.py +64 -87
  215. mindspore/nn/layer/pooling.py +175 -12
  216. mindspore/nn/layer/rnn_cells.py +6 -16
  217. mindspore/nn/layer/rnns.py +6 -5
  218. mindspore/nn/layer/thor_layer.py +1 -2
  219. mindspore/nn/layer/timedistributed.py +1 -1
  220. mindspore/nn/layer/transformer.py +55 -53
  221. mindspore/nn/learning_rate_schedule.py +6 -5
  222. mindspore/nn/loss/__init__.py +2 -2
  223. mindspore/nn/loss/loss.py +145 -88
  224. mindspore/nn/optim/__init__.py +2 -1
  225. mindspore/nn/optim/ada_grad.py +4 -2
  226. mindspore/nn/optim/adadelta.py +4 -2
  227. mindspore/nn/optim/adafactor.py +1 -1
  228. mindspore/nn/optim/adam.py +102 -181
  229. mindspore/nn/optim/adamax.py +4 -2
  230. mindspore/nn/optim/adasum.py +3 -3
  231. mindspore/nn/optim/asgd.py +4 -2
  232. mindspore/nn/optim/ftrl.py +31 -61
  233. mindspore/nn/optim/lamb.py +5 -3
  234. mindspore/nn/optim/lars.py +2 -2
  235. mindspore/nn/optim/lazyadam.py +6 -4
  236. mindspore/nn/optim/momentum.py +13 -25
  237. mindspore/nn/optim/optimizer.py +6 -3
  238. mindspore/nn/optim/proximal_ada_grad.py +4 -2
  239. mindspore/nn/optim/rmsprop.py +9 -3
  240. mindspore/nn/optim/rprop.py +4 -2
  241. mindspore/nn/optim/sgd.py +5 -3
  242. mindspore/nn/optim/tft_wrapper.py +127 -0
  243. mindspore/nn/optim/thor.py +2 -2
  244. mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
  245. mindspore/nn/probability/distribution/beta.py +2 -2
  246. mindspore/nn/probability/distribution/categorical.py +4 -6
  247. mindspore/nn/probability/distribution/cauchy.py +2 -2
  248. mindspore/nn/probability/distribution/exponential.py +2 -2
  249. mindspore/nn/probability/distribution/geometric.py +1 -1
  250. mindspore/nn/probability/distribution/gumbel.py +2 -2
  251. mindspore/nn/probability/distribution/logistic.py +1 -1
  252. mindspore/nn/probability/distribution/poisson.py +2 -2
  253. mindspore/nn/probability/distribution/uniform.py +2 -2
  254. mindspore/nn/reinforcement/_tensors_queue.py +13 -1
  255. mindspore/nn/wrap/__init__.py +2 -1
  256. mindspore/nn/wrap/cell_wrapper.py +46 -12
  257. mindspore/nn/wrap/grad_reducer.py +148 -8
  258. mindspore/nn/wrap/loss_scale.py +44 -7
  259. mindspore/numpy/__init__.py +2 -0
  260. mindspore/numpy/array_creations.py +67 -68
  261. mindspore/numpy/array_ops.py +70 -66
  262. mindspore/numpy/dtypes.py +3 -3
  263. mindspore/numpy/fft.py +966 -0
  264. mindspore/numpy/logic_ops.py +11 -10
  265. mindspore/numpy/math_ops.py +147 -152
  266. mindspore/numpy/utils.py +3 -0
  267. mindspore/numpy/utils_const.py +4 -4
  268. mindspore/opencv_core452.dll +0 -0
  269. mindspore/opencv_imgcodecs452.dll +0 -0
  270. mindspore/opencv_imgproc452.dll +0 -0
  271. mindspore/ops/__init__.py +9 -6
  272. mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
  273. mindspore/ops/_grad_experimental/grad_comm_ops.py +135 -36
  274. mindspore/ops/_grad_experimental/grad_math_ops.py +61 -298
  275. mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
  276. mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
  277. mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
  278. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  279. mindspore/ops/_op_impl/__init__.py +0 -1
  280. mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
  281. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
  282. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
  283. mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
  284. mindspore/ops/_op_impl/cpu/__init__.py +1 -3
  285. mindspore/ops/_op_impl/cpu/adam.py +2 -2
  286. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
  287. mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
  288. mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
  289. mindspore/ops/_vmap/vmap_array_ops.py +162 -101
  290. mindspore/ops/_vmap/vmap_base.py +8 -1
  291. mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
  292. mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
  293. mindspore/ops/_vmap/vmap_image_ops.py +70 -13
  294. mindspore/ops/_vmap/vmap_math_ops.py +147 -59
  295. mindspore/ops/_vmap/vmap_nn_ops.py +292 -117
  296. mindspore/ops/_vmap/vmap_other_ops.py +1 -1
  297. mindspore/ops/auto_generate/__init__.py +31 -0
  298. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  299. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  300. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  301. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  302. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  303. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  304. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  305. mindspore/ops/composite/__init__.py +5 -2
  306. mindspore/ops/composite/base.py +201 -66
  307. mindspore/ops/composite/math_ops.py +10 -49
  308. mindspore/ops/composite/multitype_ops/_compile_utils.py +192 -618
  309. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
  310. mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
  311. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
  312. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
  313. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
  314. mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
  315. mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
  316. mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
  317. mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
  318. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
  319. mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
  320. mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
  321. mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
  322. mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
  323. mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
  324. mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
  325. mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
  326. mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
  327. mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
  328. mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
  329. mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
  330. mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
  331. mindspore/ops/composite/multitype_ops/not_in_impl.py +8 -3
  332. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
  333. mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
  334. mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
  335. mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
  336. mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
  337. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
  338. mindspore/ops/deprecated.py +14 -3
  339. mindspore/ops/function/__init__.py +53 -11
  340. mindspore/ops/function/array_func.py +1269 -1821
  341. mindspore/ops/function/clip_func.py +19 -31
  342. mindspore/ops/function/debug_func.py +114 -5
  343. mindspore/ops/function/fft_func.py +44 -0
  344. mindspore/ops/function/grad/grad_func.py +30 -22
  345. mindspore/ops/function/image_func.py +27 -21
  346. mindspore/ops/function/linalg_func.py +35 -68
  347. mindspore/ops/function/math_func.py +1170 -2697
  348. mindspore/ops/function/nn_func.py +2116 -1128
  349. mindspore/ops/function/other_func.py +8 -8
  350. mindspore/ops/function/parameter_func.py +5 -93
  351. mindspore/ops/function/random_func.py +435 -113
  352. mindspore/ops/function/reshard_func.py +104 -0
  353. mindspore/ops/function/sparse_func.py +4 -4
  354. mindspore/ops/function/sparse_unary_func.py +9 -16
  355. mindspore/ops/function/spectral_func.py +1 -1
  356. mindspore/ops/function/vmap_func.py +16 -15
  357. mindspore/ops/functional.py +355 -346
  358. mindspore/ops/op_info_register.py +18 -45
  359. mindspore/ops/operations/__init__.py +38 -24
  360. mindspore/ops/operations/_grad_ops.py +21 -927
  361. mindspore/ops/operations/_infer_ops.py +19 -0
  362. mindspore/ops/operations/_inner_ops.py +173 -607
  363. mindspore/ops/operations/_rl_inner_ops.py +2 -2
  364. mindspore/ops/operations/_scalar_ops.py +5 -480
  365. mindspore/ops/operations/_sequence_ops.py +6 -36
  366. mindspore/ops/operations/_tensor_array.py +8 -8
  367. mindspore/ops/operations/array_ops.py +106 -2837
  368. mindspore/ops/operations/comm_ops.py +799 -127
  369. mindspore/ops/operations/custom_ops.py +124 -119
  370. mindspore/ops/operations/debug_ops.py +142 -41
  371. mindspore/ops/operations/image_ops.py +1 -217
  372. mindspore/ops/operations/inner_ops.py +5 -40
  373. mindspore/ops/operations/linalg_ops.py +1 -49
  374. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  375. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  376. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  377. mindspore/ops/operations/math_ops.py +666 -4972
  378. mindspore/ops/operations/nn_ops.py +205 -2213
  379. mindspore/ops/operations/other_ops.py +60 -49
  380. mindspore/ops/operations/random_ops.py +50 -54
  381. mindspore/ops/operations/reshard_ops.py +53 -0
  382. mindspore/ops/operations/sparse_ops.py +4 -4
  383. mindspore/ops/primitive.py +216 -103
  384. mindspore/ops_generate/__init__.py +27 -0
  385. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  386. mindspore/ops_generate/arg_handler.py +197 -0
  387. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  388. mindspore/ops_generate/gen_constants.py +36 -0
  389. mindspore/ops_generate/gen_ops.py +1099 -0
  390. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  391. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  392. mindspore/ops_generate/gen_utils.py +209 -0
  393. mindspore/ops_generate/op_proto.py +145 -0
  394. mindspore/ops_generate/pyboost_utils.py +367 -0
  395. mindspore/ops_generate/template.py +261 -0
  396. mindspore/parallel/__init__.py +8 -4
  397. mindspore/parallel/_auto_parallel_context.py +100 -10
  398. mindspore/parallel/_cell_wrapper.py +99 -9
  399. mindspore/parallel/_cost_model_context.py +1 -1
  400. mindspore/parallel/_dp_allreduce_fusion.py +159 -159
  401. mindspore/parallel/_parallel_serialization.py +67 -23
  402. mindspore/parallel/_ps_context.py +1 -1
  403. mindspore/parallel/_recovery_context.py +1 -1
  404. mindspore/parallel/_tensor.py +99 -22
  405. mindspore/parallel/_transformer/__init__.py +1 -1
  406. mindspore/parallel/_transformer/layers.py +1 -1
  407. mindspore/parallel/_transformer/loss.py +1 -1
  408. mindspore/parallel/_transformer/moe.py +1 -1
  409. mindspore/parallel/_transformer/op_parallel_config.py +1 -1
  410. mindspore/parallel/_transformer/transformer.py +2 -2
  411. mindspore/parallel/_utils.py +173 -6
  412. mindspore/parallel/algo_parameter_config.py +8 -10
  413. mindspore/parallel/checkpoint_transform.py +204 -38
  414. mindspore/parallel/cluster/__init__.py +15 -0
  415. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  416. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  417. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  418. mindspore/parallel/cluster/run.py +136 -0
  419. mindspore/parallel/mpi/__init__.py +1 -1
  420. mindspore/parallel/mpi/_mpi_config.py +1 -1
  421. mindspore/parallel/parameter_broadcast.py +151 -0
  422. mindspore/parallel/shard.py +279 -37
  423. mindspore/parallel/transform_safetensors.py +993 -0
  424. mindspore/pgodb140.dll +0 -0
  425. mindspore/pgort140.dll +0 -0
  426. mindspore/profiler/__init__.py +4 -2
  427. mindspore/profiler/common/constant.py +29 -0
  428. mindspore/profiler/common/process_pool.py +41 -0
  429. mindspore/profiler/common/registry.py +47 -0
  430. mindspore/profiler/common/singleton.py +28 -0
  431. mindspore/profiler/common/util.py +153 -0
  432. mindspore/profiler/dynamic_profiler.py +694 -0
  433. mindspore/profiler/envprofiling.py +18 -20
  434. mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
  435. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  436. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  437. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  438. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  439. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  440. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  441. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  442. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  443. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  444. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  445. mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
  446. mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
  447. mindspore/profiler/parser/ascend_flops_generator.py +20 -4
  448. mindspore/profiler/parser/ascend_hccl_generator.py +29 -278
  449. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  450. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  451. mindspore/profiler/parser/ascend_msprof_exporter.py +148 -146
  452. mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
  453. mindspore/profiler/parser/ascend_op_generator.py +92 -42
  454. mindspore/profiler/parser/ascend_timeline_generator.py +298 -133
  455. mindspore/profiler/parser/base_timeline_generator.py +25 -25
  456. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  457. mindspore/profiler/parser/framework_parser.py +4 -393
  458. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  459. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  460. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  461. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  462. mindspore/profiler/parser/integrator.py +3 -1
  463. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  464. mindspore/profiler/parser/minddata_parser.py +72 -3
  465. mindspore/profiler/parser/profiler_info.py +94 -7
  466. mindspore/profiler/profiler.py +153 -0
  467. mindspore/profiler/profiling.py +631 -508
  468. mindspore/rewrite/__init__.py +2 -14
  469. mindspore/rewrite/api/node.py +122 -36
  470. mindspore/rewrite/api/pattern_engine.py +2 -3
  471. mindspore/rewrite/api/scoped_value.py +16 -15
  472. mindspore/rewrite/api/symbol_tree.py +45 -29
  473. mindspore/rewrite/ast_helpers/__init__.py +3 -6
  474. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  475. mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
  476. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  477. mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
  478. mindspore/rewrite/common/__init__.py +1 -2
  479. mindspore/rewrite/common/config.py +24 -0
  480. mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
  481. mindspore/rewrite/{namer.py → common/namer.py} +63 -18
  482. mindspore/rewrite/common/namespace.py +118 -0
  483. mindspore/rewrite/node/__init__.py +5 -5
  484. mindspore/rewrite/node/call_function.py +23 -7
  485. mindspore/rewrite/node/cell_container.py +7 -3
  486. mindspore/rewrite/node/control_flow.py +53 -28
  487. mindspore/rewrite/node/node.py +212 -196
  488. mindspore/rewrite/node/node_manager.py +51 -22
  489. mindspore/rewrite/node/node_topological_manager.py +3 -23
  490. mindspore/rewrite/parsers/__init__.py +12 -0
  491. mindspore/rewrite/parsers/arguments_parser.py +8 -9
  492. mindspore/rewrite/parsers/assign_parser.py +637 -413
  493. mindspore/rewrite/parsers/attribute_parser.py +3 -4
  494. mindspore/rewrite/parsers/class_def_parser.py +115 -148
  495. mindspore/rewrite/parsers/constant_parser.py +5 -5
  496. mindspore/rewrite/parsers/container_parser.py +4 -6
  497. mindspore/rewrite/parsers/expr_parser.py +55 -0
  498. mindspore/rewrite/parsers/for_parser.py +31 -98
  499. mindspore/rewrite/parsers/function_def_parser.py +13 -5
  500. mindspore/rewrite/parsers/if_parser.py +28 -10
  501. mindspore/rewrite/parsers/module_parser.py +8 -182
  502. mindspore/rewrite/parsers/parser.py +1 -5
  503. mindspore/rewrite/parsers/parser_register.py +1 -1
  504. mindspore/rewrite/parsers/return_parser.py +5 -10
  505. mindspore/rewrite/parsers/while_parser.py +59 -0
  506. mindspore/rewrite/sparsify/utils.py +1 -1
  507. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  508. mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +705 -186
  509. mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
  510. mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
  511. mindspore/run_check/_check_version.py +40 -115
  512. mindspore/run_check/run_check.py +1 -1
  513. mindspore/safeguard/rewrite_obfuscation.py +597 -263
  514. mindspore/swresample-4.dll +0 -0
  515. mindspore/swscale-6.dll +0 -0
  516. mindspore/tbbmalloc.dll +0 -0
  517. mindspore/tinyxml2.dll +0 -0
  518. mindspore/train/__init__.py +7 -5
  519. mindspore/train/_utils.py +204 -4
  520. mindspore/train/amp.py +335 -295
  521. mindspore/train/anf_ir_pb2.py +14 -2
  522. mindspore/train/callback/__init__.py +5 -2
  523. mindspore/train/callback/_backup_and_restore.py +5 -5
  524. mindspore/train/callback/_callback.py +4 -4
  525. mindspore/train/callback/_checkpoint.py +220 -43
  526. mindspore/train/callback/_cluster_monitor.py +201 -0
  527. mindspore/train/callback/_early_stop.py +2 -2
  528. mindspore/train/callback/_flops_collector.py +239 -0
  529. mindspore/train/callback/_landscape.py +15 -9
  530. mindspore/train/callback/_loss_monitor.py +5 -5
  531. mindspore/train/callback/_on_request_exit.py +136 -33
  532. mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
  533. mindspore/train/callback/_summary_collector.py +12 -12
  534. mindspore/train/callback/_tft_register.py +352 -0
  535. mindspore/train/callback/_time_monitor.py +3 -3
  536. mindspore/train/data_sink.py +6 -5
  537. mindspore/train/dataset_helper.py +66 -23
  538. mindspore/train/loss_scale_manager.py +2 -2
  539. mindspore/train/metrics/accuracy.py +7 -7
  540. mindspore/train/metrics/confusion_matrix.py +8 -6
  541. mindspore/train/metrics/cosine_similarity.py +6 -4
  542. mindspore/train/metrics/error.py +2 -2
  543. mindspore/train/metrics/metric.py +3 -3
  544. mindspore/train/metrics/perplexity.py +2 -1
  545. mindspore/train/metrics/roc.py +4 -4
  546. mindspore/train/metrics/topk.py +2 -2
  547. mindspore/train/mind_ir_pb2.py +116 -37
  548. mindspore/train/model.py +382 -76
  549. mindspore/train/serialization.py +787 -288
  550. mindspore/train/summary/_summary_adapter.py +1 -1
  551. mindspore/train/summary/summary_record.py +51 -28
  552. mindspore/train/train_thor/convert_utils.py +3 -3
  553. mindspore/turbojpeg.dll +0 -0
  554. mindspore/utils/__init__.py +21 -0
  555. mindspore/utils/utils.py +60 -0
  556. mindspore/vcmeta.dll +0 -0
  557. mindspore/vcruntime140.dll +0 -0
  558. mindspore/vcruntime140_1.dll +0 -0
  559. mindspore/version.py +1 -1
  560. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/METADATA +8 -4
  561. mindspore-2.4.0.dist-info/RECORD +1406 -0
  562. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +1 -0
  563. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
  564. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
  565. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
  566. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
  567. mindspore/gen_ops.py +0 -273
  568. mindspore/include/c_api/ms/abstract.h +0 -67
  569. mindspore/include/c_api/ms/attribute.h +0 -197
  570. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  571. mindspore/include/c_api/ms/base/macros.h +0 -32
  572. mindspore/include/c_api/ms/base/status.h +0 -33
  573. mindspore/include/c_api/ms/base/types.h +0 -282
  574. mindspore/include/c_api/ms/context.h +0 -102
  575. mindspore/include/c_api/ms/graph.h +0 -160
  576. mindspore/include/c_api/ms/node.h +0 -606
  577. mindspore/include/c_api/ms/tensor.h +0 -161
  578. mindspore/include/c_api/ms/value.h +0 -84
  579. mindspore/mindspore_shared_lib.dll +0 -0
  580. mindspore/nn/layer/flash_attention.py +0 -189
  581. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  582. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  583. mindspore/ops/_op_impl/cpu/concat.py +0 -39
  584. mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
  585. mindspore/ops/_op_impl/tbe/__init__.py +0 -47
  586. mindspore/ops/_op_impl/tbe/abs.py +0 -38
  587. mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
  588. mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
  589. mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
  590. mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
  591. mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
  592. mindspore/ops/_op_impl/tbe/acos.py +0 -37
  593. mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
  594. mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
  595. mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
  596. mindspore/ops/_op_impl/tbe/acosh.py +0 -37
  597. mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
  598. mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
  599. mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
  600. mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
  601. mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
  602. mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
  603. mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
  604. mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
  605. mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
  606. mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
  607. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
  608. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
  609. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
  610. mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
  611. mindspore/ops/_op_impl/tbe/add.py +0 -42
  612. mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
  613. mindspore/ops/_op_impl/tbe/add_n.py +0 -39
  614. mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
  615. mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
  616. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
  617. mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
  618. mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
  619. mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
  620. mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
  621. mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
  622. mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
  623. mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
  624. mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
  625. mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
  626. mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
  627. mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
  628. mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
  629. mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
  630. mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
  631. mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
  632. mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
  633. mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
  634. mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
  635. mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
  636. mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
  637. mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
  638. mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
  639. mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
  640. mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
  641. mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
  642. mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
  643. mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
  644. mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
  645. mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
  646. mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
  647. mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
  648. mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
  649. mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
  650. mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
  651. mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
  652. mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
  653. mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
  654. mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
  655. mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
  656. mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
  657. mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
  658. mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
  659. mindspore/ops/_op_impl/tbe/asin.py +0 -37
  660. mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
  661. mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
  662. mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
  663. mindspore/ops/_op_impl/tbe/asinh.py +0 -37
  664. mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
  665. mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
  666. mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
  667. mindspore/ops/_op_impl/tbe/assign.py +0 -79
  668. mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
  669. mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
  670. mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
  671. mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
  672. mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
  673. mindspore/ops/_op_impl/tbe/atan.py +0 -37
  674. mindspore/ops/_op_impl/tbe/atan2.py +0 -38
  675. mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
  676. mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
  677. mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
  678. mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
  679. mindspore/ops/_op_impl/tbe/atanh.py +0 -37
  680. mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
  681. mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
  682. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
  683. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
  684. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
  685. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
  686. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
  687. mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
  688. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
  689. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
  690. mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
  691. mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
  692. mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
  693. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
  694. mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
  695. mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
  696. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
  697. mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
  698. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
  699. mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
  700. mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
  701. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
  702. mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
  703. mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
  704. mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
  705. mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
  706. mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
  707. mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
  708. mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
  709. mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
  710. mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
  711. mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
  712. mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
  713. mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
  714. mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
  715. mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
  716. mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
  717. mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
  718. mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
  719. mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
  720. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
  721. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
  722. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
  723. mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
  724. mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
  725. mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
  726. mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
  727. mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
  728. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
  729. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
  730. mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
  731. mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
  732. mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
  733. mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
  734. mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
  735. mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
  736. mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
  737. mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
  738. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
  739. mindspore/ops/_op_impl/tbe/cast.py +0 -55
  740. mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
  741. mindspore/ops/_op_impl/tbe/cdist.py +0 -38
  742. mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
  743. mindspore/ops/_op_impl/tbe/ceil.py +0 -37
  744. mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
  745. mindspore/ops/_op_impl/tbe/celu.py +0 -39
  746. mindspore/ops/_op_impl/tbe/centralization.py +0 -39
  747. mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
  748. mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
  749. mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
  750. mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
  751. mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
  752. mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
  753. mindspore/ops/_op_impl/tbe/concat.py +0 -40
  754. mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
  755. mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
  756. mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
  757. mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
  758. mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
  759. mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
  760. mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
  761. mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
  762. mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
  763. mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
  764. mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
  765. mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
  766. mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
  767. mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
  768. mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
  769. mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
  770. mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
  771. mindspore/ops/_op_impl/tbe/cos.py +0 -37
  772. mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
  773. mindspore/ops/_op_impl/tbe/cosh.py +0 -37
  774. mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
  775. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
  776. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
  777. mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
  778. mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
  779. mindspore/ops/_op_impl/tbe/cummin.py +0 -41
  780. mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
  781. mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
  782. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
  783. mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
  784. mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
  785. mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
  786. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
  787. mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
  788. mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
  789. mindspore/ops/_op_impl/tbe/diag.py +0 -38
  790. mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
  791. mindspore/ops/_op_impl/tbe/dilation.py +0 -40
  792. mindspore/ops/_op_impl/tbe/div.py +0 -41
  793. mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
  794. mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
  795. mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
  796. mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
  797. mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
  798. mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
  799. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
  800. mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
  801. mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
  802. mindspore/ops/_op_impl/tbe/elu.py +0 -38
  803. mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
  804. mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
  805. mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
  806. mindspore/ops/_op_impl/tbe/equal.py +0 -42
  807. mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
  808. mindspore/ops/_op_impl/tbe/erf.py +0 -37
  809. mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
  810. mindspore/ops/_op_impl/tbe/erfc.py +0 -37
  811. mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
  812. mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
  813. mindspore/ops/_op_impl/tbe/exp.py +0 -40
  814. mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
  815. mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
  816. mindspore/ops/_op_impl/tbe/expm1.py +0 -37
  817. mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
  818. mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
  819. mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
  820. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
  821. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
  822. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
  823. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
  824. mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
  825. mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
  826. mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
  827. mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
  828. mindspore/ops/_op_impl/tbe/fill.py +0 -56
  829. mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
  830. mindspore/ops/_op_impl/tbe/flatten.py +0 -48
  831. mindspore/ops/_op_impl/tbe/floor.py +0 -37
  832. mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
  833. mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
  834. mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
  835. mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
  836. mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
  837. mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
  838. mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
  839. mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
  840. mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
  841. mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
  842. mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
  843. mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
  844. mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
  845. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  846. mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
  847. mindspore/ops/_op_impl/tbe/gelu.py +0 -37
  848. mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
  849. mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
  850. mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
  851. mindspore/ops/_op_impl/tbe/ger.py +0 -43
  852. mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
  853. mindspore/ops/_op_impl/tbe/greater.py +0 -43
  854. mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
  855. mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
  856. mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
  857. mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
  858. mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
  859. mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
  860. mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
  861. mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
  862. mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
  863. mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
  864. mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
  865. mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
  866. mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
  867. mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
  868. mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
  869. mindspore/ops/_op_impl/tbe/im2col.py +0 -42
  870. mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
  871. mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
  872. mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
  873. mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
  874. mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
  875. mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
  876. mindspore/ops/_op_impl/tbe/inv.py +0 -38
  877. mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
  878. mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
  879. mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
  880. mindspore/ops/_op_impl/tbe/invert.py +0 -37
  881. mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
  882. mindspore/ops/_op_impl/tbe/iou.py +0 -38
  883. mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
  884. mindspore/ops/_op_impl/tbe/is_close.py +0 -40
  885. mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
  886. mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
  887. mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
  888. mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
  889. mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
  890. mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
  891. mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
  892. mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
  893. mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
  894. mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
  895. mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
  896. mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
  897. mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
  898. mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
  899. mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
  900. mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
  901. mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
  902. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
  903. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
  904. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
  905. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
  906. mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
  907. mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
  908. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
  909. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
  910. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
  911. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
  912. mindspore/ops/_op_impl/tbe/lerp.py +0 -38
  913. mindspore/ops/_op_impl/tbe/less.py +0 -41
  914. mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
  915. mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
  916. mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
  917. mindspore/ops/_op_impl/tbe/log.py +0 -40
  918. mindspore/ops/_op_impl/tbe/log1p.py +0 -37
  919. mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
  920. mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
  921. mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
  922. mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
  923. mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
  924. mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
  925. mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
  926. mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
  927. mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
  928. mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
  929. mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
  930. mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
  931. mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
  932. mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
  933. mindspore/ops/_op_impl/tbe/lrn.py +0 -41
  934. mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
  935. mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
  936. mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
  937. mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
  938. mindspore/ops/_op_impl/tbe/matmul.py +0 -53
  939. mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
  940. mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
  941. mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
  942. mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
  943. mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
  944. mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
  945. mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
  946. mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
  947. mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
  948. mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
  949. mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
  950. mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
  951. mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
  952. mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
  953. mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
  954. mindspore/ops/_op_impl/tbe/maximum.py +0 -39
  955. mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
  956. mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
  957. mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
  958. mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
  959. mindspore/ops/_op_impl/tbe/minimum.py +0 -40
  960. mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
  961. mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
  962. mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
  963. mindspore/ops/_op_impl/tbe/mish.py +0 -37
  964. mindspore/ops/_op_impl/tbe/mod.py +0 -41
  965. mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
  966. mindspore/ops/_op_impl/tbe/mul.py +0 -37
  967. mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
  968. mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
  969. mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
  970. mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
  971. mindspore/ops/_op_impl/tbe/neg.py +0 -39
  972. mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
  973. mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
  974. mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
  975. mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
  976. mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
  977. mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
  978. mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
  979. mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
  980. mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
  981. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
  982. mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
  983. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
  984. mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
  985. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
  986. mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
  987. mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
  988. mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
  989. mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
  990. mindspore/ops/_op_impl/tbe/pack.py +0 -58
  991. mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
  992. mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
  993. mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
  994. mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
  995. mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
  996. mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
  997. mindspore/ops/_op_impl/tbe/pdist.py +0 -36
  998. mindspore/ops/_op_impl/tbe/pooling.py +0 -46
  999. mindspore/ops/_op_impl/tbe/population_count.py +0 -38
  1000. mindspore/ops/_op_impl/tbe/pow.py +0 -41
  1001. mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
  1002. mindspore/ops/_op_impl/tbe/prelu.py +0 -37
  1003. mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
  1004. mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
  1005. mindspore/ops/_op_impl/tbe/range.py +0 -39
  1006. mindspore/ops/_op_impl/tbe/real_div.py +0 -38
  1007. mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
  1008. mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
  1009. mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
  1010. mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
  1011. mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
  1012. mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
  1013. mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
  1014. mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
  1015. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
  1016. mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
  1017. mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
  1018. mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
  1019. mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
  1020. mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
  1021. mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
  1022. mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
  1023. mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
  1024. mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
  1025. mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
  1026. mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
  1027. mindspore/ops/_op_impl/tbe/relu.py +0 -39
  1028. mindspore/ops/_op_impl/tbe/relu6.py +0 -38
  1029. mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
  1030. mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
  1031. mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
  1032. mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
  1033. mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
  1034. mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
  1035. mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
  1036. mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
  1037. mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
  1038. mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
  1039. mindspore/ops/_op_impl/tbe/renorm.py +0 -39
  1040. mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
  1041. mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
  1042. mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
  1043. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
  1044. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
  1045. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
  1046. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
  1047. mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
  1048. mindspore/ops/_op_impl/tbe/rint.py +0 -37
  1049. mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
  1050. mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
  1051. mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
  1052. mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
  1053. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
  1054. mindspore/ops/_op_impl/tbe/roll.py +0 -42
  1055. mindspore/ops/_op_impl/tbe/round.py +0 -38
  1056. mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
  1057. mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
  1058. mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
  1059. mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
  1060. mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
  1061. mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
  1062. mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
  1063. mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
  1064. mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
  1065. mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
  1066. mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
  1067. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
  1068. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
  1069. mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
  1070. mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
  1071. mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
  1072. mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
  1073. mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
  1074. mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
  1075. mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
  1076. mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
  1077. mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
  1078. mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
  1079. mindspore/ops/_op_impl/tbe/select.py +0 -38
  1080. mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
  1081. mindspore/ops/_op_impl/tbe/selu.py +0 -39
  1082. mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
  1083. mindspore/ops/_op_impl/tbe/sgd.py +0 -62
  1084. mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
  1085. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
  1086. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
  1087. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
  1088. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
  1089. mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
  1090. mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
  1091. mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
  1092. mindspore/ops/_op_impl/tbe/sign.py +0 -38
  1093. mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
  1094. mindspore/ops/_op_impl/tbe/sin.py +0 -37
  1095. mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
  1096. mindspore/ops/_op_impl/tbe/sinh.py +0 -37
  1097. mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
  1098. mindspore/ops/_op_impl/tbe/slice.py +0 -58
  1099. mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
  1100. mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
  1101. mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
  1102. mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
  1103. mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
  1104. mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
  1105. mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
  1106. mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
  1107. mindspore/ops/_op_impl/tbe/softmax.py +0 -37
  1108. mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
  1109. mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
  1110. mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
  1111. mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
  1112. mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
  1113. mindspore/ops/_op_impl/tbe/softplus.py +0 -37
  1114. mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
  1115. mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
  1116. mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
  1117. mindspore/ops/_op_impl/tbe/softsign.py +0 -37
  1118. mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
  1119. mindspore/ops/_op_impl/tbe/sort.py +0 -38
  1120. mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
  1121. mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
  1122. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
  1123. mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
  1124. mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
  1125. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
  1126. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
  1127. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
  1128. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
  1129. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
  1130. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
  1131. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
  1132. mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
  1133. mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
  1134. mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
  1135. mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
  1136. mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
  1137. mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
  1138. mindspore/ops/_op_impl/tbe/split_d.py +0 -38
  1139. mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
  1140. mindspore/ops/_op_impl/tbe/split_v.py +0 -39
  1141. mindspore/ops/_op_impl/tbe/splitv.py +0 -39
  1142. mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
  1143. mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
  1144. mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
  1145. mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
  1146. mindspore/ops/_op_impl/tbe/square.py +0 -38
  1147. mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
  1148. mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
  1149. mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
  1150. mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
  1151. mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
  1152. mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
  1153. mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
  1154. mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
  1155. mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
  1156. mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
  1157. mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
  1158. mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
  1159. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
  1160. mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
  1161. mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
  1162. mindspore/ops/_op_impl/tbe/sub.py +0 -39
  1163. mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
  1164. mindspore/ops/_op_impl/tbe/tan.py +0 -38
  1165. mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
  1166. mindspore/ops/_op_impl/tbe/tanh.py +0 -37
  1167. mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
  1168. mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
  1169. mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
  1170. mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
  1171. mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
  1172. mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
  1173. mindspore/ops/_op_impl/tbe/tile.py +0 -37
  1174. mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
  1175. mindspore/ops/_op_impl/tbe/top_k.py +0 -42
  1176. mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
  1177. mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
  1178. mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
  1179. mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
  1180. mindspore/ops/_op_impl/tbe/transpose.py +0 -60
  1181. mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
  1182. mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
  1183. mindspore/ops/_op_impl/tbe/trunc.py +0 -39
  1184. mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
  1185. mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
  1186. mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
  1187. mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
  1188. mindspore/ops/_op_impl/tbe/unpack.py +0 -38
  1189. mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
  1190. mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
  1191. mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
  1192. mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
  1193. mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
  1194. mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
  1195. mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
  1196. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
  1197. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
  1198. mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
  1199. mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
  1200. mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
  1201. mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
  1202. mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
  1203. mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
  1204. mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
  1205. mindspore/ops/_tracefunc.py +0 -241
  1206. mindspore/ops/arg_dtype_cast.py +0 -54
  1207. mindspore/ops/silent_check.py +0 -162
  1208. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  1209. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  1210. mindspore/rewrite/api/tree_node_helper.py +0 -60
  1211. mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
  1212. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
  1213. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
  1214. mindspore/rewrite/namespace.py +0 -53
  1215. mindspore-2.2.14.dist-info/RECORD +0 -1924
  1216. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
  1217. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
1
+ # Copyright 2020-2024 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -17,19 +17,23 @@
17
17
  from __future__ import absolute_import
18
18
  from __future__ import division
19
19
 
20
+ import binascii
20
21
  import copy
21
22
  import json
22
23
  import os
24
+ import re
23
25
  import shutil
24
26
  import stat
25
27
  import threading
26
28
  from threading import Thread, RLock
29
+ from multiprocessing import Process
27
30
  from collections import defaultdict, OrderedDict
28
31
  from io import BytesIO
29
32
 
30
33
  import math
31
34
  import sys
32
35
  import time
36
+ import google
33
37
  import numpy as np
34
38
 
35
39
  from mindspore.train.checkpoint_pb2 import Checkpoint
@@ -50,32 +54,41 @@ from mindspore.common.api import _generate_branch_control_input
50
54
  from mindspore.common.initializer import initializer, One
51
55
  from mindspore.common.parameter import Parameter, _offload_if_config
52
56
  from mindspore.common.tensor import Tensor
57
+ from mindspore._c_expression import Tensor as Tensor_
53
58
  from mindspore.common._utils import is_shape_unknown
59
+ from mindspore.common.file_system import FileSystem, _register_basic_file_system, _register_mindio_file_system
54
60
  from mindspore.communication.management import get_rank, get_group_size
55
61
  from mindspore.experimental import MapParameter
56
- from mindspore.parallel._cell_wrapper import get_allgather_cell
62
+ from mindspore.ops import Cast
63
+ from mindspore.parallel._cell_wrapper import get_allgather_cell, _single_parameter_broadcast
57
64
  from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
58
65
  from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
59
- from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode
66
+ from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode, \
67
+ _get_device_num, _is_parallel_mode
68
+ from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context
60
69
  from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \
61
- _restore_group_info_list
70
+ _restore_group_info_list, _get_param_list_when_first_dim_sharded
62
71
  from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
63
72
  _store_warm_up_ptr_by_tensor_list, _cache_enable
64
- from mindspore.train._utils import read_proto
73
+ from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters
74
+ from mindspore.parallel.transform_safetensors import _load_parallel_checkpoint, _get_device_num_from_strategy, \
75
+ _extract_pipeline_stage_num
76
+ from mindspore.train._utils import read_proto, get_parameter_redundancy
65
77
  from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
66
78
  split_mindir, split_dynamic_mindir
79
+ from mindspore.common.generator import Generator
80
+ from safetensors.numpy import save_file
81
+ from safetensors import safe_open
67
82
  from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
68
- from ..ops.operations import Cast
69
83
 
70
84
  tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
71
85
  "Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64,
72
86
  "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64,
73
- "Bool": mstype.bool_, "str": mstype.string, "BFloat16": mstype.bfloat16}
87
+ "Bool": mstype.bool_, "str": mstype.string, "BFloat16": mstype.bfloat16, "Int4": mstype.qint4x2}
74
88
 
75
89
  tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UInt16": np.uint16,
76
90
  "Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64,
77
- "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U",
78
- "BFloat16": np.float32}
91
+ "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"}
79
92
 
80
93
  np_type_convert = {"int32": np.int32, "float32": np.float32, "float16": np.float16, "float64": np.float64}
81
94
 
@@ -95,6 +108,92 @@ INT_64_MAX = 9223372036854775807
95
108
 
96
109
  cpu_cast = Cast().set_device("CPU")
97
110
 
111
+ _ckpt_fs = FileSystem()
112
+
113
+
114
+ def init_ckpt_file_system(fs: FileSystem):
115
+ """Initialize checkpoint file system"""
116
+ if _register_mindio_file_system(fs):
117
+ return
118
+ _register_basic_file_system(fs)
119
+
120
+
121
+ # Initialize checkpoint file system
122
+ init_ckpt_file_system(_ckpt_fs)
123
+
124
+
125
+ def _get_cur_rank_dp(parameter_layout_dict):
126
+ """ Get dp and tp from layout dict. """
127
+ pp_num = _get_auto_parallel_context("pipeline_stages")
128
+ dev_num = _get_device_num()
129
+ global_rank = get_rank()
130
+ pipe_size = dev_num // pp_num
131
+ initial_rank = (global_rank // pipe_size) * pipe_size
132
+ parameter_redundancy_dict = get_parameter_redundancy(
133
+ parameter_layout_dict, initial_rank)
134
+ value_len = sys.maxsize
135
+ min_value = ()
136
+ for key, value in parameter_redundancy_dict.items():
137
+ if "accu_grads" in key or "inputs" in key:
138
+ continue
139
+ for item in value:
140
+ if len(item) < value_len and global_rank in item:
141
+ value_len = len(item)
142
+ min_value = item
143
+ return min_value
144
+
145
+
146
+ def get_ckpt_path_with_strategy(cur_ckpt_path, cur_strategy_path):
147
+ """
148
+ Find available checkpoint file path from all backup checkpoint files of current rank.
149
+ It suppose that checkpoint path contains substring 'rank_{rank_id}' which is used to
150
+ distinguish between different path.If cur_ckpt_path doesn't have 'rank_{rank_id}' substring, will return
151
+ cur_ckpt_path itself when cur_ckpt_path is exist, otherwise return None.
152
+
153
+ Note:
154
+ This API must be called after the communication is initialized because the cluster information
155
+ needs to be obtained internally.
156
+
157
+ Args:
158
+ cur_ckpt_path (str): the checkpoint file path which cur rank needs.
159
+ cur_strategy_path (str): strategy file path for current rank.
160
+
161
+ Returns:
162
+ - new_ckpt_file (String), if found available checkpoint file , return it.
163
+ - None, if not found available checkpoint, return None.
164
+
165
+ Examples:
166
+ >>> import mindspore as ms
167
+ >>> from mindspore.communication import init
168
+ >>> from mindspore import get_ckpt_path_with_strategy
169
+ >>> ms.set_context(mode=ms.GRAPH_MODE)
170
+ >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
171
+ >>> init()
172
+ >>> ckpt_file= "./rank_5/iteration-1_40.ckpt"
173
+ >>> strategy_file = "./src_pipeline_strategys/src_strategy_5.ckpt"
174
+ >>> ckpt_file_new = get_ckpt_path_with_strategy(ckpt_file, strategy_file)
175
+ >>> print(ckpt_file_new)
176
+ """
177
+ dp = _get_cur_rank_dp(cur_strategy_path)
178
+ pattern = r'rank_\d+'
179
+ for i in dp:
180
+ new_ckpt_path = re.sub(pattern, f"rank_{str(i)}", cur_ckpt_path)
181
+ if not os.path.isfile(new_ckpt_path):
182
+ continue
183
+ return new_ckpt_path
184
+ return None
185
+
186
+
187
+ class ParamDictFuture:
188
+ def __init__(self, executor, param_dict_future):
189
+ self.executor = executor
190
+ self.param_dict_future = param_dict_future
191
+
192
+ def result(self):
193
+ param_dict = self.param_dict_future.result()
194
+ self.executor.shutdown()
195
+ return param_dict
196
+
98
197
 
99
198
  def _special_process_par(par, new_par):
100
199
  """
@@ -221,53 +320,72 @@ def _save_weight(checkpoint_dir, model_name, iteration, params):
221
320
  logger.warning(f"Checkpoint dir: '{checkpoint_dir}' is not existed.")
222
321
 
223
322
 
224
- def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False):
323
+ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False, crc_check=False,
324
+ format="ckpt"):
225
325
  """Execute the process of saving checkpoint into file."""
226
326
  try:
227
327
  with _ckpt_mutex:
328
+ file_name_list = list(os.path.splitext(ckpt_file_name))
329
+ file_name_list[1] = file_name_list[1].replace(f".{format}", ".tmp")
330
+ tmp_name = ''.join(file_name_list)
228
331
  if os.path.exists(ckpt_file_name):
229
332
  os.chmod(ckpt_file_name, stat.S_IWUSR)
230
333
  os.remove(ckpt_file_name)
231
- with open(ckpt_file_name, "ab") as f:
232
- plain_data = None
233
- if enc_key is not None:
234
- plain_data = BytesIO()
235
-
236
- for name, value in data_list.items():
237
- if name == "random_op":
238
- _write_random_seed(name, value, f)
239
- continue
240
- if value[0] == "mapparameter":
241
- _write_mapparameter(name, value, f, map_param_inc)
242
- continue
243
- if value[0] == "offload_parameter":
244
- new_value = value[1:]
245
- if value[3].dtype == mstype.bfloat16:
246
- new_value[2] = cpu_cast(value[3], mstype.float32).asnumpy().reshape(-1)
247
- else:
248
- new_value[2] = value[3].asnumpy().reshape(-1)
249
- _write_parameter_data(name, new_value, f, enc_key, plain_data)
250
- _offload_if_config(value[3])
251
- continue
252
- if value[0] == "BFloat16_tensor":
253
- _write_bfloat16_data(name, value, f, enc_key, plain_data)
254
- continue
255
- if isinstance(value[2], Tensor):
256
- _write_hugeparameter(name, value, f)
257
- continue
258
-
259
- _write_parameter_data(name, value, f, enc_key, plain_data)
260
-
261
- if enc_key is not None:
262
- plain_data.seek(0)
263
- max_block_size = ENCRYPT_BLOCK_SIZE * 1024
264
- block_data = plain_data.read(max_block_size)
265
- while block_data:
266
- f.write(_encrypt(block_data, len(block_data), enc_key, len(enc_key), enc_mode))
334
+ if os.path.exists(tmp_name):
335
+ os.chmod(tmp_name, stat.S_IWUSR)
336
+ os.remove(tmp_name)
337
+ if format == "ckpt":
338
+ with _ckpt_fs.create(tmp_name, *_ckpt_fs.create_args) as f:
339
+ plain_data = None
340
+ if enc_key is not None:
341
+ plain_data = BytesIO()
342
+
343
+ crc_num = 0
344
+ for name, value in data_list.items():
345
+ if name == "random_op":
346
+ _write_random_seed(name, value, f)
347
+ continue
348
+ if value[0] == "mapparameter":
349
+ _write_mapparameter(name, value, f, map_param_inc)
350
+ continue
351
+ if value[0] == "offload_parameter":
352
+ new_value = value[1:]
353
+ new_value[2] = value[3]
354
+ _write_parameter_bytes_data(name, new_value, f, enc_key, plain_data)
355
+ _offload_if_config(value[3])
356
+ continue
357
+ if value[1] == "str":
358
+ crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
359
+ continue
360
+ if isinstance(value[2], np.ndarray):
361
+ crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
362
+ continue
363
+ if isinstance(value[2], Tensor) and hasattr(value[2], "slice_num") and value[2].slice_num > 1:
364
+ _write_hugeparameter(name, value, f)
365
+ continue
366
+
367
+ crc_num = _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
368
+
369
+ if enc_key is not None:
370
+ plain_data.seek(0)
371
+ max_block_size = ENCRYPT_BLOCK_SIZE * 1024
267
372
  block_data = plain_data.read(max_block_size)
268
-
269
- os.chmod(ckpt_file_name, stat.S_IRUSR)
270
-
373
+ while block_data:
374
+ f.write(_encrypt(block_data, len(block_data), enc_key, len(enc_key), enc_mode))
375
+ block_data = plain_data.read(max_block_size)
376
+ if crc_check:
377
+ f.write('crc_num'.encode() + crc_num.to_bytes(10, byteorder='big'))
378
+ elif format == "safetensors":
379
+ save_dict = {}
380
+ for name, value in data_list.items():
381
+ save_dict[name] = value[2].asnumpy()
382
+ save_file(save_dict, tmp_name)
383
+ if not os.path.exists(tmp_name):
384
+ logger.warning(f"Rename failed, can't find {tmp_name}, it is possible that multiple processes have "
385
+ f"simultaneously modified a file.")
386
+ else:
387
+ os.rename(tmp_name, ckpt_file_name)
388
+ os.chmod(ckpt_file_name, stat.S_IRUSR)
271
389
  except BaseException as e:
272
390
  logger.critical("Failed to save the checkpoint file %s. Maybe don't have the permission to write files, "
273
391
  "or the disk space is insufficient and so on.", ckpt_file_name)
@@ -286,22 +404,7 @@ def _write_random_seed(name, value, f):
286
404
  f.write(checkpoint_list.SerializeToString())
287
405
 
288
406
 
289
- def _write_bfloat16_data(name, value, f, enc_key, plain_data):
290
- """Write bfloat16 data into protobuf file"""
291
- checkpoint_list = Checkpoint()
292
- param_value = checkpoint_list.value.add()
293
- param_value.tag = name
294
- param_tensor = param_value.tensor
295
- param_tensor.dims.extend(value[1])
296
- param_tensor.tensor_type = value[2]
297
- param_tensor.tensor_content = value[3].get_bytes()
298
- if enc_key is None:
299
- f.write(checkpoint_list.SerializeToString())
300
- else:
301
- plain_data.write(checkpoint_list.SerializeToString())
302
-
303
-
304
- def _write_parameter_data(name, value, f, enc_key, plain_data):
407
+ def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False):
305
408
  """Write parameter data into protobuf file."""
306
409
  data_size = value[2].nbytes / 1024
307
410
  if data_size > SLICE_SIZE:
@@ -320,10 +423,40 @@ def _write_parameter_data(name, value, f, enc_key, plain_data):
320
423
  param_tensor.tensor_content = param_slice.tobytes()
321
424
 
322
425
  if enc_key is None:
323
- f.write(checkpoint_list.SerializeToString())
426
+ output_data = checkpoint_list.SerializeToString()
427
+ if crc_check:
428
+ crc_num = binascii.crc32(output_data, crc_num)
429
+ f.write(output_data)
430
+ else:
431
+ plain_data.write(checkpoint_list.SerializeToString())
432
+
433
+ return crc_num
434
+
435
+
436
+ def _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False):
437
+ """Write parameter bytes data into protobuf file."""
438
+ bytes_value = value[2].get_bytes()
439
+ chunk_size = 1024 * SLICE_SIZE
440
+
441
+ for i in range(0, len(bytes_value), chunk_size):
442
+ checkpoint_list = Checkpoint()
443
+ param_value = checkpoint_list.value.add()
444
+ param_value.tag = name
445
+ param_tensor = param_value.tensor
446
+ param_tensor.dims.extend(value[0])
447
+ param_tensor.tensor_type = value[1]
448
+ param_tensor.tensor_content = bytes_value[i:i + chunk_size]
449
+
450
+ if enc_key is None:
451
+ output_data = checkpoint_list.SerializeToString()
452
+ if crc_check:
453
+ crc_num = binascii.crc32(output_data, crc_num)
454
+ f.write(output_data)
324
455
  else:
325
456
  plain_data.write(checkpoint_list.SerializeToString())
326
457
 
458
+ return crc_num
459
+
327
460
 
328
461
  def _write_mapparameter(name, value, f, map_param_inc=False):
329
462
  """Write map parameter into protobuf file."""
@@ -365,8 +498,11 @@ def _write_hugeparameter(name, value, f):
365
498
  offset += numpy_data.shape[0]
366
499
 
367
500
 
368
- def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
501
+ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format):
369
502
  """Check save_obj and ckpt_file_name for save_checkpoint."""
503
+ if format not in ["safetensors", "ckpt"]:
504
+ raise ValueError(f"For 'save_checkpoint', the format must be "
505
+ f"'safetensors' or 'ckpt', but got {format}.")
370
506
  if not isinstance(save_obj, (nn.Cell, list, dict)):
371
507
  raise TypeError("For 'save_checkpoint', the parameter 'save_obj' must be nn.Cell, list or dict, "
372
508
  "but got {}.".format(type(save_obj)))
@@ -374,20 +510,32 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
374
510
  raise TypeError("For 'save_checkpoint', the parameter {} for checkpoint file name is invalid,"
375
511
  "'ckpt_file_name' must be "
376
512
  "string, but got {}.".format(ckpt_file_name, type(ckpt_file_name)))
377
- ckpt_file_name = os.path.abspath(ckpt_file_name)
513
+ ckpt_file_name = os.path.realpath(ckpt_file_name)
378
514
  if os.path.isdir(ckpt_file_name):
379
515
  raise IsADirectoryError("For 'save_checkpoint', the parameter `ckpt_file_name`: {} is a directory, "
380
516
  "it must be a file name.".format(ckpt_file_name))
381
- if not ckpt_file_name.endswith('.ckpt'):
382
- ckpt_file_name += ".ckpt"
517
+ if not ckpt_file_name.endswith(format):
518
+ ckpt_file_name += f".{format}"
383
519
  return ckpt_file_name
384
520
 
385
521
 
522
+ def _check_format_and_other_params(format, enc_key, enc_mode, crc_check=False, async_save=False, map_param_inc=False,
523
+ global_step_num=None):
524
+ param_not_default = (enc_key is not None or enc_mode != "AES-GCM" or crc_check or async_save
525
+ or map_param_inc or global_step_num is not None)
526
+ if format == "safetensors" and param_not_default:
527
+ raise ValueError("For 'save_checkpoint', when format is 'safetensors', other param must be default.")
528
+
529
+
386
530
  def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
387
- async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None, **kwargs):
531
+ async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None,
532
+ crc_check=False, format="ckpt", **kwargs):
388
533
  r"""
389
534
  Save checkpoint to a specified file.
390
535
 
536
+ Note:
537
+ The `enc_mode` and `crc_check` parameters are mutually exclusive and cannot be configured simultaneously.
538
+
391
539
  Args:
392
540
  save_obj (Union[Cell, list, dict]): The object to be saved. The data type can be :class:`mindspore.nn.Cell`,
393
541
  list, or dict. If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict
@@ -409,6 +557,9 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
409
557
  If returns ``True`` , the Parameter that matching the custom condition will be saved.
410
558
  If returns ``False`` , the Parameter that not matching the custom condition will not
411
559
  be saved. Default: ``None`` .
560
+ crc_check (bool) : Whether to perform crc32 calculation when saving checkpoint and save the calculation
561
+ result to the file. Default: ``False`` .
562
+ format (str): Format of the output file, can be "ckpt" or "safetensors". Default: "ckpt".
412
563
  kwargs (dict): Configuration options dictionary.
413
564
 
414
565
  Raises:
@@ -420,7 +571,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
420
571
  >>> import mindspore as ms
421
572
  >>>
422
573
  >>> # Define the network structure of LeNet5. Refer to
423
- >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
574
+ >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
424
575
  >>> net = LeNet5()
425
576
  >>> ms.save_checkpoint(net, "./lenet.ckpt",
426
577
  ... choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
@@ -440,35 +591,57 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
440
591
 
441
592
  Tutorial Examples:
442
593
  - `Saving and Loading the Model - Saving and Loading the Model Weight
443
- <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
594
+ <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
444
595
  """
445
- ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name)
596
+ ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format)
446
597
  integrated_save = Validator.check_bool(integrated_save)
447
598
  async_save = Validator.check_bool(async_save)
448
599
  append_dict = _check_append_dict(append_dict)
449
600
  enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
450
601
  enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
602
+ crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
451
603
  map_param_inc = kwargs.get('incremental', False)
452
604
  logger.info("Execute the process of saving checkpoint files.")
453
-
454
- save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
605
+ global_step_num = kwargs.get('global_step_num', None)
606
+ _check_format_and_other_params(format, enc_key, enc_mode, crc_check, async_save, map_param_inc, global_step_num)
607
+
608
+ if append_dict and "__exception_save__" in append_dict:
609
+ s1 = mindspore.hal.Stream()
610
+ with mindspore.hal.StreamCtx(s1):
611
+ save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
612
+ s1.synchronize()
613
+ else:
614
+ save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
455
615
 
456
616
  if append_dict:
617
+ if "__exception_save__" in append_dict:
618
+ del append_dict["__exception_save__"]
457
619
  append_info_list = []
458
620
  for k_name, value in append_dict.items():
459
- if not isinstance(value, str):
621
+ if isinstance(value, Generator):
622
+ value = value.get_state()
623
+ elif not isinstance(value, str):
460
624
  value = Tensor(value)
461
625
  append_info_list.append({"name": k_name, "data": value})
462
626
  save_obj.extend(append_info_list)
463
627
 
464
628
  data_list = OrderedDict()
629
+ data_list_np = OrderedDict()
465
630
  with _ckpt_mutex:
466
631
  for param in save_obj:
467
632
  if param["name"] == "random_op":
468
- data_list["random_op"] = param["data"]
633
+ if os.getenv("AITURBO") == "1":
634
+ data_list_np["random_op"] = []
635
+ data_list_np["random_op"].append(param["data"])
636
+ if crc_check:
637
+ bytes_value = bytes(data_list_np[key][0])
638
+ data_list_np[key].append(binascii.crc32(bytes_value))
639
+ else:
640
+ data_list["random_op"] = param["data"]
469
641
  continue
470
642
  key = param["name"]
471
643
  data_list[key] = []
644
+ data_list_np[key] = []
472
645
  if isinstance(param["data"], MapParameter):
473
646
  data_list[param["name"]].append("mapparameter")
474
647
  data_list[param["name"]].append(param["data"])
@@ -479,49 +652,48 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
479
652
  elif param["data"][0] == "offload_parameter":
480
653
  data_list[key].append("offload_parameter")
481
654
  _save_param_list_data(data_list, key, param)
482
- elif param["data"][0] == "BFloat16_tensor":
483
- data_list[key].append("BFloat16_tensor")
484
- _save_param_list_data(data_list, key, param)
485
- continue
486
655
 
487
656
  if isinstance(param["data"], str):
488
- data_list[key].append([0])
489
- data_list[key].append('str')
490
- data = np.array(param["data"])
491
- data_list[key].append(data)
657
+ if os.getenv("AITURBO") == "1":
658
+ data_list_np[key].append(np.array(param["data"]))
659
+ if crc_check:
660
+ bytes_value = data_list_np[key][0].tobytes()
661
+ data_list_np[key].append(binascii.crc32(bytes_value))
662
+ else:
663
+ data_list[key].append([0])
664
+ data_list[key].append('str')
665
+ data = np.array(param["data"])
666
+ data_list[key].append(data)
492
667
  else:
493
668
  if isinstance(param["data"], Parameter):
494
669
  param["data"].init_data()
495
- if isinstance(param["data"], Tensor) and param["data"].dtype == mstype.bfloat16:
496
- data_list[key].append("BFloat16_tensor")
497
- dims = []
498
- for dim in param["data"].shape:
499
- dims.append(dim)
500
- data_list[key].append(dims)
501
- data_list[key].append("BFloat16")
502
- data_list[key].append(cpu_cast(param["data"], mstype.float32))
503
- continue
504
- dims = []
505
- if param['data'].shape == ():
506
- dims.append(0)
670
+ if os.getenv("AITURBO") == "1":
671
+ data_list_np[key].append(param["data"].asnumpy())
672
+ if crc_check:
673
+ bytes_value = data_list_np[key][0].tobytes()
674
+ data_list_np[key].append(binascii.crc32(bytes_value))
507
675
  else:
676
+ dims = []
508
677
  for dim in param['data'].shape:
509
678
  dims.append(dim)
510
- data_list[key].append(dims)
511
- tensor_type = str(param["data"].dtype)
512
- data_list[key].append(tensor_type)
513
- if param["data"].dtype == mstype.bfloat16:
514
- data = cpu_cast(param["data"], mstype.float32).asnumpy().reshape(-1)
515
- else:
516
- data = param["data"].asnumpy().reshape(-1)
517
- data_list[key].append(data)
518
-
519
- if async_save:
679
+ data_list[key].append(dims)
680
+ tensor_type = str(param["data"].dtype)
681
+ data_list[key].append(tensor_type)
682
+ data = param["data"]
683
+ data_list[key].append(data)
684
+
685
+ if os.getenv("AITURBO") == "1":
686
+ from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
687
+ ckpt_name = os.path.basename(ckpt_file_name)
688
+ aiturbo.save_ckpt(ckpt_name, global_step_num, data_list_np, crc_check)
689
+ elif async_save:
520
690
  data_copy = copy.deepcopy(data_list)
521
- thr = Thread(target=_exec_save, args=(ckpt_file_name, data_copy, enc_key, enc_mode), name="asyn_save_ckpt")
691
+ thr = Thread(target=_exec_save,
692
+ args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check, format),
693
+ name="asyn_save_ckpt")
522
694
  thr.start()
523
695
  else:
524
- _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc)
696
+ _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
525
697
 
526
698
  logger.info("Saving checkpoint process is finished.")
527
699
 
@@ -532,7 +704,21 @@ def _convert_list_to_param_list(save_obj, choice_func):
532
704
  if not save_obj:
533
705
  return param_list
534
706
  if isinstance(save_obj[0], dict):
535
- param_list = [param for param in save_obj if choice_func is None or choice_func(param["name"])]
707
+ for param in save_obj:
708
+ if isinstance(param, dict) and "name" in param and "data" in param:
709
+ if not isinstance(param["name"], str):
710
+ raise TypeError(f"For save_checkpoint, when save_obj is a list of dict items, the name in dict "
711
+ f"should be string, but got {type(param['name'])}.")
712
+ if not isinstance(param["data"], Tensor):
713
+ raise TypeError(f"For save_checkpoint, when save_obj is a list of dict items, the data in dict "
714
+ f"should be parameter, but got {type(param['data'])}.")
715
+ if choice_func is not None and not choice_func(param["name"]):
716
+ continue
717
+ each_param = {"name": param["name"], "data": param["data"]}
718
+ param_list.append(each_param)
719
+ else:
720
+ raise TypeError(f"For save_checkpoint, save_obj should be a list of dict items, and the dict should "
721
+ f"have key values 'name' and 'value', but got {type(param)} and {param}.")
536
722
  else:
537
723
  for param in save_obj:
538
724
  if isinstance(param, Parameter):
@@ -585,6 +771,7 @@ def _convert_cell_param_and_names_to_dict(save_obj, choice_func):
585
771
 
586
772
  def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func):
587
773
  """Convert nn.Cell to param_list."""
774
+ sync_pipeline_shared_parameters(save_obj)
588
775
  param_list = []
589
776
  parameter_layout_dict = save_obj.parameter_layout_dict
590
777
  if _is_in_auto_parallel_mode() and not parameter_layout_dict:
@@ -597,7 +784,7 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
597
784
  if phase in save_obj.compile_cache and _executor.has_compiled(phase):
598
785
  random_byte = _executor._graph_executor.get_random_status(phase)
599
786
  param_list.append({"name": "random_op", "data": random_byte})
600
- append_dict.pop("random_op")
787
+ append_dict.pop("random_op")
601
788
  for (key, value) in param_dict.items():
602
789
  each_param = {"name": key}
603
790
  if isinstance(value, MapParameter):
@@ -619,18 +806,16 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
619
806
  param_data.append(param_tensor.shape)
620
807
  param_data.append(str(param_tensor.dtype))
621
808
  param_data.append(value.key)
622
- elif value.data.dtype == mstype.bfloat16:
623
- param_data = ["BFloat16_tensor"]
624
- param_data.append(cpu_cast(value.data, mstype.float32))
625
- param_data.append(value.data.shape)
626
- param_data.append("BFloat16")
627
- param_data.append(value.key)
628
809
  else:
629
- param_data = Tensor(value.data.asnumpy())
810
+ param_data = value.data
811
+ if append_dict and "__exception_save__" in append_dict:
812
+ param_data = Tensor(Tensor_.move_to(value, "CPU", False))
630
813
 
631
814
  # in automatic model parallel scenario, some parameters were split to all the devices,
632
815
  # which should be combined before saving
633
816
  if key in parameter_layout_dict:
817
+ if not append_dict or "__exception_save__" not in append_dict:
818
+ param_data = Tensor(value.data)
634
819
  param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
635
820
  integrated_save)
636
821
 
@@ -670,9 +855,9 @@ def _check_append_dict(append_dict):
670
855
  raise TypeError("For 'save_checkpoint', the argument 'append_dict' must be dict, but got "
671
856
  "{}.".format(type(append_dict)))
672
857
  for key, value in append_dict.items():
673
- if not isinstance(key, str) or not isinstance(value, (int, float, bool, str, Parameter, Tensor)):
858
+ if not isinstance(key, str) or not isinstance(value, (int, float, bool, str, Parameter, Tensor, Generator)):
674
859
  raise TypeError(f"For 'save_checkpoint', the type of dict 'append_info' must be key: string, "
675
- f"value: int, float or bool, but got key: {type(key)}, value: {type(value)}")
860
+ f"value: int, float, bool or Generator, but got key: {type(key)}, value: {type(value)}")
676
861
  return append_dict
677
862
 
678
863
 
@@ -699,13 +884,13 @@ def load(file_name, **kwargs):
699
884
  - dec_key (bytes): Byte-type key used for decryption. The valid length is 16, 24, or 32.
700
885
  - dec_mode (Union[str, function]): Specifies the decryption mode, to take effect when dec_key is set.
701
886
 
702
- - Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC' or customized decryption. Default: 'AES-GCM'.
887
+ - Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC' or customized decryption. Default: ``'AES-GCM'``.
703
888
  - For details of using the customized decryption, please check the `tutorial
704
- <https://mindspore.cn/mindarmour/docs/en/r2.0/model_encrypt_protection.html>`_.
889
+ <https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_.
705
890
 
706
891
  - obf_func (function): A python function used for loading obfuscated MindIR model, which can refer to
707
892
  `obfuscate_model()
708
- <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore/mindspore.obfuscate_model.html>`_.
893
+ <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.obfuscate_model.html>`_.
709
894
 
710
895
  Returns:
711
896
  GraphCell, a compiled graph that can executed by `GraphCell`.
@@ -735,7 +920,7 @@ def load(file_name, **kwargs):
735
920
 
736
921
  Tutorial Examples:
737
922
  - `Saving and Loading the Model - Saving and Loading MindIR
738
- <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-mindir>`_
923
+ <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
739
924
  """
740
925
  if not isinstance(file_name, str):
741
926
  raise ValueError("For 'load', the argument 'file_name' must be string, but "
@@ -746,7 +931,7 @@ def load(file_name, **kwargs):
746
931
  if not os.path.exists(file_name):
747
932
  raise ValueError("For 'load', the argument 'file_name'(MindIR file) does not exist, "
748
933
  "please check whether the 'file_name' is correct.")
749
- file_name = os.path.abspath(file_name)
934
+ file_name = os.path.realpath(file_name)
750
935
 
751
936
  # set customized functions for dynamic obfuscation
752
937
  obfuscated = _check_load_obfuscate(**kwargs)
@@ -776,7 +961,7 @@ def load(file_name, **kwargs):
776
961
  return graph
777
962
 
778
963
 
779
- def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=False):
964
+ def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=True):
780
965
  """
781
966
  Auto Split MindIR.
782
967
 
@@ -784,10 +969,10 @@ def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=F
784
969
 
785
970
  Args:
786
971
  file_name (str): MindIR file name.
787
- device_num (int): device number.
788
- rank_id (int): rank id.
789
- dynamic (bool): Indicates whether the model is a dynamic shape mindir model.
790
- sapp (bool): Indicates whether to automatically generate split strategy through SAPP.
972
+ device_num (int): device number. Default: '8'.
973
+ rank_id (int): rank id. Default: '0'.
974
+ dynamic (bool): Indicates whether the model is a dynamic shape mindir model. Default: 'True'.
975
+ sapp (bool): Indicates whether to automatically generate split strategy through SAPP. Default: 'True'.
791
976
 
792
977
  Raises:
793
978
  ValueError: MindIR file does not exist or `file_name` is not a string.
@@ -809,7 +994,7 @@ def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=F
809
994
  if not os.path.exists(file_name):
810
995
  raise ValueError("For 'Split MindIR', the argument 'file_name'(MindIR file) does not exist, "
811
996
  "please check whether the 'file_name' is correct.")
812
- file_name = os.path.abspath(file_name)
997
+ file_name = os.path.realpath(file_name)
813
998
 
814
999
  logger.info("Execute the process of export and split mindir.")
815
1000
  dynamic = True
@@ -909,13 +1094,14 @@ def obfuscate_model(obf_config, **kwargs):
909
1094
  - customized_func (function): A python function used for customized function mode, which used for control
910
1095
  the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
911
1096
  Reference to 'my_func()' in
912
- `tutorials <https://www.mindspore.cn/mindarmour/docs/en/r2.0/dynamic_obfuscation_protection.html>`_).
1097
+ `tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
913
1098
  This function needs to ensure that its result is constant for any input. Users can refer to opaque
914
1099
  predicates. If customized_func is set, then it should be passed to :func:`mindspore.load` interface
915
1100
  when loading obfuscated model.
916
1101
  - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
917
1102
  structure of obfuscated models corresponding to different random seeds is different. If
918
- `obf_random_seed` is set, then it should be passed to :class:`nn.GraphCell()` interface when loading
1103
+ `obf_random_seed` is set, then it should be passed to :class:`mindspore.nn.GraphCell`
1104
+ interface when loading
919
1105
  obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
920
1106
  be set, and the latter mode would be applied if both of them are set.
921
1107
 
@@ -923,7 +1109,7 @@ def obfuscate_model(obf_config, **kwargs):
923
1109
 
924
1110
  - enc_key (bytes): Byte type key used for encryption. The valid length is 16, 24, or 32.
925
1111
  - enc_mode (str): Specifies the encryption mode, to take effect when dec_key is set.
926
- Option: 'AES-GCM' | 'AES-CBC' | 'SM4-CBC'. Default: 'AES-GCM'.
1112
+ Options: ``'AES-GCM'`` | ``'AES-CBC'`` | ``'SM4-CBC'``. Default: ``'AES-GCM'``.
927
1113
 
928
1114
  Raises:
929
1115
  TypeError: If `obf_config` is not a dict.
@@ -934,11 +1120,15 @@ def obfuscate_model(obf_config, **kwargs):
934
1120
  ValueError: If `obf_ratio` is not provided in `obf_config`.
935
1121
  ValueError: If both `customized_func` and `obf_random_seed` are not provided in `obf_config`.
936
1122
  ValueError: If `obf_random_seed` is not in (0, 9223372036854775807].
937
- ValueError: If `original_model_path` is not exist or `original_model_path` is not end with '.mindir'.
1123
+ ValueError: If `original_model_path` does not exist or `original_model_path` does not end with '.mindir'.
938
1124
 
939
1125
  Examples:
940
1126
  >>> import mindspore as ms
941
1127
  >>> import mindspore.nn as nn
1128
+ >>> import numpy as np
1129
+ >>> # Download ori_net.mindir
1130
+ >>> # https://gitee.com/mindspore/mindspore/blob/master/tests/ut/python/mindir/ori_net.mindir
1131
+ >>> input1 = ms.Tensor(np.ones((1, 1, 32, 32)).astype(np.float32))
942
1132
  >>> obf_config = {'original_model_path': "./net.mindir",
943
1133
  ... 'save_model_path': "./obf_net",
944
1134
  ... 'model_inputs': [input1, ],
@@ -998,12 +1188,81 @@ def obfuscate_model(obf_config, **kwargs):
998
1188
  obf_net = nn.GraphCell(obf_graph)
999
1189
  if obf_random_seed != 0:
1000
1190
  append_y_tensor = Tensor(np.ones((1, 1)).astype(np.int32))
1001
- model_inputs += [append_y_tensor,]
1191
+ model_inputs += [append_y_tensor]
1002
1192
  export(obf_net, *model_inputs, file_name=saved_path, file_format="MINDIR", **kwargs)
1003
1193
 
1004
1194
 
1195
def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
                          dec_mode, crc_check, format):
    """
    Load the content of a checkpoint file into `parameter_dict` (the dict is filled in place).

    Args:
        ckpt_file_name (str): Checkpoint file name, checked through `_check_ckpt_file_name`.
        parameter_dict (dict): Output dict that receives the loaded entries, keyed by parameter name.
        specify_prefix: Deprecated filter, forwarded to `_whether_load_param`.
        filter_prefix: Deprecated filter, forwarded to `_whether_load_param`.
        choice_func (Union[None, function]): Called with the parameter name; only consulted when both
            `specify_prefix` and `filter_prefix` are ``None`` . Entries for which it returns a falsy
            value are skipped.
        dec_key: Decryption key, forwarded to `_parse_ckpt_proto`.
        dec_mode: Decryption mode, forwarded to `_parse_ckpt_proto`.
        crc_check (bool): Forwarded to `_parse_ckpt_proto`.
        format (str): ``"safetensors"`` selects the safetensors reader, any other accepted value goes
            through the protobuf checkpoint path.

    Raises:
        ValueError: If any error is raised while converting the parsed checkpoint entries.
    """
    ckpt_file_name = _check_ckpt_file_name(ckpt_file_name, format)
    if format == "safetensors":
        # NOTE(review): on this path neither the prefix filters nor `choice_func` are applied and
        # every key in the file is loaded -- confirm this is intended.
        with safe_open(ckpt_file_name, framework='np') as f:
            for k in f.keys():
                parameter_dict[k] = Parameter(f.get_tensor(k))
        return
    checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check)
    try:
        param_data_list = []
        map_data_list = [[], [], []]
        map_shape_list = [0, 0, 0]
        if specify_prefix:
            logger.warning("For load_checkpoint, this parameter `specify_prefix` will be deprecated, "
                           "please use `choice_func` instead.")
        if filter_prefix:
            logger.warning("For load_checkpoint, this parameter `filter_prefix` will be deprecated, "
                           "please use `choice_func` instead.")
        for element_id, element in enumerate(checkpoint_list.value):
            if element.tag == "random_op":
                # The random-op entry is stored as raw bytes, not wrapped into a Parameter.
                parameter_dict["random_op"] = element.tensor.tensor_content
                continue
            if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
                continue
            if specify_prefix is None and filter_prefix is None and \
                    choice_func is not None and not choice_func(element.tag):
                continue
            if element.tensor.ByteSize() == 0:
                # An empty `tensor` field is delegated to the map-parameter loader.
                _load_map_parameter(checkpoint_list, element, element_id, map_data_list, map_shape_list,
                                    parameter_dict)
                if element.tag in parameter_dict:
                    # The map parameter has been stored; reset the accumulators.
                    map_data_list = [[], [], []]
                    map_shape_list = [0, 0, 0]
                continue
            data = element.tensor.tensor_content
            data_type = element.tensor.tensor_type
            np_type = tensor_to_np_type.get(data_type)
            ms_type = tensor_to_ms_type[data_type]
            if data_type == 'str':
                # presumably 4 bytes per character (numpy unicode layout) -- TODO confirm
                str_length = len(data) // 4
                np_type = np_type + str(str_length)
            param_data_list.append(data)
            # Consecutive entries sharing one tag are concatenated; the value is only materialised
            # once the last entry carrying this tag has been seen.
            if (element_id == len(checkpoint_list.value) - 1) or \
                    (element.tag != checkpoint_list.value[element_id + 1].tag):
                new_data = b"".join(param_data_list)
                param_data_list.clear()
                dims = element.tensor.dims
                if data_type == 'str':
                    str_value = np.frombuffer(new_data, np_type)
                    parameter_dict[element.tag] = str(str_value[0])
                else:
                    if dims == [0]:
                        # A stored dims of [0] is converted with an empty shape.
                        dims = []
                    param_data = Tensor_.convert_bytes_to_tensor(new_data, tuple(dims), ms_type)
                    parameter = Parameter(param_data, name=element.tag)
                    parameter_dict[element.tag] = parameter
                    _offload_if_config(parameter)

        logger.info("Loading checkpoint files process is finished.")

    # NOTE(review): BaseException also converts KeyboardInterrupt/SystemExit into ValueError; kept
    # unchanged for backward compatibility with existing callers.
    except BaseException as e:
        logger.critical("Failed to load the checkpoint file '%s'.", ckpt_file_name)
        raise ValueError(str(e) + "\nFor 'load_checkpoint', "
                                  "failed to load the checkpoint file {}.".format(ckpt_file_name)) from e
1261
+
1262
+
1005
1263
  def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None,
1006
- dec_key=None, dec_mode="AES-GCM", specify_prefix=None, choice_func=None):
1264
+ dec_key=None, dec_mode="AES-GCM", specify_prefix=None, choice_func=None,
1265
+ crc_check=False, remove_redundancy=False, format="ckpt"):
1007
1266
  """
1008
1267
  Load checkpoint info from a specified file.
1009
1268
 
@@ -1013,6 +1272,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
1013
1272
  - `specify_prefix` and `filter_prefix` are in the process of being deprecated,
1014
1273
  `choice_func` is recommended instead.
1015
1274
  And using either of those two args will override `choice_func` at the same time.
1275
+ - When loading a checkpoint that has removed redundancy, the network should be compiled.
1016
1276
 
1017
1277
  Args:
1018
1278
  ckpt_file_name (str): Checkpoint file name.
@@ -1034,6 +1294,11 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
1034
1294
  and the return value is a bool. If returns ``True`` , the Parameter
1035
1295
  that matches the custom condition will be loaded. If returns ``False`` , the Parameter that
1036
1296
  matches the custom condition will be removed. Default: ``None`` .
1297
+ crc_check (bool) : Whether to perform crc32 validation when loading checkpoint. Default: ``False`` .
1298
+ remove_redundancy (bool): Whether to enable loading of checkpoint saved with redundancy removal.
1299
+ Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` , means
1300
+ redundant-free loading is not enabled.
1301
+ format (str): Format of the input file, can be "ckpt" or "safetensors". Default: "ckpt".
1037
1302
 
1038
1303
  Returns:
1039
1304
  Dict, key is parameter name, value is a Parameter or string. When the `append_dict` parameter of
@@ -1076,83 +1341,42 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
1076
1341
 
1077
1342
  Tutorial Examples:
1078
1343
  - `Saving and Loading the Model - Saving and Loading the Model Weight
1079
- <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
1344
+ <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
1080
1345
  """
1081
- ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
1082
1346
  specify_prefix = _check_prefix(specify_prefix)
1083
1347
  filter_prefix = _check_prefix(filter_prefix)
1084
1348
  dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
1085
1349
  dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
1350
+ crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
1351
+ remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
1352
+ _check_format_and_other_params(format, dec_key, dec_mode, crc_check)
1086
1353
  logger.info("Execute the process of loading checkpoint files.")
1087
- checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode)
1088
1354
 
1089
1355
  parameter_dict = {}
1090
- try:
1091
- param_data_list = []
1092
- map_data_list = [[], [], []]
1093
- map_shape_list = [0, 0, 0]
1094
- if specify_prefix:
1095
- logger.warning("For load_checkpoint, this parameter `specity_prefix` will be deprecated, "
1096
- "please use `choice_func` instead.")
1097
- if filter_prefix:
1098
- logger.warning("For load_checkpoint, this parameter `filter_prefix` will be deprecated, "
1099
- "please use `choice_func` instead.")
1100
- for element_id, element in enumerate(checkpoint_list.value):
1101
- if element.tag == "random_op":
1102
- parameter_dict["random_op"] = element.tensor.tensor_content
1103
- continue
1104
- if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
1105
- continue
1106
- if specify_prefix is None and filter_prefix is None and \
1107
- choice_func is not None and not choice_func(element.tag):
1108
- continue
1109
- if element.tensor.ByteSize() == 0:
1110
- _load_map_parameter(checkpoint_list, element, element_id, map_data_list, map_shape_list, parameter_dict)
1111
- if element.tag in parameter_dict:
1112
- map_data_list = [[], [], []]
1113
- map_shape_list = [0, 0, 0]
1114
- continue
1115
- data = element.tensor.tensor_content
1116
- data_type = element.tensor.tensor_type
1117
- np_type = tensor_to_np_type.get(data_type)
1118
- ms_type = tensor_to_ms_type[data_type]
1119
- if data_type == 'str':
1120
- str_length = int(len(data) / 4)
1121
- np_type = np_type + str(str_length)
1122
- if data_type == "BFloat16":
1123
- dims = element.tensor.dims
1124
- param_data = np.frombuffer(data, np_type)
1125
- param_data = param_data.reshape(list(dims))
1126
- parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
1127
- parameter_dict[element.tag] = parameter
1128
- continue
1129
- element_data = np.frombuffer(data, np_type)
1130
- param_data_list.append(element_data)
1131
- if (element_id == len(checkpoint_list.value) - 1) or \
1132
- (element.tag != checkpoint_list.value[element_id + 1].tag):
1133
- new_data = b"".join(param_data_list)
1134
- param_data = np.frombuffer(new_data, np_type)
1135
- param_data_list.clear()
1136
- dims = element.tensor.dims
1137
- if dims == [0] and data_type == 'str':
1138
- parameter_dict[element.tag] = str(element_data[0])
1139
- else:
1140
- if dims == [0] and 'Float' in data_type:
1141
- param_data = float(param_data[0])
1142
- if dims == [0] and 'Int' in data_type:
1143
- param_data = int(param_data[0])
1144
- if dims not in ([0], [1]):
1145
- param_data = param_data.reshape(list(dims))
1146
- parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
1147
- parameter_dict[element.tag] = parameter
1148
- _offload_if_config(parameter)
1149
-
1150
- logger.info("Loading checkpoint files process is finished.")
1151
1356
 
1152
- except BaseException as e:
1153
- logger.critical("Failed to load the checkpoint file '%s'.", ckpt_file_name)
1154
- raise ValueError(e.__str__() + "\nFor 'load_checkpoint', "
1155
- "failed to load the checkpoint file {}.".format(ckpt_file_name)) from e
1357
+ if os.getenv("AITURBO") == "1":
1358
+ rank_id = get_rank()
1359
+ from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
1360
+ ckpt_path = os.path.dirname(ckpt_file_name)
1361
+ ckpt_name = os.path.basename(ckpt_file_name)
1362
+ np_dict = aiturbo.load_ckpt(ckpt_path, ckpt_name, rank_id, crc_check)
1363
+ for key, value in np_dict.items():
1364
+ if crc_check and len(value) != 2:
1365
+ raise ValueError(f"When loading a checkpoint from AITurbo, if CRC check is enabled, "
1366
+ f"the length of the value must be 2, but got {len(value)}.")
1367
+ if isinstance(value, str):
1368
+ if crc_check and value[1] != binascii.crc32(np.array(value[0]).tobytes()):
1369
+ raise ValueError(f"When loading a checkpoint from AITurbo, the value of the string has not "
1370
+ f"passed the CRC check and has been corrupted.")
1371
+ parameter_dict[key] = value[0]
1372
+ else:
1373
+ if crc_check and value[1] != binascii.crc32(value[0].tobytes()):
1374
+ raise ValueError(f"When loading a checkpoint from AITurbo, the value of the parameter has not "
1375
+ f"passed the CRC check and has been corrupted.")
1376
+ parameter_dict[key] = Parameter(Tensor(value[0]), name=key)
1377
+ else:
1378
+ _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
1379
+ dec_mode, crc_check, format)
1156
1380
 
1157
1381
  if not parameter_dict:
1158
1382
  raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether "
@@ -1161,13 +1385,93 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
1161
1385
  if _warm_up_host_cache_enabled(parameter_dict):
1162
1386
  (is_worker, net_dict, warm_up_dict) = _warm_up_host_cache(parameter_dict, net)
1163
1387
  if net is not None:
1164
- load_param_into_net(net, parameter_dict, strict_load)
1388
+ load_param_into_net(net, parameter_dict, strict_load, remove_redundancy)
1165
1389
  if _warm_up_host_cache_enabled(parameter_dict):
1166
1390
  _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict)
1167
1391
 
1168
1392
  return parameter_dict
1169
1393
 
1170
1394
 
1395
def load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None,
                          dec_mode="AES-GCM", specify_prefix=None, choice_func=None):
    """
    Load checkpoint info from a specified file asynchronously.

    The call to :func:`mindspore.load_checkpoint` is submitted to a thread pool and a future-like
    object is returned immediately.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Note:
        - `specify_prefix` and `filter_prefix` do not affect each other.
        - If none of the parameters are loaded from checkpoint file, it will throw ValueError.
        - `specify_prefix` and `filter_prefix` are in the process of being deprecated,
          `choice_func` is recommended instead.
          And using either of those two args will override `choice_func` at the same time.

    Args:
        ckpt_file_name (str): Checkpoint file name.
        net (Cell, optional): The network where the parameters will be loaded. Default: ``None`` .
        strict_load (bool, optional): Whether to strict load the parameter into net. If ``False`` , it will load
            parameter into net when parameter name's suffix in checkpoint file is the same as the parameter
            in the network. When the types are inconsistent perform type conversion on the parameters of the
            same type, such as float32 to float16. Default: ``False`` .
        filter_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`). Parameters
            starting with the `filter_prefix` will not be loaded. Default: ``None`` .
        dec_key (Union[None, bytes], optional): Byte type key used for decryption. If the value is ``None`` ,
            the decryption is not required. Default: ``None`` .
        dec_mode (str, optional): This parameter is valid only when dec_key is not set to ``None`` . Specifies
            the decryption mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` .
            Default: ``"AES-GCM"`` .
        specify_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`). Parameters
            starting with the specify_prefix will be loaded. Default: ``None`` .
        choice_func (Union[None, function], optional): Input value of the function is a Parameter name of type
            string, and the return value is a bool. If returns ``True`` , the Parameter that matches the custom
            condition will be loaded. If returns ``False`` , the Parameter that matches the custom condition
            will be removed. Default: ``None`` .

    Returns:
        A custom inner class, calling its `result` method yields the :func:`mindspore.load_checkpoint` result.

    Raises:
        ValueError: Checkpoint file's format is incorrect.
        ValueError: Parameter's dict is None after load checkpoint file.
        TypeError: The type of `specify_prefix` or `filter_prefix` is incorrect.

    Examples:
        >>> import mindspore
        >>> from mindspore import nn
        >>> from mindspore.train import Model
        >>> from mindspore.amp import FixedLossScaleManager
        >>> from mindspore import context
        >>> from mindspore import load_checkpoint_async
        >>> from mindspore import load_param_into_net
        >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
        >>> # Create the dataset taking MNIST as an example. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
        >>> dataset = create_dataset()
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> ckpt_file = "./checkpoint/LeNet5-1_32.ckpt"
        >>> net = LeNet5()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
        >>> loss_scale_manager = FixedLossScaleManager()
        >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
        ...               loss_scale_manager=loss_scale_manager)
        >>> pd_future = load_checkpoint_async(ckpt_file)
        >>> model.build(train_dataset=dataset, epoch=2)
        >>> param_dict = pd_future.result()
        >>> load_param_into_net(net, param_dict)
        >>> model.train(2, dataset)
        >>> print("param dict len: ", len(param_dict), flush=True)
    """
    from concurrent.futures import ThreadPoolExecutor

    # Every option is forwarded by name so the mapping onto `load_checkpoint` is explicit.
    load_options = {
        "net": net,
        "strict_load": strict_load,
        "filter_prefix": filter_prefix,
        "dec_key": dec_key,
        "dec_mode": dec_mode,
        "specify_prefix": specify_prefix,
        "choice_func": choice_func,
    }
    pool = ThreadPoolExecutor(max_workers=2)
    pending_load = pool.submit(load_checkpoint, ckpt_file_name, **load_options)
    # The pool is handed over together with the future; this function does not shut it down.
    return ParamDictFuture(pool, pending_load)
1473
+
1474
+
1171
1475
  def _load_map_parameter(checkpoint_list, element, element_id, map_data_list,
1172
1476
  map_shape_list, parameter_dict):
1173
1477
  """load map parameter."""
@@ -1198,17 +1502,20 @@ def _load_map_parameter(checkpoint_list, element, element_id, map_data_list,
1198
1502
  parameter_dict[element.tag] = map_array
1199
1503
 
1200
1504
 
1201
- def _check_ckpt_file_name(ckpt_file_name):
1505
+ def _check_ckpt_file_name(ckpt_file_name, format):
1202
1506
  """Check function load_checkpoint's ckpt_file_name."""
1203
1507
  if not isinstance(ckpt_file_name, str):
1204
1508
  raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, "
1205
1509
  "but got {}.".format(type(ckpt_file_name)))
1206
1510
 
1207
- if ckpt_file_name[-5:] != ".ckpt":
1208
- raise ValueError("For 'load_checkpoint', the checkpoint file should end with '.ckpt', please "
1511
+ if format not in ['ckpt', 'safetensors']:
1512
+ raise ValueError("For 'load_checkpoint', the checkpoint file should end with '.ckpt' or '.safetensors', please "
1209
1513
  "input the correct 'ckpt_file_name'.")
1514
+ if not ckpt_file_name.endswith(format):
1515
+ raise ValueError(f"For 'load_checkpoint', the checkpoint file format must same with 'format', but got "
1516
+ f"file_name:'{ckpt_file_name}', format:'{format}'")
1210
1517
 
1211
- ckpt_file_name = os.path.abspath(ckpt_file_name)
1518
+ ckpt_file_name = os.path.realpath(ckpt_file_name)
1212
1519
  if not os.path.exists(ckpt_file_name):
1213
1520
  raise ValueError("For 'load_checkpoint', the checkpoint file: {} does not exist, please check "
1214
1521
  "whether the 'ckpt_file_name' is correct.".format(ckpt_file_name))
@@ -1239,17 +1546,28 @@ def _check_prefix(prefix):
1239
1546
  return prefix
1240
1547
 
1241
1548
 
1242
- def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode):
1549
+ def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check):
1243
1550
  """Parse checkpoint protobuf."""
1244
1551
  checkpoint_list = Checkpoint()
1245
1552
  try:
1246
1553
  if dec_key is None:
1247
- with open(ckpt_file_name, "rb") as f:
1554
+ with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f:
1248
1555
  pb_content = f.read()
1249
1556
  else:
1250
1557
  pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
1251
1558
  if pb_content is None:
1252
1559
  raise ValueError("For 'load_checkpoint', failed to decrypt the checkpoint file.")
1560
+ if crc_check and pb_content[-17:-10] != b"crc_num":
1561
+ logger.warning("For 'load_checkpoint', the ckpt file do not contain the crc code, please check the file.")
1562
+ if pb_content[-17:-10] == b"crc_num":
1563
+ crc_num_bytes = pb_content[-10:]
1564
+ pb_content = pb_content[:-17]
1565
+ if crc_check:
1566
+ crc_num = int.from_bytes(crc_num_bytes, byteorder='big')
1567
+ cal_crc_num = binascii.crc32(pb_content, 0)
1568
+ if cal_crc_num != crc_num:
1569
+ raise ValueError("For 'load_checkpoint', the crc check is failed, "
1570
+ "please check whether the ckpt file is damaged.")
1253
1571
  checkpoint_list.ParseFromString(pb_content)
1254
1572
  except BaseException as e:
1255
1573
  if _is_cipher_file(ckpt_file_name):
@@ -1282,17 +1600,40 @@ def _whether_load_param(specify_prefix, filter_prefix, param_name):
1282
1600
 
1283
1601
def _init_parameter_data_in_parallel_mode(net, parameter_dict):
    """
    In parallel mode, only initialise the parameters of `net` that are present in the checkpoint.

    Parameters whose name is not a key of `parameter_dict` are left untouched. For the others:

    - when the parameter is flagged `from_ckpt` and the net phase does not start with ``'train'``,
      only its `shape` is overwritten with the shape of the checkpoint entry;
    - otherwise, when the parameter reports `has_init`, its data is produced with `init_data()`
      and written back through `_update_tensor_data`.
    """
    in_train_phase = net.phase.startswith('train')
    for _, param in net.parameters_and_names():
        if param.name not in parameter_dict:
            continue
        if param.from_ckpt and not in_train_phase:
            # Only the shape is taken from the checkpoint entry here; no data is initialised.
            param.shape = tuple(parameter_dict[param.name].shape)
            continue
        if not param.has_init:
            continue
        logger.warning("{} is not init while load ckpt.".format(param.name))
        initialised = param.init_data()
        param._update_tensor_data(initialised)
1290
1612
 
1291
1613
 
1292
- def load_param_into_net(net, parameter_dict, strict_load=False):
1614
def _check_load_param_into_net(net, parameter_dict):
    """
    Validate the arguments of `load_param_into_net` and move the random-op entry onto the net.

    Raises:
        TypeError: If `net` is not an `nn.Cell`, or if `parameter_dict` is not a dict.

    When `parameter_dict` holds a ``"random_op"`` entry, its value is attached to `net` as the
    ``"random_op_snapshot"`` attribute and the entry is removed from `parameter_dict`.
    """
    problem = None
    if not isinstance(net, nn.Cell):
        problem = "For 'load_param_into_net', the argument 'net' should be a Cell, but got {}.".format(type(net))
    elif not isinstance(parameter_dict, dict):
        problem = ("For 'load_param_into_net', the argument 'parameter_dict' should be a dict, "
                   "but got {}.".format(type(parameter_dict)))
    if problem is not None:
        logger.critical("Failed to combine the net and the parameters.")
        raise TypeError(problem)

    if "random_op" in parameter_dict:
        # Attach first, remove afterwards: the entry stays in the dict if `_add_attr` raises.
        net._add_attr("random_op_snapshot", parameter_dict["random_op"])
        parameter_dict.pop("random_op")
1628
+
1629
+
1630
+ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundancy=False):
1293
1631
  """
1294
1632
  Load parameters into network, return parameter list that are not loaded in the network.
1295
1633
 
1634
+ Note:
1635
+ - When loading a parameter dict that has removed redundancy, the network should be compiled.
1636
+
1296
1637
  Args:
1297
1638
  net (Cell): The network where the parameters will be loaded.
1298
1639
  parameter_dict (dict): The dictionary generated by load checkpoint file,
@@ -1301,10 +1642,13 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
1301
1642
  into net when parameter name's suffix in checkpoint file is the same as the
1302
1643
  parameter in the network. When the types are inconsistent perform type conversion
1303
1644
  on the parameters of the same type, such as float32 to float16. Default: ``False`` .
1645
+ remove_redundancy (bool): Whether to enable loading of checkpoint saved with redundancy removal.
1646
+ Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` , means
1647
+ redundant-free loading is not enabled.
1304
1648
 
1305
1649
  Returns:
1306
- param_not_load (List), the parameter name in model which are not loaded into the network.
1307
- ckpt_not_load (List), the parameter name in checkpoint file which are not loaded into the network.
1650
+ - param_not_load (List), the parameter name in model which are not loaded into the network.
1651
+ - ckpt_not_load (List), the parameter name in checkpoint file which are not loaded into the network.
1308
1652
 
1309
1653
  Raises:
1310
1654
  TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
@@ -1313,7 +1657,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
1313
1657
  >>> import mindspore as ms
1314
1658
  >>>
1315
1659
  >>> # Define the network structure of LeNet5. Refer to
1316
- >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
1660
+ >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
1317
1661
  >>> net = LeNet5()
1318
1662
  >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
1319
1663
  >>> param_dict = ms.load_checkpoint(ckpt_file_name, filter_prefix="conv1")
@@ -1323,20 +1667,9 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
1323
1667
 
1324
1668
  Tutorial Examples:
1325
1669
  - `Saving and Loading the Model - Saving and Loading the Model Weight
1326
- <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
1670
+ <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
1327
1671
  """
1328
- if not isinstance(net, nn.Cell):
1329
- logger.critical("Failed to combine the net and the parameters.")
1330
- msg = ("For 'load_param_into_net', the argument 'net' should be a Cell, but got {}.".format(type(net)))
1331
- raise TypeError(msg)
1332
- if not isinstance(parameter_dict, dict):
1333
- logger.critical("Failed to combine the net and the parameters.")
1334
- msg = ("For 'load_param_into_net', the argument 'parameter_dict' should be a dict, "
1335
- "but got {}.".format(type(parameter_dict)))
1336
- raise TypeError(msg)
1337
- if "random_op" in parameter_dict.keys():
1338
- net._add_attr("random_op_snapshot", parameter_dict["random_op"])
1339
- parameter_dict.pop("random_op")
1672
+ _check_load_param_into_net(net, parameter_dict)
1340
1673
  for key, value in parameter_dict.items():
1341
1674
  if not isinstance(key, str) or not isinstance(value, (Parameter, str, list)):
1342
1675
  logger.critical("Load parameters into net failed.")
@@ -1345,8 +1678,11 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
1345
1678
  raise TypeError(msg)
1346
1679
 
1347
1680
  strict_load = Validator.check_bool(strict_load)
1681
+ remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
1348
1682
  logger.info("Execute the process of loading parameters into net.")
1349
- if not _is_in_auto_parallel_mode():
1683
+ for _, param in net.parameters_and_names():
1684
+ param.from_ckpt = True
1685
+ if not (_is_in_auto_parallel_mode() or _is_parallel_mode()):
1350
1686
  net.init_parameters_data()
1351
1687
  else:
1352
1688
  _init_parameter_data_in_parallel_mode(net, parameter_dict)
@@ -1360,7 +1696,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
1360
1696
  # Add has attr protection when load server checkpoint file on worker.
1361
1697
  if not hasattr(parameter_dict[param.name], "data"):
1362
1698
  continue
1363
- new_param = copy.deepcopy(parameter_dict[param.name])
1699
+ new_param = parameter_dict[param.name]
1364
1700
  _update_param(param, new_param, strict_load)
1365
1701
  ckpt_not_load.remove(param.name)
1366
1702
  else:
@@ -1369,18 +1705,31 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
1369
1705
  if param_not_load and not strict_load:
1370
1706
  _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load)
1371
1707
 
1372
- logger.debug("Params not matched(in net but not in parameter_dict):")
1373
- for param_name in param_not_load:
1374
- logger.debug("%s", param_name)
1375
-
1376
1708
  logger.info("Loading parameters into net is finished.")
1377
1709
  if param_not_load:
1378
1710
  logger.warning("For 'load_param_into_net', "
1379
1711
  "{} parameters in the 'net' are not loaded, because they are not in the "
1380
1712
  "'parameter_dict', please check whether the network structure is consistent "
1381
- "when training and loading checkpoint.".format(len(param_not_load)))
1382
- for param_name in param_not_load:
1383
- logger.warning("{} is not loaded.".format(param_name))
1713
+ "when training and loading checkpoint. Another possibility is that "
1714
+ "the redundant loading is not enabled, but the loaded checkpoint is saved with "
1715
+ "redundancy removed. ".format(len(param_not_load)))
1716
+ logger.warning("{} are not loaded.".format(param_not_load))
1717
+ if remove_redundancy:
1718
+ parallel_mode = context.get_auto_parallel_context("parallel_mode")
1719
+ if parallel_mode == "stand_alone":
1720
+ raise TypeError(f"The deduplication feature for loading checkpoint can only be used "
1721
+ f"in parallel scenarios, but got {parallel_mode}.")
1722
+ if not net.compile_cache and not net.parameter_layout_dict:
1723
+ raise ValueError("When loading a parameter dict that has removed redundancy, "
1724
+ "the network should be compiled.")
1725
+ param_layout = net.parameter_layout_dict
1726
+ rank_id = get_rank()
1727
+ device_num = _get_device_num()
1728
+ stage_num = _get_auto_parallel_context("pipeline_stages")
1729
+ chunk_size = device_num // stage_num
1730
+ initial_rank = (rank_id // chunk_size) * chunk_size
1731
+ _single_parameter_broadcast(net, param_layout, rank_id, initial_rank)
1732
+
1384
1733
  return param_not_load, ckpt_not_load
1385
1734
 
1386
1735
 
@@ -1486,7 +1835,7 @@ def _save_graph(network, file_name):
1486
1835
  """
1487
1836
  logger.info("Execute the process of saving graph.")
1488
1837
 
1489
- file_name = os.path.abspath(file_name)
1838
+ file_name = os.path.realpath(file_name)
1490
1839
  graph_pb = network.get_func_graph_proto()
1491
1840
  if graph_pb:
1492
1841
  with open(file_name, "wb") as f:
@@ -1494,6 +1843,23 @@ def _save_graph(network, file_name):
1494
1843
  f.write(graph_pb)
1495
1844
 
1496
1845
 
1846
+ def _reshape_tensor(tensor, dst_shape):
1847
+ """reshape tensor to dst shape"""
1848
+ np_tensor = tensor.asnumpy()
1849
+ np_tensor = np_tensor.reshape(dst_shape)
1850
+ return Tensor(np_tensor, tensor.dtype)
1851
+
1852
+
1853
+ def _check_param_for_integrate_save(pipeline_stages, uniform_split):
1854
+ """check whether current settings and parameters are supported in integrated save checkpoint mode"""
1855
+ if pipeline_stages > 1:
1856
+ raise RuntimeError("Pipeline Parallel don't support Integrated save checkpoint now.")
1857
+ if uniform_split == 0:
1858
+ raise RuntimeError("For 'save_checkpoint' and in automatic model parallel scene, when set "
1859
+ "'integrated_save' to True, the checkpoint will be integrated save, it "
1860
+ "is only supports uniform split tensor now.")
1861
+
1862
+
1497
1863
  def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, integrated_save):
1498
1864
  """
1499
1865
  Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map.
@@ -1507,7 +1873,7 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
1507
1873
  Tensor, the combined tensor which with the whole data value.
1508
1874
  """
1509
1875
  layout = parameter_layout_dict[param_name]
1510
- if len(layout) < 6:
1876
+ if len(layout) < 8:
1511
1877
  logger.info("The layout dict does not contain the key %s", param_name)
1512
1878
  return param_data
1513
1879
 
@@ -1515,6 +1881,13 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
1515
1881
  tensor_map = layout[1]
1516
1882
  uniform_split = layout[4]
1517
1883
  opt_shard_group = layout[5]
1884
+ before_reshape_slice_shape = layout[2]
1885
+ before_reshape_full_shape = layout[6]
1886
+ after_reshape_slice_shape = layout[7]
1887
+ do_reshape = False
1888
+ if before_reshape_full_shape and after_reshape_slice_shape \
1889
+ and after_reshape_slice_shape != before_reshape_slice_shape:
1890
+ do_reshape = True
1518
1891
 
1519
1892
  allgather_net = None
1520
1893
  mp_weight = False
@@ -1527,26 +1900,26 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
1527
1900
  else:
1528
1901
  logger.info("Need to create allgather net for %s", param_name)
1529
1902
  if integrated_save:
1530
- if context.get_auto_parallel_context("pipeline_stages") > 1:
1531
- raise RuntimeError("Pipeline Parallel don't support Integrated save checkpoint now.")
1532
- if uniform_split == 0:
1533
- raise RuntimeError("For 'save_checkpoint' and in automatic model parallel scene, when set "
1534
- "'integrated_save' to True, the checkpoint will be integrated save, it "
1535
- "is only supports uniform split tensor now.")
1903
+ _check_param_for_integrate_save(context.get_auto_parallel_context("pipeline_stages"), uniform_split)
1536
1904
  # while any dim is not equal to -1, means param is split and needs to be merged
1537
1905
  # pipeline parallel need to be supported here later
1538
1906
  if mp_weight:
1539
- allgather_net = get_allgather_cell(opt_shard_group, bool(opt_shard_group))
1907
+ allgather_net = get_allgather_cell(opt_shard_group, bool(opt_shard_group), do_reshape,
1908
+ tuple(after_reshape_slice_shape))
1540
1909
  object.__setattr__(allgather_net, "keep_input_unchanged", True)
1541
1910
  elif opt_shard_group:
1542
- allgather_net = get_allgather_cell(opt_shard_group, False)
1911
+ allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
1912
+ tuple(after_reshape_slice_shape))
1543
1913
  elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save"):
1544
- allgather_net = get_allgather_cell(opt_shard_group, False)
1914
+ allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
1915
+ tuple(after_reshape_slice_shape))
1545
1916
  net.parallel_parameter_merge_net_dict[param_name] = allgather_net
1546
1917
  if allgather_net:
1547
1918
  param_data = allgather_net(param_data)
1548
1919
  if mp_weight and integrated_save:
1549
1920
  param_data = _reshape_param_data(param_data, dev_mat, tensor_map)
1921
+ if do_reshape:
1922
+ param_data = _reshape_tensor(param_data, before_reshape_full_shape)
1550
1923
  return param_data
1551
1924
 
1552
1925
 
@@ -1556,7 +1929,8 @@ def export(net, *inputs, file_name, file_format, **kwargs):
1556
1929
 
1557
1930
  Note:
1558
1931
  1. When exporting AIR, ONNX format, the size of a single tensor can not exceed 2GB.
1559
- 2. When file_name does not have a suffix, the system will automatically add one according to the file_format.
1932
+ 2. When `file_name` does not have a suffix, the system will automatically add one
1933
+ according to the `file_format`.
1560
1934
  3. Exporting functions decorated with :func:`mindspore.jit` to mindir format is supported.
1561
1935
  4. When exporting a function decorated with :func:`mindspore.jit`, the function should not involve
1562
1936
  class properties in calculations.
@@ -1576,7 +1950,7 @@ def export(net, *inputs, file_name, file_format, **kwargs):
1576
1950
  - AIR: Ascend Intermediate Representation. An intermediate representation format of Ascend model.
1577
1951
  - ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
1578
1952
  - MINDIR: MindSpore Native Intermediate Representation for Anf. An intermediate representation format
1579
- for MindSpore models.
1953
+ for MindSpore models. MINDIR does not support operators which have dictionary attribute.
1580
1954
 
1581
1955
  kwargs (dict): Configuration options dictionary.
1582
1956
 
@@ -1586,9 +1960,9 @@ def export(net, *inputs, file_name, file_format, **kwargs):
1586
1960
  - For 'AIR' and 'ONNX' models, only customized encryption is supported.
1587
1961
  - For 'MINDIR', all options are supported. Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC'
1588
1962
  or Customized encryption.
1589
- Default: 'AES-GCM'.
1963
+ Default: ``'AES-GCM'``.
1590
1964
  - For details of using the customized encryption, please check the `tutorial
1591
- <https://mindspore.cn/mindarmour/docs/en/r2.0/model_encrypt_protection.html>`_.
1965
+ <https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_.
1592
1966
 
1593
1967
  - dataset (Dataset): Specifies the preprocessing method of the dataset, which is used to import the
1594
1968
  preprocessing of the dataset into MindIR.
@@ -1602,32 +1976,49 @@ def export(net, *inputs, file_name, file_format, **kwargs):
1602
1976
  - customized_func (function): A python function used for customized function mode, which used for control
1603
1977
  the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
1604
1978
  Reference to 'my_func()' in
1605
- `tutorials <https://www.mindspore.cn/mindarmour/docs/en/r2.0/dynamic_obfuscation_protection.html>`_).
1979
+ `tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
1606
1980
  This function needs to ensure that its result is constant for any input. Users can refer to opaque
1607
1981
  predicates. If customized_func is set, then it should be passed to `load()` interface when loading
1608
1982
  obfuscated model.
1609
1983
  - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
1610
1984
  structure of obfuscated models corresponding to different random seeds is different. If
1611
- `obf_random_seed` is set, then it should be passed to :class:`nn.GraphCell()` interface when loading
1985
+ `obf_random_seed` is set, then it should be passed
1986
+ to :class:`mindspore.nn.GraphCell` interface when loading
1612
1987
  obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
1613
1988
  be set, and the latter mode would be applied if both of them are set.
1614
1989
 
1615
1990
  - incremental (bool): export MindIR incrementally.
1616
1991
 
1992
+ - custom_func (function): Functions for custom defined export policies. This function will be used to
1993
+ customize the model during network export. Currently only support for files with mindir format. The
1994
+ function only accepts one input representing the proto object of the mindir file. When modifying a model,
1995
+ it is necessary to ensure the correctness of the `custom_func` , otherwise it may lead to model loading
1996
+ failure or functional errors. Default: ``None`` .
1997
+
1617
1998
  Examples:
1618
1999
  >>> import mindspore as ms
1619
2000
  >>> import numpy as np
1620
2001
  >>> from mindspore import Tensor
1621
2002
  >>>
1622
2003
  >>> # Define the network structure of LeNet5. Refer to
1623
- >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
2004
+ >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
1624
2005
  >>> net = LeNet5()
1625
2006
  >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
1626
2007
  >>> ms.export(net, input_tensor, file_name='lenet', file_format='MINDIR')
2008
+ >>>
2009
+ >>> # Export model in MindIR format and modified the model info using custom_func
2010
+ >>> # The custom_func only support one input representing the Proto object of the model
2011
+ >>> # And custom_func does not support return value
2012
+ >>> def _custom_func(mindir_model):
2013
+ ... mindir_model.producer_name = "test11111"
2014
+ ... mindir_model.producer_version = "11.0"
2015
+ ... mindir_model.user_info["version"] = "11.0"
2016
+ >>> ms.export(net, input_tensor, file_name="lenet", file_format='MINDIR', custom_func=_custom_func)
2017
+
1627
2018
 
1628
2019
  Tutorial Examples:
1629
2020
  - `Saving and Loading the Model - Saving and Loading MindIR
1630
- <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-mindir>`_
2021
+ <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
1631
2022
  """
1632
2023
  old_ms_jit_value = context.get_context("jit_syntax_level")
1633
2024
  context.set_context(jit_syntax_level=mindspore.STRICT)
@@ -1658,7 +2049,7 @@ def export(net, *inputs, file_name, file_format, **kwargs):
1658
2049
  + str(columns))
1659
2050
  inputs = tuple(inputs_col)
1660
2051
 
1661
- file_name = os.path.abspath(file_name)
2052
+ file_name = os.path.realpath(file_name)
1662
2053
  if 'enc_key' in kwargs.keys():
1663
2054
  kwargs['enc_key'], kwargs['enc_mode'] = _check_key_mode_type(file_format, **kwargs)
1664
2055
  _export(net, file_name, file_format, *inputs, **kwargs)
@@ -1690,7 +2081,7 @@ def _get_funcgraph(net, *inputs):
1690
2081
  >>> from mindspore import Tensor
1691
2082
  >>>
1692
2083
  >>> # Define the network structure of LeNet5. Refer to
1693
- >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
2084
+ >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
1694
2085
  >>> net = LeNet5()
1695
2086
  >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
1696
2087
  >>> ms.get_funcgraph(net, input_tensor)
@@ -1712,6 +2103,8 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
1712
2103
  logger.info("exporting model file:%s format:%s.", file_name, file_format)
1713
2104
  if "obf_config" in kwargs and file_format != "MINDIR":
1714
2105
  raise ValueError(f"Dynamic obfuscation only support for MindIR format, but got {file_format} format.")
2106
+ if "custom_func" in kwargs and file_format != "MINDIR":
2107
+ raise ValueError(f"Currently only support custom_func for MindIR format, but got {file_format} format.")
1715
2108
  if file_format == 'AIR':
1716
2109
  _save_air(net, file_name, *inputs, **kwargs)
1717
2110
  elif file_format == 'ONNX':
@@ -1749,8 +2142,8 @@ def _save_air(net, file_name, *inputs, **kwargs):
1749
2142
  if os.path.exists(file_name):
1750
2143
  os.chmod(file_name, stat.S_IWUSR)
1751
2144
  if "/" in file_name:
1752
- real_path = os.path.abspath(file_name[:file_name.rfind("/")])
1753
- os.makedirs(real_path, exist_ok=True)
2145
+ real_path = os.path.realpath(file_name[:file_name.rfind("/")])
2146
+ os.makedirs(real_path, mode=0o700, exist_ok=True)
1754
2147
  if 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys():
1755
2148
  _executor.export(file_name, graph_id, enc_key=kwargs.get('enc_key'), encrypt_func=kwargs.get('enc_mode'))
1756
2149
  else:
@@ -1860,24 +2253,24 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
1860
2253
  file_prefix = file_name.split("/")[-1]
1861
2254
  if file_prefix.endswith(".mindir"):
1862
2255
  file_prefix = file_prefix[:-7]
1863
- current_path = os.path.abspath(file_name)
2256
+ current_path = os.path.realpath(file_name)
1864
2257
  dirname = os.path.dirname(current_path)
1865
2258
  data_path = os.path.join(dirname, file_prefix + "_variables")
1866
2259
  if os.path.exists(data_path):
1867
2260
  shutil.rmtree(data_path)
1868
- os.makedirs(data_path, exist_ok=True)
2261
+ os.makedirs(data_path, mode=0o700, exist_ok=True)
1869
2262
  os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
1870
2263
  index = 0
1871
2264
  external_local = os.path.join(file_prefix + "_variables", "data_" + str(index))
1872
2265
  data_file_name = os.path.join(dirname, external_local)
1873
2266
  f, parameter_size, offset = _get_data_file(is_encrypt, kwargs, data_file_name)
1874
2267
  try:
1875
- round_ = 0
2268
+ round = 0
1876
2269
  names = []
1877
2270
  for param_proto in model.graph.parameter:
1878
2271
  name = param_proto.name[param_proto.name.find(":") + 1:]
1879
2272
  names.append((name, param_proto))
1880
- names.sort(key=lambda x: x[0])
2273
+ names.sort(key=lambda x: x[0])
1881
2274
  for pairs in names:
1882
2275
  name = pairs[0]
1883
2276
  param_proto = pairs[1]
@@ -1900,8 +2293,8 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
1900
2293
  offset += (data_length + append_size)
1901
2294
  write_data = _encrypt_data(is_encrypt, write_data, kwargs)
1902
2295
  f.write(write_data)
1903
- round_ += 1
1904
- logger.debug(f"writing {round_}th split data, name:{name}")
2296
+ round += 1
2297
+ logger.debug(f"writing {round}th split data, name:{name}")
1905
2298
 
1906
2299
  graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir")
1907
2300
  if os.path.exists(graph_file_name):
@@ -1998,6 +2391,10 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
1998
2391
  dataset = kwargs.get('dataset')
1999
2392
  _save_dataset_to_mindir(model, dataset)
2000
2393
 
2394
+ custom_func = kwargs.get('custom_func', None)
2395
+ if custom_func is not None:
2396
+ custom_func(model)
2397
+
2001
2398
  save_together = _save_together(net_dict, model)
2002
2399
  is_encrypt = lambda: 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys()
2003
2400
  if save_together:
@@ -2030,9 +2427,9 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
2030
2427
  "the data of parameter cannot be exported.".format(map_param_proto.name))
2031
2428
  if not file_name.endswith('.mindir'):
2032
2429
  file_name += ".mindir"
2033
- current_path = os.path.abspath(file_name)
2430
+ current_path = os.path.realpath(file_name)
2034
2431
  dirname = os.path.dirname(current_path)
2035
- os.makedirs(dirname, exist_ok=True)
2432
+ os.makedirs(dirname, mode=0o700, exist_ok=True)
2036
2433
  if os.path.exists(file_name):
2037
2434
  os.chmod(file_name, stat.S_IWUSR)
2038
2435
  with open(file_name, 'wb') as f:
@@ -2084,6 +2481,45 @@ def _save_dataset_to_mindir(model, dataset):
2084
2481
  model.preprocessor.op[-1].offload = op['offload'] if 'offload' in op.keys() else False
2085
2482
 
2086
2483
 
2484
+ def check_checkpoint(ckpt_file_name):
2485
+ """
2486
+ Check whether the checkpoint is valid.
2487
+
2488
+ Args:
2489
+ ckpt_file_name (str): Checkpoint file name.
2490
+
2491
+ Returns:
2492
+ bool, whether the checkpoint is valid.
2493
+
2494
+ Examples:
2495
+ >>> import mindspore as ms
2496
+ >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
2497
+ >>> check_result = ms.check_checkpoint(ckpt_file_name)
2498
+ >>> print(check_result)
2499
+ True
2500
+ """
2501
+ if not ckpt_file_name.endswith('.ckpt'):
2502
+ return False
2503
+ checkpoint_list = Checkpoint()
2504
+ with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f:
2505
+ pb_content = f.read()
2506
+ if pb_content[-17:-10] == b"crc_num":
2507
+ crc_num_bytes = pb_content[-10:]
2508
+ pb_content = pb_content[:-17]
2509
+ crc_num = int.from_bytes(crc_num_bytes, byteorder='big')
2510
+ cal_crc_num = binascii.crc32(pb_content, 0)
2511
+ if cal_crc_num != crc_num:
2512
+ logger.warning("For 'check_checkpoint', the ckpt crc check is failed.")
2513
+ return False
2514
+ try:
2515
+ checkpoint_list.ParseFromString(pb_content)
2516
+ except google.protobuf.message.DecodeError as e:
2517
+ logger.warning("For 'check_checkpoint', the ckpt parse is failed.")
2518
+ logger.warning(e)
2519
+ return False
2520
+ return True
2521
+
2522
+
2087
2523
  def parse_print(print_file_name):
2088
2524
  """
2089
2525
  Parse data file generated by :class:`mindspore.ops.Print`.
@@ -2122,7 +2558,7 @@ def parse_print(print_file_name):
2122
2558
  [[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
2123
2559
  [ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])]
2124
2560
  """
2125
- print_file_path = os.path.abspath(print_file_name)
2561
+ print_file_path = os.path.realpath(print_file_name)
2126
2562
 
2127
2563
  if os.path.getsize(print_file_path) == 0:
2128
2564
  raise ValueError("For 'parse_print', the print file may be empty, please make sure enter the correct "
@@ -2411,14 +2847,15 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
2411
2847
  return merged_parameter
2412
2848
 
2413
2849
 
2414
- def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None,
2415
- train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM'):
2850
+ def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_strategy=None,
2851
+ train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM',
2852
+ format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None):
2416
2853
  """
2417
2854
  Load checkpoint into net for distributed predication. Used in the case of distributed inference.
2418
2855
 
2419
2856
  Args:
2420
2857
  network (Cell): Network for distributed predication.
2421
- checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
2858
+ checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id. Default: ``None`` .
2422
2859
  predict_strategy (dict): Strategy of predication process. It means that using one device to predict
2423
2860
  when setting predict_strategy as None. Default: ``None`` .
2424
2861
  train_strategy_filename (str): The filename of training strategy protocol buffer file.
@@ -2428,13 +2865,21 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
2428
2865
  in at least one of them. Default: ``None`` .
2429
2866
  strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
2430
2867
  into net when parameter name's suffix in checkpoint file is the same as the
2431
- parameter in the network. When the types are inconsistent perform type conversion
2868
+ parameter in the network. When the types are inconsistent, perform type conversion
2432
2869
  on the parameters of the same type, such as float32 to float16. Default: ``False`` .
2433
2870
  dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
2434
2871
  is not required. Default: ``None`` .
2435
2872
  dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption
2436
2873
  mode, currently supports ``'AES-GCM'`` , ``'AES-CBC'`` and ``'SM4-CBC'`` .
2437
2874
  Default: ``'AES-GCM'`` .
2875
+ format (str): Input weight format to be loaded into the network.
2876
+ It can be set to either "ckpt" or "safetensors". Default: "ckpt".
2877
+ unified_safetensors_dir (str): Directory of input weight files to be loaded into the network.
2878
+ Default: ``None`` .
2879
+ dst_safetensors_dir (str): In the save mode scenario, the save directory for safetensors.
2880
+ rank_id (int): The logical sequence number of the card. In non save mode, it is automatically obtained
2881
+ globally by initializing the network; In save mode, save the file according to the input
2882
+ sequence number. If it is not input, save the entire file.
2438
2883
 
2439
2884
  Raises:
2440
2885
  TypeError: The type of inputs do not match the requirements.
@@ -2449,14 +2894,14 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
2449
2894
 
2450
2895
  For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
2451
2896
  Please see the `rank table startup
2452
- <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
2897
+ <https://www.mindspore.cn/docs/en/master/model_train/parallel/rank_table.html>`_
2453
2898
  for more details.
2454
2899
 
2455
2900
  For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
2456
- <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
2901
+ <https://www.mindspore.cn/docs/en/master/model_train/parallel/mpirun.html>`_ .
2457
2902
 
2458
2903
  For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
2459
- Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
2904
+ Startup <https://www.mindspore.cn/docs/en/master/model_train/parallel/dynamic_cluster.html>`_ .
2460
2905
 
2461
2906
  >>> import os
2462
2907
  >>> import numpy as np
@@ -2538,6 +2983,54 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
2538
2983
  ...
2539
2984
  [ 1.6067538 1.6244187 1.5384722 ... 1.5449994 1.6195512 1.6176052]]
2540
2985
  """
2986
+ if format not in ['safetensors', 'ckpt']:
2987
+ raise ValueError(
2988
+ f"For 'load_distributed_checkpoint', 'format' must be 'ckpt' or 'safetensors', but got {format}.")
2989
+
2990
+ if format == 'safetensors':
2991
+ if unified_safetensors_dir is None:
2992
+ raise ValueError(f"For 'load_distributed_checkpoint', 'unified_safetensors_dir' can not be None "
2993
+ f"when format is 'safetensors'.")
2994
+ unsupport_param = [checkpoint_filenames, train_strategy_filename, dec_key]
2995
+ for param in unsupport_param:
2996
+ if param is not None:
2997
+ raise ValueError(f"For 'load_distributed_checkpoint', {param} must be None "
2998
+ f"when format is 'safetensors'.")
2999
+ if strict_load or dec_mode != 'AES-GCM':
3000
+ raise ValueError(f"For 'load_distributed_checkpoint', strict_load and dec_mode must be default "
3001
+ f"when format is 'safetensors'.")
3002
+ if network is not None:
3003
+ rank_id = get_rank()
3004
+ _load_parallel_checkpoint(unified_safetensors_dir, predict_strategy, network, rank_id=rank_id)
3005
+ else:
3006
+ if dst_safetensors_dir is None:
3007
+ raise ValueError(f"For 'load_distributed_checkpoint', 'dst_safetensors_dir' can not be None "
3008
+ f"when network is None.")
3009
+ if rank_id is not None:
3010
+ _load_parallel_checkpoint(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
3011
+ rank_id)
3012
+ else:
3013
+ dst_strategy_dict = _build_searched_strategy(predict_strategy)
3014
+ dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
3015
+ dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
3016
+ dst_device_num = dst_stage_device_num * dst_stage_num
3017
+ processes = []
3018
+ activate_processes = 0
3019
+ for rank in range(0, dst_device_num):
3020
+ p = Process(target=_load_parallel_checkpoint, args=(
3021
+ unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, rank))
3022
+ p.start()
3023
+ processes.append(p)
3024
+ activate_processes += 1
3025
+ max_processes = 64
3026
+ if activate_processes >= max_processes:
3027
+ p = processes.pop(0)
3028
+ p.join()
3029
+ activate_processes -= 1
3030
+ for p in processes:
3031
+ p.join()
3032
+ return
3033
+
2541
3034
  network = Validator.check_isinstance("network", network, nn.Cell)
2542
3035
  _check_checkpoint_file(checkpoint_filenames)
2543
3036
  _check_predict_strategy(predict_strategy)
@@ -2582,17 +3075,24 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
2582
3075
  param_rank = rank_list.get(param.name)[0]
2583
3076
  skip_merge_split = rank_list.get(param.name)[1]
2584
3077
  shard_stride = train_strategy.get(param.name)[4]
3078
+ tensor_map = train_strategy.get(param.name)[1]
3079
+ first_dim_shard_idx = tensor_map[0] if tensor_map else -1
3080
+ device_arrangement = train_strategy.get(param.name)[0]
3081
+ first_dim_shard_size = 1
3082
+ if first_dim_shard_idx >= 0:
3083
+ first_dim_shard_size = device_arrangement[-1 - first_dim_shard_idx]
2585
3084
  if train_strategy.get(param.name)[5]:
2586
- shard_size = ckpt_file_len / shard_stride / train_strategy.get(param.name)[5]
3085
+ shard_size = int(ckpt_file_len / shard_stride / train_strategy.get(param.name)[5] / first_dim_shard_size)
2587
3086
  else:
2588
3087
  shard_size = 0
2589
3088
  for rank in param_rank:
2590
3089
  param_total_list = list(range(0, ckpt_file_len))
3090
+ if first_dim_shard_size != 1:
3091
+ param_total_list = _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_shard_idx, rank)
2591
3092
  if shard_size > 0:
2592
- shard_total_list = []
2593
- for i in range(0, ckpt_file_len, shard_size):
2594
- shard_total_list.append(param_total_list[i:i + shard_size])
2595
- param_total_list = shard_total_list[rank // shard_size]
3093
+ rank_index = param_total_list.index(rank)
3094
+ start = rank_index // shard_size * shard_size
3095
+ param_total_list = param_total_list[start:start + shard_size]
2596
3096
  if shard_stride > 0:
2597
3097
  param_stride = []
2598
3098
  # merge pre parameter
@@ -2722,11 +3222,10 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
2722
3222
  param_name = merged_param.name
2723
3223
  tensor_layout = predict_strategy[param_name]
2724
3224
  rank = get_rank()
2725
- split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank)
3225
+ split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank_id=rank)
2726
3226
  requires_grad = merged_param.requires_grad
2727
3227
  layerwise_parallel = merged_param.layerwise_parallel
2728
- data_type = merged_param.data.dtype
2729
- if data_type == mstype.bfloat16:
3228
+ if merged_param.data.dtype == mstype.bfloat16:
2730
3229
  split_param = Parameter(Tensor(split_tensor, mstype.bfloat16), param_name, requires_grad, layerwise_parallel)
2731
3230
  else:
2732
3231
  split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
@@ -2765,7 +3264,7 @@ def _get_mindir_inputs(file_name):
2765
3264
  >>> input_tensor = get_mindir_inputs("lenet.mindir")
2766
3265
  """
2767
3266
  Validator.check_file_name_by_regular(file_name)
2768
- file_name = os.path.abspath(file_name)
3267
+ file_name = os.path.realpath(file_name)
2769
3268
  model = read_proto(file_name)
2770
3269
  input_tensor = []
2771
3270
 
@@ -2794,7 +3293,7 @@ def _get_mindir_inputs(file_name):
2794
3293
 
2795
3294
  def convert_model(mindir_file, convert_file, file_format):
2796
3295
  """
2797
- Convert mindir model to other format model. Current version only support convert to "ONNX" format.
3296
+ Convert mindir model to other format model. The current version only supports conversion to ONNX models.
2798
3297
 
2799
3298
  .. warning::
2800
3299
  This is an experimental API that is subject to change or deletion.