mindspore 2.2.11__cp39-cp39-win_amd64.whl → 2.3.0__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (1151) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +7 -5
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +76 -18
  7. mindspore/_extends/builtin_operations.py +2 -1
  8. mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
  9. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
  10. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
  11. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
  12. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  13. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
  14. mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
  15. mindspore/_extends/parse/__init__.py +18 -14
  16. mindspore/_extends/parse/compile_config.py +258 -0
  17. mindspore/_extends/parse/namespace.py +2 -2
  18. mindspore/_extends/parse/parser.py +174 -62
  19. mindspore/_extends/parse/resources.py +45 -14
  20. mindspore/_extends/parse/standard_method.py +142 -240
  21. mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
  22. mindspore/_extends/pijit/pijit_func_white_list.py +343 -0
  23. mindspore/_extends/remote/kernel_build_server.py +2 -0
  24. mindspore/_profiler.py +30 -0
  25. mindspore/amp.py +51 -24
  26. mindspore/avcodec-59.dll +0 -0
  27. mindspore/avdevice-59.dll +0 -0
  28. mindspore/avfilter-8.dll +0 -0
  29. mindspore/avformat-59.dll +0 -0
  30. mindspore/avutil-57.dll +0 -0
  31. mindspore/boost/adasum.py +1 -1
  32. mindspore/boost/base.py +1 -1
  33. mindspore/boost/boost_cell_wrapper.py +2 -2
  34. mindspore/boost/grad_freeze.py +2 -2
  35. mindspore/boost/group_loss_scale_manager.py +1 -1
  36. mindspore/boost/less_batch_normalization.py +9 -6
  37. mindspore/common/__init__.py +15 -4
  38. mindspore/common/_jit_fallback_utils.py +2 -3
  39. mindspore/common/_register_for_adapter.py +7 -0
  40. mindspore/common/_register_for_recompute.py +48 -0
  41. mindspore/common/_register_for_tensor.py +8 -9
  42. mindspore/common/_stub_tensor.py +7 -1
  43. mindspore/common/_utils.py +5 -17
  44. mindspore/common/api.py +411 -106
  45. mindspore/common/auto_dynamic_shape.py +27 -14
  46. mindspore/common/dtype.py +17 -10
  47. mindspore/common/dump.py +6 -8
  48. mindspore/common/file_system.py +48 -0
  49. mindspore/common/generator.py +260 -0
  50. mindspore/common/hook_handle.py +51 -4
  51. mindspore/common/initializer.py +1 -1
  52. mindspore/common/jit_config.py +34 -14
  53. mindspore/common/lazy_inline.py +72 -19
  54. mindspore/common/mindir_util.py +12 -2
  55. mindspore/common/mutable.py +79 -14
  56. mindspore/common/no_inline.py +54 -0
  57. mindspore/common/np_dtype.py +25 -0
  58. mindspore/common/parameter.py +30 -11
  59. mindspore/common/recompute.py +262 -0
  60. mindspore/common/seed.py +9 -9
  61. mindspore/common/sparse_tensor.py +272 -24
  62. mindspore/common/symbol.py +122 -0
  63. mindspore/common/tensor.py +468 -496
  64. mindspore/communication/__init__.py +6 -11
  65. mindspore/communication/_comm_helper.py +5 -0
  66. mindspore/communication/comm_func.py +1140 -0
  67. mindspore/communication/management.py +118 -102
  68. mindspore/config/op_info.config +22 -54
  69. mindspore/context.py +378 -65
  70. mindspore/dataset/__init__.py +5 -5
  71. mindspore/dataset/audio/__init__.py +6 -6
  72. mindspore/dataset/audio/transforms.py +711 -158
  73. mindspore/dataset/callback/ds_callback.py +2 -2
  74. mindspore/dataset/engine/cache_client.py +2 -2
  75. mindspore/dataset/engine/datasets.py +163 -83
  76. mindspore/dataset/engine/datasets_audio.py +14 -14
  77. mindspore/dataset/engine/datasets_standard_format.py +33 -3
  78. mindspore/dataset/engine/datasets_text.py +38 -38
  79. mindspore/dataset/engine/datasets_user_defined.py +78 -59
  80. mindspore/dataset/engine/datasets_vision.py +77 -73
  81. mindspore/dataset/engine/offload.py +5 -7
  82. mindspore/dataset/engine/queue.py +56 -38
  83. mindspore/dataset/engine/validators.py +11 -5
  84. mindspore/dataset/text/__init__.py +3 -3
  85. mindspore/dataset/text/transforms.py +408 -121
  86. mindspore/dataset/text/utils.py +9 -9
  87. mindspore/dataset/transforms/__init__.py +1 -1
  88. mindspore/dataset/transforms/transforms.py +261 -76
  89. mindspore/dataset/utils/browse_dataset.py +9 -9
  90. mindspore/dataset/vision/__init__.py +8 -8
  91. mindspore/dataset/vision/c_transforms.py +10 -10
  92. mindspore/dataset/vision/py_transforms_util.py +3 -3
  93. mindspore/dataset/vision/transforms.py +2844 -549
  94. mindspore/dataset/vision/utils.py +161 -10
  95. mindspore/dataset/vision/validators.py +14 -2
  96. mindspore/dnnl.dll +0 -0
  97. mindspore/experimental/optim/__init__.py +12 -2
  98. mindspore/experimental/optim/adadelta.py +161 -0
  99. mindspore/experimental/optim/adagrad.py +168 -0
  100. mindspore/experimental/optim/adam.py +35 -34
  101. mindspore/experimental/optim/adamax.py +170 -0
  102. mindspore/experimental/optim/adamw.py +40 -16
  103. mindspore/experimental/optim/asgd.py +153 -0
  104. mindspore/experimental/optim/lr_scheduler.py +71 -127
  105. mindspore/experimental/optim/nadam.py +157 -0
  106. mindspore/experimental/optim/optimizer.py +15 -8
  107. mindspore/experimental/optim/radam.py +194 -0
  108. mindspore/experimental/optim/rmsprop.py +154 -0
  109. mindspore/experimental/optim/rprop.py +164 -0
  110. mindspore/experimental/optim/sgd.py +28 -19
  111. mindspore/hal/__init__.py +40 -0
  112. mindspore/hal/_ascend.py +57 -0
  113. mindspore/hal/_base.py +57 -0
  114. mindspore/hal/_cpu.py +56 -0
  115. mindspore/hal/_gpu.py +57 -0
  116. mindspore/hal/device.py +356 -0
  117. mindspore/hal/event.py +179 -0
  118. mindspore/hal/memory.py +326 -0
  119. mindspore/hal/stream.py +339 -0
  120. mindspore/include/api/data_type.h +2 -2
  121. mindspore/include/api/dual_abi_helper.h +16 -3
  122. mindspore/include/api/model.h +4 -3
  123. mindspore/include/api/status.h +14 -0
  124. mindspore/include/c_api/model_c.h +173 -0
  125. mindspore/include/c_api/ms/base/types.h +1 -0
  126. mindspore/include/c_api/types_c.h +19 -0
  127. mindspore/include/dataset/execute.h +1 -3
  128. mindspore/include/dataset/vision.h +54 -2
  129. mindspore/jpeg62.dll +0 -0
  130. mindspore/log.py +2 -2
  131. mindspore/mindrecord/__init__.py +5 -1
  132. mindspore/mindrecord/config.py +809 -0
  133. mindspore/mindrecord/filereader.py +25 -0
  134. mindspore/mindrecord/filewriter.py +76 -58
  135. mindspore/mindrecord/mindpage.py +40 -6
  136. mindspore/mindrecord/shardutils.py +3 -2
  137. mindspore/mindrecord/shardwriter.py +7 -0
  138. mindspore/mindrecord/tools/cifar100_to_mr.py +53 -66
  139. mindspore/mindrecord/tools/cifar10_to_mr.py +48 -63
  140. mindspore/mindrecord/tools/csv_to_mr.py +7 -17
  141. mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
  142. mindspore/mindrecord/tools/mnist_to_mr.py +11 -21
  143. mindspore/mindrecord/tools/tfrecord_to_mr.py +2 -10
  144. mindspore/mindspore_backend.dll +0 -0
  145. mindspore/mindspore_common.dll +0 -0
  146. mindspore/mindspore_core.dll +0 -0
  147. mindspore/mindspore_glog.dll +0 -0
  148. mindspore/mindspore_np_dtype.dll +0 -0
  149. mindspore/mindspore_shared_lib.dll +0 -0
  150. mindspore/mint/__init__.py +1137 -0
  151. mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
  152. mindspore/mint/nn/__init__.py +512 -0
  153. mindspore/mint/nn/functional.py +573 -0
  154. mindspore/mint/optim/__init__.py +24 -0
  155. mindspore/mint/optim/adamw.py +185 -0
  156. mindspore/multiprocessing/__init__.py +72 -0
  157. mindspore/nn/__init__.py +1 -0
  158. mindspore/nn/cell.py +213 -257
  159. mindspore/nn/dynamic_lr.py +2 -2
  160. mindspore/nn/extend/__init__.py +29 -0
  161. mindspore/nn/extend/basic.py +140 -0
  162. mindspore/nn/extend/embedding.py +143 -0
  163. mindspore/{rewrite/ast_creator_register.py → nn/extend/layer/__init__.py} +9 -19
  164. mindspore/nn/extend/layer/normalization.py +109 -0
  165. mindspore/nn/extend/pooling.py +117 -0
  166. mindspore/nn/layer/activation.py +84 -94
  167. mindspore/nn/layer/basic.py +177 -82
  168. mindspore/nn/layer/channel_shuffle.py +3 -16
  169. mindspore/nn/layer/container.py +3 -3
  170. mindspore/nn/layer/conv.py +75 -66
  171. mindspore/nn/layer/embedding.py +103 -45
  172. mindspore/nn/layer/embedding_service.py +531 -0
  173. mindspore/nn/layer/embedding_service_layer.py +393 -0
  174. mindspore/nn/layer/image.py +4 -7
  175. mindspore/nn/layer/math.py +1 -1
  176. mindspore/nn/layer/normalization.py +52 -66
  177. mindspore/nn/layer/padding.py +30 -39
  178. mindspore/nn/layer/pooling.py +18 -9
  179. mindspore/nn/layer/rnn_cells.py +6 -16
  180. mindspore/nn/layer/rnns.py +6 -5
  181. mindspore/nn/layer/thor_layer.py +1 -2
  182. mindspore/nn/layer/timedistributed.py +1 -1
  183. mindspore/nn/layer/transformer.py +52 -50
  184. mindspore/nn/learning_rate_schedule.py +6 -5
  185. mindspore/nn/loss/loss.py +63 -84
  186. mindspore/nn/optim/ada_grad.py +6 -4
  187. mindspore/nn/optim/adadelta.py +3 -1
  188. mindspore/nn/optim/adafactor.py +1 -1
  189. mindspore/nn/optim/adam.py +102 -181
  190. mindspore/nn/optim/adamax.py +4 -2
  191. mindspore/nn/optim/adasum.py +3 -3
  192. mindspore/nn/optim/asgd.py +4 -2
  193. mindspore/nn/optim/ftrl.py +31 -61
  194. mindspore/nn/optim/lamb.py +5 -3
  195. mindspore/nn/optim/lars.py +2 -2
  196. mindspore/nn/optim/lazyadam.py +6 -4
  197. mindspore/nn/optim/momentum.py +13 -25
  198. mindspore/nn/optim/optimizer.py +6 -3
  199. mindspore/nn/optim/proximal_ada_grad.py +4 -2
  200. mindspore/nn/optim/rmsprop.py +9 -3
  201. mindspore/nn/optim/rprop.py +4 -2
  202. mindspore/nn/optim/sgd.py +7 -4
  203. mindspore/nn/optim/thor.py +2 -2
  204. mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
  205. mindspore/nn/probability/distribution/beta.py +2 -2
  206. mindspore/nn/probability/distribution/categorical.py +4 -6
  207. mindspore/nn/probability/distribution/cauchy.py +2 -2
  208. mindspore/nn/probability/distribution/exponential.py +2 -2
  209. mindspore/nn/probability/distribution/geometric.py +1 -1
  210. mindspore/nn/probability/distribution/gumbel.py +2 -2
  211. mindspore/nn/probability/distribution/logistic.py +1 -1
  212. mindspore/nn/probability/distribution/poisson.py +2 -2
  213. mindspore/nn/probability/distribution/uniform.py +2 -2
  214. mindspore/nn/reinforcement/_tensors_queue.py +13 -1
  215. mindspore/nn/wrap/__init__.py +2 -1
  216. mindspore/nn/wrap/cell_wrapper.py +58 -13
  217. mindspore/nn/wrap/grad_reducer.py +148 -8
  218. mindspore/nn/wrap/loss_scale.py +32 -9
  219. mindspore/numpy/__init__.py +2 -0
  220. mindspore/numpy/array_creations.py +2 -0
  221. mindspore/numpy/array_ops.py +6 -6
  222. mindspore/numpy/dtypes.py +3 -3
  223. mindspore/numpy/fft.py +431 -0
  224. mindspore/numpy/math_ops.py +61 -67
  225. mindspore/numpy/utils.py +3 -0
  226. mindspore/opencv_core452.dll +0 -0
  227. mindspore/opencv_imgcodecs452.dll +0 -0
  228. mindspore/opencv_imgproc452.dll +0 -0
  229. mindspore/ops/__init__.py +8 -4
  230. mindspore/ops/_grad_experimental/grad_array_ops.py +4 -160
  231. mindspore/ops/_grad_experimental/grad_comm_ops.py +93 -36
  232. mindspore/ops/_grad_experimental/grad_inner_ops.py +8 -0
  233. mindspore/ops/_grad_experimental/grad_math_ops.py +92 -287
  234. mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
  235. mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
  236. mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
  237. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  238. mindspore/ops/_op_impl/__init__.py +0 -1
  239. mindspore/ops/_op_impl/aicpu/__init__.py +1 -0
  240. mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
  241. mindspore/ops/_op_impl/{cpu/concat.py → aicpu/generate_eod_mask.py} +16 -17
  242. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
  243. mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
  244. mindspore/ops/_op_impl/cpu/__init__.py +1 -3
  245. mindspore/ops/_op_impl/cpu/adam.py +2 -2
  246. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
  247. mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
  248. mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
  249. mindspore/ops/_vmap/vmap_array_ops.py +164 -101
  250. mindspore/ops/_vmap/vmap_base.py +8 -1
  251. mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
  252. mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
  253. mindspore/ops/_vmap/vmap_image_ops.py +70 -13
  254. mindspore/ops/_vmap/vmap_math_ops.py +130 -58
  255. mindspore/ops/_vmap/vmap_nn_ops.py +249 -115
  256. mindspore/ops/_vmap/vmap_other_ops.py +1 -1
  257. mindspore/ops/auto_generate/__init__.py +31 -0
  258. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +231 -0
  259. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +250 -0
  260. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  261. mindspore/ops/auto_generate/gen_extend_func.py +980 -0
  262. mindspore/ops/auto_generate/gen_ops_def.py +6443 -0
  263. mindspore/ops/auto_generate/gen_ops_prim.py +13167 -0
  264. mindspore/ops/auto_generate/pyboost_inner_prim.py +429 -0
  265. mindspore/ops/composite/__init__.py +5 -2
  266. mindspore/ops/composite/base.py +121 -23
  267. mindspore/ops/composite/math_ops.py +10 -49
  268. mindspore/ops/composite/multitype_ops/_compile_utils.py +191 -618
  269. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
  270. mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
  271. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
  272. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
  273. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
  274. mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
  275. mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
  276. mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
  277. mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
  278. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
  279. mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
  280. mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
  281. mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
  282. mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
  283. mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
  284. mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
  285. mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
  286. mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
  287. mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
  288. mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
  289. mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
  290. mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
  291. mindspore/ops/composite/multitype_ops/not_in_impl.py +6 -1
  292. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
  293. mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
  294. mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
  295. mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
  296. mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
  297. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
  298. mindspore/ops/deprecated.py +14 -3
  299. mindspore/ops/extend/__init__.py +53 -0
  300. mindspore/ops/extend/array_func.py +218 -0
  301. mindspore/ops/extend/math_func.py +76 -0
  302. mindspore/ops/extend/nn_func.py +308 -0
  303. mindspore/ops/function/__init__.py +31 -11
  304. mindspore/ops/function/array_func.py +848 -1736
  305. mindspore/ops/function/clip_func.py +19 -31
  306. mindspore/ops/function/debug_func.py +2 -5
  307. mindspore/ops/function/fft_func.py +31 -0
  308. mindspore/ops/function/grad/grad_func.py +27 -20
  309. mindspore/ops/function/image_func.py +27 -21
  310. mindspore/ops/function/linalg_func.py +30 -53
  311. mindspore/ops/function/math_func.py +916 -2791
  312. mindspore/ops/function/nn_func.py +1445 -889
  313. mindspore/ops/function/other_func.py +6 -7
  314. mindspore/ops/function/parameter_func.py +6 -92
  315. mindspore/ops/function/random_func.py +254 -108
  316. mindspore/ops/function/reshard_func.py +102 -0
  317. mindspore/ops/function/sparse_func.py +4 -4
  318. mindspore/ops/function/sparse_unary_func.py +11 -18
  319. mindspore/ops/function/spectral_func.py +1 -1
  320. mindspore/ops/function/vmap_func.py +15 -14
  321. mindspore/ops/functional.py +342 -343
  322. mindspore/ops/op_info_register.py +16 -43
  323. mindspore/ops/operations/__init__.py +32 -23
  324. mindspore/ops/operations/_embedding_cache_ops.py +1 -1
  325. mindspore/ops/operations/_grad_ops.py +21 -853
  326. mindspore/ops/operations/_infer_ops.py +19 -0
  327. mindspore/ops/operations/_inner_ops.py +155 -511
  328. mindspore/ops/operations/_quant_ops.py +4 -4
  329. mindspore/ops/operations/_rl_inner_ops.py +3 -3
  330. mindspore/ops/operations/_scalar_ops.py +5 -480
  331. mindspore/ops/operations/_sequence_ops.py +6 -36
  332. mindspore/ops/operations/_tensor_array.py +8 -8
  333. mindspore/ops/operations/array_ops.py +112 -2698
  334. mindspore/ops/operations/comm_ops.py +801 -118
  335. mindspore/ops/operations/custom_ops.py +62 -121
  336. mindspore/ops/operations/debug_ops.py +105 -36
  337. mindspore/ops/operations/image_ops.py +3 -219
  338. mindspore/ops/operations/inner_ops.py +54 -40
  339. mindspore/ops/operations/linalg_ops.py +1 -49
  340. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  341. mindspore/ops/operations/manually_defined/_inner.py +61 -0
  342. mindspore/ops/operations/manually_defined/ops_def.py +2016 -0
  343. mindspore/ops/operations/math_ops.py +621 -4654
  344. mindspore/ops/operations/nn_ops.py +316 -2226
  345. mindspore/ops/operations/other_ops.py +53 -45
  346. mindspore/ops/operations/random_ops.py +4 -51
  347. mindspore/ops/operations/reshard_ops.py +53 -0
  348. mindspore/ops/operations/sparse_ops.py +8 -8
  349. mindspore/ops/primitive.py +204 -103
  350. mindspore/ops/silent_check.py +162 -0
  351. mindspore/ops_generate/__init__.py +27 -0
  352. mindspore/ops_generate/arg_dtype_cast.py +250 -0
  353. mindspore/ops_generate/arg_handler.py +197 -0
  354. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  355. mindspore/ops_generate/gen_ops.py +1084 -0
  356. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  357. mindspore/ops_generate/gen_pyboost_func.py +968 -0
  358. mindspore/ops_generate/gen_utils.py +209 -0
  359. mindspore/ops_generate/op_proto.py +138 -0
  360. mindspore/ops_generate/pyboost_utils.py +354 -0
  361. mindspore/ops_generate/template.py +239 -0
  362. mindspore/parallel/__init__.py +7 -4
  363. mindspore/parallel/_auto_parallel_context.py +155 -6
  364. mindspore/parallel/_cell_wrapper.py +16 -9
  365. mindspore/parallel/_cost_model_context.py +1 -1
  366. mindspore/parallel/_dp_allreduce_fusion.py +159 -159
  367. mindspore/parallel/_parallel_serialization.py +62 -14
  368. mindspore/parallel/_ps_context.py +1 -1
  369. mindspore/parallel/_recovery_context.py +1 -1
  370. mindspore/parallel/_tensor.py +18 -9
  371. mindspore/parallel/_transformer/__init__.py +1 -1
  372. mindspore/parallel/_transformer/layers.py +1 -1
  373. mindspore/parallel/_transformer/loss.py +1 -1
  374. mindspore/parallel/_transformer/moe.py +1 -1
  375. mindspore/parallel/_transformer/op_parallel_config.py +1 -1
  376. mindspore/parallel/_transformer/transformer.py +10 -10
  377. mindspore/parallel/_utils.py +161 -6
  378. mindspore/parallel/algo_parameter_config.py +6 -8
  379. mindspore/parallel/checkpoint_transform.py +369 -64
  380. mindspore/parallel/cluster/__init__.py +15 -0
  381. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  382. mindspore/parallel/cluster/process_entity/_api.py +344 -0
  383. mindspore/parallel/cluster/process_entity/_utils.py +126 -0
  384. mindspore/parallel/cluster/run.py +136 -0
  385. mindspore/parallel/mpi/__init__.py +1 -1
  386. mindspore/parallel/mpi/_mpi_config.py +1 -1
  387. mindspore/parallel/parameter_broadcast.py +152 -0
  388. mindspore/parallel/shard.py +128 -17
  389. mindspore/profiler/__init__.py +3 -2
  390. mindspore/profiler/common/process_pool.py +41 -0
  391. mindspore/profiler/common/singleton.py +28 -0
  392. mindspore/profiler/common/util.py +125 -0
  393. mindspore/profiler/envprofiling.py +2 -2
  394. mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
  395. mindspore/profiler/parser/ascend_analysis/constant.py +53 -0
  396. mindspore/profiler/parser/ascend_analysis/file_manager.py +159 -0
  397. mindspore/profiler/parser/ascend_analysis/function_event.py +161 -0
  398. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +131 -0
  399. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +85 -0
  400. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +57 -0
  401. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +116 -0
  402. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  403. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +68 -0
  404. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  405. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  406. mindspore/profiler/parser/ascend_flops_generator.py +27 -5
  407. mindspore/profiler/parser/ascend_fpbp_generator.py +8 -2
  408. mindspore/profiler/parser/ascend_hccl_generator.py +31 -280
  409. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  410. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  411. mindspore/profiler/parser/ascend_msprof_exporter.py +151 -126
  412. mindspore/profiler/parser/ascend_msprof_generator.py +75 -274
  413. mindspore/profiler/parser/ascend_op_generator.py +94 -36
  414. mindspore/profiler/parser/ascend_timeline_generator.py +297 -131
  415. mindspore/profiler/parser/base_timeline_generator.py +17 -3
  416. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -1
  417. mindspore/profiler/parser/framework_parser.py +11 -4
  418. mindspore/profiler/parser/integrator.py +3 -1
  419. mindspore/profiler/parser/memory_usage_parser.py +8 -2
  420. mindspore/profiler/parser/minddata_analyzer.py +8 -2
  421. mindspore/profiler/parser/minddata_parser.py +73 -4
  422. mindspore/profiler/parser/msadvisor_analyzer.py +5 -3
  423. mindspore/profiler/parser/msadvisor_parser.py +10 -4
  424. mindspore/profiler/parser/profiler_info.py +16 -1
  425. mindspore/profiler/profiling.py +522 -195
  426. mindspore/rewrite/__init__.py +2 -13
  427. mindspore/rewrite/api/node.py +123 -37
  428. mindspore/rewrite/api/pattern_engine.py +2 -3
  429. mindspore/rewrite/api/scoped_value.py +16 -15
  430. mindspore/rewrite/api/symbol_tree.py +46 -30
  431. mindspore/rewrite/ast_helpers/__init__.py +3 -6
  432. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  433. mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
  434. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  435. mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
  436. mindspore/rewrite/common/__init__.py +1 -2
  437. mindspore/rewrite/common/config.py +24 -0
  438. mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
  439. mindspore/rewrite/{namer.py → common/namer.py} +63 -18
  440. mindspore/rewrite/common/namespace.py +118 -0
  441. mindspore/rewrite/node/__init__.py +5 -5
  442. mindspore/rewrite/node/call_function.py +23 -7
  443. mindspore/rewrite/node/cell_container.py +7 -3
  444. mindspore/rewrite/node/control_flow.py +53 -28
  445. mindspore/rewrite/node/node.py +212 -196
  446. mindspore/rewrite/node/node_manager.py +51 -22
  447. mindspore/rewrite/node/node_topological_manager.py +3 -23
  448. mindspore/rewrite/parsers/__init__.py +12 -0
  449. mindspore/rewrite/parsers/arguments_parser.py +8 -9
  450. mindspore/rewrite/parsers/assign_parser.py +637 -413
  451. mindspore/rewrite/parsers/attribute_parser.py +3 -4
  452. mindspore/rewrite/parsers/class_def_parser.py +115 -148
  453. mindspore/rewrite/parsers/constant_parser.py +5 -5
  454. mindspore/rewrite/parsers/container_parser.py +4 -6
  455. mindspore/rewrite/parsers/expr_parser.py +55 -0
  456. mindspore/rewrite/parsers/for_parser.py +31 -98
  457. mindspore/rewrite/parsers/function_def_parser.py +13 -5
  458. mindspore/rewrite/parsers/if_parser.py +28 -10
  459. mindspore/rewrite/parsers/module_parser.py +8 -182
  460. mindspore/rewrite/parsers/parser.py +1 -5
  461. mindspore/rewrite/parsers/parser_register.py +1 -1
  462. mindspore/rewrite/parsers/return_parser.py +5 -10
  463. mindspore/rewrite/parsers/while_parser.py +59 -0
  464. mindspore/rewrite/sparsify/utils.py +1 -1
  465. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  466. mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +704 -185
  467. mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
  468. mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
  469. mindspore/run_check/_check_version.py +6 -14
  470. mindspore/run_check/run_check.py +1 -1
  471. mindspore/safeguard/rewrite_obfuscation.py +9 -19
  472. mindspore/swresample-4.dll +0 -0
  473. mindspore/swscale-6.dll +0 -0
  474. mindspore/tinyxml2.dll +0 -0
  475. mindspore/train/__init__.py +6 -5
  476. mindspore/train/_utils.py +178 -4
  477. mindspore/train/amp.py +167 -245
  478. mindspore/train/anf_ir_pb2.py +14 -2
  479. mindspore/train/callback/__init__.py +5 -2
  480. mindspore/train/callback/_backup_and_restore.py +5 -5
  481. mindspore/train/callback/_callback.py +4 -4
  482. mindspore/train/callback/_checkpoint.py +151 -37
  483. mindspore/train/callback/_cluster_monitor.py +201 -0
  484. mindspore/train/callback/_early_stop.py +2 -2
  485. mindspore/train/callback/_flops_collector.py +238 -0
  486. mindspore/train/callback/_landscape.py +16 -11
  487. mindspore/train/callback/_loss_monitor.py +2 -2
  488. mindspore/train/callback/_mindio_ttp.py +443 -0
  489. mindspore/train/callback/_on_request_exit.py +2 -2
  490. mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
  491. mindspore/train/callback/_summary_collector.py +13 -14
  492. mindspore/train/callback/_time_monitor.py +3 -3
  493. mindspore/train/data_sink.py +6 -5
  494. mindspore/train/dataset_helper.py +66 -21
  495. mindspore/train/loss_scale_manager.py +2 -2
  496. mindspore/train/metrics/accuracy.py +7 -7
  497. mindspore/train/metrics/confusion_matrix.py +8 -6
  498. mindspore/train/metrics/cosine_similarity.py +6 -4
  499. mindspore/train/metrics/error.py +2 -2
  500. mindspore/train/metrics/metric.py +3 -3
  501. mindspore/train/metrics/perplexity.py +2 -1
  502. mindspore/train/metrics/topk.py +2 -2
  503. mindspore/train/mind_ir_pb2.py +89 -15
  504. mindspore/train/model.py +298 -56
  505. mindspore/train/serialization.py +501 -221
  506. mindspore/train/summary/_summary_adapter.py +1 -1
  507. mindspore/train/summary/_writer_pool.py +1 -1
  508. mindspore/train/summary/summary_record.py +56 -34
  509. mindspore/train/train_thor/convert_utils.py +3 -3
  510. mindspore/turbojpeg.dll +0 -0
  511. mindspore/version.py +1 -1
  512. {mindspore-2.2.11.dist-info → mindspore-2.3.0.dist-info}/METADATA +3 -3
  513. mindspore-2.3.0.dist-info/RECORD +1400 -0
  514. {mindspore-2.2.11.dist-info → mindspore-2.3.0.dist-info}/entry_points.txt +1 -0
  515. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
  516. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
  517. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
  518. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
  519. mindspore/gen_ops.py +0 -273
  520. mindspore/nn/layer/flash_attention.py +0 -189
  521. mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
  522. mindspore/ops/_op_impl/tbe/__init__.py +0 -47
  523. mindspore/ops/_op_impl/tbe/abs.py +0 -38
  524. mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
  525. mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
  526. mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
  527. mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
  528. mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
  529. mindspore/ops/_op_impl/tbe/acos.py +0 -37
  530. mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
  531. mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
  532. mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
  533. mindspore/ops/_op_impl/tbe/acosh.py +0 -37
  534. mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
  535. mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
  536. mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
  537. mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
  538. mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
  539. mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
  540. mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
  541. mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
  542. mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
  543. mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
  544. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
  545. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
  546. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
  547. mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
  548. mindspore/ops/_op_impl/tbe/add.py +0 -42
  549. mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
  550. mindspore/ops/_op_impl/tbe/add_n.py +0 -39
  551. mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
  552. mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
  553. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
  554. mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
  555. mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
  556. mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
  557. mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
  558. mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
  559. mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
  560. mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
  561. mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
  562. mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
  563. mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
  564. mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
  565. mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
  566. mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
  567. mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
  568. mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
  569. mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
  570. mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
  571. mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
  572. mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
  573. mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
  574. mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
  575. mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
  576. mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
  577. mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
  578. mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
  579. mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
  580. mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
  581. mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
  582. mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
  583. mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
  584. mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
  585. mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
  586. mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
  587. mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
  588. mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
  589. mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
  590. mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
  591. mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
  592. mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
  593. mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
  594. mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
  595. mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
  596. mindspore/ops/_op_impl/tbe/asin.py +0 -37
  597. mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
  598. mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
  599. mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
  600. mindspore/ops/_op_impl/tbe/asinh.py +0 -37
  601. mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
  602. mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
  603. mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
  604. mindspore/ops/_op_impl/tbe/assign.py +0 -79
  605. mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
  606. mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
  607. mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
  608. mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
  609. mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
  610. mindspore/ops/_op_impl/tbe/atan.py +0 -37
  611. mindspore/ops/_op_impl/tbe/atan2.py +0 -38
  612. mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
  613. mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
  614. mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
  615. mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
  616. mindspore/ops/_op_impl/tbe/atanh.py +0 -37
  617. mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
  618. mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
  619. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
  620. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
  621. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
  622. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
  623. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
  624. mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
  625. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
  626. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
  627. mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
  628. mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
  629. mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
  630. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
  631. mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
  632. mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
  633. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
  634. mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
  635. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
  636. mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
  637. mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
  638. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
  639. mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
  640. mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
  641. mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
  642. mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
  643. mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
  644. mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
  645. mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
  646. mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
  647. mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
  648. mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
  649. mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
  650. mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
  651. mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
  652. mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
  653. mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
  654. mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
  655. mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
  656. mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
  657. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
  658. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
  659. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
  660. mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
  661. mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
  662. mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
  663. mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
  664. mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
  665. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
  666. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
  667. mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
  668. mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
  669. mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
  670. mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
  671. mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
  672. mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
  673. mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
  674. mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
  675. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
  676. mindspore/ops/_op_impl/tbe/cast.py +0 -55
  677. mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
  678. mindspore/ops/_op_impl/tbe/cdist.py +0 -38
  679. mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
  680. mindspore/ops/_op_impl/tbe/ceil.py +0 -37
  681. mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
  682. mindspore/ops/_op_impl/tbe/celu.py +0 -39
  683. mindspore/ops/_op_impl/tbe/centralization.py +0 -39
  684. mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
  685. mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
  686. mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
  687. mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
  688. mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
  689. mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
  690. mindspore/ops/_op_impl/tbe/concat.py +0 -40
  691. mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
  692. mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
  693. mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
  694. mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
  695. mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
  696. mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
  697. mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
  698. mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
  699. mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
  700. mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
  701. mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
  702. mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
  703. mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
  704. mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
  705. mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
  706. mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
  707. mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
  708. mindspore/ops/_op_impl/tbe/cos.py +0 -37
  709. mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
  710. mindspore/ops/_op_impl/tbe/cosh.py +0 -37
  711. mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
  712. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
  713. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
  714. mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
  715. mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
  716. mindspore/ops/_op_impl/tbe/cummin.py +0 -41
  717. mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
  718. mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
  719. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
  720. mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
  721. mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
  722. mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
  723. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
  724. mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
  725. mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
  726. mindspore/ops/_op_impl/tbe/diag.py +0 -38
  727. mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
  728. mindspore/ops/_op_impl/tbe/dilation.py +0 -40
  729. mindspore/ops/_op_impl/tbe/div.py +0 -41
  730. mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
  731. mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
  732. mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
  733. mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
  734. mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
  735. mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
  736. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
  737. mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
  738. mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
  739. mindspore/ops/_op_impl/tbe/elu.py +0 -38
  740. mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
  741. mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
  742. mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
  743. mindspore/ops/_op_impl/tbe/equal.py +0 -42
  744. mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
  745. mindspore/ops/_op_impl/tbe/erf.py +0 -37
  746. mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
  747. mindspore/ops/_op_impl/tbe/erfc.py +0 -37
  748. mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
  749. mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
  750. mindspore/ops/_op_impl/tbe/exp.py +0 -40
  751. mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
  752. mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
  753. mindspore/ops/_op_impl/tbe/expm1.py +0 -37
  754. mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
  755. mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
  756. mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
  757. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
  758. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
  759. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
  760. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
  761. mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
  762. mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
  763. mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
  764. mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
  765. mindspore/ops/_op_impl/tbe/fill.py +0 -56
  766. mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
  767. mindspore/ops/_op_impl/tbe/flatten.py +0 -48
  768. mindspore/ops/_op_impl/tbe/floor.py +0 -37
  769. mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
  770. mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
  771. mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
  772. mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
  773. mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
  774. mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
  775. mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
  776. mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
  777. mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
  778. mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
  779. mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
  780. mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
  781. mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
  782. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  783. mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
  784. mindspore/ops/_op_impl/tbe/gelu.py +0 -37
  785. mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
  786. mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
  787. mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
  788. mindspore/ops/_op_impl/tbe/ger.py +0 -43
  789. mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
  790. mindspore/ops/_op_impl/tbe/greater.py +0 -43
  791. mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
  792. mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
  793. mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
  794. mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
  795. mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
  796. mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
  797. mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
  798. mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
  799. mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
  800. mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
  801. mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
  802. mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
  803. mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
  804. mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
  805. mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
  806. mindspore/ops/_op_impl/tbe/im2col.py +0 -42
  807. mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
  808. mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
  809. mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
  810. mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
  811. mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
  812. mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
  813. mindspore/ops/_op_impl/tbe/inv.py +0 -38
  814. mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
  815. mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
  816. mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
  817. mindspore/ops/_op_impl/tbe/invert.py +0 -37
  818. mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
  819. mindspore/ops/_op_impl/tbe/iou.py +0 -38
  820. mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
  821. mindspore/ops/_op_impl/tbe/is_close.py +0 -40
  822. mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
  823. mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
  824. mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
  825. mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
  826. mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
  827. mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
  828. mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
  829. mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
  830. mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
  831. mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
  832. mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
  833. mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
  834. mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
  835. mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
  836. mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
  837. mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
  838. mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
  839. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
  840. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
  841. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
  842. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
  843. mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
  844. mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
  845. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
  846. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
  847. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
  848. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
  849. mindspore/ops/_op_impl/tbe/lerp.py +0 -38
  850. mindspore/ops/_op_impl/tbe/less.py +0 -41
  851. mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
  852. mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
  853. mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
  854. mindspore/ops/_op_impl/tbe/log.py +0 -40
  855. mindspore/ops/_op_impl/tbe/log1p.py +0 -37
  856. mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
  857. mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
  858. mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
  859. mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
  860. mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
  861. mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
  862. mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
  863. mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
  864. mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
  865. mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
  866. mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
  867. mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
  868. mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
  869. mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
  870. mindspore/ops/_op_impl/tbe/lrn.py +0 -41
  871. mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
  872. mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
  873. mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
  874. mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
  875. mindspore/ops/_op_impl/tbe/matmul.py +0 -53
  876. mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
  877. mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
  878. mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
  879. mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
  880. mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
  881. mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
  882. mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
  883. mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
  884. mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
  885. mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
  886. mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
  887. mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
  888. mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
  889. mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
  890. mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
  891. mindspore/ops/_op_impl/tbe/maximum.py +0 -39
  892. mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
  893. mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
  894. mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
  895. mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
  896. mindspore/ops/_op_impl/tbe/minimum.py +0 -40
  897. mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
  898. mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
  899. mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
  900. mindspore/ops/_op_impl/tbe/mish.py +0 -37
  901. mindspore/ops/_op_impl/tbe/mod.py +0 -41
  902. mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
  903. mindspore/ops/_op_impl/tbe/mul.py +0 -37
  904. mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
  905. mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
  906. mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
  907. mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
  908. mindspore/ops/_op_impl/tbe/neg.py +0 -39
  909. mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
  910. mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
  911. mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
  912. mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
  913. mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
  914. mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
  915. mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
  916. mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
  917. mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
  918. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
  919. mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
  920. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
  921. mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
  922. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
  923. mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
  924. mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
  925. mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
  926. mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
  927. mindspore/ops/_op_impl/tbe/pack.py +0 -58
  928. mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
  929. mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
  930. mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
  931. mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
  932. mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
  933. mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
  934. mindspore/ops/_op_impl/tbe/pdist.py +0 -36
  935. mindspore/ops/_op_impl/tbe/pooling.py +0 -46
  936. mindspore/ops/_op_impl/tbe/population_count.py +0 -38
  937. mindspore/ops/_op_impl/tbe/pow.py +0 -41
  938. mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
  939. mindspore/ops/_op_impl/tbe/prelu.py +0 -37
  940. mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
  941. mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
  942. mindspore/ops/_op_impl/tbe/range.py +0 -39
  943. mindspore/ops/_op_impl/tbe/real_div.py +0 -38
  944. mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
  945. mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
  946. mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
  947. mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
  948. mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
  949. mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
  950. mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
  951. mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
  952. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
  953. mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
  954. mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
  955. mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
  956. mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
  957. mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
  958. mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
  959. mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
  960. mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
  961. mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
  962. mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
  963. mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
  964. mindspore/ops/_op_impl/tbe/relu.py +0 -39
  965. mindspore/ops/_op_impl/tbe/relu6.py +0 -38
  966. mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
  967. mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
  968. mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
  969. mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
  970. mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
  971. mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
  972. mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
  973. mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
  974. mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
  975. mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
  976. mindspore/ops/_op_impl/tbe/renorm.py +0 -39
  977. mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
  978. mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
  979. mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
  980. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
  981. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
  982. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
  983. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
  984. mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
  985. mindspore/ops/_op_impl/tbe/rint.py +0 -37
  986. mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
  987. mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
  988. mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
  989. mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
  990. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
  991. mindspore/ops/_op_impl/tbe/roll.py +0 -42
  992. mindspore/ops/_op_impl/tbe/round.py +0 -38
  993. mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
  994. mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
  995. mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
  996. mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
  997. mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
  998. mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
  999. mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
  1000. mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
  1001. mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
  1002. mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
  1003. mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
  1004. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
  1005. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
  1006. mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
  1007. mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
  1008. mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
  1009. mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
  1010. mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
  1011. mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
  1012. mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
  1013. mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
  1014. mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
  1015. mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
  1016. mindspore/ops/_op_impl/tbe/select.py +0 -38
  1017. mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
  1018. mindspore/ops/_op_impl/tbe/selu.py +0 -39
  1019. mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
  1020. mindspore/ops/_op_impl/tbe/sgd.py +0 -62
  1021. mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
  1022. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
  1023. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
  1024. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
  1025. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
  1026. mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
  1027. mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
  1028. mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
  1029. mindspore/ops/_op_impl/tbe/sign.py +0 -38
  1030. mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
  1031. mindspore/ops/_op_impl/tbe/sin.py +0 -37
  1032. mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
  1033. mindspore/ops/_op_impl/tbe/sinh.py +0 -37
  1034. mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
  1035. mindspore/ops/_op_impl/tbe/slice.py +0 -58
  1036. mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
  1037. mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
  1038. mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
  1039. mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
  1040. mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
  1041. mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
  1042. mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
  1043. mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
  1044. mindspore/ops/_op_impl/tbe/softmax.py +0 -37
  1045. mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
  1046. mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
  1047. mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
  1048. mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
  1049. mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
  1050. mindspore/ops/_op_impl/tbe/softplus.py +0 -37
  1051. mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
  1052. mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
  1053. mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
  1054. mindspore/ops/_op_impl/tbe/softsign.py +0 -37
  1055. mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
  1056. mindspore/ops/_op_impl/tbe/sort.py +0 -38
  1057. mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
  1058. mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
  1059. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
  1060. mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
  1061. mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
  1062. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
  1063. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
  1064. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
  1065. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
  1066. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
  1067. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
  1068. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
  1069. mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
  1070. mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
  1071. mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
  1072. mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
  1073. mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
  1074. mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
  1075. mindspore/ops/_op_impl/tbe/split_d.py +0 -38
  1076. mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
  1077. mindspore/ops/_op_impl/tbe/split_v.py +0 -39
  1078. mindspore/ops/_op_impl/tbe/splitv.py +0 -39
  1079. mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
  1080. mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
  1081. mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
  1082. mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
  1083. mindspore/ops/_op_impl/tbe/square.py +0 -38
  1084. mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
  1085. mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
  1086. mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
  1087. mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
  1088. mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
  1089. mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
  1090. mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
  1091. mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
  1092. mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
  1093. mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
  1094. mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
  1095. mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
  1096. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
  1097. mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
  1098. mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
  1099. mindspore/ops/_op_impl/tbe/sub.py +0 -39
  1100. mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
  1101. mindspore/ops/_op_impl/tbe/tan.py +0 -38
  1102. mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
  1103. mindspore/ops/_op_impl/tbe/tanh.py +0 -37
  1104. mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
  1105. mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
  1106. mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
  1107. mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
  1108. mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
  1109. mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
  1110. mindspore/ops/_op_impl/tbe/tile.py +0 -37
  1111. mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
  1112. mindspore/ops/_op_impl/tbe/top_k.py +0 -42
  1113. mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
  1114. mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
  1115. mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
  1116. mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
  1117. mindspore/ops/_op_impl/tbe/transpose.py +0 -60
  1118. mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
  1119. mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
  1120. mindspore/ops/_op_impl/tbe/trunc.py +0 -39
  1121. mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
  1122. mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
  1123. mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
  1124. mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
  1125. mindspore/ops/_op_impl/tbe/unpack.py +0 -38
  1126. mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
  1127. mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
  1128. mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
  1129. mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
  1130. mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
  1131. mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
  1132. mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
  1133. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
  1134. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
  1135. mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
  1136. mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
  1137. mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
  1138. mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
  1139. mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
  1140. mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
  1141. mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
  1142. mindspore/ops/_tracefunc.py +0 -241
  1143. mindspore/ops/arg_dtype_cast.py +0 -54
  1144. mindspore/rewrite/api/tree_node_helper.py +0 -60
  1145. mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
  1146. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
  1147. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
  1148. mindspore/rewrite/namespace.py +0 -53
  1149. mindspore-2.2.11.dist-info/RECORD +0 -1920
  1150. {mindspore-2.2.11.dist-info → mindspore-2.3.0.dist-info}/WHEEL +0 -0
  1151. {mindspore-2.2.11.dist-info → mindspore-2.3.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
1
+ # Copyright 2020-2024 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
17
17
  from __future__ import absolute_import
18
18
  from __future__ import division
19
19
 
20
+ import binascii
20
21
  import copy
21
22
  import json
22
23
  import os
@@ -30,6 +31,7 @@ from io import BytesIO
30
31
  import math
31
32
  import sys
32
33
  import time
34
+ import google
33
35
  import numpy as np
34
36
 
35
37
  from mindspore.train.checkpoint_pb2 import Checkpoint
@@ -50,9 +52,12 @@ from mindspore.common.api import _generate_branch_control_input
50
52
  from mindspore.common.initializer import initializer, One
51
53
  from mindspore.common.parameter import Parameter, _offload_if_config
52
54
  from mindspore.common.tensor import Tensor
55
+ from mindspore._c_expression import Tensor as Tensor_
53
56
  from mindspore.common._utils import is_shape_unknown
57
+ from mindspore.common.file_system import FileSystem, _register_basic_file_system, _register_mindio_file_system
54
58
  from mindspore.communication.management import get_rank, get_group_size
55
59
  from mindspore.experimental import MapParameter
60
+ from mindspore.ops import Cast
56
61
  from mindspore.parallel._cell_wrapper import get_allgather_cell
57
62
  from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
58
63
  from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
@@ -61,21 +66,23 @@ from mindspore.parallel._parallel_serialization import _convert_to_list, _conver
61
66
  _restore_group_info_list
62
67
  from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
63
68
  _store_warm_up_ptr_by_tensor_list, _cache_enable
69
+ from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters
64
70
  from mindspore.train._utils import read_proto
65
71
  from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
66
72
  split_mindir, split_dynamic_mindir
73
+ from mindspore.common.generator import Generator
74
+ from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy
75
+ from mindspore.parallel.parameter_broadcast import parameter_broadcast
67
76
  from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
68
- from ..ops.operations import Cast
69
77
 
70
78
  tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
71
79
  "Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64,
72
80
  "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64,
73
- "Bool": mstype.bool_, "str": mstype.string, "BFloat16": mstype.bfloat16}
81
+ "Bool": mstype.bool_, "str": mstype.string, "BFloat16": mstype.bfloat16, "Int4": mstype.qint4x2}
74
82
 
