mindspore 2.1.0__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (550) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +49 -16
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  20. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  21. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  22. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  23. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  24. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  25. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  26. mindspore/_check_jit_forbidden_api.py +3 -1
  27. mindspore/_checkparam.py +26 -32
  28. mindspore/_extends/graph_kernel/__init__.py +0 -1
  29. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  30. mindspore/_extends/graph_kernel/splitter.py +1 -9
  31. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  32. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  33. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  34. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  35. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
  36. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  37. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  38. mindspore/_extends/parse/__init__.py +12 -15
  39. mindspore/_extends/parse/namespace.py +7 -33
  40. mindspore/_extends/parse/parser.py +61 -71
  41. mindspore/_extends/parse/resources.py +1 -1
  42. mindspore/_extends/parse/standard_method.py +72 -95
  43. mindspore/_extends/parse/trope.py +1 -1
  44. mindspore/_extends/remote/kernel_build_server.py +24 -7
  45. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  46. mindspore/_install_custom.py +43 -0
  47. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  48. mindspore/amp.py +47 -11
  49. mindspore/bin/cache_admin +0 -0
  50. mindspore/bin/cache_server +0 -0
  51. mindspore/boost/boost.py +1 -8
  52. mindspore/boost/boost_cell_wrapper.py +3 -2
  53. mindspore/boost/grad_accumulation.py +1 -1
  54. mindspore/boost/group_loss_scale_manager.py +8 -7
  55. mindspore/common/__init__.py +5 -3
  56. mindspore/common/_jit_fallback_utils.py +6 -0
  57. mindspore/common/_register_for_adapter.py +2 -0
  58. mindspore/common/_register_for_tensor.py +2 -2
  59. mindspore/common/_stub_tensor.py +13 -0
  60. mindspore/common/_utils.py +13 -0
  61. mindspore/common/api.py +173 -258
  62. mindspore/common/auto_dynamic_shape.py +498 -0
  63. mindspore/common/dtype.py +18 -11
  64. mindspore/common/dump.py +6 -4
  65. mindspore/common/initializer.py +14 -14
  66. mindspore/common/jit_config.py +33 -15
  67. mindspore/common/lazy_inline.py +126 -7
  68. mindspore/common/mindir_util.py +101 -0
  69. mindspore/common/parameter.py +51 -41
  70. mindspore/common/seed.py +4 -4
  71. mindspore/common/sparse_tensor.py +13 -14
  72. mindspore/common/tensor.py +240 -145
  73. mindspore/communication/__init__.py +7 -4
  74. mindspore/communication/_comm_helper.py +83 -4
  75. mindspore/communication/management.py +152 -84
  76. mindspore/config/op_info.config +13 -2
  77. mindspore/config/super_bar_config.json +4 -2
  78. mindspore/context.py +143 -59
  79. mindspore/dataset/__init__.py +5 -5
  80. mindspore/dataset/audio/__init__.py +2 -2
  81. mindspore/dataset/audio/transforms.py +52 -52
  82. mindspore/dataset/callback/ds_callback.py +16 -2
  83. mindspore/dataset/core/config.py +68 -51
  84. mindspore/dataset/engine/cache_client.py +28 -5
  85. mindspore/dataset/engine/datasets.py +250 -112
  86. mindspore/dataset/engine/datasets_audio.py +43 -211
  87. mindspore/dataset/engine/datasets_standard_format.py +11 -35
  88. mindspore/dataset/engine/datasets_text.py +43 -67
  89. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  90. mindspore/dataset/engine/datasets_vision.py +219 -1029
  91. mindspore/dataset/engine/iterators.py +11 -4
  92. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  93. mindspore/dataset/engine/obs/util.py +3 -0
  94. mindspore/dataset/engine/samplers.py +1 -1
  95. mindspore/dataset/engine/validators.py +19 -5
  96. mindspore/dataset/text/__init__.py +3 -3
  97. mindspore/dataset/text/transforms.py +101 -127
  98. mindspore/dataset/text/utils.py +205 -138
  99. mindspore/dataset/transforms/__init__.py +1 -1
  100. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  101. mindspore/dataset/transforms/transforms.py +95 -40
  102. mindspore/dataset/utils/browse_dataset.py +8 -2
  103. mindspore/dataset/utils/line_reader.py +17 -19
  104. mindspore/dataset/vision/__init__.py +3 -3
  105. mindspore/dataset/vision/c_transforms.py +6 -3
  106. mindspore/dataset/vision/transforms.py +409 -287
  107. mindspore/dataset/vision/utils.py +13 -14
  108. mindspore/dataset/vision/validators.py +11 -1
  109. mindspore/experimental/map_parameter.py +14 -0
  110. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  111. mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
  112. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  113. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  114. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  115. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  116. mindspore/gen_ops.py +273 -0
  117. mindspore/include/OWNERS +0 -1
  118. mindspore/include/api/data_type.h +2 -1
  119. mindspore/include/api/graph.h +0 -15
  120. mindspore/include/api/kernel.h +2 -0
  121. mindspore/include/api/kernel_api.h +37 -12
  122. mindspore/include/api/model.h +0 -14
  123. mindspore/include/api/types.h +37 -4
  124. mindspore/include/c_api/ms/abstract.h +67 -0
  125. mindspore/include/c_api/ms/attribute.h +197 -0
  126. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  127. mindspore/include/c_api/ms/base/macros.h +32 -0
  128. mindspore/include/c_api/ms/base/status.h +33 -0
  129. mindspore/include/c_api/ms/base/types.h +282 -0
  130. mindspore/include/c_api/ms/context.h +102 -0
  131. mindspore/include/c_api/ms/graph.h +160 -0
  132. mindspore/include/c_api/ms/node.h +606 -0
  133. mindspore/include/c_api/ms/tensor.h +161 -0
  134. mindspore/include/c_api/ms/value.h +84 -0
  135. mindspore/include/dataset/constants.h +6 -5
  136. mindspore/include/dataset/execute.h +23 -13
  137. mindspore/include/dataset/text.h +26 -26
  138. mindspore/include/dataset/transforms.h +13 -13
  139. mindspore/include/dataset/vision.h +60 -60
  140. mindspore/include/dataset/vision_ascend.h +5 -6
  141. mindspore/include/dataset/vision_lite.h +17 -17
  142. mindspore/include/mindapi/base/type_id.h +1 -0
  143. mindspore/include/mindapi/base/types.h +1 -0
  144. mindspore/lib/libdnnl.so.2 +0 -0
  145. mindspore/lib/libjemalloc.so.2 +0 -0
  146. mindspore/lib/libmindspore.so +0 -0
  147. mindspore/lib/libmindspore_backend.so +0 -0
  148. mindspore/lib/libmindspore_common.so +0 -0
  149. mindspore/lib/libmindspore_core.so +0 -0
  150. mindspore/lib/libmindspore_glog.so.0 +0 -0
  151. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  152. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  153. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  154. mindspore/lib/libmindspore_shared_lib.so +0 -0
  155. mindspore/lib/libnnacl.so +0 -0
  156. mindspore/lib/libopencv_core.so.4.5 +0 -0
  157. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  158. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  159. mindspore/lib/libps_cache.so +0 -0
  160. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  161. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  163. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  164. mindspore/lib/plugin/ascend/libakg.so +0 -0
  165. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  166. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  167. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  168. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  169. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  170. mindspore/lib/plugin/cpu/libakg.so +0 -0
  171. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  172. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  173. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  174. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  175. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  176. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  177. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  178. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  179. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  180. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  181. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  182. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  183. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  184. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  185. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  186. mindspore/nn/__init__.py +0 -2
  187. mindspore/nn/cell.py +316 -74
  188. mindspore/nn/dynamic_lr.py +21 -21
  189. mindspore/nn/layer/activation.py +21 -28
  190. mindspore/nn/layer/basic.py +15 -13
  191. mindspore/nn/layer/channel_shuffle.py +1 -1
  192. mindspore/nn/layer/container.py +271 -9
  193. mindspore/nn/layer/conv.py +310 -207
  194. mindspore/nn/layer/dense.py +8 -5
  195. mindspore/nn/layer/embedding.py +33 -27
  196. mindspore/nn/layer/flash_attention.py +82 -41
  197. mindspore/nn/layer/image.py +8 -6
  198. mindspore/nn/layer/math.py +13 -18
  199. mindspore/nn/layer/normalization.py +107 -66
  200. mindspore/nn/layer/padding.py +1 -1
  201. mindspore/nn/layer/pooling.py +131 -109
  202. mindspore/nn/layer/rnn_cells.py +22 -17
  203. mindspore/nn/layer/rnns.py +13 -16
  204. mindspore/nn/layer/thor_layer.py +1 -1
  205. mindspore/nn/layer/transformer.py +221 -154
  206. mindspore/nn/learning_rate_schedule.py +9 -1
  207. mindspore/nn/loss/loss.py +235 -174
  208. mindspore/nn/optim/ada_grad.py +2 -1
  209. mindspore/nn/optim/adadelta.py +1 -0
  210. mindspore/nn/optim/adafactor.py +2 -1
  211. mindspore/nn/optim/adam.py +7 -4
  212. mindspore/nn/optim/adamax.py +3 -2
  213. mindspore/nn/optim/adasum.py +2 -2
  214. mindspore/nn/optim/asgd.py +2 -3
  215. mindspore/nn/optim/ftrl.py +6 -5
  216. mindspore/nn/optim/lamb.py +7 -4
  217. mindspore/nn/optim/lars.py +1 -1
  218. mindspore/nn/optim/lazyadam.py +5 -3
  219. mindspore/nn/optim/momentum.py +2 -1
  220. mindspore/nn/optim/optimizer.py +53 -4
  221. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  222. mindspore/nn/optim/rmsprop.py +4 -3
  223. mindspore/nn/optim/rprop.py +23 -12
  224. mindspore/nn/optim/sgd.py +26 -11
  225. mindspore/nn/optim/thor.py +9 -7
  226. mindspore/nn/probability/bijector/bijector.py +5 -5
  227. mindspore/nn/probability/bijector/power_transform.py +27 -27
  228. mindspore/nn/probability/bijector/softplus.py +3 -3
  229. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  230. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  231. mindspore/nn/probability/distribution/beta.py +3 -3
  232. mindspore/nn/probability/distribution/categorical.py +7 -7
  233. mindspore/nn/probability/distribution/cauchy.py +0 -1
  234. mindspore/nn/probability/distribution/distribution.py +3 -3
  235. mindspore/nn/probability/distribution/gamma.py +3 -3
  236. mindspore/nn/probability/distribution/geometric.py +4 -4
  237. mindspore/nn/probability/distribution/gumbel.py +4 -4
  238. mindspore/nn/probability/distribution/log_normal.py +2 -2
  239. mindspore/nn/probability/distribution/logistic.py +2 -2
  240. mindspore/nn/probability/distribution/poisson.py +4 -4
  241. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  242. mindspore/nn/probability/distribution/uniform.py +6 -6
  243. mindspore/nn/wrap/cell_wrapper.py +78 -34
  244. mindspore/nn/wrap/grad_reducer.py +8 -5
  245. mindspore/nn/wrap/loss_scale.py +105 -42
  246. mindspore/numpy/array_creations.py +1 -2
  247. mindspore/numpy/array_ops.py +3 -2
  248. mindspore/offline_debug/convert_async.py +2 -2
  249. mindspore/ops/_grad_experimental/__init__.py +0 -5
  250. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
  251. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  252. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  253. mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
  254. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  255. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
  256. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  257. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  258. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  259. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  260. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  261. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  262. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  263. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  264. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  265. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  266. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  267. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  268. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  269. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  270. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  271. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  272. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  273. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  274. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  275. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  276. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  277. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  278. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  279. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  280. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  281. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  282. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  283. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  284. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  285. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  286. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  287. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  288. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  289. mindspore/ops/_primitive_cache.py +1 -1
  290. mindspore/ops/_tracefunc.py +45 -13
  291. mindspore/ops/_utils/utils.py +4 -1
  292. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  293. mindspore/ops/_vmap/vmap_base.py +3 -3
  294. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  295. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  296. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  297. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  298. mindspore/ops/arg_dtype_cast.py +54 -0
  299. mindspore/ops/composite/base.py +37 -10
  300. mindspore/ops/composite/math_ops.py +5 -4
  301. mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
  302. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  303. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  304. mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
  305. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  306. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  308. mindspore/ops/deprecated.py +304 -0
  309. mindspore/ops/function/__init__.py +4 -1
  310. mindspore/ops/function/array_func.py +167 -189
  311. mindspore/ops/function/clip_func.py +81 -13
  312. mindspore/ops/function/debug_func.py +1 -1
  313. mindspore/ops/function/grad/grad_func.py +18 -8
  314. mindspore/ops/function/image_func.py +10 -4
  315. mindspore/ops/function/linalg_func.py +5 -5
  316. mindspore/ops/function/math_func.py +575 -386
  317. mindspore/ops/function/nn_func.py +470 -251
  318. mindspore/ops/function/random_func.py +86 -56
  319. mindspore/ops/function/sparse_func.py +1 -1
  320. mindspore/ops/function/sparse_unary_func.py +14 -12
  321. mindspore/ops/function/vmap_func.py +6 -5
  322. mindspore/ops/functional.py +15 -10
  323. mindspore/ops/op_info_register.py +235 -19
  324. mindspore/ops/operations/__init__.py +25 -17
  325. mindspore/ops/operations/_grad_ops.py +52 -7
  326. mindspore/ops/operations/_inner_ops.py +213 -12
  327. mindspore/ops/operations/_quant_ops.py +4 -8
  328. mindspore/ops/operations/_sequence_ops.py +42 -0
  329. mindspore/ops/operations/array_ops.py +64 -280
  330. mindspore/ops/operations/comm_ops.py +105 -57
  331. mindspore/ops/operations/custom_ops.py +10 -3
  332. mindspore/ops/operations/debug_ops.py +8 -4
  333. mindspore/ops/operations/image_ops.py +18 -12
  334. mindspore/ops/operations/math_ops.py +185 -138
  335. mindspore/ops/operations/nn_ops.py +716 -492
  336. mindspore/ops/operations/other_ops.py +0 -22
  337. mindspore/ops/operations/random_ops.py +53 -111
  338. mindspore/ops/operations/sparse_ops.py +3 -1
  339. mindspore/ops/primitive.py +24 -18
  340. mindspore/parallel/_auto_parallel_context.py +68 -8
  341. mindspore/parallel/_cost_model_context.py +2 -2
  342. mindspore/parallel/_offload_context.py +17 -3
  343. mindspore/parallel/_parallel_serialization.py +2 -2
  344. mindspore/parallel/_ps_context.py +12 -0
  345. mindspore/parallel/_tensor.py +14 -12
  346. mindspore/parallel/_transformer/layers.py +5 -3
  347. mindspore/parallel/_transformer/loss.py +1 -0
  348. mindspore/parallel/_transformer/moe.py +2 -2
  349. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  350. mindspore/parallel/_transformer/transformer.py +23 -3
  351. mindspore/parallel/_utils.py +11 -7
  352. mindspore/parallel/algo_parameter_config.py +85 -5
  353. mindspore/parallel/checkpoint_transform.py +6 -10
  354. mindspore/parallel/shard.py +4 -4
  355. mindspore/profiler/common/struct_type.py +3 -3
  356. mindspore/profiler/common/util.py +3 -2
  357. mindspore/profiler/envprofiling.py +1 -1
  358. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  359. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  360. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  361. mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
  362. mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
  363. mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
  364. mindspore/profiler/parser/ascend_op_generator.py +5 -5
  365. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  366. mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
  367. mindspore/profiler/parser/base_timeline_generator.py +9 -7
  368. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
  369. mindspore/profiler/parser/flops_parser.py +15 -11
  370. mindspore/profiler/parser/framework_parser.py +37 -21
  371. mindspore/profiler/parser/hccl_parser.py +16 -12
  372. mindspore/profiler/parser/integrator.py +22 -11
  373. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  374. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  375. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  376. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  377. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  378. mindspore/profiler/parser/optime_parser.py +1 -1
  379. mindspore/profiler/parser/profiler_info.py +2 -2
  380. mindspore/profiler/parser/step_trace_parser.py +11 -14
  381. mindspore/profiler/profiling.py +139 -71
  382. mindspore/rewrite/api/node.py +102 -19
  383. mindspore/rewrite/api/node_type.py +5 -1
  384. mindspore/rewrite/api/scoped_value.py +9 -17
  385. mindspore/rewrite/api/symbol_tree.py +131 -47
  386. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  387. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  388. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  389. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  390. mindspore/rewrite/common/rewrite_elog.py +5 -1
  391. mindspore/rewrite/namer.py +33 -24
  392. mindspore/rewrite/namespace.py +14 -5
  393. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  394. mindspore/rewrite/node/call_function.py +79 -0
  395. mindspore/rewrite/node/cell_container.py +135 -0
  396. mindspore/rewrite/node/control_flow.py +88 -0
  397. mindspore/rewrite/{node.py → node/node.py} +273 -234
  398. mindspore/rewrite/node/node_manager.py +254 -0
  399. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  400. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  401. mindspore/rewrite/parsers/assign_parser.py +216 -221
  402. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  403. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  404. mindspore/rewrite/parsers/constant_parser.py +9 -6
  405. mindspore/rewrite/parsers/container_parser.py +9 -7
  406. mindspore/rewrite/parsers/for_parser.py +36 -15
  407. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  408. mindspore/rewrite/parsers/if_parser.py +28 -24
  409. mindspore/rewrite/parsers/module_parser.py +196 -25
  410. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  411. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  412. mindspore/rewrite/parsers/return_parser.py +6 -6
  413. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  414. mindspore/rewrite/sparsify/utils.py +1 -1
  415. mindspore/rewrite/symbol_tree.py +525 -577
  416. mindspore/rewrite/symbol_tree_builder.py +9 -193
  417. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  418. mindspore/run_check/_check_version.py +2 -2
  419. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  420. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  421. mindspore/scipy/linalg.py +1 -1
  422. mindspore/scipy/optimize/minimize.py +7 -3
  423. mindspore/train/_utils.py +7 -3
  424. mindspore/train/amp.py +323 -123
  425. mindspore/train/anf_ir_pb2.py +14 -2
  426. mindspore/train/callback/_backup_and_restore.py +2 -12
  427. mindspore/train/callback/_callback.py +29 -4
  428. mindspore/train/callback/_checkpoint.py +23 -8
  429. mindspore/train/callback/_early_stop.py +2 -2
  430. mindspore/train/callback/_landscape.py +4 -4
  431. mindspore/train/callback/_loss_monitor.py +2 -2
  432. mindspore/train/callback/_on_request_exit.py +2 -2
  433. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  434. mindspore/train/callback/_summary_collector.py +14 -7
  435. mindspore/train/callback/_time_monitor.py +58 -5
  436. mindspore/train/data_sink.py +5 -11
  437. mindspore/train/dataset_helper.py +83 -57
  438. mindspore/train/loss_scale_manager.py +2 -2
  439. mindspore/train/metrics/__init__.py +3 -3
  440. mindspore/train/metrics/cosine_similarity.py +1 -1
  441. mindspore/train/metrics/hausdorff_distance.py +3 -2
  442. mindspore/train/metrics/mean_surface_distance.py +3 -2
  443. mindspore/train/metrics/metric.py +39 -19
  444. mindspore/train/metrics/roc.py +2 -2
  445. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  446. mindspore/train/mind_ir_pb2.py +85 -36
  447. mindspore/train/model.py +185 -45
  448. mindspore/train/serialization.py +390 -150
  449. mindspore/train/summary/_writer_pool.py +3 -2
  450. mindspore/train/summary/summary_record.py +14 -10
  451. mindspore/train/train_thor/convert_utils.py +3 -3
  452. mindspore/train/train_thor/dataset_helper.py +1 -1
  453. mindspore/version.py +1 -1
  454. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
  455. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +458 -518
  456. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  457. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  458. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  459. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  460. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  461. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  462. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  463. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  464. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  465. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  466. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  467. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  468. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  469. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  470. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  471. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  472. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  473. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  474. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  475. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  476. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  477. mindspore/_extends/graph_kernel/expander.py +0 -80
  478. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  479. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  480. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  481. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  482. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  483. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  484. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  485. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  486. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  487. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  488. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  489. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  490. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  491. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  492. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  493. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  494. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  495. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  496. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  497. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  498. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  499. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  500. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  501. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  502. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  503. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  504. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  505. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  506. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  507. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  508. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  509. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  510. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  511. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  512. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  513. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  514. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  515. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  516. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  517. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  518. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  519. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  520. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  521. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  522. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  523. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  524. mindspore/dataset/datapreprocess/__init__.py +0 -20
  525. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  526. mindspore/include/api/net.h +0 -142
  527. mindspore/nn/lr_scheduler.py +0 -262
  528. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  529. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  530. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  531. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  532. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  533. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  534. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  537. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  538. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  539. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  541. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  542. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  543. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  544. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  545. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  546. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  547. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  548. mindspore/rewrite/node_visitor.py +0 -44
  549. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  550. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -23,7 +23,7 @@ import os
