mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. (The original page linked to more details at this point; the link was not preserved.)

Files changed (870)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  201. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  202. mindspore/log.py +9 -6
  203. mindspore/mindrecord/filereader.py +33 -4
  204. mindspore/mindrecord/filewriter.py +70 -35
  205. mindspore/mindrecord/mindpage.py +40 -34
  206. mindspore/mindrecord/shardreader.py +1 -1
  207. mindspore/mindrecord/shardsegment.py +1 -1
  208. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  209. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  210. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  211. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  212. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  213. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  214. mindspore/nn/cell.py +463 -169
  215. mindspore/nn/dynamic_lr.py +47 -43
  216. mindspore/nn/layer/activation.py +225 -82
  217. mindspore/nn/layer/basic.py +121 -79
  218. mindspore/nn/layer/channel_shuffle.py +21 -21
  219. mindspore/nn/layer/combined.py +33 -26
  220. mindspore/nn/layer/container.py +277 -22
  221. mindspore/nn/layer/conv.py +441 -304
  222. mindspore/nn/layer/dense.py +19 -13
  223. mindspore/nn/layer/embedding.py +62 -49
  224. mindspore/nn/layer/flash_attention.py +264 -0
  225. mindspore/nn/layer/image.py +50 -39
  226. mindspore/nn/layer/math.py +62 -51
  227. mindspore/nn/layer/normalization.py +219 -167
  228. mindspore/nn/layer/padding.py +58 -70
  229. mindspore/nn/layer/pooling.py +334 -287
  230. mindspore/nn/layer/rnn_cells.py +53 -38
  231. mindspore/nn/layer/rnns.py +59 -56
  232. mindspore/nn/layer/thor_layer.py +52 -44
  233. mindspore/nn/layer/timedistributed.py +6 -4
  234. mindspore/nn/layer/transformer.py +284 -164
  235. mindspore/nn/learning_rate_schedule.py +34 -25
  236. mindspore/nn/loss/__init__.py +3 -2
  237. mindspore/nn/loss/loss.py +554 -311
  238. mindspore/nn/optim/ada_grad.py +12 -9
  239. mindspore/nn/optim/adadelta.py +14 -11
  240. mindspore/nn/optim/adafactor.py +19 -16
  241. mindspore/nn/optim/adam.py +62 -47
  242. mindspore/nn/optim/adamax.py +13 -10
  243. mindspore/nn/optim/adasum.py +12 -8
  244. mindspore/nn/optim/asgd.py +10 -9
  245. mindspore/nn/optim/ftrl.py +20 -17
  246. mindspore/nn/optim/lamb.py +16 -12
  247. mindspore/nn/optim/lars.py +8 -6
  248. mindspore/nn/optim/lazyadam.py +25 -20
  249. mindspore/nn/optim/momentum.py +10 -7
  250. mindspore/nn/optim/optimizer.py +61 -9
  251. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  252. mindspore/nn/optim/rmsprop.py +17 -13
  253. mindspore/nn/optim/rprop.py +30 -17
  254. mindspore/nn/optim/sgd.py +40 -23
  255. mindspore/nn/optim/thor.py +24 -26
  256. mindspore/nn/probability/bijector/bijector.py +11 -11
  257. mindspore/nn/probability/bijector/exp.py +1 -1
  258. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  259. mindspore/nn/probability/bijector/invert.py +1 -1
  260. mindspore/nn/probability/bijector/power_transform.py +29 -29
  261. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  262. mindspore/nn/probability/bijector/softplus.py +5 -5
  263. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  264. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  265. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  266. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  267. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  268. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  269. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  270. mindspore/nn/probability/distribution/beta.py +8 -8
  271. mindspore/nn/probability/distribution/categorical.py +23 -15
  272. mindspore/nn/probability/distribution/cauchy.py +5 -6
  273. mindspore/nn/probability/distribution/distribution.py +3 -3
  274. mindspore/nn/probability/distribution/exponential.py +4 -4
  275. mindspore/nn/probability/distribution/gamma.py +10 -10
  276. mindspore/nn/probability/distribution/geometric.py +8 -8
  277. mindspore/nn/probability/distribution/gumbel.py +8 -9
  278. mindspore/nn/probability/distribution/half_normal.py +5 -5
  279. mindspore/nn/probability/distribution/laplace.py +5 -5
  280. mindspore/nn/probability/distribution/log_normal.py +12 -11
  281. mindspore/nn/probability/distribution/logistic.py +8 -8
  282. mindspore/nn/probability/distribution/normal.py +6 -5
  283. mindspore/nn/probability/distribution/poisson.py +10 -11
  284. mindspore/nn/probability/distribution/student_t.py +8 -9
  285. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  286. mindspore/nn/probability/distribution/uniform.py +11 -11
  287. mindspore/nn/reinforcement/tensor_array.py +2 -2
  288. mindspore/nn/sparse/sparse.py +9 -9
  289. mindspore/nn/wrap/cell_wrapper.py +188 -63
  290. mindspore/nn/wrap/grad_reducer.py +21 -12
  291. mindspore/nn/wrap/loss_scale.py +136 -49
  292. mindspore/numpy/__init__.py +4 -4
  293. mindspore/numpy/array_creations.py +55 -56
  294. mindspore/numpy/array_ops.py +134 -35
  295. mindspore/numpy/logic_ops.py +66 -20
  296. mindspore/numpy/math_ops.py +142 -139
  297. mindspore/numpy/utils_const.py +2 -2
  298. mindspore/offline_debug/convert_async.py +2 -2
  299. mindspore/ops/_grad_experimental/__init__.py +7 -5
  300. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  301. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  302. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  303. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  304. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  305. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  306. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  307. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  308. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  309. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  310. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  311. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  312. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  313. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  314. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  315. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  316. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  317. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  318. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  319. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  320. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  321. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  322. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  323. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  325. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  326. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  327. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  328. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  329. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  331. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  332. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  333. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  334. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  335. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  336. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  337. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  338. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  339. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  340. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  341. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  342. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  343. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  345. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  346. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  347. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  348. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  349. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  350. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  351. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  352. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  353. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  354. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  355. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  356. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  357. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  358. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  359. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  360. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  361. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  362. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  363. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  364. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  365. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  366. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  367. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  368. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  369. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  370. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  371. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  372. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  373. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  374. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  375. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  376. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  377. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  378. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  379. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  380. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  381. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  382. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  383. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  384. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  385. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  386. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  387. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  388. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  389. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  390. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  391. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  392. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  393. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  394. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  395. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  396. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  397. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  398. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  399. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  400. mindspore/ops/_primitive_cache.py +1 -1
  401. mindspore/ops/_tracefunc.py +241 -0
  402. mindspore/ops/_utils/utils.py +10 -2
  403. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  404. mindspore/ops/_vmap/vmap_base.py +5 -4
  405. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  406. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  407. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  408. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  409. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  410. mindspore/ops/arg_dtype_cast.py +54 -0
  411. mindspore/ops/composite/__init__.py +7 -5
  412. mindspore/ops/composite/base.py +78 -34
  413. mindspore/ops/composite/math_ops.py +5 -695
  414. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  415. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  416. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  417. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  418. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  419. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  420. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  421. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  422. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  423. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  424. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  425. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  426. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  427. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  428. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  429. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  430. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  431. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  432. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  433. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  434. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  436. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  437. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  438. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  440. mindspore/ops/deprecated.py +304 -0
  441. mindspore/ops/function/__init__.py +41 -4
  442. mindspore/ops/function/array_func.py +1108 -467
  443. mindspore/ops/function/clip_func.py +94 -27
  444. mindspore/ops/function/debug_func.py +3 -1
  445. mindspore/ops/function/grad/grad_func.py +82 -73
  446. mindspore/ops/function/image_func.py +28 -12
  447. mindspore/ops/function/linalg_func.py +135 -39
  448. mindspore/ops/function/math_func.py +3779 -894
  449. mindspore/ops/function/nn_func.py +1584 -657
  450. mindspore/ops/function/parameter_func.py +13 -3
  451. mindspore/ops/function/random_func.py +247 -153
  452. mindspore/ops/function/sparse_func.py +14 -11
  453. mindspore/ops/function/sparse_unary_func.py +173 -47
  454. mindspore/ops/function/spectral_func.py +8 -4
  455. mindspore/ops/function/vmap_func.py +8 -7
  456. mindspore/ops/functional.py +47 -16
  457. mindspore/ops/op_info_register.py +346 -86
  458. mindspore/ops/operations/__init__.py +38 -22
  459. mindspore/ops/operations/_grad_ops.py +145 -149
  460. mindspore/ops/operations/_inner_ops.py +298 -56
  461. mindspore/ops/operations/_ms_kernel.py +3 -3
  462. mindspore/ops/operations/_quant_ops.py +24 -28
  463. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  464. mindspore/ops/operations/_scalar_ops.py +115 -0
  465. mindspore/ops/operations/_sequence_ops.py +148 -10
  466. mindspore/ops/operations/_tensor_array.py +1 -1
  467. mindspore/ops/operations/_thor_ops.py +2 -2
  468. mindspore/ops/operations/array_ops.py +1239 -561
  469. mindspore/ops/operations/comm_ops.py +166 -90
  470. mindspore/ops/operations/control_ops.py +3 -3
  471. mindspore/ops/operations/custom_ops.py +124 -102
  472. mindspore/ops/operations/debug_ops.py +24 -11
  473. mindspore/ops/operations/image_ops.py +86 -71
  474. mindspore/ops/operations/inner_ops.py +18 -13
  475. mindspore/ops/operations/linalg_ops.py +30 -11
  476. mindspore/ops/operations/math_ops.py +1730 -435
  477. mindspore/ops/operations/nn_ops.py +1953 -943
  478. mindspore/ops/operations/other_ops.py +65 -43
  479. mindspore/ops/operations/random_ops.py +258 -98
  480. mindspore/ops/operations/rl_ops.py +4 -36
  481. mindspore/ops/operations/sparse_ops.py +38 -33
  482. mindspore/ops/operations/spectral_ops.py +8 -4
  483. mindspore/ops/primitive.py +66 -44
  484. mindspore/ops/signature.py +5 -5
  485. mindspore/parallel/_auto_parallel_context.py +80 -19
  486. mindspore/parallel/_cost_model_context.py +42 -0
  487. mindspore/parallel/_offload_context.py +162 -72
  488. mindspore/parallel/_parallel_serialization.py +2 -2
  489. mindspore/parallel/_ps_context.py +16 -4
  490. mindspore/parallel/_recovery_context.py +2 -1
  491. mindspore/parallel/_tensor.py +15 -13
  492. mindspore/parallel/_transformer/layers.py +8 -6
  493. mindspore/parallel/_transformer/loss.py +1 -0
  494. mindspore/parallel/_transformer/moe.py +7 -7
  495. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  496. mindspore/parallel/_transformer/transformer.py +34 -14
  497. mindspore/parallel/_utils.py +36 -14
  498. mindspore/parallel/algo_parameter_config.py +114 -20
  499. mindspore/parallel/checkpoint_transform.py +16 -18
  500. mindspore/parallel/shard.py +16 -13
  501. mindspore/profiler/__init__.py +1 -1
  502. mindspore/profiler/common/struct_type.py +3 -3
  503. mindspore/profiler/common/util.py +3 -2
  504. mindspore/profiler/envprofiling.py +11 -4
  505. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  506. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  507. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  508. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  509. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  510. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  511. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  512. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  513. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  514. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  515. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  516. mindspore/profiler/parser/flops_parser.py +15 -11
  517. mindspore/profiler/parser/framework_parser.py +92 -73
  518. mindspore/profiler/parser/hccl_parser.py +16 -12
  519. mindspore/profiler/parser/integrator.py +22 -11
  520. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  521. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  522. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  523. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  524. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  525. mindspore/profiler/parser/optime_parser.py +1 -1
  526. mindspore/profiler/parser/profiler_info.py +4 -5
  527. mindspore/profiler/parser/step_trace_parser.py +11 -14
  528. mindspore/profiler/profiling.py +678 -377
  529. mindspore/rewrite/api/node.py +211 -54
  530. mindspore/rewrite/api/node_type.py +5 -0
  531. mindspore/rewrite/api/pattern_engine.py +22 -23
  532. mindspore/rewrite/api/scoped_value.py +20 -17
  533. mindspore/rewrite/api/symbol_tree.py +252 -106
  534. mindspore/rewrite/api/tree_node_helper.py +3 -0
  535. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  536. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  537. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  538. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  539. mindspore/rewrite/common/rewrite_elog.py +5 -1
  540. mindspore/rewrite/namer.py +51 -51
  541. mindspore/rewrite/namespace.py +14 -5
  542. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  543. mindspore/rewrite/node/call_function.py +79 -0
  544. mindspore/rewrite/node/cell_container.py +135 -0
  545. mindspore/rewrite/node/control_flow.py +88 -0
  546. mindspore/rewrite/{node.py → node/node.py} +313 -247
  547. mindspore/rewrite/node/node_manager.py +254 -0
  548. mindspore/rewrite/node/node_topological_manager.py +243 -0
  549. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  550. mindspore/rewrite/parsers/assign_parser.py +225 -239
  551. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  552. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  553. mindspore/rewrite/parsers/constant_parser.py +9 -6
  554. mindspore/rewrite/parsers/container_parser.py +9 -7
  555. mindspore/rewrite/parsers/for_parser.py +36 -15
  556. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  557. mindspore/rewrite/parsers/if_parser.py +28 -24
  558. mindspore/rewrite/parsers/module_parser.py +202 -25
  559. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  560. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  561. mindspore/rewrite/parsers/return_parser.py +6 -6
  562. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  563. mindspore/rewrite/sparsify/sparsify.py +4 -1
  564. mindspore/rewrite/sparsify/utils.py +11 -5
  565. mindspore/rewrite/symbol_tree.py +577 -732
  566. mindspore/rewrite/symbol_tree_builder.py +9 -175
  567. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  568. mindspore/run_check/_check_version.py +46 -39
  569. mindspore/run_check/run_check.py +3 -2
  570. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  571. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  572. mindspore/scipy/__init__.py +1 -1
  573. mindspore/scipy/linalg.py +67 -61
  574. mindspore/scipy/ops.py +5 -41
  575. mindspore/scipy/ops_grad.py +3 -2
  576. mindspore/scipy/ops_wrapper.py +5 -5
  577. mindspore/scipy/optimize/line_search.py +8 -8
  578. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  579. mindspore/scipy/optimize/minimize.py +16 -12
  580. mindspore/scipy/utils.py +1 -52
  581. mindspore/scipy/utils_const.py +4 -4
  582. mindspore/train/__init__.py +4 -4
  583. mindspore/train/_utils.py +13 -5
  584. mindspore/train/amp.py +410 -148
  585. mindspore/train/anf_ir_pb2.py +16 -4
  586. mindspore/train/callback/_backup_and_restore.py +8 -11
  587. mindspore/train/callback/_callback.py +80 -3
  588. mindspore/train/callback/_checkpoint.py +82 -51
  589. mindspore/train/callback/_early_stop.py +12 -15
  590. mindspore/train/callback/_history.py +1 -1
  591. mindspore/train/callback/_lambda_callback.py +13 -13
  592. mindspore/train/callback/_landscape.py +21 -17
  593. mindspore/train/callback/_loss_monitor.py +9 -10
  594. mindspore/train/callback/_on_request_exit.py +16 -33
  595. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  596. mindspore/train/callback/_summary_collector.py +44 -30
  597. mindspore/train/callback/_time_monitor.py +62 -12
  598. mindspore/train/data_sink.py +10 -16
  599. mindspore/train/dataset_helper.py +154 -86
  600. mindspore/train/loss_scale_manager.py +14 -9
  601. mindspore/train/metrics/__init__.py +10 -2
  602. mindspore/train/metrics/accuracy.py +1 -1
  603. mindspore/train/metrics/auc.py +1 -1
  604. mindspore/train/metrics/bleu_score.py +2 -2
  605. mindspore/train/metrics/confusion_matrix.py +14 -14
  606. mindspore/train/metrics/cosine_similarity.py +3 -3
  607. mindspore/train/metrics/dice.py +1 -1
  608. mindspore/train/metrics/fbeta.py +1 -1
  609. mindspore/train/metrics/hausdorff_distance.py +8 -6
  610. mindspore/train/metrics/mean_surface_distance.py +5 -4
  611. mindspore/train/metrics/metric.py +49 -17
  612. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  613. mindspore/train/metrics/perplexity.py +1 -1
  614. mindspore/train/metrics/precision.py +2 -2
  615. mindspore/train/metrics/recall.py +2 -3
  616. mindspore/train/metrics/roc.py +7 -7
  617. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  618. mindspore/train/metrics/topk.py +7 -4
  619. mindspore/train/mind_ir_pb2.py +193 -48
  620. mindspore/train/model.py +377 -133
  621. mindspore/train/serialization.py +697 -245
  622. mindspore/train/summary/_summary_adapter.py +5 -2
  623. mindspore/train/summary/_writer_pool.py +4 -3
  624. mindspore/train/summary/summary_record.py +25 -23
  625. mindspore/train/train_thor/convert_utils.py +39 -23
  626. mindspore/train/train_thor/dataset_helper.py +4 -3
  627. mindspore/train/train_thor/model_thor.py +8 -8
  628. mindspore/version.py +1 -1
  629. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  630. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
  631. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  632. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  633. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  634. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  635. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  636. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  637. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  638. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  639. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  640. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  641. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  642. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  643. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  644. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  645. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  646. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  647. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  648. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  649. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  650. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  651. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  652. mindspore/_extends/graph_kernel/expander.py +0 -80
  653. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  654. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  655. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  656. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  657. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  658. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  659. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  660. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  661. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  662. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  663. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  664. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  665. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  666. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  667. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  668. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  669. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  670. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  671. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  672. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  673. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  674. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  675. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  676. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  677. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  678. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  679. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  680. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  681. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  682. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  683. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  684. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  685. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  686. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  687. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  688. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  689. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  690. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  691. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  692. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  693. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  694. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  695. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  696. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  697. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  698. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  699. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  700. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  701. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  702. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  703. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  704. mindspore/dataset/engine/graphdata.py +0 -1586
  705. mindspore/include/api/net.h +0 -142
  706. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  707. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  708. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  709. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  710. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  711. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  712. mindspore/ops/_grad/grad_other_ops.py +0 -89
  713. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  714. mindspore/ops/_grad/grad_sparse.py +0 -323
  715. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  716. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  717. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  718. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  719. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  720. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  721. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  722. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  723. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  724. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  725. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  726. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  727. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  728. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  729. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  730. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  731. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  732. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  733. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  734. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  735. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  736. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  737. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  738. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  739. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  740. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  741. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  742. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  743. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  744. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  745. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  746. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  747. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  748. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  749. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  751. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  752. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  753. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  754. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  755. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  756. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  758. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  759. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  760. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  761. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  762. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  763. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  764. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  765. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  766. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  767. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  768. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  769. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  770. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  771. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  772. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  774. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  775. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  776. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  777. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  778. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  779. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  780. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  784. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  785. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  786. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  787. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  789. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  790. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  791. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  792. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  793. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  794. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  796. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  797. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  798. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  799. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  800. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  801. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  802. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  803. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  804. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  805. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  806. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  807. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  808. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  809. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  810. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  811. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  812. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  814. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  815. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  816. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  817. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  818. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  819. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  820. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  821. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  822. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  823. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  824. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  825. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  826. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  827. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  828. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  829. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  830. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  831. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  832. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  833. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  834. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  835. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  836. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  837. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  839. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  840. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  841. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  842. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  843. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  844. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  845. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  846. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  847. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  848. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  849. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  850. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  851. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  852. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  853. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  854. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  855. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  856. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  857. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  858. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  859. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  860. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  861. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  862. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  863. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  864. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  865. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  866. mindspore/rewrite/node_visitor.py +0 -44
  867. mindspore/rewrite/topological_manager.py +0 -203
  868. mindspore/scipy/sparse/linalg.py +0 -192
  869. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  870. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -23,7 +23,7 @@ import os