75
83
  tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UInt16": np.uint16,
76
84
  "Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64,
77
- "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U",
78
- "BFloat16": np.float32}
85
+ "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"}
79
86
 
80
87
  np_type_convert = {"int32": np.int32, "float32": np.float32, "float16": np.float16, "float64": np.float64}
81
88
 
@@ -95,6 +102,30 @@ INT_64_MAX = 9223372036854775807
95
102
 
96
103
  cpu_cast = Cast().set_device("CPU")
97
104
 
105
+ _ckpt_fs = FileSystem()
106
+
107
+
108
+ def init_ckpt_file_system(fs: FileSystem):
109
+ """Initialize checkpoint file system"""
110
+ if _register_mindio_file_system(fs):
111
+ return
112
+ _register_basic_file_system(fs)
113
+
114
+
115
+ # Initialize checkpoint file system
116
+ init_ckpt_file_system(_ckpt_fs)
117
+
118
+
119
+ class ParamDictFuture:
120
+ def __init__(self, executor, param_dict_future):
121
+ self.executor = executor
122
+ self.param_dict_future = param_dict_future
123
+
124
+ def result(self):
125
+ param_dict = self.param_dict_future.result()
126
+ self.executor.shutdown()
127
+ return param_dict
128
+
98
129
 
99
130
  def _special_process_par(par, new_par):