23
23
  import shutil
24
24
  import stat
25
25
  import threading
26
- from threading import Thread, Lock
26
+ from threading import Thread, RLock
27
27
  from collections import defaultdict, OrderedDict
28
28
  from io import BytesIO
29
29
 
@@ -59,9 +59,11 @@ from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_
59
59
  from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode
60
60
  from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \
61
61
  _restore_group_info_list
62
+ from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
63
+ _store_warm_up_ptr_by_tensor_list, _cache_enable
62
64
  from mindspore.train._utils import read_proto
63
65
  from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
64
- split_mindir
66
+ split_mindir, split_dynamic_mindir
65
67
  from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
66
68
 
67
69
  tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
@@ -79,7 +81,7 @@ mindir_to_tensor_type = {1: mstype.float32, 2: mstype.uint8, 3: mstype.int8, 4:
79
81
  5: mstype.int16, 6: mstype.int32, 7: mstype.int64, 10: mstype.float16,
80
82
  11: mstype.float64, 12: mstype.uint32, 13: mstype.uint64}
81
83
 
82
- _ckpt_mutex = Lock()
84
+ _ckpt_mutex = RLock()
83
85
 
84
86
  # unit is KB
85
87
  SLICE_SIZE = 512 * 1024
@@ -333,8 +335,8 @@ def _write_hugeparameter(name, value, f):
333
335
 