23
23
  import shutil
24
24
  import stat
25
25
  import threading
26
- from threading import Thread, Lock
26
+ from threading import Thread, RLock
27
27
  from collections import defaultdict, OrderedDict
28
28
  from io import BytesIO
29
29
 
@@ -48,7 +48,7 @@ from mindspore.common.api import _MindsporeFunctionExecutor
48
48
  from mindspore.common.api import _get_parameter_layout
49
49
  from mindspore.common.api import _generate_branch_control_input
50
50
  from mindspore.common.initializer import initializer, One
51
- from mindspore.common.parameter import Parameter
51
+ from mindspore.common.parameter import Parameter, _offload_if_config
52
52
  from mindspore.common.tensor import Tensor
53
53
  from mindspore.common._utils import is_shape_unknown
54
54
  from mindspore.communication.management import get_rank, get_group_size
@@ -59,8 +59,11 @@ from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_
59
59
  from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode
60
60
  from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \
61
61
  _restore_group_info_list
62
+ from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
63
+ _store_warm_up_ptr_by_tensor_list, _cache_enable
62
64
  from mindspore.train._utils import read_proto
63
- from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir
65
+ from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
66
+ split_mindir, split_dynamic_mindir
64
67
  from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
65
68
 
66
69
  tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
@@ -72,11 +75,13 @@ tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UIn
72
75
  "Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64,
73
76
  "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"}
74
77
 
78
+ np_type_convert = {"int32": np.int32, "float32": np.float32, "float16": np.float16, "float64": np.float64}
79
+
75
80
  mindir_to_tensor_type = {1: mstype.float32, 2: mstype.uint8, 3: mstype.int8, 4: mstype.uint16,
76
81
  5: mstype.int16, 6: mstype.int32, 7: mstype.int64, 10: mstype.float16,
77
82
  11: mstype.float64, 12: mstype.uint32, 13: mstype.uint64}
78
83
 
79
- _ckpt_mutex = Lock()
84
+ _ckpt_mutex = RLock()
80
85
 
81
86
  # unit is KB
82
87
  SLICE_SIZE = 512 * 1024
@@ -124,7 +129,7 @@ def _update_param(param, new_param, strict_load):
124
129
  if param.data.dtype != new_param.data.dtype:
125
130
  if _type_convert(param, new_param, strict_load):
126
131
  new_tensor = Tensor(new_param.data.asnumpy(), param.data.dtype)
127
- param.set_data(new_tensor)
132
+ param.set_data(new_tensor, param.sliced)
128
133
  return
129
134
 
130
135
  logger.critical("Failed to combine the net and the parameters for param %s.", param.name)
@@ -205,7 +210,7 @@ def _save_weight(checkpoint_dir, model_name, iteration, params):
205
210
  logger.warning(f"Checkpoint dir: '{checkpoint_dir}' is not existed.")