100
131
  """
@@ -176,7 +207,7 @@ def _update_param(param, new_param, strict_load):
176
207
 
177
208
  def _type_convert(param, new_param, strict_load):
178
209
  """Whether to convert parameter's type during load checkpoint into network."""
179
- float_type = (mstype.float16, mstype.float32, mstype.float64)
210
+ float_type = (mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16)
180
211
  int_type = (mstype.int8, mstype.int16, mstype.int32, mstype.int64)
181
212
  if not strict_load and ({param.data.dtype, new_param.data.dtype}.issubset(float_type) or
182
213
  {param.data.dtype, new_param.data.dtype}.issubset(int_type)):
@@ -221,18 +252,19 @@ def _save_weight(checkpoint_dir, model_name, iteration, params):
221
252
  logger.warning(f"Checkpoint dir: '{checkpoint_dir}' is not existed.")
222
253
 
223
254
 
224
- def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False):
255
+ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False, crc_check=False):
225
256
  """Execute the process of saving checkpoint into file."""
226
257
  try:
227
258
  with _ckpt_mutex:
228
259
  if os.path.exists(ckpt_file_name):
229
260
  os.chmod(ckpt_file_name, stat.S_IWUSR)
230
261
  os.remove(ckpt_file_name)
231
- with open(ckpt_file_name, "ab") as f:
262
+ with _ckpt_fs.create(ckpt_file_name, *_ckpt_fs.create_args) as f:
232
263
  plain_data = None