334
336
  def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
335
337
  """Check save_obj and ckpt_file_name for save_checkpoint."""
336
- if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list):
337
- raise TypeError("For 'save_checkpoint', the parameter 'save_obj' must be nn.Cell or list, "
338
+ if not isinstance(save_obj, (nn.Cell, list, dict)):
339
+ raise TypeError("For 'save_checkpoint', the parameter 'save_obj' must be nn.Cell, list or dict, "
338
340
  "but got {}.".format(type(save_obj)))
339
341
  if not isinstance(ckpt_file_name, str):
340
342
  raise TypeError("For 'save_checkpoint', the parameter {} for checkpoint file name is invalid,"
@@ -351,14 +353,15 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
351
353
 
352
354
  def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
353
355
  async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None, **kwargs):
354
- """
356
+ r"""
355
357
  Save checkpoint to a specified file.
356
358
 
357
359
  Args:
358
- save_obj (Union[Cell, list]): The cell object or data list(each element is a dictionary, like
359
- [{"name": param_name, "data": param_data},...], the type of
360
- param_name would be string, and the type of param_data would
361
- be parameter or Tensor).
360
+ save_obj (Union[Cell, list, dict]): The object to be saved. The data type can be :class:`mindspore.nn.Cell`,
361
+ list, or dict. If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict
362
+ elements(each element is a dictionary, like [{"name": param_name, "data": param_data},...], the type of
363
+ `param_name` must be string, and the type of `param_data` must be parameter or Tensor); If dict,
364
+ it can be the returned value of `mindspore.load_checkpoint()`.
362
365
  ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
363
366
  integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: ``True`` .
364
367
  async_save (bool): Whether to open an independent thread to save the checkpoint file. Default: ``False`` .
@@ -370,16 +373,14 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
370
373
  mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` .
371
374
  Default: ``"AES-GCM"`` .
372
375
  choice_func (function) : A function for saving custom selected parameters. The input value of `choice_func` is
373
- a parameter name in string type, and the return value is a bool.
376
+ a parameter name in string type, and the returned value is a bool.
374
377
  If returns ``True`` , the Parameter that matching the custom condition will be saved.
375
378
  If returns ``False`` , the Parameter that not matching the custom condition will not
376
379
  be saved. Default: ``None`` .
377
380
  kwargs (dict): Configuration options dictionary.
378
381
 
379
- - incremental (bool): Whether export checkpoint for MapParameter incrementally.
380
-
381
382
  Raises:
382
- TypeError: If the parameter `save_obj` is not `nn.Cell` or list type.
383
+ TypeError: If the parameter `save_obj` is not :class:`mindspore.nn.Cell` , list or dict type.
383
384
  TypeError: If the parameter `integrated_save` or `async_save` is not bool type.
384
385
  TypeError: If the parameter `ckpt_file_name` is not string type.
385
386
 
@@ -387,17 +388,27 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
387
388
  >>> import mindspore as ms
388
389
  >>>
389
390
  >>> # Define the network structure of LeNet5. Refer to
390
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
391
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
391
392
  >>> net = LeNet5()
392
393
  >>> ms.save_checkpoint(net, "./lenet.ckpt",
393
- >>> choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
394
- >>> param_dict = ms.load_checkpoint("./lenet.ckpt")
395
- >>> print(param_dict)
394
+ ... choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
395
+ >>> param_dict1 = ms.load_checkpoint("./lenet.ckpt")
396
+ >>> print(param_dict1)
396
397
  {'conv2.weight': Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)}
398
+ >>> params_list = net.trainable_params()
399
+ >>> ms.save_checkpoint(params_list, "./lenet_list.ckpt",
400
+ ... choice_func=lambda x: x.startswith("conv") and not x.startswith("conv2"))
401
+ >>> param_dict2 = ms.load_checkpoint("./lenet_list.ckpt")
402
+ >>> print(param_dict2)
403
+ {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
404
+ >>> ms.save_checkpoint(param_dict2, "./lenet_dict.ckpt")
405
+ >>> param_dict3 = ms.load_checkpoint("./lenet_dict.ckpt")
406
+ >>> print(param_dict3)
407
+ {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
397
408
 
398
409
  Tutorial Examples:
399
410
  - `Saving and Loading the Model - Saving and Loading the Model Weight
400
- <https://mindspore.cn/tutorials/en/r2.1/beginner/save_load.html#saving-and-loading-the-model-weight>`_
411
+ <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
401
412
  """
402
413
  ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name)
403
414
  integrated_save = Validator.check_bool(integrated_save)
@@ -408,70 +419,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
408
419
  map_param_inc = kwargs.get('incremental', False)
409
420
  logger.info("Execute the process of saving checkpoint files.")
410
421
 
411
- if isinstance(save_obj, nn.Cell):
412
- if save_obj.ge_init and not save_obj.ge_sync_data:
413
- from mindspore.train.callback._callback import set_cur_net
414
- set_cur_net(save_obj)
415
- save_obj.exec_checkpoint_graph()
416
- parameter_layout_dict = save_obj.parameter_layout_dict
417
- if _is_in_auto_parallel_mode() and not parameter_layout_dict:
418
- parameter_layout_dict = _get_parameter_layout()
419
- if not _is_in_auto_parallel_mode():
420
- save_obj.init_parameters_data()
421
- param_dict = OrderedDict()
422
- for _, param in save_obj.parameters_and_names():
423
- not_sliced = not param.sliced
424
- is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
425
- # All parameters are initialized immediately under PyNative mode, skip this judgement.
426
- if is_graph_mode and _is_in_auto_parallel_mode() and (not_sliced or param.has_init):
427
- continue
428
- if choice_func is not None and not choice_func(param.name):
429
- continue
430
- param_dict[param.name] = param
431
- param_list = []
432
- if append_dict and "random_op" in append_dict:
433
- phase = 'train' + '.' + str(save_obj.create_time) + '.' + str(id(save_obj)) + '.' + save_obj.arguments_key
434
- if phase in save_obj.compile_cache and _executor.has_compiled(phase):
435
- random_byte = _executor._graph_executor.get_random_status(phase)
436
- param_list.append({"name": "random_op", "data": random_byte})
437
- append_dict.pop("random_op")
438
- for (key, value) in param_dict.items():
439
- each_param = {"name": key}
440
- if isinstance(value, MapParameter):
441
- each_param["data"] = value
442
- param_list.append(each_param)
443
- continue
444
-
445
- if value.data.is_persistent_data():
446
- # list save persistent_data: [Tensor, shape, type, param.key]
447
- param_data = ["persistent_data"]
448
- param_data.append(value.data)
449
- param_data.append(value.param_info.origin_shape)
450
- param_data.append(str(value.dtype))
451
- param_data.append(value.key)
452
- elif value.data.offload_file_path() != "":
453
- # list save offload data: [Param, shape, type, param.key]
454
- param_data = ["offload_parameter"]
455
- param_tensor = value.data
456
- if key in parameter_layout_dict:
457
- param_tensor = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_tensor,
458
- integrated_save)
459
- param_data.append(param_tensor)
460
- param_data.append(param_tensor.shape)
461
- param_data.append(str(param_tensor.dtype))
462
- param_data.append(value.key)
463
- else:
464
- param_data = Tensor(value.data.asnumpy())
465
-
466
- # in automatic model parallel scenario, some parameters were split to all the devices,
467
- # which should be combined before saving
468
- if key in parameter_layout_dict:
469
- param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
470
- integrated_save)
471
-
472
- each_param["data"] = param_data
473
- param_list.append(each_param)
474
- save_obj = param_list
422
+ save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
475
423
 
476
424
  if append_dict:
477
425
  append_info_list = []
@@ -479,7 +427,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
479
427
  if not isinstance(value, str):
480
428
  value = Tensor(value)
481
429
  append_info_list.append({"name": k_name, "data": value})
482
- save_obj.extend(append_info_list)
430
+ save_obj.extend(append_info_list)
483
431
 
484
432
  data_list = OrderedDict()
485
433
  with _ckpt_mutex:
@@ -530,6 +478,124 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
530
478
  logger.info("Saving checkpoint process is finished.")
531
479
 
532
480
 
481
+ def _convert_list_to_param_list(save_obj, choice_func):
482
+ """Convert a list of Parameter to param_list."""
483
+ param_list = []
484
+ if not save_obj:
485
+ return param_list
486
+ if isinstance(save_obj[0], dict):
487
+ param_list = [param for param in save_obj if choice_func is None or choice_func(param["name"])]
488
+ else:
489
+ for param in save_obj:
490
+ if isinstance(param, Parameter):
491
+ if choice_func is not None and not choice_func(param.name):
492
+ continue
493
+ each_param = {"name": param.name, "data": param}
494
+ param_list.append(each_param)
495
+ else:
496
+ raise TypeError(f"For save_checkpoint, when save_obj is made up by list of Parameter,"
497
+ f"the param should be parameter, but got {type(param)}")
498
+ return param_list
499
+
500
+
501
+ def _convert_dict_to_param_dict(save_obj, choice_func):
502
+ """Convert a dict of Parameter to param_list."""
503
+ param_list = []
504
+ for (key, value) in save_obj.items():
505
+ if isinstance(key, str) and isinstance(value, (Parameter, str)):
506
+ if choice_func is not None and not choice_func(key):
507
+ continue
508
+ each_param = {"name": key, "data": value}
509
+ param_list.append(each_param)
510
+ else:
511
+ raise TypeError(f"For save_checkpoint, when save_obj is made up by dict, the key should be str and"
512
+ f"value should be Parameter, but got the type of key is {type(key)} and"
513
+ f"the type of value is {type(value)}")
514
+ return param_list
515
+
516
+
517
+ def _convert_cell_param_and_names_to_dict(save_obj, choice_func):
518
+ """Convert cell.parameters_and_names to OrderedDict."""
519
+ param_dict = OrderedDict()
520
+ for _, param in save_obj.parameters_and_names():
521
+ not_sliced = not param.sliced
522
+ is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
523
+ # All parameters are initialized immediately under PyNative mode, skip this judgement.
524
+ judgment = not_sliced or param.has_init
525
+ if is_graph_mode and _is_in_auto_parallel_mode() and judgment:
526
+ continue
527
+ if choice_func is not None and not choice_func(param.name):
528
+ continue
529
+ # Add suffix for cache_enabled parameter, and then parameter can carry key info.
530
+ # Notice that suffix needs be removed when loading into net.
531
+ if param.cache_enable:
532
+ param_dict[param.name + ".__param_key__" + str(param.key)] = param
533
+ else:
534
+ param_dict[param.name] = param
535
+ return param_dict
536
+
537
+
538
+ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func):
539
+ """Convert nn.Cell to param_list."""
540
+ param_list = []
541
+ parameter_layout_dict = save_obj.parameter_layout_dict
542
+ if _is_in_auto_parallel_mode() and not parameter_layout_dict:
543
+ parameter_layout_dict = _get_parameter_layout()
544
+ if not _is_in_auto_parallel_mode():
545
+ save_obj.init_parameters_data()
546
+ param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func)
547
+ if append_dict and "random_op" in append_dict:
548
+ phase = 'train' + '.' + str(save_obj.create_time) + '.' + str(id(save_obj)) + '.' + save_obj.arguments_key
549
+ if phase in save_obj.compile_cache and _executor.has_compiled(phase):
550
+ random_byte = _executor._graph_executor.get_random_status(phase)
551
+ param_list.append({"name": "random_op", "data": random_byte})
552
+ append_dict.pop("random_op")
553
+ for (key, value) in param_dict.items():
554
+ each_param = {"name": key}
555
+ if isinstance(value, MapParameter):
556
+ each_param["data"] = value
557
+ param_list.append(each_param)
558
+ continue
559
+
560
+ if value.data.is_persistent_data():
561
+ # list save persistent_data: [Tensor, shape, type, param.key]
562
+ param_data = ["persistent_data", value.data, value.param_info.origin_shape, str(value.dtype), value.key]
563
+ elif value.data.offload_file_path() != "":
564
+ # list save offload data: [Param, shape, type, param.key]
565
+ param_data = ["offload_parameter"]
566
+ param_tensor = value.data
567
+ if key in parameter_layout_dict:
568
+ param_tensor = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_tensor,
569
+ integrated_save)
570
+ param_data.append(param_tensor)
571
+ param_data.append(param_tensor.shape)
572
+ param_data.append(str(param_tensor.dtype))
573
+ param_data.append(value.key)
574
+ else:
575
+ param_data = Tensor(value.data.asnumpy())
576
+
577
+ # in automatic model parallel scenario, some parameters were split to all the devices,
578
+ # which should be combined before saving
579
+ if key in parameter_layout_dict:
580
+ param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
581
+ integrated_save)
582
+
583
+ each_param["data"] = param_data
584
+ param_list.append(each_param)
585
+ return param_list
586
+
587
+
588
+ def _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func):
589
+ """Convert a save_obj to param_list."""
590
+ if isinstance(save_obj, list):
591
+ return _convert_list_to_param_list(save_obj, choice_func)
592
+
593
+ if isinstance(save_obj, dict):
594
+ return _convert_dict_to_param_dict(save_obj, choice_func)
595
+
596
+ return _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func)
597
+
598
+
533
599
  def _save_param_list_data(data_list, key, param):
534
600
  """Save persistent data into save_obj."""
535
601
  dims = []
@@ -585,7 +651,7 @@ def load(file_name, **kwargs):
585
651
 
586
652
  - obf_func (function): A python function used for loading obfuscated MindIR model, which can refer to
587
653
  `obfuscate_model()
588
- <https://www.mindspore.cn/docs/en/r2.1/api_python/mindspore/mindspore.obfuscate_model.html>`_.
654
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore/mindspore.obfuscate_model.html>`_.
589
655
 
590
656
  Returns:
591
657
  GraphCell, a compiled graph that can executed by `GraphCell`.
@@ -615,7 +681,7 @@ def load(file_name, **kwargs):
615
681
 
616
682
  Tutorial Examples:
617
683
  - `Saving and Loading the Model - Saving and Loading MindIR
618
- <https://mindspore.cn/tutorials/en/r2.1/beginner/save_load.html#saving-and-loading-mindir>`_
684
+ <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-mindir>`_
619
685
  """
620
686
  if not isinstance(file_name, str):
621
687
  raise ValueError("For 'load', the argument 'file_name' must be string, but "
@@ -656,7 +722,7 @@ def load(file_name, **kwargs):
656
722
  return graph
657
723
 
658
724
 
659
- def export_split_mindir(file_name):
725
+ def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=False):
660
726
  """
661
727
  Auto Split MindIR.
662
728
 
@@ -664,6 +730,10 @@ def export_split_mindir(file_name):
664
730
 
665
731
  Args:
666
732
  file_name (str): MindIR file name.
733
+ device_num (int): device number.
734
+ rank_id (int): rank id.
735
+ dynamic (bool): Indicates whether the model is a dynamic shape mindir model.
736
+ sapp (bool): Indicates whether to automatically generate split strategy through SAPP.
667
737
 
668
738
  Raises:
669
739
  ValueError: MindIR file does not exist or `file_name` is not a string.
@@ -671,11 +741,9 @@ def export_split_mindir(file_name):
671
741
 
672
742
  Examples:
673
743
  >>> import mindspore as ms
674
- >>> from mindspore.communication import init
675
744
  >>> context.set_context(mode=context.GRAPH_MODE)
676
745
  >>>
677
- >>> init(backend_name="hccl")
678
- >>> ms.export_split_mindir("net.mindir")
746
+ >>> ms.export_split_mindir("net.mindir", device_num=8, rank_id=0)
679
747
 
680
748
  """
681
749
  if not isinstance(file_name, str):
@@ -690,8 +758,11 @@ def export_split_mindir(file_name):
690
758
  file_name = os.path.abspath(file_name)
691
759
 
692
760
  logger.info("Execute the process of export and split mindir.")
693
-
694
- graph = split_mindir(file_name)
761
+ dynamic = True
762
+ if dynamic:
763
+ graph = split_dynamic_mindir(file_name, device_num, rank_id, sapp)
764
+ else:
765
+ graph = split_mindir(file_name)
695
766
 
696
767
  if graph is None:
697
768
  if _is_cipher_file(file_name):
@@ -779,17 +850,20 @@ def obfuscate_model(obf_config, **kwargs):
779
850
  - model_inputs (list(Tensor)): The inputs of the original model, the values of Tensor can be random, which
780
851
  is the same as using :func:`mindspore.export`.
781
852
  - obf_ratio (Union(float, str)): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
782
- should be in range of (0, 1] or in ["small", "medium", "large"].
853
+ should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large" are
854
+ correspond to 0.1, 0.3, and 0.6 respectively.
783
855
  - customized_func (function): A python function used for customized function mode, which used for control
784
- the switch branch of obfuscation structure. The outputs of customized_func should be boolean. This
785
- function needs to ensure that its result is constant for any input. Users can refer to opaque
856
+ the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
857
+ Reference to 'my_func()' in
858
+ `tutorials <https://www.mindspore.cn/mindarmour/docs/en/r2.0/dynamic_obfuscation_protection.html>`_).
859
+ This function needs to ensure that its result is constant for any input. Users can refer to opaque
786
860
  predicates. If customized_func is set, then it should be passed to :func:`mindspore.load` interface