206
211
 
207
212
 
208
- def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
213
+ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False):
209
214
  """Execute the process of saving checkpoint into file."""
210
215
  try:
211
216
  with _ckpt_mutex:
@@ -213,37 +218,28 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
213
218
  os.chmod(ckpt_file_name, stat.S_IWUSR)
214
219
  os.remove(ckpt_file_name)
215
220
  with open(ckpt_file_name, "ab") as f:
221
+ plain_data = None
216
222
  if enc_key is not None:
217
223
  plain_data = BytesIO()
218
224
 
219
225
  for name, value in data_list.items():
226
+ if name == "random_op":
227
+ _write_random_seed(name, value, f)
228
+ continue
220
229
  if value[0] == "mapparameter":
221
- _write_mapparameter(name, value, f)
230
+ _write_mapparameter(name, value, f, map_param_inc)
231
+ continue
232
+ if value[0] == "offload_parameter":
233
+ new_value = value[1:]
234
+ new_value[2] = value[3].asnumpy().reshape(-1)
235
+ _write_parameter_data(name, new_value, f, enc_key, plain_data)
236
+ _offload_if_config(value[3])
222
237
  continue
223
238
  if isinstance(value[2], Tensor):
224
239
  _write_hugeparameter(name, value, f)
225
240
  continue
226
241
 
227
- data_size = value[2].nbytes / 1024
228
- if data_size > SLICE_SIZE:
229
- slice_count = math.ceil(data_size / SLICE_SIZE)
230
- param_slice_list = np.array_split(value[2], slice_count)
231
- else:
232
- param_slice_list = [value[2]]
233
-
234
- for param_slice in param_slice_list:
235
- checkpoint_list = Checkpoint()
236
- param_value = checkpoint_list.value.add()
237
- param_value.tag = name
238
- param_tensor = param_value.tensor
239
- param_tensor.dims.extend(value[0])
240
- param_tensor.tensor_type = value[1]
241
- param_tensor.tensor_content = param_slice.tobytes()
242
-
243
- if enc_key is None:
244
- f.write(checkpoint_list.SerializeToString())
245
- else:
246
- plain_data.write(checkpoint_list.SerializeToString())
242
+ _write_parameter_data(name, value, f, enc_key, plain_data)
247
243
 
248
244
  if enc_key is not None:
249
245
  plain_data.seek(0)
@@ -261,18 +257,59 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
261
257
  raise e
262
258
 
263
259
 
264
- def _write_mapparameter(name, value, f):
265
- """Write map parameter into protobuf file."""
260
+ def _write_random_seed(name, value, f):
261
+ """Write random op into protobuf file."""
266
262
  checkpoint_list = Checkpoint()
267
263
  param_value = checkpoint_list.value.add()
268
264
  param_value.tag = name
269
- map_tensor = param_value.maptensor
270
- for v in value[1:]:
271
- tensor = map_tensor.tensor.add()
272
- tensor.dims.extend(v[0])
273
- tensor.tensor_type = v[1]
274
- tensor.tensor_content = v[2].tobytes()
265
+ param_tensor = param_value.tensor
266
+ param_tensor.dims.extend(0)
267
+ param_tensor.tensor_type = "random_op"
268
+ param_tensor.tensor_content = value
269
+ f.write(checkpoint_list.SerializeToString())
270
+
271
+
272
+ def _write_parameter_data(name, value, f, enc_key, plain_data):
273
+ """Write parameter data into protobuf file."""
274
+ data_size = value[2].nbytes / 1024
275
+ if data_size > SLICE_SIZE:
276
+ slice_count = math.ceil(data_size / SLICE_SIZE)
277
+ param_slice_list = np.array_split(value[2], slice_count)
278
+ else:
279
+ param_slice_list = [value[2]]
280
+
281
+ for param_slice in param_slice_list:
282
+ checkpoint_list = Checkpoint()
283
+ param_value = checkpoint_list.value.add()
284
+ param_value.tag = name
285
+ param_tensor = param_value.tensor
286
+ param_tensor.dims.extend(value[0])
287
+ param_tensor.tensor_type = value[1]
288
+ param_tensor.tensor_content = param_slice.tobytes()
289
+
290
+ if enc_key is None:
291
+ f.write(checkpoint_list.SerializeToString())
292
+ else:
293
+ plain_data.write(checkpoint_list.SerializeToString())
294
+
295
+
296
+ def _write_mapparameter(name, value, f, map_param_inc=False):
297
+ """Write map parameter into protobuf file."""
298
+ while True:
299
+ logger.info("Checkpoint save map_parameter.")
300
+ data_map_slice = value[1].export_slice_data(map_param_inc)
301
+ checkpoint_list = Checkpoint()
302
+ param_value = checkpoint_list.value.add()
303
+ param_value.tag = name
304
+ map_tensor = param_value.maptensor
305
+ for numpy_data in data_map_slice[:3]:
306
+ tensor_pro = map_tensor.tensor.add()
307
+ tensor_pro.dims.extend(numpy_data.shape)
308
+ tensor_pro.tensor_type = str(numpy_data.dtype)
309
+ tensor_pro.tensor_content = numpy_data.reshape(-1).tobytes()
275
310
  f.write(checkpoint_list.SerializeToString())
311
+ if data_map_slice[3]:
312
+ break
276
313
 
277
314
 
278
315
  def _write_hugeparameter(name, value, f):
@@ -298,8 +335,8 @@ def _write_hugeparameter(name, value, f):
298
335
 
299
336
  def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
300
337
  """Check save_obj and ckpt_file_name for save_checkpoint."""