233
264
  if enc_key is not None:
234
265
  plain_data = BytesIO()
235
266
 
267
+ crc_num = 0
236
268
  for name, value in data_list.items():
237
269
  if name == "random_op":
238
270
  _write_random_seed(name, value, f)
@@ -242,21 +274,21 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
242
274
  continue
243
275
  if value[0] == "offload_parameter":
244
276
  new_value = value[1:]
245
- if value[3].dtype == mstype.bfloat16:
246
- new_value[2] = cpu_cast(value[3], mstype.float32).asnumpy().reshape(-1)
247
- else:
248
- new_value[2] = value[3].asnumpy().reshape(-1)
249
- _write_parameter_data(name, new_value, f, enc_key, plain_data)
277
+ new_value[2] = value[3]
278
+ _write_parameter_bytes_data(name, new_value, f, enc_key, plain_data)
250
279
  _offload_if_config(value[3])
251
280
  continue
252
- if value[0] == "BFloat16_tensor":
253
- _write_bfloat16_data(name, value, f, enc_key, plain_data)
281
+ if value[1] == "str":
282
+ crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
283
+ continue
284
+ if isinstance(value[2], np.ndarray):
285
+ crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
254
286
  continue
255
- if isinstance(value[2], Tensor):
287
+ if isinstance(value[2], Tensor) and hasattr(value[2], "slice_num") and value[2].slice_num > 1:
256
288
  _write_hugeparameter(name, value, f)
257
289
  continue
258
290
 
259
- _write_parameter_data(name, value, f, enc_key, plain_data)
291
+ crc_num = _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
260
292
 
261
293
  if enc_key is not None:
262
294
  plain_data.seek(0)
@@ -266,7 +298,10 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
266
298
  f.write(_encrypt(block_data, len(block_data), enc_key, len(enc_key), enc_mode))
267
299
  block_data = plain_data.read(max_block_size)
268
300
 
269
- os.chmod(ckpt_file_name, stat.S_IRUSR)
301
+ if crc_check:
302
+ f.write('crc_num'.encode() + crc_num.to_bytes(10, byteorder='big'))
303
+
304
+ os.chmod(ckpt_file_name, stat.S_IRUSR)
270
305
 
271
306
  except BaseException as e:
272
307
  logger.critical("Failed to save the checkpoint file %s. Maybe don't have the permission to write files, "
@@ -286,22 +321,7 @@ def _write_random_seed(name, value, f):
286
321
  f.write(checkpoint_list.SerializeToString())
287
322
 
288
323
 
289
- def _write_bfloat16_data(name, value, f, enc_key, plain_data):
290
- """Write bfloat16 data into protobuf file"""
291
- checkpoint_list = Checkpoint()
292
- param_value = checkpoint_list.value.add()
293
- param_value.tag = name
294
- param_tensor = param_value.tensor
295
- param_tensor.dims.extend(value[1])
296
- param_tensor.tensor_type = value[2]
297
- param_tensor.tensor_content = value[3].get_bytes()
298
- if enc_key is None:
299
- f.write(checkpoint_list.SerializeToString())
300
- else:
301
- plain_data.write(checkpoint_list.SerializeToString())
302
-
303
-
304
- def _write_parameter_data(name, value, f, enc_key, plain_data):
324
+ def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False):
305
325
  """Write parameter data into protobuf file."""
306
326
  data_size = value[2].nbytes / 1024
307
327
  if data_size > SLICE_SIZE:
@@ -320,10 +340,40 @@ def _write_parameter_data(name, value, f, enc_key, plain_data):
320
340
  param_tensor.tensor_content = param_slice.tobytes()
321
341
 
322
342
  if enc_key is None:
323
- f.write(checkpoint_list.SerializeToString())
343
+ output_data = checkpoint_list.SerializeToString()
344
+ if crc_check:
345
+ crc_num = binascii.crc32(output_data, crc_num)
346
+ f.write(output_data)
347
+ else:
348
+ plain_data.write(checkpoint_list.SerializeToString())
349
+
350
+ return crc_num
351
+
352
+
353
+ def _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False):
354
+ """Write parameter bytes data into protobuf file."""
355
+ bytes_value = value[2].get_bytes()
356
+ chunk_size = 1024 * SLICE_SIZE
357
+
358
+ for i in range(0, len(bytes_value), chunk_size):
359
+ checkpoint_list = Checkpoint()
360
+ param_value = checkpoint_list.value.add()
361
+ param_value.tag = name
362
+ param_tensor = param_value.tensor
363
+ param_tensor.dims.extend(value[0])
364
+ param_tensor.tensor_type = value[1]
365
+ param_tensor.tensor_content = bytes_value[i:i + chunk_size]
366
+
367
+ if enc_key is None:
368
+ output_data = checkpoint_list.SerializeToString()
369
+ if crc_check:
370
+ crc_num = binascii.crc32(output_data, crc_num)
371
+ f.write(output_data)
324
372
  else:
325
373
  plain_data.write(checkpoint_list.SerializeToString())
326
374
 