787
861
  when loading obfuscated model.
788
- - obf_random_seed (int): The random seed used for determine the distribution of confusion branches and the
789
- weight confusion coefficient, which should be in (0, 9223372036854775807]. If `obf_random_seed` is set,
790
- then it should be passed to :class:`nn.GraphCell()` interface when loading obfuscated model. It should be
791
- noted that at least one of `customized_func` or `obf_random_seed` should be set, and the latter mode
792
- would be applied if both of them are set.
862
+ - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
863
+ structure of obfuscated models corresponding to different random seeds is different. If
864
+ `obf_random_seed` is set, then it should be passed to :class:`nn.GraphCell()` interface when loading
865
+ obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
866
+ be set, and the latter mode would be applied if both of them are set.
793
867
 
794
868
  kwargs (dict): Configuration options dictionary.
795
869
 
@@ -928,27 +1002,27 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
928
1002
  >>> print(param_dict["conv2.weight"])
929
1003
  Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
930
1004
  >>> def func(param_name):
931
- >>> whether_load = False
932
- >>> if param_name.startswith("conv"):
933
- >>> whether_load = True
934
- >>> if param_name.startswith("conv1"):
935
- >>> whether_load = False
936
- >>> return whether_load
1005
+ ... whether_load = False
1006
+ ... if param_name.startswith("conv"):
1007
+ ... whether_load = True
1008
+ ... if param_name.startswith("conv1"):
1009
+ ... whether_load = False
1010
+ ... return whether_load
937
1011
  >>> param_dict1 = ms.load_checkpoint(ckpt_file_name, choice_func=func)