301
- if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list):
302
- raise TypeError("For 'save_checkpoint', the parameter 'save_obj' must be nn.Cell or list, "
338
+ if not isinstance(save_obj, (nn.Cell, list, dict)):
339
+ raise TypeError("For 'save_checkpoint', the parameter 'save_obj' must be nn.Cell, list or dict, "
303
340
  "but got {}.".format(type(save_obj)))
304
341
  if not isinstance(ckpt_file_name, str):
305
342
  raise TypeError("For 'save_checkpoint', the parameter {} for checkpoint file name is invalid,"
@@ -315,34 +352,63 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
315
352
 
316
353
 
317
354
  def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
318
- async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM"):
319
- """
355
+ async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None, **kwargs):
356
+ r"""
320
357
  Save checkpoint to a specified file.
321
358
 
322
359
  Args:
323
- save_obj (Union[Cell, list]): The cell object or data list(each element is a dictionary, like
324
- [{"name": param_name, "data": param_data},...], the type of
325
- param_name would be string, and the type of param_data would
326
- be parameter or Tensor).
360
+ save_obj (Union[Cell, list, dict]): The object to be saved. The data type can be :class:`mindspore.nn.Cell`,
361
+ list, or dict. If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict
362
+ elements(each element is a dictionary, like [{"name": param_name, "data": param_data},...], the type of
363
+ `param_name` must be string, and the type of `param_data` must be parameter or Tensor); If dict,
364
+ it can be the returned value of `mindspore.load_checkpoint()`.
327
365
  ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
328
- integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True
329
- async_save (bool): Whether to open an independent thread to save the checkpoint file. Default: False
366
+ integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: ``True`` .
367
+ async_save (bool): Whether to open an independent thread to save the checkpoint file. Default: ``False`` .
330
368
  append_dict (dict): Additional information that needs to be saved. The key of dict must be str, the value
331
- of dict must be one of int, float, bool, string, Parameter or Tensor. Default: None.
332
- enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption
333
- is not required. Default: None.
334
- enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption
335
- mode, currently supports 'AES-GCM' and 'AES-CBC' and 'SM4-CBC'. Default: 'AES-GCM'.
369
+ of dict must be one of int, float, bool, string, Parameter or Tensor. Default: ``None`` .
370
+ enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is ``None`` , the encryption
371
+ is not required. Default: ``None`` .
372
+ enc_mode (str): This parameter is valid only when enc_key is not set to ``None`` . Specifies the encryption
373
+ mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` .
374
+ Default: ``"AES-GCM"`` .
375
+ choice_func (function) : A function for saving custom selected parameters. The input value of `choice_func` is
376
+ a parameter name in string type, and the returned value is a bool.
377
+ If returns ``True`` , the Parameter that matching the custom condition will be saved.
378
+ If returns ``False`` , the Parameter that not matching the custom condition will not
379
+ be saved. Default: ``None`` .
380
+ kwargs (dict): Configuration options dictionary.
336
381
 
337
382
  Raises:
338
- TypeError: If the parameter save_obj is not `nn.Cell` or list type. And if the parameter `integrated_save`
339
- and `async_save` are not bool type. If the parameter ckpt_file_name is not string type.
383
+ TypeError: If the parameter `save_obj` is not :class:`mindspore.nn.Cell` , list or dict type.
384
+ TypeError: If the parameter `integrated_save` or `async_save` is not bool type.
385
+ TypeError: If the parameter `ckpt_file_name` is not string type.
340
386
 
341
387
  Examples:
342
388
  >>> import mindspore as ms
343
389
  >>>
344
- >>> net = Net()
345
- >>> ms.save_checkpoint(net, "lenet.ckpt")
390
+ >>> # Define the network structure of LeNet5. Refer to
391
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
392
+ >>> net = LeNet5()
393
+ >>> ms.save_checkpoint(net, "./lenet.ckpt",
394
+ ... choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
395
+ >>> param_dict1 = ms.load_checkpoint("./lenet.ckpt")
396
+ >>> print(param_dict1)
397
+ {'conv2.weight': Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)}
398
+ >>> params_list = net.trainable_params()
399
+ >>> ms.save_checkpoint(params_list, "./lenet_list.ckpt",
400
+ ... choice_func=lambda x: x.startswith("conv") and not x.startswith("conv2"))
401
+ >>> param_dict2 = ms.load_checkpoint("./lenet_list.ckpt")
402
+ >>> print(param_dict2)
403
+ {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
404
+ >>> ms.save_checkpoint(param_dict2, "./lenet_dict.ckpt")
405
+ >>> param_dict3 = ms.load_checkpoint("./lenet_dict.ckpt")
406
+ >>> print(param_dict3)
407
+ {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
408
+
409
+ Tutorial Examples:
410
+ - `Saving and Loading the Model - Saving and Loading the Model Weight
411
+ <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
346
412
  """
347
413
  ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name)
348
414
  integrated_save = Validator.check_bool(integrated_save)
@@ -350,46 +416,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
350
416
  append_dict = _check_append_dict(append_dict)
351
417
  enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
352
418
  enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
353
-
419
+ map_param_inc = kwargs.get('incremental', False)
354
420
  logger.info("Execute the process of saving checkpoint files.")
355
421
 
356
- if isinstance(save_obj, nn.Cell):
357
- parameter_layout_dict = save_obj.parameter_layout_dict
358
- if _is_in_auto_parallel_mode() and not parameter_layout_dict:
359
- parameter_layout_dict = _get_parameter_layout()
360
- save_obj.init_parameters_data()
361
- param_dict = OrderedDict()
362
- for _, param in save_obj.parameters_and_names():
363
- param_dict[param.name] = param
364
- param_list = []
365
- for (key, value) in param_dict.items():
366
- each_param = {"name": key}
367
- if isinstance(value, MapParameter):
368
- param_data = []
369
- for export_data in value.export_data():
370
- param_data.append(Tensor(export_data))
371
- each_param["data"] = param_data
372
- param_list.append(each_param)
373
- continue
374
-
375
- if value.data.is_persistent_data():
376
- # list save persistent_data: [Tensor, shape, type, param.key]
377
- param_data = ["persistent_data"]
378
- param_data.append(value.data)
379
- param_data.append(value.param_info.origin_shape)
380
- param_data.append(str(value.dtype))
381
- param_data.append(value.key)
382
- else:
383
- param_data = Tensor(value.data.asnumpy())
384
-
385
- # in automatic model parallel scenario, some parameters were split to all the devices,
386
- # which should be combined before saving
387
- if key in parameter_layout_dict:
388
- param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data, integrated_save)
389
-
390
- each_param["data"] = param_data
391
- param_list.append(each_param)
392
- save_obj = param_list
422
+ save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
393
423
 
394
424
  if append_dict:
395
425
  append_info_list = []
@@ -397,19 +427,27 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
397
427
  if not isinstance(value, str):
398
428
  value = Tensor(value)
399
429
  append_info_list.append({"name": k_name, "data": value})
400
- save_obj.extend(append_info_list)
430
+ save_obj.extend(append_info_list)
401
431
 
402
432
  data_list = OrderedDict()
403
433
  with _ckpt_mutex:
404
434
  for param in save_obj:
435
+ if param["name"] == "random_op":
436
+ data_list["random_op"] = param["data"]
437
+ continue
405
438
  key = param["name"]
406
439
  data_list[key] = []
440
+ if isinstance(param["data"], MapParameter):
441
+ data_list[param["name"]].append("mapparameter")
442
+ data_list[param["name"]].append(param["data"])
443
+ continue
407
444
  if isinstance(param["data"], list):
408
445
  if param["data"][0] == "persistent_data":
409
- _save_persistent_data(data_list, key, param)
410
- else:
411
- _save_mapparameter(data_list, param)
412
- continue
446
+ _save_param_list_data(data_list, key, param)
447
+ elif param["data"][0] == "offload_parameter":
448
+ data_list[key].append("offload_parameter")
449
+ _save_param_list_data(data_list, key, param)
450
+
413
451
  if isinstance(param["data"], str):
414
452
  data_list[key].append([0])
415
453
  data_list[key].append('str')
@@ -435,28 +473,130 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
435
473
  thr = Thread(target=_exec_save, args=(ckpt_file_name, data_copy, enc_key, enc_mode), name="asyn_save_ckpt")
436
474
  thr.start()
437
475
  else:
438
- _exec_save(ckpt_file_name, data_list, enc_key, enc_mode)
476
+ _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc)
439
477
 
440
478
  logger.info("Saving checkpoint process is finished.")
441
479
 
442
480
 
443
- def _save_mapparameter(data_list, param):
444
- """Save map parameter into save_obj."""
445
- data_list[param["name"]].append("mapparameter")
446
- for value in param["data"]:
447
- dims = []
448
- tmp_list = []
449
- for dim in value.shape:
450
- dims.append(dim)
451
- tmp_list.append(dims)
452
- tensor_type = str(value.dtype)
453
- tmp_list.append(tensor_type)
454
- data = value.asnumpy().reshape(-1)
455
- tmp_list.append(data)
456
- data_list[param["name"]].append(tmp_list)
481
+ def _convert_list_to_param_list(save_obj, choice_func):
482
+ """Convert a list of Parameter to param_list."""
483
+ param_list = []
484
+ if not save_obj:
485
+ return param_list
486
+ if isinstance(save_obj[0], dict):
487
+ param_list = [param for param in save_obj if choice_func is None or choice_func(param["name"])]
488
+ else:
489
+ for param in save_obj:
490
+ if isinstance(param, Parameter):
491
+ if choice_func is not None and not choice_func(param.name):
492
+ continue
493
+ each_param = {"name": param.name, "data": param}
494
+ param_list.append(each_param)
495
+ else:
496
+ raise TypeError(f"For save_checkpoint, when save_obj is made up by list of Parameter,"
497
+ f"the param should be parameter, but got {type(param)}")
498
+ return param_list
499
+
500
+
501
+ def _convert_dict_to_param_dict(save_obj, choice_func):
502
+ """Convert a dict of Parameter to param_list."""
503
+ param_list = []
504
+ for (key, value) in save_obj.items():
505
+ if isinstance(key, str) and isinstance(value, (Parameter, str)):
506
+ if choice_func is not None and not choice_func(key):
507
+ continue
508
+ each_param = {"name": key, "data": value}
509
+ param_list.append(each_param)
510
+ else:
511
+ raise TypeError(f"For save_checkpoint, when save_obj is made up by dict, the key should be str and"
512
+ f"value should be Parameter, but got the type of key is {type(key)} and"
513
+ f"the type of value is {type(value)}")
514
+ return param_list
515
+
516
+
517
+ def _convert_cell_param_and_names_to_dict(save_obj, choice_func):
518
+ """Convert cell.parameters_and_names to OrderedDict."""
519
+ param_dict = OrderedDict()
520
+ for _, param in save_obj.parameters_and_names():
521
+ not_sliced = not param.sliced
522
+ is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
523
+ # All parameters are initialized immediately under PyNative mode, skip this judgement.
524
+ judgment = not_sliced or param.has_init
525
+ if is_graph_mode and _is_in_auto_parallel_mode() and judgment:
526
+ continue
527
+ if choice_func is not None and not choice_func(param.name):
528
+ continue
529
+ # Add suffix for cache_enabled parameter, and then parameter can carry key info.
530
+ # Notice that suffix needs be removed when loading into net.
531
+ if param.cache_enable:
532
+ param_dict[param.name + ".__param_key__" + str(param.key)] = param
533
+ else:
534
+ param_dict[param.name] = param
535
+ return param_dict
457
536
 
458
537
 
459
- def _save_persistent_data(data_list, key, param):
538
+ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func):
539
+ """Convert nn.Cell to param_list."""
540
+ param_list = []
541
+ parameter_layout_dict = save_obj.parameter_layout_dict
542
+ if _is_in_auto_parallel_mode() and not parameter_layout_dict:
543
+ parameter_layout_dict = _get_parameter_layout()
544
+ if not _is_in_auto_parallel_mode():
545
+ save_obj.init_parameters_data()
546
+ param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func)
547
+ if append_dict and "random_op" in append_dict:
548
+ phase = 'train' + '.' + str(save_obj.create_time) + '.' + str(id(save_obj)) + '.' + save_obj.arguments_key
549
+ if phase in save_obj.compile_cache and _executor.has_compiled(phase):
550
+ random_byte = _executor._graph_executor.get_random_status(phase)
551
+ param_list.append({"name": "random_op", "data": random_byte})
552
+ append_dict.pop("random_op")
553
+ for (key, value) in param_dict.items():
554
+ each_param = {"name": key}
555
+ if isinstance(value, MapParameter):
556
+ each_param["data"] = value
557
+ param_list.append(each_param)
558
+ continue
559
+
560
+ if value.data.is_persistent_data():
561
+ # list save persistent_data: [Tensor, shape, type, param.key]
562
+ param_data = ["persistent_data", value.data, value.param_info.origin_shape, str(value.dtype), value.key]
563
+ elif value.data.offload_file_path() != "":
564
+ # list save offload data: [Param, shape, type, param.key]
565
+ param_data = ["offload_parameter"]
566
+ param_tensor = value.data
567
+ if key in parameter_layout_dict:
568
+ param_tensor = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_tensor,
569
+ integrated_save)
570
+ param_data.append(param_tensor)
571
+ param_data.append(param_tensor.shape)
572
+ param_data.append(str(param_tensor.dtype))
573
+ param_data.append(value.key)
574
+ else:
575
+ param_data = Tensor(value.data.asnumpy())
576
+
577
+ # in automatic model parallel scenario, some parameters were split to all the devices,
578
+ # which should be combined before saving
579
+ if key in parameter_layout_dict:
580
+ param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
581
+ integrated_save)
582
+
583
+ each_param["data"] = param_data
584
+ param_list.append(each_param)
585
+ return param_list
586
+
587
+
588
+ def _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func):
589
+ """Convert a save_obj to param_list."""
590
+ if isinstance(save_obj, list):
591
+ return _convert_list_to_param_list(save_obj, choice_func)
592
+
593
+ if isinstance(save_obj, dict):
594
+ return _convert_dict_to_param_dict(save_obj, choice_func)
595
+
596
+ return _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func)
597
+
598
+
599
+ def _save_param_list_data(data_list, key, param):
460
600
  """Save persistent data into save_obj."""
461
601
  dims = []
462
602
  # persistent_data shape can not be ()
@@ -511,7 +651,7 @@ def load(file_name, **kwargs):
511
651
 
512
652
  - obf_func (function): A python function used for loading obfuscated MindIR model, which can refer to
513
653
  `obfuscate_model()
514
- <https://www.mindspore.cn/docs/en/r2.0/api_python/mindspore/mindspore.obfuscate_model.html>` .
654
+ <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore/mindspore.obfuscate_model.html>`_.
515
655
 
516
656
  Returns:
517
657
  GraphCell, a compiled graph that can executed by `GraphCell`.
@@ -538,6 +678,10 @@ def load(file_name, **kwargs):
538
678
  [[[[4. 6. 4.]
539
679
  [6. 9. 6.]
540
680
  [4. 6. 4.]]]]
681
+
682
+ Tutorial Examples:
683
+ - `Saving and Loading the Model - Saving and Loading MindIR
684
+ <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-mindir>`_
541
685
  """
542
686
  if not isinstance(file_name, str):
543
687
  raise ValueError("For 'load', the argument 'file_name' must be string, but "
@@ -578,6 +722,57 @@ def load(file_name, **kwargs):
578
722
  return graph
579
723
 
580
724
 
725
+ def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=False):
726
+ """
727
+ Auto Split MindIR.
728
+
729
+ The returned object can be executed by a `GraphCell`, see class :class:`mindspore.nn.GraphCell` for more details.
730
+
731
+ Args:
732
+ file_name (str): MindIR file name.
733
+ device_num (int): device number.
734
+ rank_id (int): rank id.
735
+ dynamic (bool): Indicates whether the model is a dynamic shape mindir model.
736
+ sapp (bool): Indicates whether to automatically generate split strategy through SAPP.
737
+
738
+ Raises:
739
+ ValueError: MindIR file does not exist or `file_name` is not a string.
740
+ RuntimeError: Failed to split MindIR file.
741
+
742
+ Examples:
743
+ >>> import mindspore as ms
744
+ >>> context.set_context(mode=context.GRAPH_MODE)
745
+ >>>
746
+ >>> ms.export_split_mindir("net.mindir", device_num=8, rank_id=0)
747
+
748
+ """
749
+ if not isinstance(file_name, str):
750
+ raise ValueError("For 'Split MindIR', the argument 'file_name' must be string, but "
751
+ "got {}.".format(type(file_name)))
752
+ if not file_name.endswith(".mindir"):
753
+ raise ValueError("For 'Split MindIR', the argument 'file_name'(MindIR file) should end with '.mindir', "
754
+ "please input the correct 'file_name'.")
755
+ if not os.path.exists(file_name):
756
+ raise ValueError("For 'Split MindIR', the argument 'file_name'(MindIR file) does not exist, "
757
+ "please check whether the 'file_name' is correct.")
758
+ file_name = os.path.abspath(file_name)
759
+
760
+ logger.info("Execute the process of export and split mindir.")
761
+ dynamic = True
762
+ if dynamic:
763
+ graph = split_dynamic_mindir(file_name, device_num, rank_id, sapp)
764
+ else:
765
+ graph = split_mindir(file_name)
766
+
767
+ if graph is None:
768
+ if _is_cipher_file(file_name):
769
+ raise RuntimeError("Export and split MindIR failed. The file may be encrypted and decrypt failed, you "
770
+ "can check whether the values of the arguments 'dec_key' and 'dec_mode'"
771
+ " are the same as when exported MindIR file, or check the file integrity.")
772
+ raise RuntimeError("Export and split MindIR failed.")
773
+ return graph
774
+
775
+
581
776
  def _check_param_type(param_config, key, target_type, requested):
582
777
  """check type of parameters"""
583
778
  if key in param_config:
@@ -655,17 +850,20 @@ def obfuscate_model(obf_config, **kwargs):
655
850
  - model_inputs (list(Tensor)): The inputs of the original model, the values of Tensor can be random, which
656
851
  is the same as using :func:`mindspore.export`.
657
852
  - obf_ratio (Union(float, str)): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
658
- should be in range of (0, 1] or in ["small", "medium", "large"].
853
+ should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large" are
854
+ correspond to 0.1, 0.3, and 0.6 respectively.
659
855
  - customized_func (function): A python function used for customized function mode, which used for control
660
- the switch branch of obfuscation structure. The outputs of customized_func should be boolean. This
661
- function needs to ensure that its result is constant for any input. Users can refer to opaque
856
+ the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
857
+ Reference to 'my_func()' in
858
+ `tutorials <https://www.mindspore.cn/mindarmour/docs/en/r2.0/dynamic_obfuscation_protection.html>`_).
859
+ This function needs to ensure that its result is constant for any input. Users can refer to opaque
662
860
  predicates. If customized_func is set, then it should be passed to :func:`mindspore.load` interface
663
861
  when loading obfuscated model.
664
- - obf_random_seed (int): The random seed used for determine the distribution of confusion branches and the
665
- weight confusion coefficient, which should be in (0, 9223372036854775807]. If `obf_random_seed` is set,
666
- then it should be passed to :class:`nn.GraphCell()` interface when loading obfuscated model. It should be
667
- noted that at least one of `customized_func` or `obf_random_seed` should be set, and the latter mode
668
- would be applied if both of them are set.
862
+ - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
863
+ structure of obfuscated models corresponding to different random seeds is different. If
864
+ `obf_random_seed` is set, then it should be passed to :class:`nn.GraphCell()` interface when loading
865
+ obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
866
+ be set, and the latter mode would be applied if both of them are set.
669
867
 
670
868
  kwargs (dict): Configuration options dictionary.
671
869
 
@@ -685,12 +883,14 @@ def obfuscate_model(obf_config, **kwargs):
685
883
  ValueError: If `original_model_path` is not exist or `original_model_path` is not end with '.mindir'.
686
884
 
687
885
  Examples:
886
+ >>> import mindspore as ms
887
+ >>> import mindspore.nn as nn
688
888
  >>> obf_config = {'original_model_path': "./net.mindir",
689
889
  ... 'save_model_path': "./obf_net",
690
890
  ... 'model_inputs': [input1, ],
691
891
  ... 'obf_ratio': 0.1, 'obf_random_seed': 173262358423}
692
- >>> obfuscate_model(obf_config)
693
- >>> obf_func = load("obf_net.mindir")
892
+ >>> ms.obfuscate_model(obf_config)
893
+ >>> obf_func = ms.load("obf_net.mindir")
694
894
  >>> obf_net = nn.GraphCell(obf_func, obf_random_seed=173262358423)
695
895
  >>> print(obf_net(input1).asnumpy())
696
896
  """
@@ -762,23 +962,24 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
762
962
 
763
963
  Args:
764
964
  ckpt_file_name (str): Checkpoint file name.
765
- net (Cell): The network where the parameters will be loaded. Default: None
766
- strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
965
+ net (Cell): The network where the parameters will be loaded. Default: ``None`` .
966
+ strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
767
967
  into net when parameter name's suffix in checkpoint file is the same as the
768
968
  parameter in the network. When the types are inconsistent perform type conversion
769
- on the parameters of the same type, such as float32 to float16. Default: False.
969
+ on the parameters of the same type, such as float32 to float16. Default: ``False`` .
770
970
  filter_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`). Parameters starting with the
771
- filter_prefix will not be loaded. Default: None.
772
- dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption
773
- is not required. Default: None.
774
- dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
775
- mode, currently supports 'AES-GCM' and 'AES-CBC' and 'SM4-CBC'. Default: 'AES-GCM'.
971
+ filter_prefix will not be loaded. Default: ``None`` .
972
+ dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
973
+ is not required. Default: ``None`` .
974
+ dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption
975
+ mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` .
976
+ Default: ``"AES-GCM"`` .
776
977
  specify_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`). Parameters starting with the
777
- specify_prefix will be loaded. Default: None.
978
+ specify_prefix will be loaded. Default: ``None`` .
778
979
  choice_func (Union[None, function]) : Input value of the function is a Parameter name of type string,
779
- and the return value is a bool. If returns True, the Parameter
780
- that matches the custom condition will be loaded. If returns False, the Parameter that
781
- matches the custom condition will be removed. Default: None.
980
+ and the return value is a bool. If returns ``True`` , the Parameter
981
+ that matches the custom condition will be loaded. If returns ``False`` , the Parameter that
982
+ matches the custom condition will be removed. Default: ``None`` .
782
983
 
783
984
  Returns:
784
985
  Dict, key is parameter name, value is a Parameter or string. When the `append_dict` parameter of
@@ -801,23 +1002,27 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
801
1002
  >>> print(param_dict["conv2.weight"])
802
1003
  Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
803
1004
  >>> def func(param_name):
804
- >>> whether_load = False
805
- >>> if param_name.startswith("conv"):
806
- >>> whether_load = True
807
- >>> if param_name.startswith("conv1"):
808
- >>> whether_load = False
809
- >>> return whether_load
1005
+ ... whether_load = False
1006
+ ... if param_name.startswith("conv"):
1007
+ ... whether_load = True
1008
+ ... if param_name.startswith("conv1"):
1009
+ ... whether_load = False
1010
+ ... return whether_load
810
1011
  >>> param_dict1 = ms.load_checkpoint(ckpt_file_name, choice_func=func)
811
1012
  >>> print(param_dict1["conv2.weight"])
812
1013
  Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
813
1014
  >>> def func(param_name):
814
- >>> whether_load = False
815
- >>> if param_name.startswith("conv1"):
816
- >>> whether_load = True
817
- >>> return whether_load
1015
+ ... whether_load = False
1016
+ ... if param_name.startswith("conv1"):
1017
+ ... whether_load = True
1018
+ ... return whether_load
818
1019
  >>> param_dict2 = ms.load_checkpoint(ckpt_file_name, choice_func=func)
819
1020
  >>> print(param_dict2)
820
1021
  {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
1022
+
1023
+ Tutorial Examples:
1024
+ - `Saving and Loading the Model - Saving and Loading the Model Weight
1025
+ <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
821
1026
  """
822
1027
  ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
823
1028
  specify_prefix = _check_prefix(specify_prefix)
@@ -830,6 +1035,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
830
1035
  parameter_dict = {}
831
1036
  try:
832
1037
  param_data_list = []
1038
+ map_data_list = [[], [], []]
1039
+ map_shape_list = [0, 0, 0]
833
1040
  if specify_prefix:
834
1041
  logger.warning("For load_checkpoint, this parameter `specity_prefix` will be deprecated, "
835
1042
  "please use `choice_func` instead.")
@@ -837,13 +1044,19 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
837
1044
  logger.warning("For load_checkpoint, this parameter `filter_prefix` will be deprecated, "
838
1045
  "please use `choice_func` instead.")
839
1046
  for element_id, element in enumerate(checkpoint_list.value):
1047
+ if element.tag == "random_op":
1048
+ parameter_dict["random_op"] = element.tensor.tensor_content
1049
+ continue
840
1050
  if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
841
1051
  continue
842
1052
  if specify_prefix is None and filter_prefix is None and \
843
1053
  choice_func is not None and not choice_func(element.tag):
844
1054
  continue
845
1055
  if element.tensor.ByteSize() == 0:
846
- _load_mapparameter(element, parameter_dict)
1056
+ _load_map_parameter(checkpoint_list, element, element_id, map_data_list, map_shape_list, parameter_dict)
1057
+ if element.tag in parameter_dict:
1058
+ map_data_list = [[], [], []]
1059
+ map_shape_list = [0, 0, 0]
847
1060
  continue
848
1061
  data = element.tensor.tensor_content
849
1062
  data_type = element.tensor.tensor_type
@@ -856,7 +1069,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
856
1069
  param_data_list.append(element_data)
857
1070
  if (element_id == len(checkpoint_list.value) - 1) or \
858
1071
  (element.tag != checkpoint_list.value[element_id + 1].tag):
859
- param_data = np.concatenate((param_data_list), axis=0)
1072
+ new_data = b"".join(param_data_list)
1073
+ param_data = np.frombuffer(new_data, np_type)
860
1074
  param_data_list.clear()
861
1075
  dims = element.tensor.dims
862
1076
  if dims == [0] and data_type == 'str':
@@ -868,7 +1082,9 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
868
1082
  param_data = int(param_data[0])
869
1083
  if dims not in ([0], [1]):
870
1084
  param_data = param_data.reshape(list(dims))
871
- parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
1085
+ parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
1086
+ parameter_dict[element.tag] = parameter
1087
+ _offload_if_config(parameter)
872
1088
 
873
1089
  logger.info("Loading checkpoint files process is finished.")
874
1090
 
@@ -881,14 +1097,48 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
881
1097
  raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether "
882
1098
  f"'filter_prefix' or 'specify_prefix' are set correctly.")