375
+ return crc_num
376
+
327
377
 
328
378
  def _write_mapparameter(name, value, f, map_param_inc=False):
329
379
  """Write map parameter into protobuf file."""
@@ -384,10 +434,14 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
384
434
 
385
435
 
386
436
  def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
387
- async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None, **kwargs):
437
+ async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None,
438
+ crc_check=False, **kwargs):
388
439
  r"""
389
440
  Save checkpoint to a specified file.
390
441
 
442
+ Note:
443
+ The `enc_mode` and `crc_check` parameters are mutually exclusive and cannot be configured simultaneously.
444
+
391
445
  Args:
392
446
  save_obj (Union[Cell, list, dict]): The object to be saved. The data type can be :class:`mindspore.nn.Cell`,
393
447
  list, or dict. If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict
@@ -409,6 +463,8 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
409
463
  If returns ``True`` , the Parameter that matching the custom condition will be saved.
410
464
  If returns ``False`` , the Parameter that not matching the custom condition will not
411
465
  be saved. Default: ``None`` .
466
+ crc_check (bool) : Whether to perform crc32 calculation when saving checkpoint and save the calculation
467
+ result to the file. Default: ``False`` .
412
468
  kwargs (dict): Configuration options dictionary.
413
469
 
414
470
  Raises:
@@ -420,7 +476,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
420
476
  >>> import mindspore as ms
421
477
  >>>
422
478
  >>> # Define the network structure of LeNet5. Refer to
423
- >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
479
+ >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
424
480
  >>> net = LeNet5()
425
481
  >>> ms.save_checkpoint(net, "./lenet.ckpt",
426
482
  ... choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
@@ -440,7 +496,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
440
496
 
441
497
  Tutorial Examples:
442
498
  - `Saving and Loading the Model - Saving and Loading the Model Weight
443
- <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
499
+ <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
444
500
  """
445
501
  ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name)
446
502
  integrated_save = Validator.check_bool(integrated_save)
@@ -448,24 +504,32 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
448
504
  append_dict = _check_append_dict(append_dict)
449
505
  enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
450
506
  enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
507
+ crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
451
508
  map_param_inc = kwargs.get('incremental', False)
452
509
  logger.info("Execute the process of saving checkpoint files.")
510
+ global_step_num = kwargs.get('global_step_num', None)
453
511
 
454
512
  save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
455
513
 
456
514
  if append_dict:
457
515
  append_info_list = []
458
516
  for k_name, value in append_dict.items():
459
- if not isinstance(value, str):
517
+ if isinstance(value, Generator):
518
+ value = value.get_state()
519
+ elif not isinstance(value, str):
460
520
  value = Tensor(value)
461
521
  append_info_list.append({"name": k_name, "data": value})
462
522
  save_obj.extend(append_info_list)
463
523
 
464
524
  data_list = OrderedDict()
525
+ data_list_np = OrderedDict()
465
526
  with _ckpt_mutex:
466
527
  for param in save_obj:
467
528
  if param["name"] == "random_op":
468
- data_list["random_op"] = param["data"]
529
+ if os.getenv("AITURBO") == "1":
530
+ data_list_np["random_op"] = param["data"]
531
+ else:
532
+ data_list["random_op"] = param["data"]
469
533
  continue
470
534
  key = param["name"]
471
535
  data_list[key] = []
@@ -479,49 +543,41 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
479
543
  elif param["data"][0] == "offload_parameter":
480
544
  data_list[key].append("offload_parameter")
481
545
  _save_param_list_data(data_list, key, param)
482
- elif param["data"][0] == "BFloat16_tensor":
483
- data_list[key].append("BFloat16_tensor")
484
- _save_param_list_data(data_list, key, param)
485
- continue
486
546
 
487
547
  if isinstance(param["data"], str):
488
- data_list[key].append([0])
489
- data_list[key].append('str')
490
- data = np.array(param["data"])
491
- data_list[key].append(data)
548
+ if os.getenv("AITURBO") == "1":
549
+ data_list_np[key] = np.array(param["data"])
550
+ else:
551
+ data_list[key].append([0])
552
+ data_list[key].append('str')
553
+ data = np.array(param["data"])
554
+ data_list[key].append(data)
492
555
  else:
493
556
  if isinstance(param["data"], Parameter):
494
557
  param["data"].init_data()
495
- if isinstance(param["data"], Tensor) and param["data"].dtype == mstype.bfloat16:
496
- data_list[key].append("BFloat16_tensor")
497
- dims = []
498
- for dim in param["data"].shape:
499
- dims.append(dim)
500
- data_list[key].append(dims)
501
- data_list[key].append("BFloat16")
502
- data_list[key].append(cpu_cast(param["data"], mstype.float32))
503
- continue
504
- dims = []
505
- if param['data'].shape == ():
506
- dims.append(0)
558
+ if os.getenv("AITURBO") == "1":
559
+ data_list_np[key] = param["data"].asnumpy()
507
560
  else:
561
+ dims = []
508
562
  for dim in param['data'].shape:
509
563
  dims.append(dim)
510
- data_list[key].append(dims)
511
- tensor_type = str(param["data"].dtype)
512
- data_list[key].append(tensor_type)
513
- if param["data"].dtype == mstype.bfloat16:
514
- data = cpu_cast(param["data"], mstype.float32).asnumpy().reshape(-1)
515
- else:
516
- data = param["data"].asnumpy().reshape(-1)
517
- data_list[key].append(data)
518
-
519
- if async_save:
564
+ data_list[key].append(dims)
565
+ tensor_type = str(param["data"].dtype)
566
+ data_list[key].append(tensor_type)
567
+ data = param["data"]
568
+ data_list[key].append(data)
569
+
570
+ if os.getenv("AITURBO") == "1":
571
+ import aiturbo
572
+ ckpt_name = os.path.basename(ckpt_file_name)
573
+ aiturbo.save_ckpt(ckpt_name, global_step_num, data_list_np)
574
+ elif async_save:
520
575
  data_copy = copy.deepcopy(data_list)
521
- thr = Thread(target=_exec_save, args=(ckpt_file_name, data_copy, enc_key, enc_mode), name="asyn_save_ckpt")
576
+ thr = Thread(target=_exec_save, args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check),
577
+ name="asyn_save_ckpt")
522
578
  thr.start()
523
579
  else:
524
- _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc)
580
+ _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check)
525
581
 
526
582
  logger.info("Saving checkpoint process is finished.")
527
583
 
@@ -532,7 +588,21 @@ def _convert_list_to_param_list(save_obj, choice_func):
532
588
  if not save_obj:
533
589
  return param_list
534
590
  if isinstance(save_obj[0], dict):
535
- param_list = [param for param in save_obj if choice_func is None or choice_func(param["name"])]
591
+ for param in save_obj:
592
+ if isinstance(param, dict) and "name" in param and "data" in param:
593
+ if not isinstance(param["name"], str):
594
+ raise TypeError(f"For save_checkpoint, when save_obj is a list of dict items, the name in dict "
595
+ f"should be string, but got {type(param['name'])}.")
596
+ if not isinstance(param["data"], Tensor):
597
+ raise TypeError(f"For save_checkpoint, when save_obj is a list of dict items, the data in dict "
598
+ f"should be parameter, but got {type(param['data'])}.")
599
+ if choice_func is not None and not choice_func(param["name"]):
600
+ continue
601
+ each_param = {"name": param["name"], "data": param["data"]}
602
+ param_list.append(each_param)
603
+ else:
604
+ raise TypeError(f"For save_checkpoint, save_obj should be a list of dict items, and the dict should "
605
+ f"have key values 'name' and 'value', but got {type(param)} and {param}.")
536
606
  else:
537
607
  for param in save_obj:
538
608
  if isinstance(param, Parameter):
@@ -585,6 +655,7 @@ def _convert_cell_param_and_names_to_dict(save_obj, choice_func):
585
655
 
586
656
  def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func):
587
657
  """Convert nn.Cell to param_list."""
658
+ sync_pipeline_shared_parameters(save_obj)
588
659
  param_list = []
589
660
  parameter_layout_dict = save_obj.parameter_layout_dict
590
661
  if _is_in_auto_parallel_mode() and not parameter_layout_dict:
@@ -597,7 +668,7 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
597
668
  if phase in save_obj.compile_cache and _executor.has_compiled(phase):
598
669
  random_byte = _executor._graph_executor.get_random_status(phase)
599
670
  param_list.append({"name": "random_op", "data": random_byte})
600
- append_dict.pop("random_op")
671
+ append_dict.pop("random_op")
601
672
  for (key, value) in param_dict.items():
602
673
  each_param = {"name": key}
603
674
  if isinstance(value, MapParameter):
@@ -619,18 +690,13 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
619
690
  param_data.append(param_tensor.shape)
620
691
  param_data.append(str(param_tensor.dtype))
621
692
  param_data.append(value.key)
622
- elif value.data.dtype == mstype.bfloat16:
623
- param_data = ["BFloat16_tensor"]
624
- param_data.append(cpu_cast(value.data, mstype.float32))
625
- param_data.append(value.data.shape)
626
- param_data.append("BFloat16")
627
- param_data.append(value.key)
628
693
  else:
629
- param_data = Tensor(value.data.asnumpy())
694
+ param_data = value.data
630
695
 
631
696
  # in automatic model parallel scenario, some parameters were split to all the devices,
632
697
  # which should be combined before saving
633
698
  if key in parameter_layout_dict:
699
+ param_data = Tensor(value.data)
634
700
  param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
635
701
  integrated_save)
636
702
 
@@ -670,9 +736,9 @@ def _check_append_dict(append_dict):
670
736
  raise TypeError("For 'save_checkpoint', the argument 'append_dict' must be dict, but got "
671
737
  "{}.".format(type(append_dict)))
672
738
  for key, value in append_dict.items():
673
- if not isinstance(key, str) or not isinstance(value, (int, float, bool, str, Parameter, Tensor)):
739
+ if not isinstance(key, str) or not isinstance(value, (int, float, bool, str, Parameter, Tensor, Generator)):
674
740
  raise TypeError(f"For 'save_checkpoint', the type of dict 'append_info' must be key: string, "
675
- f"value: int, float or bool, but got key: {type(key)}, value: {type(value)}")
741
+ f"value: int, float, bool or Generator, but got key: {type(key)}, value: {type(value)}")
676
742
  return append_dict
677
743
 
678
744
 
@@ -699,13 +765,13 @@ def load(file_name, **kwargs):
699
765
  - dec_key (bytes): Byte-type key used for decryption. The valid length is 16, 24, or 32.
700
766
  - dec_mode (Union[str, function]): Specifies the decryption mode, to take effect when dec_key is set.
701
767
 
702
- - Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC' or customized decryption. Default: 'AES-GCM'.
768
+ - Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC' or customized decryption. Default: ``'AES-GCM'``.
703
769
  - For details of using the customized decryption, please check the `tutorial
704
- <https://mindspore.cn/mindarmour/docs/en/r2.0/model_encrypt_protection.html>`_.
770
+ <https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_.
705
771
 
706
772
  - obf_func (function): A python function used for loading obfuscated MindIR model, which can refer to
707
773
  `obfuscate_model()
708
- <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore/mindspore.obfuscate_model.html>`_.
774
+ <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.obfuscate_model.html>`_.
709
775
 
710
776
  Returns:
711
777
  GraphCell, a compiled graph that can executed by `GraphCell`.
@@ -735,7 +801,7 @@ def load(file_name, **kwargs):
735
801
 
736
802
  Tutorial Examples:
737
803
  - `Saving and Loading the Model - Saving and Loading MindIR