938
1012
  >>> print(param_dict1["conv2.weight"])
939
1013
  Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
940
1014
  >>> def func(param_name):
941
- >>> whether_load = False
942
- >>> if param_name.startswith("conv1"):
943
- >>> whether_load = True
944
- >>> return whether_load
1015
+ ... whether_load = False
1016
+ ... if param_name.startswith("conv1"):
1017
+ ... whether_load = True
1018
+ ... return whether_load
945
1019
  >>> param_dict2 = ms.load_checkpoint(ckpt_file_name, choice_func=func)
946
1020
  >>> print(param_dict2)
947
1021
  {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
948
1022
 
949
1023
  Tutorial Examples:
950
1024
  - `Saving and Loading the Model - Saving and Loading the Model Weight
951
- <https://mindspore.cn/tutorials/en/r2.1/beginner/save_load.html#saving-and-loading-the-model-weight>`_
1025
+ <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
952
1026
  """
953
1027
  ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
954
1028
  specify_prefix = _check_prefix(specify_prefix)
@@ -979,8 +1053,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
979
1053
  choice_func is not None and not choice_func(element.tag):
980
1054
  continue
981
1055
  if element.tensor.ByteSize() == 0:
982
- _load_map_parameter(checkpoint_list, element, element_id,
983
- map_data_list, map_shape_list, parameter_dict)
1056
+ _load_map_parameter(checkpoint_list, element, element_id, map_data_list, map_shape_list, parameter_dict)
984
1057
  if element.tag in parameter_dict:
985
1058
  map_data_list = [[], [], []]
986
1059
  map_shape_list = [0, 0, 0]
@@ -1024,8 +1097,12 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
1024
1097
  raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether "
1025
1098
  f"'filter_prefix' or 'specify_prefix' are set correctly.")
1026
1099
 
1100
+ if _warm_up_host_cache_enabled(parameter_dict):
1101
+ (is_worker, net_dict, warm_up_dict) = _warm_up_host_cache(parameter_dict, net)
1027
1102
  if net is not None:
1028
1103
  load_param_into_net(net, parameter_dict, strict_load)
1104
+ if _warm_up_host_cache_enabled(parameter_dict):
1105
+ _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict)
1029
1106
 
1030
1107
  return parameter_dict
1031
1108
 
@@ -1061,7 +1138,7 @@ def _load_map_parameter(checkpoint_list, element, element_id, map_data_list,
1061
1138
 
1062
1139
 
1063
1140
  def _check_ckpt_file_name(ckpt_file_name):
1064
- """Check function load_checkpoint's cket_file_name."""
1141
+ """Check function load_checkpoint's ckpt_file_name."""
1065
1142
  if not isinstance(ckpt_file_name, str):
1066
1143
  raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, "
1067
1144
  "but got {}.".format(type(ckpt_file_name)))