883
1099
 
1100
+ if _warm_up_host_cache_enabled(parameter_dict):
1101
+ (is_worker, net_dict, warm_up_dict) = _warm_up_host_cache(parameter_dict, net)
884
1102
  if net is not None:
885
1103
  load_param_into_net(net, parameter_dict, strict_load)
1104
+ if _warm_up_host_cache_enabled(parameter_dict):
1105
+ _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict)
886
1106
 
887
1107
  return parameter_dict
888
1108
 
889
1109
 
1110
+ def _load_map_parameter(checkpoint_list, element, element_id, map_data_list,
1111
+ map_shape_list, parameter_dict):
1112
+ """load map parameter."""
1113
+ logger.info("Checkpoint load map_parameter.")
1114
+ if (element_id != len(checkpoint_list.value) - 1) and \
1115
+ element.tag == checkpoint_list.value[element_id + 1].tag:
1116
+ for index, tensor in enumerate(element.maptensor.tensor):
1117
+ data = tensor.tensor_content
1118
+ data_type = tensor.tensor_type
1119
+ np_type = np_type_convert.get(data_type)
1120
+ element_data = np.frombuffer(data, np_type)
1121
+ map_data_list[index].append(element_data)
1122
+ map_shape_list[index] += tensor.dims[0]
1123
+ else:
1124
+ map_array = []
1125
+ for index, tensor in enumerate(element.maptensor.tensor):
1126
+ data = tensor.tensor_content
1127
+ data_type = tensor.tensor_type
1128
+ np_type = np_type_convert.get(data_type)
1129
+ element_data = np.frombuffer(data, np_type)
1130
+ map_data_list[index].append(element_data)
1131
+ new_data = b"".join(map_data_list[index])
1132
+ param_data = np.frombuffer(new_data, np_type)
1133
+ dims = tensor.dims
1134
+ dims[0] += map_shape_list[index]
1135
+ param_data = param_data.reshape(list(dims))
1136
+ map_array.append(param_data)
1137
+ parameter_dict[element.tag] = map_array
1138
+
1139
+
890
1140
  def _check_ckpt_file_name(ckpt_file_name):