738
- <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-mindir>`_
804
+ <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
739
805
  """
740
806
  if not isinstance(file_name, str):
741
807
  raise ValueError("For 'load', the argument 'file_name' must be string, but "
@@ -776,7 +842,7 @@ def load(file_name, **kwargs):
776
842
  return graph
777
843
 
778
844
 
779
- def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=False):
845
+ def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=True):
780
846
  """
781
847
  Auto Split MindIR.
782
848
 
@@ -784,10 +850,10 @@ def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=F
784
850
 
785
851
  Args:
786
852
  file_name (str): MindIR file name.
787
- device_num (int): device number.
788
- rank_id (int): rank id.
789
- dynamic (bool): Indicates whether the model is a dynamic shape mindir model.
790
- sapp (bool): Indicates whether to automatically generate split strategy through SAPP.
853
+ device_num (int): device number. Default: '8'.
854
+ rank_id (int): rank id. Default: '0'.
855
+ dynamic (bool): Indicates whether the model is a dynamic shape mindir model. Default: 'True'.
856
+ sapp (bool): Indicates whether to automatically generate split strategy through SAPP. Default: 'True'.
791
857
 
792
858
  Raises:
793
859
  ValueError: MindIR file does not exist or `file_name` is not a string.
@@ -909,13 +975,14 @@ def obfuscate_model(obf_config, **kwargs):
909
975
  - customized_func (function): A python function used for customized function mode, which used for control
910
976
  the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
911
977
  Reference to 'my_func()' in
912
- `tutorials <https://www.mindspore.cn/mindarmour/docs/en/r2.0/dynamic_obfuscation_protection.html>`_).
978
+ `tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
913
979
  This function needs to ensure that its result is constant for any input. Users can refer to opaque
914
980
  predicates. If customized_func is set, then it should be passed to :func:`mindspore.load` interface
915
981
  when loading obfuscated model.
916
982
  - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
917
983
  structure of obfuscated models corresponding to different random seeds is different. If
918
- `obf_random_seed` is set, then it should be passed to :class:`nn.GraphCell()` interface when loading
984
+ `obf_random_seed` is set, then it should be passed to :class:`mindspore.nn.GraphCell`
985
+ interface when loading
919
986
  obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
920
987
  be set, and the latter mode would be applied if both of them are set.
921
988
 
@@ -923,7 +990,7 @@ def obfuscate_model(obf_config, **kwargs):
923
990
 
924
991
  - enc_key (bytes): Byte type key used for encryption. The valid length is 16, 24, or 32.
925
992
  - enc_mode (str): Specifies the encryption mode, to take effect when dec_key is set.
926
- Option: 'AES-GCM' | 'AES-CBC' | 'SM4-CBC'. Default: 'AES-GCM'.
993
+ Options: ``'AES-GCM'`` | ``'AES-CBC'`` | ``'SM4-CBC'``. Default: ``'AES-GCM'``.
927
994
 
928
995
  Raises:
929
996
  TypeError: If `obf_config` is not a dict.
@@ -934,11 +1001,15 @@ def obfuscate_model(obf_config, **kwargs):
934
1001
  ValueError: If `obf_ratio` is not provided in `obf_config`.
935
1002
  ValueError: If both `customized_func` and `obf_random_seed` are not provided in `obf_config`.
936
1003
  ValueError: If `obf_random_seed` is not in (0, 9223372036854775807].
937
- ValueError: If `original_model_path` is not exist or `original_model_path` is not end with '.mindir'.
1004
+ ValueError: If `original_model_path` does not exist or `original_model_path` does not end with '.mindir'.
938
1005
 
939
1006
  Examples:
940
1007
  >>> import mindspore as ms
941
1008
  >>> import mindspore.nn as nn
1009
+ >>> import numpy as np
1010
+ >>> # Download ori_net.mindir
1011
+ >>> # https://gitee.com/mindspore/mindspore/blob/master/tests/ut/python/mindir/ori_net.mindir
1012
+ >>> input1 = ms.Tensor(np.ones((1, 1, 32, 32)).astype(np.float32))
942
1013
  >>> obf_config = {'original_model_path': "./net.mindir",
943
1014
  ... 'save_model_path': "./obf_net",
944
1015
  ... 'model_inputs': [input1, ],
@@ -998,12 +1069,76 @@ def obfuscate_model(obf_config, **kwargs):
998
1069
  obf_net = nn.GraphCell(obf_graph)
999
1070
  if obf_random_seed != 0:
1000
1071
  append_y_tensor = Tensor(np.ones((1, 1)).astype(np.int32))
1001
- model_inputs += [append_y_tensor,]
1072
+ model_inputs += [append_y_tensor]
1002
1073
  export(obf_net, *model_inputs, file_name=saved_path, file_format="MINDIR", **kwargs)
1003
1074
 
1004
1075
 
1076
+ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
1077
+ dec_mode, crc_check):
1078
+ """load parameter into parameter_dict"""
1079
+ ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
1080
+ checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check)
1081
+ try:
1082
+ param_data_list = []
1083
+ map_data_list = [[], [], []]
1084
+ map_shape_list = [0, 0, 0]
1085
+ if specify_prefix:
1086
+ logger.warning("For load_checkpoint, this parameter `specity_prefix` will be deprecated, "
1087
+ "please use `choice_func` instead.")
1088
+ if filter_prefix:
1089
+ logger.warning("For load_checkpoint, this parameter `filter_prefix` will be deprecated, "
1090
+ "please use `choice_func` instead.")
1091
+ for element_id, element in enumerate(checkpoint_list.value):
1092
+ if element.tag == "random_op":
1093
+ parameter_dict["random_op"] = element.tensor.tensor_content
1094
+ continue
1095
+ if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
1096
+ continue
1097
+ if specify_prefix is None and filter_prefix is None and \
1098
+ choice_func is not None and not choice_func(element.tag):
1099
+ continue
1100
+ if element.tensor.ByteSize() == 0:
1101
+ _load_map_parameter(checkpoint_list, element, element_id, map_data_list, map_shape_list,
1102
+ parameter_dict)
1103
+ if element.tag in parameter_dict:
1104
+ map_data_list = [[], [], []]
1105
+ map_shape_list = [0, 0, 0]
1106
+ continue
1107
+ data = element.tensor.tensor_content
1108
+ data_type = element.tensor.tensor_type
1109
+ np_type = tensor_to_np_type.get(data_type)
1110
+ ms_type = tensor_to_ms_type[data_type]
1111
+ if data_type == 'str':
1112
+ str_length = int(len(data) / 4)
1113
+ np_type = np_type + str(str_length)
1114
+ param_data_list.append(data)
1115
+ if (element_id == len(checkpoint_list.value) - 1) or \
1116
+ (element.tag != checkpoint_list.value[element_id + 1].tag):
1117
+ new_data = b"".join(param_data_list)
1118
+ param_data_list.clear()
1119
+ dims = element.tensor.dims
1120
+ if data_type == 'str':
1121
+ str_value = np.frombuffer(new_data, np_type)
1122
+ parameter_dict[element.tag] = str(str_value[0])
1123
+ else:
1124
+ if dims == [0]:
1125
+ dims = []
1126
+ param_data = Tensor_.convert_bytes_to_tensor(new_data, tuple(dims), ms_type)
1127
+ parameter = Parameter(param_data, name=element.tag)
1128
+ parameter_dict[element.tag] = parameter
1129
+ _offload_if_config(parameter)
1130
+
1131
+ logger.info("Loading checkpoint files process is finished.")
1132
+
1133
+ except BaseException as e:
1134
+ logger.critical("Failed to load the checkpoint file '%s'.", ckpt_file_name)
1135
+ raise ValueError(e.__str__() + "\nFor 'load_checkpoint', "
1136
+ "failed to load the checkpoint file {}.".format(ckpt_file_name)) from e
1137
+
1138
+
1005
1139
  def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None,
1006
- dec_key=None, dec_mode="AES-GCM", specify_prefix=None, choice_func=None):
1140
+ dec_key=None, dec_mode="AES-GCM", specify_prefix=None, choice_func=None,
1141
+ crc_check=False):
1007
1142
  """
1008
1143
  Load checkpoint info from a specified file.
1009
1144
 
@@ -1034,6 +1169,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
1034
1169
  and the return value is a bool. If returns ``True`` , the Parameter
1035
1170
  that matches the custom condition will be loaded. If returns ``False`` , the Parameter that
1036
1171
  matches the custom condition will be removed. Default: ``None`` .
1172
+ crc_check (bool) : Whether to perform crc32 validation when loading checkpoint. Default: ``False`` .
1037
1173
 
1038
1174
  Returns:
1039
1175
  Dict, key is parameter name, value is a Parameter or string. When the `append_dict` parameter of
@@ -1076,83 +1212,31 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
1076
1212
 
1077
1213
  Tutorial Examples:
1078
1214
  - `Saving and Loading the Model - Saving and Loading the Model Weight
1079
- <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
1215
+ <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
1080
1216
  """
1081
- ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
1082
1217
  specify_prefix = _check_prefix(specify_prefix)
1083
1218
  filter_prefix = _check_prefix(filter_prefix)
1084
1219
  dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
1085
1220
  dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
1221
+ crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
1086
1222
  logger.info("Execute the process of loading checkpoint files.")
1087
- checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode)
1088
1223
 
1089
1224
  parameter_dict = {}
1090
- try:
1091
- param_data_list = []
1092
- map_data_list = [[], [], []]
1093
- map_shape_list = [0, 0, 0]
1094
- if specify_prefix:
1095
- logger.warning("For load_checkpoint, this parameter `specity_prefix` will be deprecated, "
1096
- "please use `choice_func` instead.")
1097
- if filter_prefix:
1098
- logger.warning("For load_checkpoint, this parameter `filter_prefix` will be deprecated, "
1099
- "please use `choice_func` instead.")
1100
- for element_id, element in enumerate(checkpoint_list.value):
1101
- if element.tag == "random_op":
1102
- parameter_dict["random_op"] = element.tensor.tensor_content
1103
- continue
1104
- if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
1105
- continue
1106
- if specify_prefix is None and filter_prefix is None and \
1107
- choice_func is not None and not choice_func(element.tag):
1108
- continue
1109
- if element.tensor.ByteSize() == 0:
1110
- _load_map_parameter(checkpoint_list, element, element_id, map_data_list, map_shape_list, parameter_dict)
1111
- if element.tag in parameter_dict:
1112
- map_data_list = [[], [], []]
1113
- map_shape_list = [0, 0, 0]
1114
- continue
1115
- data = element.tensor.tensor_content
1116
- data_type = element.tensor.tensor_type
1117
- np_type = tensor_to_np_type.get(data_type)
1118
- ms_type = tensor_to_ms_type[data_type]
1119
- if data_type == 'str':
1120
- str_length = int(len(data) / 4)
1121
- np_type = np_type + str(str_length)
1122
- if data_type == "BFloat16":
1123
- dims = element.tensor.dims
1124
- param_data = np.frombuffer(data, np_type)
1125
- param_data = param_data.reshape(list(dims))
1126
- parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
1127
- parameter_dict[element.tag] = parameter
1128
- continue
1129
- element_data = np.frombuffer(data, np_type)
1130
- param_data_list.append(element_data)
1131
- if (element_id == len(checkpoint_list.value) - 1) or \
1132
- (element.tag != checkpoint_list.value[element_id + 1].tag):
1133
- new_data = b"".join(param_data_list)
1134
- param_data = np.frombuffer(new_data, np_type)
1135
- param_data_list.clear()
1136
- dims = element.tensor.dims
1137
- if dims == [0] and data_type == 'str':
1138
- parameter_dict[element.tag] = str(element_data[0])
1139
- else:
1140
- if dims == [0] and 'Float' in data_type:
1141
- param_data = float(param_data[0])
1142
- if dims == [0] and 'Int' in data_type:
1143
- param_data = int(param_data[0])
1144
- if dims not in ([0], [1]):
1145
- param_data = param_data.reshape(list(dims))
1146
- parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
1147
- parameter_dict[element.tag] = parameter
1148
- _offload_if_config(parameter)
1149
-
1150
- logger.info("Loading checkpoint files process is finished.")
1151
1225
 
1152
- except BaseException as e:
1153
- logger.critical("Failed to load the checkpoint file '%s'.", ckpt_file_name)
1154
- raise ValueError(e.__str__() + "\nFor 'load_checkpoint', "
1155
- "failed to load the checkpoint file {}.".format(ckpt_file_name)) from e
1226
+ if os.getenv("AITURBO") == "1":
1227
+ rank_id = get_rank()
1228
+ import aiturbo
1229
+ ckpt_path = os.path.dirname(ckpt_file_name)
1230
+ ckpt_name = os.path.basename(ckpt_file_name)
1231
+ np_dict = aiturbo.load_ckpt(ckpt_path, ckpt_name, rank_id)
1232
+ for key, value in np_dict.items():
1233
+ if isinstance(value, str):
1234
+ parameter_dict[key] = value
1235
+ else:
1236
+ parameter_dict[key] = Parameter(Tensor(value), name=key)
1237
+ else:
1238
+ _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
1239
+ dec_mode, crc_check)
1156
1240
 
1157
1241
  if not parameter_dict:
1158
1242
  raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether "
@@ -1168,6 +1252,86 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
1168
1252
  return parameter_dict
1169
1253
 
1170
1254
 
1255
+ def load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None,
1256
+ dec_mode="AES-GCM", specify_prefix=None, choice_func=None):
1257
+ """
1258
+ Load checkpoint info from a specified file asyncly.
1259
+
1260
+ .. warning::
1261
+ This is an experimental API that is subject to change or deletion.
1262
+
1263
+ Note:
1264
+ - `specify_prefix` and `filter_prefix` do not affect each other.
1265
+ - If none of the parameters are loaded from checkpoint file, it will throw ValueError.
1266
+ - `specify_prefix` and `filter_prefix` are in the process of being deprecated,
1267
+ `choice_func` is recommended instead.
1268
+ And using either of those two args will override `choice_func` at the same time.
1269
+
1270
+ Args:
1271
+ ckpt_file_name (str): Checkpoint file name.
1272
+ net (Cell, optional): The network where the parameters will be loaded. Default: ``None`` .
1273
+ strict_load (bool, optional): Whether to strict load the parameter into net. If ``False`` , it will load
1274
+ parameter into net when parameter name's suffix in checkpoint file is the
1275
+ same as the parameter in the network. When the types are inconsistent
1276
+ perform type conversion on the parameters of the same type, such as float32
1277
+ to float16. Default: ``False`` .
1278
+ filter_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`). Parameters
1279
+ starting with the `filter_prefix` will not be loaded. Default: ``None`` .
1280
+ dec_key (Union[None, bytes], optional): Byte type key used for decryption. If the value is ``None`` ,
1281
+ the decryption is not required. Default: ``None`` .
1282
+ dec_mode (str, optional): This parameter is valid only when dec_key is not set to ``None`` . Specifies
1283
+ the decryption mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"``
1284
+ and ``"SM4-CBC"`` . Default: ``"AES-GCM"`` .
1285
+ specify_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`). Parameters
1286
+ starting with the specify_prefix will be loaded. Default: ``None`` .
1287
+ choice_func (Union[None, function], optional): Input value of the function is a Parameter name of type
1288
+ string, and the return value is a bool. If returns ``True`` , the Parameter
1289
+ that matches the custom condition will be loaded. If returns ``False`` , the Parameter that
1290
+ matches the custom condition will be removed. Default: ``None`` .
1291
+
1292
+ Returns:
1293
+ A custom inner class, calling its `result` method yields the :func:`mindspore.load_checkpoint` result.
1294
+
1295
+ Raises:
1296
+ ValueError: Checkpoint file's format is incorrect.
1297
+ ValueError: Parameter's dict is None after load checkpoint file.
1298
+ TypeError: The type of `specify_prefix` or `filter_prefix` is incorrect.
1299
+
1300
+ Examples:
1301
+ >>> import mindspore
1302
+ >>> from mindspore import nn
1303
+ >>> from mindspore.train import Model
1304
+ >>> from mindspore.amp import FixedLossScaleManager
1305
+ >>> from mindspore import context
1306
+ >>> from mindspore import load_checkpoint_async
1307
+ >>> from mindspore import load_param_into_net
1308
+ >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
1309
+ >>> # Create the dataset taking MNIST as an example. Refer to
1310
+ >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
1311
+ >>> dataset = create_dataset()
1312
+ >>> # Define the network structure of LeNet5. Refer to
1313
+ >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
1314
+ >>> ckpt_file = "./checkpoint/LeNet5-1_32.ckpt"
1315
+ >>> net = LeNet5()
1316
+ >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
1317
+ >>> loss_scale_manager = FixedLossScaleManager()
1318
+ >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
1319
+ >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
1320
+ ... loss_scale_manager=loss_scale_manager)
1321
+ >>> pd_future = load_checkpoint_async(ckpt_file)
1322
+ >>> model.build(train_dataset=dataset, epoch=2)
1323
+ >>> param_dict = pd_future.result()
1324
+ >>> load_param_into_net(net, param_dict)
1325
+ >>> model.train(2, dataset)
1326
+ >>> print("param dict len: ", len(param_dict), flush=True)
1327
+ """
1328
+ from concurrent.futures import ThreadPoolExecutor
1329
+ executor = ThreadPoolExecutor(max_workers=2)
1330
+ param_dict_future = executor.submit(load_checkpoint, ckpt_file_name, net, strict_load, filter_prefix,
1331
+ dec_key, dec_mode, specify_prefix, choice_func)
1332
+ return ParamDictFuture(executor, param_dict_future)
1333
+
1334
+
1171
1335
  def _load_map_parameter(checkpoint_list, element, element_id, map_data_list,
1172
1336
  map_shape_list, parameter_dict):
1173
1337
  """load map parameter."""
@@ -1239,17 +1403,28 @@ def _check_prefix(prefix):
1239
1403
  return prefix
1240
1404
 
1241
1405
 
1242
- def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode):
1406
+ def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check):
1243
1407
  """Parse checkpoint protobuf."""
1244
1408
  checkpoint_list = Checkpoint()
1245
1409
  try:
1246
1410
  if dec_key is None:
1247
- with open(ckpt_file_name, "rb") as f:
1411
+ with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f:
1248
1412
  pb_content = f.read()
1249
1413
  else:
1250
1414
  pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
1251
1415
  if pb_content is None:
1252
1416
  raise ValueError("For 'load_checkpoint', failed to decrypt the checkpoint file.")
1417
+ if crc_check and pb_content[-17:-10] == b"crc_num":
1418
+ logger.warning("For 'load_checkpoint', the ckpt file do not contain the crc code, please check the file.")
1419
+ if pb_content[-17:-10] == b"crc_num":
1420
+ crc_num_bytes = pb_content[-10:]
1421
+ pb_content = pb_content[:-17]
1422
+ if crc_check:
1423
+ crc_num = int.from_bytes(crc_num_bytes, byteorder='big')
1424
+ cal_crc_num = binascii.crc32(pb_content, 0)
1425
+ if cal_crc_num != crc_num:
1426
+ raise ValueError("For 'load_checkpoint', the crc check is failed, "
1427
+ "please check whether the ckpt file is damaged.")
1253
1428
  checkpoint_list.ParseFromString(pb_content)
1254
1429
  except BaseException as e:
1255
1430
  if _is_cipher_file(ckpt_file_name):
@@ -1282,13 +1457,33 @@ def _whether_load_param(specify_prefix, filter_prefix, param_name):
1282
1457
 
1283
1458
  def _init_parameter_data_in_parallel_mode(net, parameter_dict):
1284
1459
  """In parallel mode, only init the paraemters in ckpt."""
1460
+ is_train_phase = net.phase.startswith('train')
1285
1461
  for _, param in net.parameters_and_names():
1462
+ if param.name in parameter_dict and param.from_ckpt and not is_train_phase:
1463
+ param.shape = tuple(parameter_dict[param.name].shape)
1464
+ continue
1286
1465
  if param.name in parameter_dict and param.has_init:
1287
1466
  logger.warning("{} is not init while load ckpt.".format(param.name))
1288
1467
  new_tensor = param.init_data()
1289
1468
  param._update_tensor_data(new_tensor)
1290
1469
 
1291
1470
 
1471
+ def _check_load_param_into_net(net, parameter_dict):
1472
+ """check load_param_into_net"""
1473
+ if not isinstance(net, nn.Cell):
1474
+ logger.critical("Failed to combine the net and the parameters.")
1475
+ msg = ("For 'load_param_into_net', the argument 'net' should be a Cell, but got {}.".format(type(net)))
1476
+ raise TypeError(msg)
1477
+ if not isinstance(parameter_dict, dict):
1478
+ logger.critical("Failed to combine the net and the parameters.")
1479
+ msg = ("For 'load_param_into_net', the argument 'parameter_dict' should be a dict, "
1480
+ "but got {}.".format(type(parameter_dict)))
1481
+ raise TypeError(msg)
1482
+ if "random_op" in parameter_dict.keys():
1483
+ net._add_attr("random_op_snapshot", parameter_dict["random_op"])
1484
+ parameter_dict.pop("random_op")
1485
+
1486
+
1292
1487
  def load_param_into_net(net, parameter_dict, strict_load=False):
1293
1488
  """