@@ -1175,7 +1252,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
1175
1252
  >>> import mindspore as ms
1176
1253
  >>>
1177
1254
  >>> # Define the network structure of LeNet5. Refer to
1178
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
1255
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
1179
1256
  >>> net = LeNet5()
1180
1257
  >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
1181
1258
  >>> param_dict = ms.load_checkpoint(ckpt_file_name, filter_prefix="conv1")
@@ -1185,7 +1262,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
1185
1262
 
1186
1263
  Tutorial Examples:
1187
1264
  - `Saving and Loading the Model - Saving and Loading the Model Weight
1188
- <https://mindspore.cn/tutorials/en/r2.1/beginner/save_load.html#saving-and-loading-the-model-weight>`_
1265
+ <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
1189
1266
  """
1190
1267
  if not isinstance(net, nn.Cell):
1191
1268
  logger.critical("Failed to combine the net and the parameters.")
@@ -1219,6 +1296,9 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
1219
1296
  if isinstance(param, MapParameter):
1220
1297
  param.import_data(parameter_dict[param.name])
1221
1298
  continue
1299
+ # Add has attr protection when load server checkpoint file on worker.
1300
+ if not hasattr(parameter_dict[param.name], "data"):
1301
+ continue
1222
1302
  new_param = copy.deepcopy(parameter_dict[param.name])
1223
1303
  _update_param(param, new_param, strict_load)
1224
1304
  ckpt_not_load.remove(param.name)
@@ -1243,6 +1323,72 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
1243
1323
  return param_not_load, ckpt_not_load
1244
1324
 
1245
1325
 
1326
+ def _warm_up_host_cache_enabled(parameter_dict):
1327
+ """Warm up host cache enabled."""
1328
+ if _cache_enable():
1329
+ return True
1330
+ for key in parameter_dict.keys():
1331
+ if key.find(".__param_key__") != -1:
1332
+ return True
1333
+ return False
1334
+
1335
+
1336
+ def _warm_up_host_cache(parameter_dict, net):
1337
+ """Warm up host cache."""
1338
+ ms_role = os.getenv("MS_ROLE")
1339
+ is_worker = ms_role == "MS_WORKER"
1340
+ param_key_dict = {}
1341
+ # Traverse key, value in parameter_dict, warm up param key and record param key into param_key_dict.
1342
+ if is_worker:
1343
+ net.init_parameters_data()
1344
+ net_dict = {}
1345
+ for name, value in net.parameters_and_names():
1346
+ net_dict[name] = value
1347
+ for param_name, value in parameter_dict.items():
1348
+ pos = param_name.find(".__param_key__")
1349
+ if pos != -1:
1350
+ net_param_name = param_name[:pos]
1351
+ param_key_dict[param_name] = net_param_name
1352
+ net_value = None
1353
+ if net_param_name not in net_dict:
1354
+ logger.warning("net param name : %s is not in net", net_param_name)
1355
+ else:
1356
+ net_value = net_dict.get(net_param_name, None)
1357
+ pos += len(".__param_key__")
1358
+ param_key = int(param_name[pos:])
1359
+ value_is_map_parameter = isinstance(value, list) and len(value) == 3
1360
+ if value_is_map_parameter and (net_value is None or isinstance(net_value, Parameter)):
1361
+ key_tensor = Tensor.from_numpy(value[0])
1362
+ value_tensor = Tensor.from_numpy(value[1])
1363
+ status_tensor = Tensor.from_numpy(value[2])
1364
+ _store_warm_up_ptr_by_tensor_list(param_key, key_tensor, value_tensor, status_tensor)
1365
+ elif not isinstance(value, list) and isinstance(net_value, Parameter):
1366
+ _store_warm_up_ptr_by_tensor(param_key, value)
1367
+ else:
1368
+ logger.warning("Unknown matches parameter type %s and net_value %s", type(value), type(net_value))
1369
+ else:
1370
+ for param_name, value in parameter_dict.items():
1371
+ pos = param_name.find(".__param_key__")
1372
+ if pos != -1:
1373
+ net_param_name = param_name[:pos]
1374
+ param_key_dict[param_name] = net_param_name
1375
+ # Split param key from parameter_dict since worker cannot load param key.
1376
+ warm_up_dict = {}
1377
+ for key, value in param_key_dict.items():
1378
+ if is_worker:
1379
+ warm_up_dict[value] = parameter_dict.pop(key)
1380
+ else:
1381
+ parameter_dict[value] = parameter_dict.pop(key)
1382
+ return (is_worker, parameter_dict, warm_up_dict)
1383
+
1384
+
1385
+ def _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict):
1386
+ """Warm up host cache post process."""
1387
+ if is_worker:
1388
+ net_dict.update(warm_up_dict)
1389
+ _set_checkpoint_load_status(True)
1390
+
1391
+
1246
1392
  def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load):
1247
1393
  """When some net parameter did not load, try to continue loading."""
1248
1394
  prefix_name = ""
@@ -1350,9 +1496,9 @@ def export(net, *inputs, file_name, file_format, **kwargs):
1350
1496
  Note:
1351
1497
  1. When exporting AIR, ONNX format, the size of a single tensor can not exceed 2GB.
1352
1498
  2. When file_name does not have a suffix, the system will automatically add one according to the file_format.
1353
- 3. Exporting functions decorated with 'jit' to mindir format is supported.
1354
- 4. When exporting a function decorated with 'jit', the function should not involve class properties in
1355
- calculations.
1499
+ 3. Exporting functions decorated with :func:`mindspore.jit` to mindir format is supported.
1500
+ 4. When exporting a function decorated with :func:`mindspore.jit`, the function should not involve
1501
+ class properties in calculations.
1356
1502
 
1357
1503
  Args:
1358
1504
  net (Union[Cell, function]): MindSpore network.
@@ -1388,17 +1534,20 @@ def export(net, *inputs, file_name, file_format, **kwargs):
1388
1534
 
1389
1535
  - type (str): The type of obfuscation, only 'dynamic' is supported until now.
1390
1536
  - obf_ratio (float, str): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
1391
- should be in range of (0, 1] or in ["small", "medium", "large"].
1537
+ should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large"
1538
+ correspond to 0.1, 0.3, and 0.6 respectively.
1392
1539
  - customized_func (function): A python function used for customized function mode, which used for control
1393
- the switch branch of obfuscation structure. The outputs of customized_func should be boolean. This
1394
- function needs to ensure that its result is constant for any input. Users can refer to opaque
1540
+ the switch branch of obfuscation structure. The outputs of customized_func should be boolean and constant
1541
+ (refer to 'my_func()' in
1542
+ `tutorials <https://www.mindspore.cn/mindarmour/docs/en/r2.0/dynamic_obfuscation_protection.html>`_).
1543
+ This function needs to ensure that its result is constant for any input. Users can refer to opaque
1395
1544
  predicates. If customized_func is set, then it should be passed to `load()` interface when loading
1396
1545
  obfuscated model.
1397
- - obf_random_seed (int): The random seed used for determine the distribution of confusion branches and the
1398
- weight confusion coefficient, which should be in (0, 9223372036854775807]. If `obf_random_seed` is set,
1399
- then it should be passed to :class:`nn.GraphCell()` interface when loading obfuscated model. It should
1400
- be noted that at least one of `customized_func` or `obf_random_seed` should be set, and the latter mode
1401
- would be applied if both of them are set.
1546
+ - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
1547
+ structure of obfuscated models corresponding to different random seeds is different. If
1548
+ `obf_random_seed` is set, then it should be passed to :class:`nn.GraphCell()` interface when loading
1549
+ obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
1550
+ be set, and the latter mode would be applied if both of them are set.
1402
1551
 
1403
1552
  - incremental (bool): export MindIR incrementally.
1404
1553
 
@@ -1408,14 +1557,14 @@ def export(net, *inputs, file_name, file_format, **kwargs):
1408
1557
  >>> from mindspore import Tensor
1409
1558
  >>>
1410
1559
  >>> # Define the network structure of LeNet5. Refer to
1411
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
1560
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
1412
1561
  >>> net = LeNet5()
1413
1562
  >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
1414
1563
  >>> ms.export(net, input_tensor, file_name='lenet', file_format='MINDIR')
1415
1564
 
1416
1565
  Tutorial Examples:
1417
1566
  - `Saving and Loading the Model - Saving and Loading MindIR
1418
- <https://mindspore.cn/tutorials/en/r2.1/beginner/save_load.html#saving-and-loading-mindir>`_
1567
+ <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-mindir>`_
1419
1568
  """