891
- """Check function load_checkpoint's cket_file_name."""
1141
+ """Check function load_checkpoint's ckpt_file_name."""
892
1142
  if not isinstance(ckpt_file_name, str):
893
1143
  raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, "
894
1144
  "but got {}.".format(type(ckpt_file_name)))
@@ -969,18 +1219,13 @@ def _whether_load_param(specify_prefix, filter_prefix, param_name):
969
1219
  return whether_load
970
1220
 
971
1221
 
972
- def _load_mapparameter(element, parameter_dict):
973
- """Load map parameter from ckpt file."""
974
- map_array = []
975
- for tensor in element.maptensor.tensor:
976
- data = tensor.tensor_content
977
- data_type = tensor.tensor_type
978
- np_type = tensor_to_np_type.get(data_type)
979
- element_data = np.frombuffer(data, np_type)
980
- dims = tensor.dims
981
- param_data = element_data.reshape(list(dims))
982
- map_array.append(param_data)
983
- parameter_dict[element.tag] = map_array
1222
+ def _init_parameter_data_in_parallel_mode(net, parameter_dict):
1223
+ """In parallel mode, only init the paraemters in ckpt."""
1224
+ for _, param in net.parameters_and_names():
1225
+ if param.name in parameter_dict and param.has_init:
1226
+ logger.warning("{} is not init while load ckpt.".format(param.name))
1227
+ new_tensor = param.init_data()
1228
+ param._update_tensor_data(new_tensor)
984
1229
 
985
1230
 
986
1231
  def load_param_into_net(net, parameter_dict, strict_load=False):
@@ -991,10 +1236,10 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
991
1236
  net (Cell): The network where the parameters will be loaded.
992
1237
  parameter_dict (dict): The dictionary generated by load checkpoint file,
993
1238
  it is a dictionary consisting of key: parameters's name, value: parameter.
994
- strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
1239
+ strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
995
1240
  into net when parameter name's suffix in checkpoint file is the same as the
996
1241
  parameter in the network. When the types are inconsistent perform type conversion
997
- on the parameters of the same type, such as float32 to float16. Default: False.
1242
+ on the parameters of the same type, such as float32 to float16. Default: ``False`` .
998
1243
 
999
1244
  Returns:
1000
1245
  param_not_load (List), the parameter name in model which are not loaded into the network.
@@ -1006,25 +1251,33 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
1006
1251
  Examples:
1007
1252
  >>> import mindspore as ms
1008
1253
  >>>
1009
- >>> net = Net()
1254
+ >>> # Define the network structure of LeNet5. Refer to
1255
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
1256
+ >>> net = LeNet5()
1010
1257
  >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
1011
1258
  >>> param_dict = ms.load_checkpoint(ckpt_file_name, filter_prefix="conv1")
1012
1259
  >>> param_not_load, _ = ms.load_param_into_net(net, param_dict)
1013
1260
  >>> print(param_not_load)
1014
1261
  ['conv1.weight']
1262
+
1263
+ Tutorial Examples:
1264
+ - `Saving and Loading the Model - Saving and Loading the Model Weight
1265
+ <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
1015
1266
  """
1016
1267
  if not isinstance(net, nn.Cell):
1017
1268
  logger.critical("Failed to combine the net and the parameters.")
1018
1269
  msg = ("For 'load_param_into_net', the argument 'net' should be a Cell, but got {}.".format(type(net)))
1019
1270
  raise TypeError(msg)
1020
-
1021
1271
  if not isinstance(parameter_dict, dict):
1022
1272
  logger.critical("Failed to combine the net and the parameters.")
1023
1273
  msg = ("For 'load_param_into_net', the argument 'parameter_dict' should be a dict, "
1024
1274
  "but got {}.".format(type(parameter_dict)))
1025
1275
  raise TypeError(msg)
1276
+ if "random_op" in parameter_dict.keys():
1277
+ net._add_attr("random_op_snapshot", parameter_dict["random_op"])
1278
+ parameter_dict.pop("random_op")
1026
1279
  for key, value in parameter_dict.items():
1027
- if not isinstance(key, str) or not isinstance(value, (Parameter, str)):
1280
+ if not isinstance(key, str) or not isinstance(value, (Parameter, str, list)):
1028
1281
  logger.critical("Load parameters into net failed.")
1029
1282
  msg = ("For 'parameter_dict', the element in the argument 'parameter_dict' should be a "
1030
1283
  "'str' and 'Parameter' , but got {} and {}.".format(type(key), type(value)))
@@ -1032,11 +1285,20 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
1032
1285
 
1033
1286
  strict_load = Validator.check_bool(strict_load)
1034
1287
  logger.info("Execute the process of loading parameters into net.")
1035
- net.init_parameters_data()
1288
+ if not _is_in_auto_parallel_mode():
1289
+ net.init_parameters_data()
1290
+ else:
1291
+ _init_parameter_data_in_parallel_mode(net, parameter_dict)
1036
1292
  param_not_load = []
1037
1293
  ckpt_not_load = list(parameter_dict.keys())
1038
1294
  for _, param in net.parameters_and_names():
1039
1295
  if param.name in parameter_dict:
1296
+ if isinstance(param, MapParameter):
1297
+ param.import_data(parameter_dict[param.name])
1298
+ continue
1299
+ # Add has attr protection when load server checkpoint file on worker.
1300
+ if not hasattr(parameter_dict[param.name], "data"):
1301
+ continue
1040
1302
  new_param = copy.deepcopy(parameter_dict[param.name])
1041
1303
  _update_param(param, new_param, strict_load)
1042
1304
  ckpt_not_load.remove(param.name)
@@ -1061,6 +1323,72 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
1061
1323
  return param_not_load, ckpt_not_load
1062
1324
 
1063
1325
 
1326
+ def _warm_up_host_cache_enabled(parameter_dict):
1327
+ """Warm up host cache enabled."""
1328
+ if _cache_enable():
1329
+ return True
1330
+ for key in parameter_dict.keys():
1331
+ if key.find(".__param_key__") != -1:
1332
+ return True
1333
+ return False
1334
+
1335
+
1336
+ def _warm_up_host_cache(parameter_dict, net):
1337
+ """Warm up host cache."""
1338
+ ms_role = os.getenv("MS_ROLE")
1339
+ is_worker = ms_role == "MS_WORKER"
1340
+ param_key_dict = {}
1341
+ # Traverse key, value in parameter_dict, warm up param key and record param key into param_key_dict.
1342
+ if is_worker:
1343
+ net.init_parameters_data()
1344
+ net_dict = {}
1345
+ for name, value in net.parameters_and_names():
1346
+ net_dict[name] = value
1347
+ for param_name, value in parameter_dict.items():
1348
+ pos = param_name.find(".__param_key__")
1349
+ if pos != -1:
1350
+ net_param_name = param_name[:pos]
1351
+ param_key_dict[param_name] = net_param_name
1352
+ net_value = None
1353
+ if net_param_name not in net_dict:
1354
+ logger.warning("net param name : %s is not in net", net_param_name)
1355
+ else:
1356
+ net_value = net_dict.get(net_param_name, None)
1357
+ pos += len(".__param_key__")
1358
+ param_key = int(param_name[pos:])
1359
+ value_is_map_parameter = isinstance(value, list) and len(value) == 3
1360
+ if value_is_map_parameter and (net_value is None or isinstance(net_value, Parameter)):
1361
+ key_tensor = Tensor.from_numpy(value[0])
1362
+ value_tensor = Tensor.from_numpy(value[1])
1363
+ status_tensor = Tensor.from_numpy(value[2])
1364
+ _store_warm_up_ptr_by_tensor_list(param_key, key_tensor, value_tensor, status_tensor)
1365
+ elif not isinstance(value, list) and isinstance(net_value, Parameter):
1366
+ _store_warm_up_ptr_by_tensor(param_key, value)
1367
+ else:
1368
+ logger.warning("Unknown matches parameter type %s and net_value %s", type(value), type(net_value))
1369
+ else:
1370
+ for param_name, value in parameter_dict.items():
1371
+ pos = param_name.find(".__param_key__")
1372
+ if pos != -1:
1373
+ net_param_name = param_name[:pos]
1374
+ param_key_dict[param_name] = net_param_name
1375
+ # Split param key from parameter_dict since worker cannot load param key.
1376
+ warm_up_dict = {}
1377
+ for key, value in param_key_dict.items():
1378
+ if is_worker:
1379
+ warm_up_dict[value] = parameter_dict.pop(key)
1380
+ else:
1381
+ parameter_dict[value] = parameter_dict.pop(key)
1382
+ return (is_worker, parameter_dict, warm_up_dict)
1383
+
1384
+
1385
+ def _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict):
1386
+ """Warm up host cache post process."""
1387
+ if is_worker:
1388
+ net_dict.update(warm_up_dict)
1389
+ _set_checkpoint_load_status(True)
1390
+
1391
+
1064
1392
  def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load):
1065
1393
  """When some net parameter did not load, try to continue loading."""
1066
1394
  prefix_name = ""
@@ -1161,31 +1489,6 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
1161
1489
  return param_data
1162
1490
 
1163
1491
 
1164
- def _fill_param_into_net(net, parameter_list):
1165
- """
1166
- Fills parameter_list into net.
1167
-
1168
- Args:
1169
- net (Cell): train network.
1170
- parameter_list (list): parameters list from ge callback.
1171
- """
1172
- parameter_dict = {}
1173
- for each_param in parameter_list:
1174
- param_name = each_param["name"]
1175
- if isinstance(each_param["data"], Parameter):
1176
- each_param["data"].init_data()
1177
- np_val = each_param["data"].asnumpy()
1178
- if np_val.shape == (1,):
1179
- parameter_dict[param_name] = Parameter(np_val, name=param_name)
1180
- elif np_val.shape == ():
1181
- parameter_dict[param_name] = Parameter(Tensor(np_val.tolist(), mstype.pytype_to_dtype(np_val.dtype)),
1182
- name=param_name)
1183
- else:
1184
- parameter_dict[param_name] = Parameter(Tensor(np_val), name=param_name)
1185
-
1186
- load_param_into_net(net, parameter_dict, strict_load=True)
1187
-
1188
-
1189
1492
  def export(net, *inputs, file_name, file_format, **kwargs):
1190
1493
  """
1191
1494
  Export the MindSpore network into an offline model in the specified format.
@@ -1193,9 +1496,9 @@ def export(net, *inputs, file_name, file_format, **kwargs):
1193
1496
  Note:
1194
1497
  1. When exporting AIR, ONNX format, the size of a single tensor can not exceed 2GB.
1195
1498
  2. When file_name does not have a suffix, the system will automatically add one according to the file_format.
1196
- 3. Exporting functions decorated with 'jit' to mindir format is supported.
1197
- 4. When exporting a function decorated with 'jit', the function should not involve class properties in
1198
- calculations.
1499
+ 3. Exporting functions decorated with :func:`mindspore.jit` to mindir format is supported.
1500
+ 4. When exporting a function decorated with :func:`mindspore.jit`, the function should not involve
1501
+ class properties in calculations.
1199
1502
 
1200
1503
  Args:
1201
1504
  net (Union[Cell, function]): MindSpore network.
@@ -1231,17 +1534,20 @@ def export(net, *inputs, file_name, file_format, **kwargs):
1231
1534
 
1232
1535
  - type (str): The type of obfuscation, only 'dynamic' is supported until now.
1233
1536
  - obf_ratio (float, str): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
1234
- should be in range of (0, 1] or in ["small", "medium", "large"].
1537
+ should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large" are
1538
+ correspond to 0.1, 0.3, and 0.6 respectively.
1235
1539
  - customized_func (function): A python function used for customized function mode, which used for control
1236
- the switch branch of obfuscation structure. The outputs of customized_func should be boolean. This
1237
- function needs to ensure that its result is constant for any input. Users can refer to opaque
1540
+ the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
1541
+ Reference to 'my_func()' in
1542
+ `tutorials <https://www.mindspore.cn/mindarmour/docs/en/r2.0/dynamic_obfuscation_protection.html>`_).
1543
+ This function needs to ensure that its result is constant for any input. Users can refer to opaque
1238
1544
  predicates. If customized_func is set, then it should be passed to `load()` interface when loading