1294
1489
  Load parameters into network, return parameter list that are not loaded in the network.
@@ -1303,8 +1498,8 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
1303
1498
  on the parameters of the same type, such as float32 to float16. Default: ``False`` .
1304
1499
 
1305
1500
  Returns:
1306
- param_not_load (List), the parameter name in model which are not loaded into the network.
1307
- ckpt_not_load (List), the parameter name in checkpoint file which are not loaded into the network.
1501
+ - param_not_load (List), the parameter name in model which are not loaded into the network.
1502
+ - ckpt_not_load (List), the parameter name in checkpoint file which are not loaded into the network.
1308
1503
 
1309
1504
  Raises:
1310
1505
  TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
@@ -1313,7 +1508,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
1313
1508
  >>> import mindspore as ms
1314
1509
  >>>
1315
1510
  >>> # Define the network structure of LeNet5. Refer to
1316
- >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
1511
+ >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
1317
1512
  >>> net = LeNet5()
1318
1513
  >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
1319
1514
  >>> param_dict = ms.load_checkpoint(ckpt_file_name, filter_prefix="conv1")
@@ -1323,20 +1518,9 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
1323
1518
 
1324
1519
  Tutorial Examples:
1325
1520
  - `Saving and Loading the Model - Saving and Loading the Model Weight
1326
- <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
1521
+ <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
1327
1522
  """
1328
- if not isinstance(net, nn.Cell):
1329
- logger.critical("Failed to combine the net and the parameters.")
1330
- msg = ("For 'load_param_into_net', the argument 'net' should be a Cell, but got {}.".format(type(net)))
1331
- raise TypeError(msg)
1332
- if not isinstance(parameter_dict, dict):
1333
- logger.critical("Failed to combine the net and the parameters.")
1334
- msg = ("For 'load_param_into_net', the argument 'parameter_dict' should be a dict, "
1335
- "but got {}.".format(type(parameter_dict)))
1336
- raise TypeError(msg)
1337
- if "random_op" in parameter_dict.keys():
1338
- net._add_attr("random_op_snapshot", parameter_dict["random_op"])
1339
- parameter_dict.pop("random_op")
1523
+ _check_load_param_into_net(net, parameter_dict)
1340
1524
  for key, value in parameter_dict.items():
1341
1525
  if not isinstance(key, str) or not isinstance(value, (Parameter, str, list)):
1342
1526
  logger.critical("Load parameters into net failed.")
@@ -1346,6 +1530,8 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
1346
1530
 
1347
1531
  strict_load = Validator.check_bool(strict_load)
1348
1532
  logger.info("Execute the process of loading parameters into net.")
1533
+ for _, param in net.parameters_and_names():
1534
+ param.from_ckpt = True
1349
1535
  if not _is_in_auto_parallel_mode():
1350
1536
  net.init_parameters_data()
1351
1537
  else:
@@ -1360,7 +1546,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
1360
1546
  # Add has attr protection when load server checkpoint file on worker.
1361
1547
  if not hasattr(parameter_dict[param.name], "data"):
1362
1548
  continue
1363
- new_param = copy.deepcopy(parameter_dict[param.name])
1549
+ new_param = parameter_dict[param.name]
1364
1550
  _update_param(param, new_param, strict_load)
1365
1551
  ckpt_not_load.remove(param.name)
1366
1552
  else:
@@ -1369,18 +1555,21 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
1369
1555
  if param_not_load and not strict_load:
1370
1556
  _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load)
1371
1557
 
1372
- logger.debug("Params not matched(in net but not in parameter_dict):")
1373
- for param_name in param_not_load:
1374
- logger.debug("%s", param_name)
1375
-
1376
1558
  logger.info("Loading parameters into net is finished.")
1377
1559
  if param_not_load:
1378
1560
  logger.warning("For 'load_param_into_net', "
1379
1561
  "{} parameters in the 'net' are not loaded, because they are not in the "
1380
1562
  "'parameter_dict', please check whether the network structure is consistent "
1381
1563
  "when training and loading checkpoint.".format(len(param_not_load)))
1382
- for param_name in param_not_load:
1383
- logger.warning("{} is not loaded.".format(param_name))
1564
+ logger.warning("{} are not loaded.".format(param_not_load))
1565
+ if os.getenv("AITURBO") == "1" and net.parameter_layout_dict is not None:
1566
+ param_layout = net.parameter_layout_dict
1567
+ param_redundancy = get_parameter_redundancy(param_layout)
1568
+ remove_param_redundancy_dict = remove_param_redundancy(param_redundancy)
1569
+ target_parameter_name_set = set(parameter_dict.keys())
1570
+ for rank_id, param_name_set in remove_param_redundancy_dict:
1571
+ if param_name_set == target_parameter_name_set:
1572
+ parameter_broadcast(net, param_layout, rank_id)
1384
1573
  return param_not_load, ckpt_not_load
1385
1574
 
1386
1575
 
@@ -1494,6 +1683,23 @@ def _save_graph(network, file_name):
1494
1683
  f.write(graph_pb)
1495
1684
 
1496
1685
 
1686
+ def _reshape_tensor(tensor, dst_shape):
1687
+ """reshape tensor to dst shape"""
1688
+ np_tensor = tensor.asnumpy()
1689
+ np_tensor = np_tensor.reshape(dst_shape)
1690
+ return Tensor(np_tensor, tensor.dtype)
1691
+
1692
+
1693
+ def _check_param_for_integrate_save(pipeline_stages, uniform_split):
1694
+ """check whether current settings and parameters are supported in integrated save checkpoint mode"""
1695
+ if pipeline_stages > 1:
1696
+ raise RuntimeError("Pipeline Parallel don't support Integrated save checkpoint now.")
1697
+ if uniform_split == 0:
1698
+ raise RuntimeError("For 'save_checkpoint' and in automatic model parallel scene, when set "
1699
+ "'integrated_save' to True, the checkpoint will be integrated save, it "
1700
+ "is only supports uniform split tensor now.")
1701
+
1702
+
1497
1703
  def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, integrated_save):
1498
1704
  """
1499
1705
  Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map.
@@ -1507,7 +1713,7 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
1507
1713
  Tensor, the combined tensor which with the whole data value.
1508
1714
  """
1509
1715
  layout = parameter_layout_dict[param_name]
1510
- if len(layout) < 6:
1716
+ if len(layout) < 8:
1511
1717
  logger.info("The layout dict does not contain the key %s", param_name)
1512
1718
  return param_data
1513
1719
 
@@ -1515,6 +1721,13 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
1515
1721
  tensor_map = layout[1]
1516
1722
  uniform_split = layout[4]
1517
1723
  opt_shard_group = layout[5]
1724
+ before_reshape_slice_shape = layout[2]
1725
+ before_reshape_full_shape = layout[6]
1726
+ after_reshape_slice_shape = layout[7]
1727
+ do_reshape = False
1728
+ if before_reshape_full_shape and after_reshape_slice_shape \
1729
+ and after_reshape_slice_shape != before_reshape_slice_shape:
1730
+ do_reshape = True
1518
1731
 
1519
1732
  allgather_net = None
1520
1733
  mp_weight = False
@@ -1527,26 +1740,26 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
1527
1740
  else:
1528
1741
  logger.info("Need to create allgather net for %s", param_name)
1529
1742
  if integrated_save:
1530
- if context.get_auto_parallel_context("pipeline_stages") > 1:
1531
- raise RuntimeError("Pipeline Parallel don't support Integrated save checkpoint now.")
1532
- if uniform_split == 0:
1533
- raise RuntimeError("For 'save_checkpoint' and in automatic model parallel scene, when set "
1534
- "'integrated_save' to True, the checkpoint will be integrated save, it "
1535
- "is only supports uniform split tensor now.")
1743
+ _check_param_for_integrate_save(context.get_auto_parallel_context("pipeline_stages"), uniform_split)
1536
1744
  # while any dim is not equal to -1, means param is split and needs to be merged
1537
1745
  # pipeline parallel need to be supported here later
1538
1746
  if mp_weight:
1539
- allgather_net = get_allgather_cell(opt_shard_group, bool(opt_shard_group))
1747
+ allgather_net = get_allgather_cell(opt_shard_group, bool(opt_shard_group), do_reshape,
1748
+ tuple(after_reshape_slice_shape))
1540
1749
  object.__setattr__(allgather_net, "keep_input_unchanged", True)
1541
1750
  elif opt_shard_group:
1542
- allgather_net = get_allgather_cell(opt_shard_group, False)
1751
+ allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
1752
+ tuple(after_reshape_slice_shape))
1543
1753
  elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save"):
1544
- allgather_net = get_allgather_cell(opt_shard_group, False)
1754
+ allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
1755
+ tuple(after_reshape_slice_shape))
1545
1756
  net.parallel_parameter_merge_net_dict[param_name] = allgather_net
1546
1757
  if allgather_net:
1547
1758
  param_data = allgather_net(param_data)
1548
1759
  if mp_weight and integrated_save:
1549
1760
  param_data = _reshape_param_data(param_data, dev_mat, tensor_map)
1761
+ if do_reshape:
1762
+ param_data = _reshape_tensor(param_data, before_reshape_full_shape)
1550
1763
  return param_data
1551
1764
 
1552
1765
 
@@ -1556,10 +1769,13 @@ def export(net, *inputs, file_name, file_format, **kwargs):
1556
1769
 
1557
1770
  Note:
1558
1771
  1. When exporting AIR, ONNX format, the size of a single tensor can not exceed 2GB.
1559
- 2. When file_name does not have a suffix, the system will automatically add one according to the file_format.
1772
+ 2. When `file_name` does not have a suffix, the system will automatically add one
1773
+ according to the `file_format`.
1560
1774
  3. Exporting functions decorated with :func:`mindspore.jit` to mindir format is supported.
1561
1775
  4. When exporting a function decorated with :func:`mindspore.jit`, the function should not involve
1562
1776
  class properties in calculations.
1777
+ 5. AIR format is deprecated, and will be removed in a future version, please use other format or use
1778
+ MindSpore Lite to do offline inference.
1563
1779
 
1564
1780
  Args:
1565
1781
  net (Union[Cell, function]): MindSpore network.
@@ -1584,9 +1800,9 @@ def export(net, *inputs, file_name, file_format, **kwargs):
1584
1800
  - For 'AIR' and 'ONNX' models, only customized encryption is supported.
1585
1801
  - For 'MINDIR', all options are supported. Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC'
1586
1802
  or Customized encryption.
1587
- Default: 'AES-GCM'.
1803
+ Default: ``'AES-GCM'``.
1588
1804
  - For details of using the customized encryption, please check the `tutorial