1420
1569
  old_ms_jit_value = context.get_context("jit_syntax_level")
1421
1570
  context.set_context(jit_syntax_level=mindspore.STRICT)
@@ -1475,7 +1624,7 @@ def _get_funcgraph(net, *inputs):
1475
1624
  >>> from mindspore import Tensor
1476
1625
  >>>
1477
1626
  >>> # Define the network structure of LeNet5. Refer to
1478
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
1627
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
1479
1628
  >>> net = LeNet5()
1480
1629
  >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
1481
1630
  >>> ms.get_funcgraph(net, input_tensor)
@@ -1660,7 +1809,7 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
1660
1809
  for param_proto in model.graph.parameter:
1661
1810
  name = param_proto.name[param_proto.name.find(":") + 1:]
1662
1811
  param = net_dict[name]
1663
- raw_data = param.data.asnumpy().tobytes()
1812
+ raw_data = param.data.get_bytes()
1664
1813
  data_length = len(raw_data)
1665
1814
  append_size = 0
1666
1815
  if data_length % 64 != 0:
@@ -1787,7 +1936,7 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
1787
1936
  for param_proto in model.graph.parameter:
1788
1937
  param_name = param_proto.name[param_proto.name.find(":") + 1:]
1789
1938
  if param_name in net_dict.keys():
1790
- param_data = net_dict[param_name].data.asnumpy().tobytes()
1939
+ param_data = net_dict[param_name].data.get_bytes()
1791
1940
  param_proto.raw_data = param_data
1792
1941
  else:
1793
1942
  raise ValueError("The parameter '{}' is not belongs to any cell,"
@@ -1797,10 +1946,10 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
1797
1946
  map_param_name = map_param_proto.name[map_param_proto.name.find(":") + 1:]
1798
1947
  if map_param_name in net_dict.keys():
1799
1948
  map_parameter = net_dict[map_param_name]
1800
- key_nparr, value_nparr, status_nparr = map_parameter.export_data(incremental)
1801
- map_param_proto.key_tensor.raw_data = key_nparr.tobytes()
1802
- map_param_proto.value_tensor.raw_data = value_nparr.tobytes()
1803
- map_param_proto.status_tensor.raw_data = status_nparr.tobytes()
1949
+ key_bytes, value_bytes, status_bytes = map_parameter.export_bytes(incremental)
1950
+ map_param_proto.key_tensor.raw_data = key_bytes
1951
+ map_param_proto.value_tensor.raw_data = value_bytes
1952
+ map_param_proto.status_tensor.raw_data = status_bytes
1804
1953
  else:
1805
1954
  raise ValueError("The map_parameter '{}' is not belongs to any cell,"
1806
1955
  "the data of parameter cannot be exported.".format(map_param_proto.name))
@@ -1831,7 +1980,7 @@ def _save_together(net_dict, model):
1831
1980
  for param_proto in model.graph.parameter:
1832
1981
  name = param_proto.name[param_proto.name.find(":") + 1:]
1833
1982
  if name in net_dict.keys():
1834
- data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024
1983
+ data_total += sys.getsizeof(net_dict[name].data.get_bytes()) / 1024
1835
1984
  else:
1836
1985
  raise ValueError("The parameter '{}' is not belongs to any cell,"
1837
1986
  "the data of parameter cannot be exported.".format(param_proto.name))
@@ -1862,7 +2011,7 @@ def _save_dataset_to_mindir(model, dataset):
1862
2011
 
1863
2012
  def parse_print(print_file_name):
1864
2013
  """
1865
- Parse data file generated by mindspore.ops.Print.
2014
+ Parse data file generated by :class:`mindspore.ops.Print`.
1866
2015
 
1867
2016
  Args:
1868
2017
  print_file_name (str): The file name needs to be parsed.
@@ -2039,8 +2188,8 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
2039
2188
  def restore_group_info_list(group_info_file_name):
2040
2189
  """
2041
2190
  Build rank list, the checkpoint of ranks in the rank list has the same contents with the local rank
2042
- who saves the group_info_file_name. To save the group info file, please export GROUP_INFO_FILE environment variables
2043
- like "export GROUP_INFO_FILE=/data/group_info.pb".
2191
+ who saves the `group_info_file_name`. To save the group info file, please export the GROUP_INFO_FILE
2192
+ environment variable, e.g. "export GROUP_INFO_FILE=/data/group_info.pb".
2044
2193
 
2045
2194
  Args:
2046
2195
  group_info_file_name (str): Name of group information file.
@@ -2050,7 +2199,7 @@ def restore_group_info_list(group_info_file_name):
2050
2199
 
2051
2200
  Raises:
2052
2201
  ValueError: group information file is incorrect.
2053
- TypeError: group_info_file_name is not str.
2202
+ TypeError: `group_info_file_name` is not str.
2054
2203
 
2055
2204
  Examples:
2056
2205
  >>> import mindspore as ms
@@ -2072,9 +2221,6 @@ def restore_group_info_list(group_info_file_name):
2072
2221
  def build_searched_strategy(strategy_filename):
2073
2222
  """
2074
2223
  Build strategy of every parameter in network. Used in the case of distributed inference.
2075
- For details of it, please check:
2076
- `Saving and Loading Models in Hybrid Parallel Mode
2077
- <https://www.mindspore.cn/tutorials/experts/en/r2.1/parallel/save_load.html>`_.
2078
2224
 
2079
2225
  Args:
2080
2226
  strategy_filename (str): Name of strategy file.
@@ -2096,8 +2242,6 @@ def build_searched_strategy(strategy_filename):
2096
2242
  def merge_sliced_parameter(sliced_parameters, strategy=None):
2097
2243
  """
2098
2244
  Merge parameter slices into one parameter. Used in the case of distributed inference.
2099
- For details of it, please check:
2100
- `<https://www.mindspore.cn/tutorials/experts/en/r2.1/parallel/save_load.html>`_.
2101
2245
 
2102
2246
  Args:
2103
2247
  sliced_parameters (list[Parameter]): Parameter slices in order of rank id.
@@ -2191,9 +2335,6 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
2191
2335
  train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM'):
2192
2336
  """
2193
2337
  Load checkpoint into net for distributed prediction. Used in the case of distributed inference.
2194
- For details of distributed inference, please check:
2195
- `Distributed Inference
2196
- <https://www.mindspore.cn/tutorials/experts/en/r2.1/parallel/distributed_inference.html>`_ .
2197
2338
 
2198
2339
  Args:
2199
2340
  network (Cell): Network for distributed prediction.
@@ -2218,6 +2359,104 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
2218
2359
  Raises:
2219
2360
  TypeError: The type of inputs do not match the requirements.
2220
2361
  ValueError: Failed to load checkpoint into net.
2362
+
2363
+ Supported Platforms:
2364
+ ``Ascend`` ``GPU``
2365
+
2366
+ Examples:
2367
+ .. note::
2368
+ Before running the following examples, you need to configure the communication environment variables.
2369
+
2370
+ For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
2371
+ Please see the `rank table startup
2372
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
2373
+ for more details.
2374
+
2375
+ For the GPU devices, users need to prepare the host file and mpi; please see the `mpirun startup
2376
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
2377
+
2378
+ For the CPU device, users need to write a dynamic cluster startup script; please see the `Dynamic Cluster
2379
+ Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
2380
+
2381
+ >>> import os
2382
+ >>> import numpy as np
2383
+ >>> import mindspore as ms
2384
+ >>> import mindspore.dataset as ds
2385
+ >>> from mindspore import nn, ops, train
2386
+ >>> from mindspore.communication import init
2387
+ >>>
2388
+ >>> step_per_epoch = 4
2389
+ >>> device_num = 8
2390
+ >>>
2391
+ >>> # Define the network structure.
2392
+ >>> class Net(nn.Cell):
2393
+ ... def __init__(self, matmul_size, strategy=None):
2394
+ ... super().__init__()
2395
+ ... matmul_np = np.full(matmul_size, 0.5, dtype=np.float32)
2396
+ ... self.matmul_weight = ms.Parameter(ms.Tensor(matmul_np))
2397
+ ... self.matmul = ops.MatMul()
2398
+ ... self.neg = ops.Neg()
2399
+ ... if strategy is not None:
2400
+ ... self.matmul.shard(strategy)
2401
+ ...
2402
+ ... def construct(self, inputs):
2403
+ ... x = self.matmul(inputs, self.matmul_weight)
2404
+ ... x = self.neg(x)
2405
+ ... return x
2406
+ >>>
2407
+ >>> # Create dataset.
2408
+ >>> def get_dataset(*inputs):
2409
+ ... def generate():
2410
+ ... for _ in range(step_per_epoch):
2411
+ ... yield inputs
2412
+ ... return generate
2413
+ >>>
2414
+ >>> # Train network and save distributed checkpoint.
2415
+ >>> def train_net():
2416
+ ... ms.set_context(mode=ms.GRAPH_MODE)
2417
+ ... init()
2418
+ ... np.random.seed(1)
2419
+ ... input_data = np.random.rand(16, 96).astype(np.float32)
2420
+ ... label_data = np.random.rand(16, 16).astype(np.float32)
2421
+ ... fake_dataset = get_dataset(input_data, label_data)
2422
+ ... dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
2423
+ ...
2424
+ ... # Set parallel strategy.
2425
+ ... strategy = ((1, 4), (4, 1))
2426
+ ... ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num,
2427
+ ... strategy_ckpt_save_file="./train_strategy.ckpt")
2428
+ ... network = Net(matmul_size=(96, 16), strategy=strategy)
2429
+ ... net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
2430
+ ... net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
2431
+ ... model = ms.Model(network=network, loss_fn=net_loss, optimizer=net_opt)
2432
+ ... ckpt_config = train.CheckpointConfig(keep_checkpoint_max=1, integrated_save=False)
2433
+ ... global_rank_id = int(os.getenv("RANK_ID"))
2434
+ ... ckpt_path = "./rank_{}_ckpt".format(global_rank_id)
2435
+ ... ckpt_callback = train.ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
2436
+ ... model.train(epoch=2, train_dataset=dataset, callbacks=[ckpt_callback], dataset_sink_mode=False)
2437
+ ... ms.reset_auto_parallel_context()
2438
+ >>>
2439
+ >>> # Load distributed checkpoint and test.
2440
+ >>> def load_model():
2441
+ ... ms.set_context(mode=ms.GRAPH_MODE)
2442
+ ... init()
2443
+ ... ms.set_auto_parallel_context(full_batch=True, parallel_mode="semi_auto_parallel",
2444
+ ... strategy_ckpt_load_file="./train_strategy.ckpt", device_num=device_num)
2445
+ ... predict_data = ms.Tensor(np.random.randn(128, 96).astype(np.float32))
2446
+ ... network = Net(matmul_size=(96, 16))
2447
+ ... model = ms.Model(network)
2448
+ ... predict_layout = model.infer_predict_layout(ms.Tensor(predict_data))
2449
+ ... ckpt_file_list = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(0, device_num)]
2450
+ ... ms.load_distributed_checkpoint(network, ckpt_file_list, predict_layout)
2451
+ ... predict_result = model.predict(predict_data)
2452
+ ... print(predict_result)
2453
+ >>>
2454
+ >>> train_net()
2455
+ >>> load_model()
2456
+ [[-7.3259363 -7.497216 -7.398196 ... -7.374962 -7.204874 -7.234935 ]
2457
+ [ 3.362938 3.3535435 3.3832688 ... 3.4263954 3.279045 3.3202887]
2458
+ ...
2459
+ [ 1.6067538 1.6244187 1.5384722 ... 1.5449994 1.6195512 1.6176052]]
2221
2460
  """
2222
2461
  network = Validator.check_isinstance("network", network, nn.Cell)
2223
2462
  _check_checkpoint_file(checkpoint_filenames)
@@ -2395,7 +2634,8 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
2395
2634
  return merged_param
2396
2635
  param_name = merged_param.name
2397
2636
  tensor_layout = predict_strategy[param_name]
2398
- split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1])
2637
+ rank = get_rank()
2638
+ split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank)
2399
2639
  requires_grad = merged_param.requires_grad
2400
2640
  layerwise_parallel = merged_param.layerwise_parallel
2401
2641
  split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)