1239
1545
  obfuscated model.
1240
- - obf_random_seed (int): The random seed used for determine the distribution of confusion branches and the
1241
- weight confusion coefficient, which should be in (0, 9223372036854775807]. If `obf_random_seed` is set,
1242
- then it should be passed to :class:`nn.GraphCell()` interface when loading obfuscated model. It should
1243
- be noted that at least one of `customized_func` or `obf_random_seed` should be set, and the latter mode
1244
- would be applied if both of them are set.
1546
+ - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
1547
+ structure of obfuscated models corresponding to different random seeds is different. If
1548
+ `obf_random_seed` is set, then it should be passed to :class:`nn.GraphCell()` interface when loading
1549
+ obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
1550
+ be set, and the latter mode would be applied if both of them are set.
1245
1551
 
1246
1552
  - incremental (bool): export MindIR incrementally.
1247
1553
 
@@ -1250,10 +1556,19 @@ def export(net, *inputs, file_name, file_format, **kwargs):
1250
1556
  >>> import numpy as np
1251
1557
  >>> from mindspore import Tensor
1252
1558
  >>>
1253
- >>> net = LeNet()
1559
+ >>> # Define the network structure of LeNet5. Refer to
1560
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
1561
+ >>> net = LeNet5()
1254
1562
  >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
1255
1563
  >>> ms.export(net, input_tensor, file_name='lenet', file_format='MINDIR')
1564
+
1565
+ Tutorial Examples:
1566
+ - `Saving and Loading the Model - Saving and Loading MindIR
1567
+ <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-mindir>`_
1256
1568
  """
1569
+ old_ms_jit_value = context.get_context("jit_syntax_level")
1570
+ context.set_context(jit_syntax_level=mindspore.STRICT)
1571
+
1257
1572
  supported_formats = ['AIR', 'ONNX', 'MINDIR']
1258
1573
  if file_format not in supported_formats:
1259
1574
  raise ValueError(f"For 'export', 'file_format' must be one of {supported_formats}, but got {file_format}.")
@@ -1282,6 +1597,47 @@ def export(net, *inputs, file_name, file_format, **kwargs):
1282
1597
  kwargs['enc_key'], kwargs['enc_mode'] = _check_key_mode_type(file_format, **kwargs)
1283
1598
  _export(net, file_name, file_format, *inputs, **kwargs)
1284
1599
 
1600
+ context.set_context(jit_syntax_level=old_ms_jit_value)
1601
+
1602
+
1603
+ def _get_funcgraph(net, *inputs):
1604
+ """
1605
+ Compile the MindSpore network and get FuncGraph.
1606
+
1607
+ Arg:
1608
+ net (Union[Cell, function]): MindSpore network.
1609
+ inputs (Union[Tensor, Dataset, List, Tuple, Number, Bool]): It represents the inputs
1610
+ of the `net`, if the network has multiple inputs, set them together. While its type is Dataset,
1611
+ it represents the preprocess behavior of the `net`, data preprocess operations will be serialized.
1612
+ In second situation, you should adjust batch size of dataset script manually which will impact on
1613
+ the batch size of 'net' input. Only supports parse "image" column from dataset currently.
1614
+
1615
+ Returns:
1616
+ FuncGraph, a mindspore._c_expression.FuncGraph obj.
1617
+
1618
+ Raises:
1619
+ ValueError: input `net` is not a nn.Cell.
1620
+
1621
+ Examples:
1622
+ >>> import mindspore as ms
1623
+ >>> import numpy as np
1624
+ >>> from mindspore import Tensor
1625
+ >>>
1626
+ >>> # Define the network structure of LeNet5. Refer to
1627
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
1628
+ >>> net = LeNet5()
1629
+ >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
1630
+ >>> ms.get_funcgraph(net, input_tensor)
1631
+
1632
+ """
1633
+ if not isinstance(net, nn.Cell):
1634
+ raise ValueError(f"For get_funcgraph's parameter 'net', currently only support Cell right now.")
1635
+ phase_name = "lite_infer_predict" if _is_in_auto_parallel_mode() else "lite_infer_get_func_graph"
1636
+ graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
1637
+ # pylint: disable=protected-access
1638
+ func_graph = _executor._get_func_graph(net, graph_id)
1639
+ return func_graph
1640
+
1285
1641
 
1286
1642
  def _export(net, file_name, file_format, *inputs, **kwargs):
1287
1643
  """
@@ -1290,7 +1646,6 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
1290
1646
  logger.info("exporting model file:%s format:%s.", file_name, file_format)
1291
1647
  if "obf_config" in kwargs and file_format != "MINDIR":
1292
1648
  raise ValueError(f"Dynamic obfuscation only support for MindIR format, but got {file_format} format.")
1293
-
1294
1649
  if file_format == 'AIR':
1295
1650
  _save_air(net, file_name, *inputs, **kwargs)
1296
1651
  elif file_format == 'ONNX':
@@ -1454,7 +1809,7 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
1454
1809
  for param_proto in model.graph.parameter:
1455
1810
  name = param_proto.name[param_proto.name.find(":") + 1:]
1456
1811
  param = net_dict[name]
1457
- raw_data = param.data.asnumpy().tobytes()
1812
+ raw_data = param.data.get_bytes()
1458
1813
  data_length = len(raw_data)
1459
1814
  append_size = 0
1460
1815
  if data_length % 64 != 0:
@@ -1508,7 +1863,7 @@ def _msfunc_info(net, *inputs):
1508
1863
 
1509
1864
  def _cell_info(net, incremental, *inputs):
1510
1865
  """Get mindir stream and net dict of cell"""
1511
- phase_name = "predict" if _is_in_auto_parallel_mode() else "export.mindir"
1866
+ phase_name = "export.mindir"
1512
1867
  graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
1513
1868
  # pylint: disable=protected-access
1514
1869
  mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir', incremental=incremental)
@@ -1581,7 +1936,7 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
1581
1936
  for param_proto in model.graph.parameter:
1582
1937
  param_name = param_proto.name[param_proto.name.find(":") + 1:]
1583
1938
  if param_name in net_dict.keys():
1584
- param_data = net_dict[param_name].data.asnumpy().tobytes()
1939
+ param_data = net_dict[param_name].data.get_bytes()
1585
1940
  param_proto.raw_data = param_data
1586
1941
  else:
1587
1942
  raise ValueError("The parameter '{}' is not belongs to any cell,"
@@ -1591,10 +1946,10 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
1591
1946
  map_param_name = map_param_proto.name[map_param_proto.name.find(":") + 1:]
1592
1947
  if map_param_name in net_dict.keys():
1593
1948
  map_parameter = net_dict[map_param_name]
1594
- key_nparr, value_nparr, status_nparr = map_parameter.export_data(incremental)
1595
- map_param_proto.key_tensor.raw_data = key_nparr.tobytes()
1596
- map_param_proto.value_tensor.raw_data = value_nparr.tobytes()
1597
- map_param_proto.status_tensor.raw_data = status_nparr.tobytes()
1949
+ key_bytes, value_bytes, status_bytes = map_parameter.export_bytes(incremental)
1950
+ map_param_proto.key_tensor.raw_data = key_bytes
1951
+ map_param_proto.value_tensor.raw_data = value_bytes
1952
+ map_param_proto.status_tensor.raw_data = status_bytes
1598
1953
  else:
1599
1954
  raise ValueError("The map_parameter '{}' is not belongs to any cell,"
1600
1955
  "the data of parameter cannot be exported.".format(map_param_proto.name))
@@ -1625,7 +1980,7 @@ def _save_together(net_dict, model):
1625
1980
  for param_proto in model.graph.parameter:
1626
1981
  name = param_proto.name[param_proto.name.find(":") + 1:]
1627
1982
  if name in net_dict.keys():
1628
- data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024
1983
+ data_total += sys.getsizeof(net_dict[name].data.get_bytes()) / 1024
1629
1984
  else:
1630
1985
  raise ValueError("The parameter '{}' is not belongs to any cell,"
1631
1986
  "the data of parameter cannot be exported.".format(param_proto.name))
@@ -1656,7 +2011,7 @@ def _save_dataset_to_mindir(model, dataset):
1656
2011
 
1657
2012
  def parse_print(print_file_name):
1658
2013
  """
1659
- Parse data file generated by mindspore.ops.Print.
2014
+ Parse data file generated by :class:`mindspore.ops.Print`.
1660
2015
 
1661
2016
  Args:
1662
2017
  print_file_name (str): The file name needs to be parsed.
@@ -1671,9 +2026,7 @@ def parse_print(print_file_name):
1671
2026
  Examples:
1672
2027
  >>> import numpy as np
1673
2028
  >>> import mindspore as ms
1674
- >>> import mindspore.ops as ops
1675
- >>> from mindspore import nn
1676
- >>> from mindspore import Tensor
2029
+ >>> from mindspore import nn, Tensor, ops
1677
2030
  >>> ms.set_context(mode=ms.GRAPH_MODE, print_file_path='log.data')
1678
2031
  >>> class PrintInputTensor(nn.Cell):
1679
2032
  ... def __init__(self):
@@ -1688,8 +2041,7 @@ def parse_print(print_file_name):
1688
2041
  >>> net = PrintInputTensor()
1689
2042
  >>> net(input_pra)
1690
2043
  >>>
1691
- >>> import mindspore
1692
- >>> data = mindspore.parse_print('./log.data')
2044
+ >>> data = ms.parse_print('./log.data')
1693
2045
  >>> print(data)
1694
2046
  ['print:', Tensor(shape=[2, 4], dtype=Float32, value=
1695
2047
  [[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
@@ -1836,8 +2188,8 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
1836
2188
  def restore_group_info_list(group_info_file_name):
1837
2189
  """
1838
2190
  Build rank list, the checkpoint of ranks in the rank list has the same contents with the local rank
1839
- who saves the group_info_file_name. To save the group info file, please export GROUP_INFO_FILE environment variables
1840
- like "export GROUP_INFO_FILE=/data/group_info.pb".
2191
+ who saves the `group_info_file_name`. To save the group info file, please export GROUP_INFO_FIL
2192
+ environment variables like "export GROUP_INFO_FILE=/data/group_info.pb".
1841
2193
 
1842
2194
  Args:
1843
2195
  group_info_file_name (str): Name of group information file.
@@ -1847,10 +2199,11 @@ def restore_group_info_list(group_info_file_name):
1847
2199
 
1848
2200
  Raises:
1849
2201
  ValueError: group information file is incorrect.
1850
- TypeError: group_info_file_name is not str.
2202
+ TypeError: `group_info_file_name` is not str.
1851
2203
 
1852
2204
  Examples:
1853
- >>> restore_list = restore_group_info_list("./group_info.pb")
2205
+ >>> import mindspore as ms
2206
+ >>> ms.restore_list = restore_group_info_list("./group_info.pb")
1854
2207
  """
1855
2208
  if not isinstance(group_info_file_name, str):
1856
2209
  raise TypeError(f"For 'restore_group_info_list', the argument 'group_info_file_name' should be str, "
@@ -1868,9 +2221,6 @@ def restore_group_info_list(group_info_file_name):
1868
2221
  def build_searched_strategy(strategy_filename):
1869
2222
  """
1870
2223
  Build strategy of every parameter in network. Used in the case of distributed inference.
1871
- For details of it, please check:
1872
- `Saving and Loading Models in Hybrid Parallel Mode
1873
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/save_load.html>`_.
1874
2224
 
1875
2225
  Args:
1876
2226
  strategy_filename (str): Name of strategy file.
@@ -1880,10 +2230,11 @@ def build_searched_strategy(strategy_filename):
1880
2230
 
1881
2231
  Raises:
1882
2232
  ValueError: Strategy file is incorrect.
1883
- TypeError: strategy_filename is not a string.
2233
+ TypeError: `strategy_filename` is not a string.
1884
2234
 
1885
2235
  Examples:
1886
- >>> strategy = build_searched_strategy("./strategy_train.ckpt")
2236
+ >>> import mindspore as ms
2237
+ >>> strategy = ms.build_searched_strategy("./strategy_train.ckpt")
1887
2238
  """
1888
2239
  return _build_searched_strategy(strategy_filename)
1889
2240
 
@@ -1891,14 +2242,12 @@ def build_searched_strategy(strategy_filename):
1891
2242
  def merge_sliced_parameter(sliced_parameters, strategy=None):
1892
2243
  """
1893
2244
  Merge parameter slices into one parameter. Used in the case of distributed inference.
1894
- For details of it, please check:
1895
- `<https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/save_load.html>`_.
1896
2245
 
1897
2246
  Args:
1898
2247
  sliced_parameters (list[Parameter]): Parameter slices in order of rank id.
1899
2248
  strategy (Optional[dict]): Parameter slice strategy, whose key is parameter name and
1900
2249
  value is slice strategy of this parameter. If strategy is None, just merge
1901
- parameter slices in 0 axis order. Default: None.
2250
+ parameter slices in 0 axis order. Default: ``None``.
1902
2251
 
1903
2252
  Returns:
1904
2253
  Parameter, the merged parameter which has the whole data.
@@ -1986,32 +2335,128 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
1986
2335
  train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM'):
1987
2336
  """
1988
2337
  Load checkpoint into net for distributed predication. Used in the case of distributed inference.
1989
- For details of distributed inference, please check:
1990
- `Distributed Inference
1991
- <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/distributed_inference.html>`_ .
1992
2338
 
1993
2339
  Args:
1994
2340
  network (Cell): Network for distributed predication.
1995
2341
  checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
1996
2342
  predict_strategy (dict): Strategy of predication process. It means that using one device to predict
1997
- when setting predict_strategy as None. Default: None.
2343
+ when setting predict_strategy as None. Default: ``None`` .
1998
2344
  train_strategy_filename (str): The filename of training strategy protocol buffer file.
1999
2345
  When train_strategy_filename is None, the training strategy file will be
2000
2346
  obtained from context.get_auto_parallel_context("strategy_ckpt_load_file").
2001
2347
  Therefore, the training strategy file needs to be specified
2002
- in at least one of them. Default: None.
2003
- strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
2348
+ in at least one of them. Default: ``None`` .
2349
+ strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
2004
2350
  into net when parameter name's suffix in checkpoint file is the same as the
2005
2351
  parameter in the network. When the types are inconsistent perform type conversion
2006
- on the parameters of the same type, such as float32 to float16. Default: False.
2007
- dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption
2008
- is not required. Default: None.
2009
- dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
2010
- mode, currently supports 'AES-GCM', 'AES-CBC' and 'SM4-CBC'. Default: 'AES-GCM'.
2352
+ on the parameters of the same type, such as float32 to float16. Default: ``False`` .
2353
+ dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
2354
+ is not required. Default: ``None`` .
2355
+ dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption
2356
+ mode, currently supports ``'AES-GCM'`` , ``'AES-CBC'`` and ``'SM4-CBC'`` .
2357
+ Default: ``'AES-GCM'`` .
2011
2358
 
2012
2359
  Raises:
2013
2360
  TypeError: The type of inputs do not match the requirements.
2014
2361
  ValueError: Failed to load checkpoint into net.
2362
+
2363
+ Supported Platforms:
2364
+ ``Ascend`` ``GPU``
2365
+
2366
+ Examples:
2367
+ .. note::
2368
+ Before running the following examples, you need to configure the communication environment variables.
2369
+
2370
+ For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
2371
+ Please see the `rank table startup
2372
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
2373
+ for more details.
2374
+
2375
+ For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
2376
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
2377
+
2378
+ For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
2379
+ Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
2380
+
2381
+ >>> import os
2382
+ >>> import numpy as np
2383
+ >>> import mindspore as ms
2384
+ >>> import mindspore.dataset as ds
2385
+ >>> from mindspore import nn, ops, train
2386
+ >>> from mindspore.communication import init
2387
+ >>>
2388
+ >>> step_per_epoch = 4
2389
+ >>> device_num = 8
2390
+ >>>
2391
+ >>> # Define the network structure.
2392
+ >>> class Net(nn.Cell):
2393
+ ... def __init__(self, matmul_size, strategy=None):
2394
+ ... super().__init__()
2395
+ ... matmul_np = np.full(matmul_size, 0.5, dtype=np.float32)
2396
+ ... self.matmul_weight = ms.Parameter(ms.Tensor(matmul_np))
2397
+ ... self.matmul = ops.MatMul()
2398
+ ... self.neg = ops.Neg()
2399
+ ... if strategy is not None:
2400
+ ... self.matmul.shard(strategy)
2401
+ ...
2402
+ ... def construct(self, inputs):
2403
+ ... x = self.matmul(inputs, self.matmul_weight)
2404
+ ... x = self.neg(x)
2405
+ ... return x
2406
+ >>>
2407
+ >>> # Create dataset.
2408
+ >>> def get_dataset(*inputs):
2409
+ ... def generate():
2410
+ ... for _ in range(step_per_epoch):
2411
+ ... yield inputs
2412
+ ... return generate
2413
+ >>>
2414
+ >>> # Train network and save distributed checkpoint.
2415
+ >>> def train_net():
2416
+ ... ms.set_context(mode=ms.GRAPH_MODE)
2417
+ ... init()
2418
+ ... np.random.seed(1)
2419
+ ... input_data = np.random.rand(16, 96).astype(np.float32)
2420
+ ... label_data = np.random.rand(16, 16).astype(np.float32)
2421
+ ... fake_dataset = get_dataset(input_data, label_data)
2422
+ ... dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
2423
+ ...
2424
+ ... # Set parallel strategy.
2425
+ ... strategy = ((1, 4), (4, 1))
2426
+ ... ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num,
2427
+ ... strategy_ckpt_save_file="./train_strategy.ckpt")
2428
+ ... network = Net(matmul_size=(96, 16), strategy=strategy)
2429
+ ... net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
2430
+ ... net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
2431
+ ... model = ms.Model(network=network, loss_fn=net_loss, optimizer=net_opt)
2432
+ ... ckpt_config = train.CheckpointConfig(keep_checkpoint_max=1, integrated_save=False)
2433
+ ... global_rank_id = int(os.getenv("RANK_ID"))
2434
+ ... ckpt_path = "./rank_{}_ckpt".format(global_rank_id)
2435
+ ... ckpt_callback = train.ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
2436
+ ... model.train(epoch=2, train_dataset=dataset, callbacks=[ckpt_callback], dataset_sink_mode=False)
2437
+ ... ms.reset_auto_parallel_context()
2438
+ >>>
2439
+ >>> # Load distributed checkpoint and test.
2440
+ >>> def load_model():
2441
+ ... ms.set_context(mode=ms.GRAPH_MODE)
2442
+ ... init()
2443
+ ... ms.set_auto_parallel_context(full_batch=True, parallel_mode="semi_auto_parallel",
2444
+ ... strategy_ckpt_load_file="./train_strategy.ckpt", device_num=device_num)
2445
+ ... predict_data = ms.Tensor(np.random.randn(128, 96).astype(np.float32))
2446
+ ... network = Net(matmul_size=(96, 16))
2447
+ ... model = ms.Model(network)
2448
+ ... predict_layout = model.infer_predict_layout(ms.Tensor(predict_data))
2449
+ ... ckpt_file_list = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(0, device_num)]
2450
+ ... ms.load_distributed_checkpoint(network, ckpt_file_list, predict_layout)
2451
+ ... predict_result = model.predict(predict_data)
2452
+ ... print(predict_result)
2453
+ >>>
2454
+ >>> train_net()
2455
+ >>> load_model()
2456
+ [[-7.3259363 -7.497216 -7.398196 ... -7.374962 -7.204874 -7.234935 ]
2457
+ [ 3.362938 3.3535435 3.3832688 ... 3.4263954 3.279045 3.3202887]
2458
+ ...
2459
+ [ 1.6067538 1.6244187 1.5384722 ... 1.5449994 1.6195512 1.6176052]]
2015
2460
  """
2016
2461
  network = Validator.check_isinstance("network", network, nn.Cell)
2017
2462
  _check_checkpoint_file(checkpoint_filenames)
@@ -2127,6 +2572,11 @@ def async_ckpt_thread_status():
2127
2572
  Returns:
2128
2573
  bool, True, Asynchronous save checkpoint thread is running.
2129
2574
  False, Asynchronous save checkpoint thread is not executing.
2575
+
2576
+ Examples:
2577
+ >>> import mindspore as ms
2578
+ >>> ms.async_ckpt_thread_status()
2579
+ False
2130
2580
  """
2131
2581
  thr_list = threading.enumerate()
2132
2582
  return True in [ele.getName() == "asyn_save_ckpt" for ele in thr_list]
@@ -2184,7 +2634,8 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
2184
2634
  return merged_param
2185
2635
  param_name = merged_param.name
2186
2636
  tensor_layout = predict_strategy[param_name]
2187
- split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1])
2637
+ rank = get_rank()
2638
+ split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank)
2188
2639
  requires_grad = merged_param.requires_grad
2189
2640
  layerwise_parallel = merged_param.layerwise_parallel
2190
2641
  split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
@@ -2268,7 +2719,8 @@ def convert_model(mindir_file, convert_file, file_format):
2268
2719
  ValueError: If the parameter `file_format` is not "ONNX".
2269
2720
 
2270
2721
  Examples:
2271
- >>> convert_model("lenet.mindir", "lenet.onnx", "ONNX")
2722
+ >>> import mindspore as ms
2723
+ >>> ms.convert_model("lenet.mindir", "lenet.onnx", "ONNX")
2272
2724
  """
2273
2725
  Validator.check_file_name_by_regular(mindir_file)
2274
2726
  Validator.check_file_name_by_regular(convert_file)