1589
- <https://mindspore.cn/mindarmour/docs/en/r2.0/model_encrypt_protection.html>`_.
1805
+ <https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_.
1590
1806
 
1591
1807
  - dataset (Dataset): Specifies the preprocessing method of the dataset, which is used to import the
1592
1808
  preprocessing of the dataset into MindIR.
@@ -1600,32 +1816,49 @@ def export(net, *inputs, file_name, file_format, **kwargs):
1600
1816
  - customized_func (function): A python function used for customized function mode, which used for control
1601
1817
  the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
1602
1818
  Reference to 'my_func()' in
1603
- `tutorials <https://www.mindspore.cn/mindarmour/docs/en/r2.0/dynamic_obfuscation_protection.html>`_).
1819
+ `tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
1604
1820
  This function needs to ensure that its result is constant for any input. Users can refer to opaque
1605
1821
  predicates. If customized_func is set, then it should be passed to `load()` interface when loading
1606
1822
  obfuscated model.
1607
1823
  - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
1608
1824
  structure of obfuscated models corresponding to different random seeds is different. If
1609
- `obf_random_seed` is set, then it should be passed to :class:`nn.GraphCell()` interface when loading
1825
+ `obf_random_seed` is set, then it should be passed
1826
+ to :class:`mindspore.nn.GraphCell` interface when loading
1610
1827
  obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
1611
1828
  be set, and the latter mode would be applied if both of them are set.
1612
1829
 
1613
1830
  - incremental (bool): export MindIR incrementally.
1614
1831
 
1832
+ - custom_func (function): Functions for custom defined export policies. This function will be used to
1833
+ customize the model during network export. Currently only support for files with mindir format. The
1834
+ function only accepts one input representing the proto object of the mindir file. When modifying a model,
1835
+ it is necessary to ensure the correctness of the `custom_func` , otherwise it may lead to model loading
1836
+ failure or functional errors. Default: ``None`` .
1837
+
1615
1838
  Examples:
1616
1839
  >>> import mindspore as ms
1617
1840
  >>> import numpy as np
1618
1841
  >>> from mindspore import Tensor
1619
1842
  >>>
1620
1843
  >>> # Define the network structure of LeNet5. Refer to
1621
- >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
1844
+ >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
1622
1845
  >>> net = LeNet5()
1623
1846
  >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
1624
1847
  >>> ms.export(net, input_tensor, file_name='lenet', file_format='MINDIR')
1848
+ >>>
1849
+ >>> # Export model in MindIR format and modified the model info using custom_func
1850
+ >>> # The custom_func only support one input representing the Proto object of the model
1851
+ >>> # And custom_func does not support return value
1852
+ >>> def _custom_func(mindir_model):
1853
+ ... mindir_model.producer_name = "test11111"
1854
+ ... mindir_model.producer_version = "11.0"
1855
+ ... mindir_model.user_info["version"] = "11.0"
1856
+ >>> ms.export(net, input_tensor, file_name="lenet", file_format='MINDIR', custom_func=_custom_func)
1857
+
1625
1858
 
1626
1859
  Tutorial Examples:
1627
1860
  - `Saving and Loading the Model - Saving and Loading MindIR
1628
- <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-mindir>`_
1861
+ <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
1629
1862
  """
1630
1863
  old_ms_jit_value = context.get_context("jit_syntax_level")
1631
1864
  context.set_context(jit_syntax_level=mindspore.STRICT)
@@ -1633,6 +1866,9 @@ def export(net, *inputs, file_name, file_format, **kwargs):
1633
1866
  supported_formats = ['AIR', 'ONNX', 'MINDIR']
1634
1867
  if file_format not in supported_formats:
1635
1868
  raise ValueError(f"For 'export', 'file_format' must be one of {supported_formats}, but got {file_format}.")
1869
+ if file_format == 'AIR':
1870
+ logger.warning("AIR format is deprecated, and will be removed in a future version, please use other format or "
1871
+ "use MindSpore Lite to do offline inference")
1636
1872
  Validator.check_file_name_by_regular(file_name)
1637
1873
  logger.info("exporting model file:%s format:%s.", file_name, file_format)
1638
1874
 
@@ -1685,7 +1921,7 @@ def _get_funcgraph(net, *inputs):
1685
1921
  >>> from mindspore import Tensor
1686
1922
  >>>
1687
1923
  >>> # Define the network structure of LeNet5. Refer to
1688
- >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
1924
+ >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
1689
1925
  >>> net = LeNet5()
1690
1926
  >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
1691
1927
  >>> ms.get_funcgraph(net, input_tensor)
@@ -1707,6 +1943,8 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
1707
1943
  logger.info("exporting model file:%s format:%s.", file_name, file_format)
1708
1944
  if "obf_config" in kwargs and file_format != "MINDIR":
1709
1945
  raise ValueError(f"Dynamic obfuscation only support for MindIR format, but got {file_format} format.")
1946
+ if "custom_func" in kwargs and file_format != "MINDIR":
1947
+ raise ValueError(f"Currently only support custom_func for MindIR format, but got {file_format} format.")
1710
1948
  if file_format == 'AIR':
1711
1949
  _save_air(net, file_name, *inputs, **kwargs)
1712
1950
  elif file_format == 'ONNX':
@@ -1867,12 +2105,12 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
1867
2105
  data_file_name = os.path.join(dirname, external_local)
1868
2106
  f, parameter_size, offset = _get_data_file(is_encrypt, kwargs, data_file_name)
1869
2107
  try:
1870
- round_ = 0
2108
+ round = 0
1871
2109
  names = []
1872
2110
  for param_proto in model.graph.parameter:
1873
2111
  name = param_proto.name[param_proto.name.find(":") + 1:]
1874
2112
  names.append((name, param_proto))
1875
- names.sort(key=lambda x: x[0])
2113
+ names.sort(key=lambda x: x[0])
1876
2114
  for pairs in names:
1877
2115
  name = pairs[0]
1878
2116
  param_proto = pairs[1]
@@ -1895,8 +2133,8 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
1895
2133
  offset += (data_length + append_size)
1896
2134
  write_data = _encrypt_data(is_encrypt, write_data, kwargs)
1897
2135
  f.write(write_data)
1898
- round_ += 1
1899
- logger.debug(f"writing {round_}th split data, name:{name}")
2136
+ round += 1
2137
+ logger.debug(f"writing {round}th split data, name:{name}")
1900
2138
 
1901
2139
  graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir")
1902
2140
  if os.path.exists(graph_file_name):
@@ -1993,6 +2231,10 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
1993
2231
  dataset = kwargs.get('dataset')
1994
2232
  _save_dataset_to_mindir(model, dataset)
1995
2233
 
2234
+ custom_func = kwargs.get('custom_func', None)
2235
+ if custom_func is not None:
2236
+ custom_func(model)
2237
+
1996
2238
  save_together = _save_together(net_dict, model)
1997
2239
  is_encrypt = lambda: 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys()
1998
2240
  if save_together:
@@ -2079,6 +2321,45 @@ def _save_dataset_to_mindir(model, dataset):
2079
2321
  model.preprocessor.op[-1].offload = op['offload'] if 'offload' in op.keys() else False
2080
2322
 
2081
2323
 
2324
+ def check_checkpoint(ckpt_file_name):
2325
+ """
2326
+ Check whether the checkpoint is valid.
2327
+
2328
+ Args:
2329
+ ckpt_file_name (str): Checkpoint file name.
2330
+
2331
+ Returns:
2332
+ bool, whether the checkpoint is valid.
2333
+
2334
+ Examples:
2335
+ >>> import mindspore as ms
2336
+ >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
2337
+ >>> check_result = ms.check_checkpoint(ckpt_file_name)
2338
+ >>> print(check_result)
2339
+ True
2340
+ """
2341
+ if not ckpt_file_name.endswith('.ckpt'):
2342
+ return False
2343
+ checkpoint_list = Checkpoint()
2344
+ with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f:
2345
+ pb_content = f.read()
2346
+ if pb_content[-17:-10] == b"crc_num":
2347
+ crc_num_bytes = pb_content[-10:]
2348
+ pb_content = pb_content[:-17]
2349
+ crc_num = int.from_bytes(crc_num_bytes, byteorder='big')
2350
+ cal_crc_num = binascii.crc32(pb_content, 0)
2351
+ if cal_crc_num != crc_num:
2352
+ logger.warning("For 'check_checkpoint', the ckpt crc check is failed.")
2353
+ return False
2354
+ try:
2355
+ checkpoint_list.ParseFromString(pb_content)
2356
+ except google.protobuf.message.DecodeError as e:
2357
+ logger.warning("For 'check_checkpoint', the ckpt parse is failed.")
2358
+ logger.warning(e)
2359
+ return False
2360
+ return True
2361
+
2362
+
2082
2363
  def parse_print(print_file_name):
2083
2364
  """
2084
2365
  Parse data file generated by :class:`mindspore.ops.Print`.
@@ -2423,7 +2704,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
2423
2704
  in at least one of them. Default: ``None`` .
2424
2705
  strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
2425
2706
  into net when parameter name's suffix in checkpoint file is the same as the
2426
- parameter in the network. When the types are inconsistent perform type conversion
2707
+ parameter in the network. When the types are inconsistent, perform type conversion
2427
2708
  on the parameters of the same type, such as float32 to float16. Default: ``False`` .
2428
2709
  dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
2429
2710
  is not required. Default: ``None`` .
@@ -2444,14 +2725,14 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
2444
2725
 
2445
2726
  For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
2446
2727
  Please see the `rank table startup
2447
- <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
2728
+ <https://www.mindspore.cn/tutorials/experts/en/master/parallel/rank_table.html>`_
2448
2729
  for more details.
2449
2730
 
2450
2731
  For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
2451
- <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
2732
+ <https://www.mindspore.cn/tutorials/experts/en/master/parallel/mpirun.html>`_ .
2452
2733
 
2453
2734
  For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
2454
- Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
2735
+ Startup <https://www.mindspore.cn/tutorials/experts/en/master/parallel/dynamic_cluster.html>`_ .
2455
2736
 
2456
2737
  >>> import os
2457
2738
  >>> import numpy as np
@@ -2717,11 +2998,10 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
2717
2998
  param_name = merged_param.name
2718
2999
  tensor_layout = predict_strategy[param_name]
2719
3000
  rank = get_rank()
2720
- split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank)
3001
+ split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank_id=rank)
2721
3002
  requires_grad = merged_param.requires_grad
2722
3003
  layerwise_parallel = merged_param.layerwise_parallel
2723
- data_type = merged_param.data.dtype
2724
- if data_type == mstype.bfloat16:
3004
+ if merged_param.data.dtype == mstype.bfloat16:
2725
3005
  split_param = Parameter(Tensor(split_tensor, mstype.bfloat16), param_name, requires_grad, layerwise_parallel)
2726
3006
  else:
2727
3007
  split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
@@ -2789,7 +3069,7 @@ def _get_mindir_inputs(file_name):
2789
3069
 
2790
3070
  def convert_model(mindir_file, convert_file, file_format):
2791
3071
  """
2792
- Convert mindir model to other format model. Current version only support convert to "ONNX" format.
3072
+ Convert mindir model to other format model. The current version only supports conversion to ONNX models.
2793
3073
 
2794
3074
  .. warning::
2795
3075
  This is an experimental API that is subject to change or deletion.