mindspore-2.1.0-cp37-cp37m-manylinux1_x86_64.whl → mindspore-2.2.10-cp37-cp37m-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (580)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +46 -19
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  25. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  26. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +38 -0
  31. mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +12 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +61 -71
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +74 -104
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +13 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +28 -5
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  196. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  197. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  198. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  199. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  201. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  202. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  203. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  204. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  205. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  206. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  207. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  208. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  209. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  210. mindspore/nn/__init__.py +0 -2
  211. mindspore/nn/cell.py +313 -74
  212. mindspore/nn/dynamic_lr.py +21 -21
  213. mindspore/nn/layer/activation.py +22 -30
  214. mindspore/nn/layer/basic.py +15 -13
  215. mindspore/nn/layer/channel_shuffle.py +1 -1
  216. mindspore/nn/layer/container.py +271 -9
  217. mindspore/nn/layer/conv.py +323 -204
  218. mindspore/nn/layer/dense.py +8 -5
  219. mindspore/nn/layer/embedding.py +33 -27
  220. mindspore/nn/layer/flash_attention.py +141 -88
  221. mindspore/nn/layer/image.py +8 -6
  222. mindspore/nn/layer/math.py +16 -25
  223. mindspore/nn/layer/normalization.py +107 -66
  224. mindspore/nn/layer/padding.py +1 -1
  225. mindspore/nn/layer/pooling.py +131 -109
  226. mindspore/nn/layer/rnn_cells.py +27 -22
  227. mindspore/nn/layer/rnns.py +13 -16
  228. mindspore/nn/layer/thor_layer.py +1 -1
  229. mindspore/nn/layer/transformer.py +221 -154
  230. mindspore/nn/learning_rate_schedule.py +9 -1
  231. mindspore/nn/loss/loss.py +235 -174
  232. mindspore/nn/optim/ada_grad.py +2 -1
  233. mindspore/nn/optim/adadelta.py +1 -0
  234. mindspore/nn/optim/adafactor.py +2 -1
  235. mindspore/nn/optim/adam.py +7 -4
  236. mindspore/nn/optim/adamax.py +3 -2
  237. mindspore/nn/optim/adasum.py +2 -2
  238. mindspore/nn/optim/asgd.py +2 -3
  239. mindspore/nn/optim/ftrl.py +6 -5
  240. mindspore/nn/optim/lamb.py +7 -4
  241. mindspore/nn/optim/lars.py +1 -1
  242. mindspore/nn/optim/lazyadam.py +5 -3
  243. mindspore/nn/optim/momentum.py +2 -1
  244. mindspore/nn/optim/optimizer.py +53 -4
  245. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  246. mindspore/nn/optim/rmsprop.py +4 -3
  247. mindspore/nn/optim/rprop.py +23 -12
  248. mindspore/nn/optim/sgd.py +26 -11
  249. mindspore/nn/optim/thor.py +9 -7
  250. mindspore/nn/probability/bijector/bijector.py +5 -5
  251. mindspore/nn/probability/bijector/power_transform.py +27 -27
  252. mindspore/nn/probability/bijector/softplus.py +3 -3
  253. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  254. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  255. mindspore/nn/probability/distribution/beta.py +3 -3
  256. mindspore/nn/probability/distribution/categorical.py +7 -7
  257. mindspore/nn/probability/distribution/cauchy.py +0 -1
  258. mindspore/nn/probability/distribution/distribution.py +3 -3
  259. mindspore/nn/probability/distribution/gamma.py +3 -3
  260. mindspore/nn/probability/distribution/geometric.py +4 -4
  261. mindspore/nn/probability/distribution/gumbel.py +4 -4
  262. mindspore/nn/probability/distribution/log_normal.py +2 -2
  263. mindspore/nn/probability/distribution/logistic.py +2 -2
  264. mindspore/nn/probability/distribution/poisson.py +4 -4
  265. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  266. mindspore/nn/probability/distribution/uniform.py +6 -6
  267. mindspore/nn/wrap/cell_wrapper.py +84 -34
  268. mindspore/nn/wrap/grad_reducer.py +8 -5
  269. mindspore/nn/wrap/loss_scale.py +105 -42
  270. mindspore/numpy/array_creations.py +1 -2
  271. mindspore/numpy/array_ops.py +3 -2
  272. mindspore/numpy/utils_const.py +5 -5
  273. mindspore/offline_debug/convert_async.py +2 -2
  274. mindspore/ops/_grad_experimental/__init__.py +0 -5
  275. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  276. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  277. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  278. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  279. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  280. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  281. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  282. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  283. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  284. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  285. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  286. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  287. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  288. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  289. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  290. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  291. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  292. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  293. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  294. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  295. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  296. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  297. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  298. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  299. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  300. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  301. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  302. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  303. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  304. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  305. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  306. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  307. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  308. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  309. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  310. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  311. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  312. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  313. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  314. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  315. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  316. mindspore/ops/_primitive_cache.py +1 -1
  317. mindspore/ops/_tracefunc.py +45 -13
  318. mindspore/ops/_utils/utils.py +6 -1
  319. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  320. mindspore/ops/_vmap/vmap_base.py +3 -3
  321. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  322. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  323. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  324. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  325. mindspore/ops/arg_dtype_cast.py +54 -0
  326. mindspore/ops/composite/base.py +37 -10
  327. mindspore/ops/composite/math_ops.py +5 -4
  328. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  329. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  330. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  331. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  332. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  333. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  334. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  335. mindspore/ops/deprecated.py +304 -0
  336. mindspore/ops/function/__init__.py +4 -1
  337. mindspore/ops/function/array_func.py +174 -193
  338. mindspore/ops/function/clip_func.py +81 -13
  339. mindspore/ops/function/debug_func.py +1 -1
  340. mindspore/ops/function/grad/grad_func.py +18 -9
  341. mindspore/ops/function/image_func.py +10 -4
  342. mindspore/ops/function/linalg_func.py +5 -5
  343. mindspore/ops/function/math_func.py +575 -386
  344. mindspore/ops/function/nn_func.py +568 -260
  345. mindspore/ops/function/random_func.py +88 -57
  346. mindspore/ops/function/sparse_func.py +1 -1
  347. mindspore/ops/function/sparse_unary_func.py +14 -12
  348. mindspore/ops/function/vmap_func.py +6 -5
  349. mindspore/ops/functional.py +15 -10
  350. mindspore/ops/op_info_register.py +244 -25
  351. mindspore/ops/operations/__init__.py +28 -19
  352. mindspore/ops/operations/_grad_ops.py +72 -7
  353. mindspore/ops/operations/_inner_ops.py +350 -17
  354. mindspore/ops/operations/_quant_ops.py +4 -8
  355. mindspore/ops/operations/_sequence_ops.py +42 -0
  356. mindspore/ops/operations/array_ops.py +68 -282
  357. mindspore/ops/operations/comm_ops.py +107 -59
  358. mindspore/ops/operations/custom_ops.py +94 -70
  359. mindspore/ops/operations/debug_ops.py +8 -4
  360. mindspore/ops/operations/image_ops.py +18 -12
  361. mindspore/ops/operations/inner_ops.py +26 -3
  362. mindspore/ops/operations/math_ops.py +189 -141
  363. mindspore/ops/operations/nn_ops.py +794 -489
  364. mindspore/ops/operations/other_ops.py +0 -22
  365. mindspore/ops/operations/random_ops.py +53 -111
  366. mindspore/ops/operations/sparse_ops.py +3 -1
  367. mindspore/ops/primitive.py +24 -18
  368. mindspore/parallel/_auto_parallel_context.py +68 -8
  369. mindspore/parallel/_cost_model_context.py +2 -2
  370. mindspore/parallel/_offload_context.py +17 -3
  371. mindspore/parallel/_parallel_serialization.py +12 -5
  372. mindspore/parallel/_ps_context.py +12 -0
  373. mindspore/parallel/_tensor.py +18 -13
  374. mindspore/parallel/_transformer/layers.py +5 -3
  375. mindspore/parallel/_transformer/loss.py +1 -0
  376. mindspore/parallel/_transformer/moe.py +2 -2
  377. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  378. mindspore/parallel/_transformer/transformer.py +23 -3
  379. mindspore/parallel/_utils.py +11 -7
  380. mindspore/parallel/algo_parameter_config.py +85 -5
  381. mindspore/parallel/checkpoint_transform.py +19 -12
  382. mindspore/parallel/shard.py +21 -14
  383. mindspore/profiler/common/struct_type.py +3 -3
  384. mindspore/profiler/common/util.py +4 -2
  385. mindspore/profiler/envprofiling.py +1 -1
  386. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  387. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  388. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  389. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  390. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  391. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  392. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  393. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  394. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  395. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  396. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  397. mindspore/profiler/parser/flops_parser.py +15 -11
  398. mindspore/profiler/parser/framework_parser.py +38 -22
  399. mindspore/profiler/parser/hccl_parser.py +16 -12
  400. mindspore/profiler/parser/integrator.py +22 -11
  401. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  402. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  403. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  404. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  405. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  406. mindspore/profiler/parser/optime_parser.py +1 -1
  407. mindspore/profiler/parser/profiler_info.py +21 -2
  408. mindspore/profiler/parser/step_trace_parser.py +11 -14
  409. mindspore/profiler/profiling.py +179 -89
  410. mindspore/rewrite/api/node.py +102 -19
  411. mindspore/rewrite/api/node_type.py +5 -1
  412. mindspore/rewrite/api/pattern_engine.py +1 -1
  413. mindspore/rewrite/api/scoped_value.py +9 -17
  414. mindspore/rewrite/api/symbol_tree.py +131 -47
  415. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  416. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  417. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  418. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  419. mindspore/rewrite/common/rewrite_elog.py +5 -1
  420. mindspore/rewrite/namer.py +33 -24
  421. mindspore/rewrite/namespace.py +14 -5
  422. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  423. mindspore/rewrite/node/call_function.py +79 -0
  424. mindspore/rewrite/node/cell_container.py +135 -0
  425. mindspore/rewrite/node/control_flow.py +88 -0
  426. mindspore/rewrite/{node.py → node/node.py} +273 -234
  427. mindspore/rewrite/node/node_manager.py +254 -0
  428. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  429. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  430. mindspore/rewrite/parsers/assign_parser.py +216 -221
  431. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  432. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  433. mindspore/rewrite/parsers/constant_parser.py +9 -6
  434. mindspore/rewrite/parsers/container_parser.py +9 -7
  435. mindspore/rewrite/parsers/for_parser.py +36 -15
  436. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  437. mindspore/rewrite/parsers/if_parser.py +28 -24
  438. mindspore/rewrite/parsers/module_parser.py +196 -25
  439. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  440. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  441. mindspore/rewrite/parsers/return_parser.py +6 -6
  442. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  443. mindspore/rewrite/sparsify/utils.py +1 -1
  444. mindspore/rewrite/symbol_tree.py +523 -578
  445. mindspore/rewrite/symbol_tree_builder.py +9 -193
  446. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  447. mindspore/run_check/_check_version.py +6 -4
  448. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  449. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  450. mindspore/scipy/linalg.py +1 -1
  451. mindspore/scipy/optimize/minimize.py +7 -3
  452. mindspore/train/_utils.py +7 -3
  453. mindspore/train/amp.py +323 -123
  454. mindspore/train/anf_ir_pb2.py +14 -2
  455. mindspore/train/callback/_backup_and_restore.py +2 -12
  456. mindspore/train/callback/_callback.py +29 -4
  457. mindspore/train/callback/_checkpoint.py +23 -8
  458. mindspore/train/callback/_early_stop.py +2 -2
  459. mindspore/train/callback/_landscape.py +4 -4
  460. mindspore/train/callback/_loss_monitor.py +2 -2
  461. mindspore/train/callback/_on_request_exit.py +2 -2
  462. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  463. mindspore/train/callback/_summary_collector.py +15 -8
  464. mindspore/train/callback/_time_monitor.py +58 -5
  465. mindspore/train/data_sink.py +5 -11
  466. mindspore/train/dataset_helper.py +84 -57
  467. mindspore/train/loss_scale_manager.py +2 -2
  468. mindspore/train/metrics/__init__.py +3 -3
  469. mindspore/train/metrics/cosine_similarity.py +1 -1
  470. mindspore/train/metrics/hausdorff_distance.py +3 -2
  471. mindspore/train/metrics/mean_surface_distance.py +3 -2
  472. mindspore/train/metrics/metric.py +39 -19
  473. mindspore/train/metrics/roc.py +2 -2
  474. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  475. mindspore/train/mind_ir_pb2.py +85 -36
  476. mindspore/train/model.py +187 -47
  477. mindspore/train/serialization.py +487 -161
  478. mindspore/train/summary/_summary_adapter.py +1 -1
  479. mindspore/train/summary/_writer_pool.py +3 -2
  480. mindspore/train/summary/summary_record.py +37 -17
  481. mindspore/train/train_thor/convert_utils.py +3 -3
  482. mindspore/train/train_thor/dataset_helper.py +1 -1
  483. mindspore/version.py +1 -1
  484. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
  485. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +488 -528
  486. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
  487. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  488. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  489. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  490. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  491. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  492. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  493. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  494. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  495. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  496. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  497. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  498. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  499. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  500. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  501. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  502. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  503. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  504. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  505. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  506. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  507. mindspore/_extends/graph_kernel/expander.py +0 -80
  508. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  509. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  510. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  511. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  512. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  513. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  514. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  515. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  516. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  517. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  518. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  519. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  520. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  521. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  522. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  523. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  524. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  525. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  526. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  527. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  528. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  529. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  530. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  531. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  532. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  533. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  534. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  535. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  536. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  537. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  538. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  539. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  540. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  541. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  542. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  543. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  544. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  545. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  546. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  547. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  548. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  549. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  550. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  551. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  552. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  553. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  554. mindspore/dataset/datapreprocess/__init__.py +0 -20
  555. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  556. mindspore/include/api/net.h +0 -142
  557. mindspore/nn/lr_scheduler.py +0 -262
  558. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  559. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  560. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  561. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  562. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  563. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  567. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  568. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  569. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  571. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  575. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  577. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  578. mindspore/rewrite/node_visitor.py +0 -44
  579. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  580. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
@@ -23,7 +23,7 @@ import os
 import shutil
 import stat
 import threading
-from threading import Thread, Lock
+from threading import Thread, RLock
 from collections import defaultdict, OrderedDict
 from io import BytesIO
 
@@ -59,19 +59,23 @@ from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_
 from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode
 from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \
     _restore_group_info_list
+from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
+    _store_warm_up_ptr_by_tensor_list, _cache_enable
 from mindspore.train._utils import read_proto
 from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
-    split_mindir
+    split_mindir, split_dynamic_mindir
 from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
+from ..ops.operations import Cast
 
 tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
                      "Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64,
                      "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64,
-                     "Bool": mstype.bool_, "str": mstype.string}
+                     "Bool": mstype.bool_, "str": mstype.string, "BFloat16": mstype.bfloat16}
 
 tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UInt16": np.uint16,
                      "Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64,
-                     "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"}
+                     "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U",
+                     "BFloat16": np.float32}
 
 np_type_convert = {"int32": np.int32, "float32": np.float32, "float16": np.float16, "float64": np.float64}
 
@@ -79,7 +83,7 @@ mindir_to_tensor_type = {1: mstype.float32, 2: mstype.uint8, 3: mstype.int8, 4:
                          5: mstype.int16, 6: mstype.int32, 7: mstype.int64, 10: mstype.float16,
                          11: mstype.float64, 12: mstype.uint32, 13: mstype.uint64}
 
-_ckpt_mutex = Lock()
+_ckpt_mutex = RLock()
 
 # unit is KB
 SLICE_SIZE = 512 * 1024
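The hunk above swaps the module-level checkpoint mutex from Lock to RLock (matching the Thread, RLock import change earlier), presumably so that a save path which re-enters _ckpt_mutex through a nested helper cannot deadlock itself. A minimal standalone sketch of the difference; the function names here are illustrative, not from serialization.py:

    import threading

    _mutex = threading.RLock()  # with threading.Lock() this program would hang

    def outer():
        with _mutex:
            inner()  # the same thread re-enters the lock it already owns

    def inner():
        with _mutex:  # RLock bumps a recursion count instead of blocking
            print("re-entered safely")

    outer()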
@@ -89,6 +93,8 @@ PARAMETER_SPLIT_SIZE = 1024 * 1024 * 1024
 ENCRYPT_BLOCK_SIZE = 64 * 1024
 INT_64_MAX = 9223372036854775807
 
+cpu_cast = Cast().set_device("CPU")
+
 
 def _special_process_par(par, new_par):
     """
@@ -105,7 +111,11 @@ def _special_process_par(par, new_par):
             if new_par.data.shape[par_shape_len + i] != 1:
                 return False
 
-    new_val = new_par.data.asnumpy()
+    if new_par.data.dtype == mstype.bfloat16:
+        new_val = cpu_cast(new_par.data, mstype.float32).asnumpy()
+    else:
+        new_val = new_par.data.asnumpy()
+
     new_val = new_val.reshape(par.data.shape)
     par.set_data(Tensor(new_val, par.data.dtype))
     return True
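The bfloat16 branches above, and throughout the rest of this file, exist because NumPy has no native bfloat16 dtype: asnumpy() cannot represent such a tensor directly, so the new module-level cpu_cast = Cast().set_device("CPU") primitive first widens the data to float32 on the host. That is also why tensor_to_np_type maps "BFloat16" to np.float32 while tensor_to_ms_type maps it to a true mstype.bfloat16. A self-contained NumPy sketch of the precision bfloat16 carries (truncation shown for simplicity; real casts round to nearest):

    import numpy as np

    def bf16_truncate(x32: np.ndarray) -> np.ndarray:
        """Zero the low 16 bits of float32, leaving bfloat16's 8-bit mantissa."""
        bits = x32.view(np.uint32) & np.uint32(0xFFFF0000)
        return bits.view(np.float32)

    x = np.array([0.1, 1.0, 3.14159], dtype=np.float32)
    print(bf16_truncate(x))  # approx. [0.09960938 1. 3.140625]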
@@ -126,7 +136,10 @@ def _update_param(param, new_param, strict_load):
 
     if param.data.dtype != new_param.data.dtype:
         if _type_convert(param, new_param, strict_load):
-            new_tensor = Tensor(new_param.data.asnumpy(), param.data.dtype)
+            if new_param.data.dtype == mstype.bfloat16:
+                new_tensor = cpu_cast(new_param.data, param.data.dtype)
+            else:
+                new_tensor = Tensor(new_param.data.asnumpy(), param.data.dtype)
             param.set_data(new_tensor, param.sliced)
             return
 
@@ -229,10 +242,16 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
                 continue
             if value[0] == "offload_parameter":
                 new_value = value[1:]
-                new_value[2] = value[3].asnumpy().reshape(-1)
+                if value[3].dtype == mstype.bfloat16:
+                    new_value[2] = cpu_cast(value[3], mstype.float32).asnumpy().reshape(-1)
+                else:
+                    new_value[2] = value[3].asnumpy().reshape(-1)
                 _write_parameter_data(name, new_value, f, enc_key, plain_data)
                 _offload_if_config(value[3])
                 continue
+            if value[0] == "BFloat16_tensor":
+                _write_bfloat16_data(name, value, f, enc_key, plain_data)
+                continue
             if isinstance(value[2], Tensor):
                 _write_hugeparameter(name, value, f)
                 continue
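In _exec_save, each staged entry is now routed on its leading string tag, with "BFloat16_tensor" joining "offload_parameter" and the other tags handled above this hunk. A self-contained sketch of that routing shape; the writer names echo the real ones, but the bodies are stand-ins, not the actual protobuf serializers:

    def route(name, value):
        """Dispatch one staged checkpoint entry by its leading tag."""
        if value[0] == "offload_parameter":
            return f"{name} -> _write_parameter_data (bfloat16 payload cast to float32 first)"
        if value[0] == "BFloat16_tensor":
            return f"{name} -> _write_bfloat16_data"
        return f"{name} -> default writers"

    # entry layout mirrors the diff: [tag, dims, type string, payload]
    print(route("net.weight", ["BFloat16_tensor", [16, 16], "BFloat16", b"..."]))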
@@ -267,6 +286,21 @@ def _write_random_seed(name, value, f):
     f.write(checkpoint_list.SerializeToString())
 
 
+def _write_bfloat16_data(name, value, f, enc_key, plain_data):
+    """Write bfloat16 data into protobuf file"""
+    checkpoint_list = Checkpoint()
+    param_value = checkpoint_list.value.add()
+    param_value.tag = name
+    param_tensor = param_value.tensor
+    param_tensor.dims.extend(value[1])
+    param_tensor.tensor_type = value[2]
+    param_tensor.tensor_content = value[3].get_bytes()
+    if enc_key is None:
+        f.write(checkpoint_list.SerializeToString())
+    else:
+        plain_data.write(checkpoint_list.SerializeToString())
+
+
 def _write_parameter_data(name, value, f, enc_key, plain_data):
     """Write parameter data into protobuf file."""
     data_size = value[2].nbytes / 1024
@@ -333,8 +367,8 @@ def _write_hugeparameter(name, value, f):
 
 def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
     """Check save_obj and ckpt_file_name for save_checkpoint."""
-    if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list):
-        raise TypeError("For 'save_checkpoint', the parameter 'save_obj' must be nn.Cell or list, "
+    if not isinstance(save_obj, (nn.Cell, list, dict)):
+        raise TypeError("For 'save_checkpoint', the parameter 'save_obj' must be nn.Cell, list or dict, "
                         "but got {}.".format(type(save_obj)))
     if not isinstance(ckpt_file_name, str):
         raise TypeError("For 'save_checkpoint', the parameter {} for checkpoint file name is invalid,"
@@ -351,14 +385,15 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
 
 def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
                     async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None, **kwargs):
-    """
+    r"""
     Save checkpoint to a specified file.
 
     Args:
-        save_obj (Union[Cell, list]): The cell object or data list(each element is a dictionary, like
-                                      [{"name": param_name, "data": param_data},...], the type of
-                                      param_name would be string, and the type of param_data would
-                                      be parameter or Tensor).
+        save_obj (Union[Cell, list, dict]): The object to be saved. The data type can be :class:`mindspore.nn.Cell`,
+            list, or dict. If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict
+            elements(each element is a dictionary, like [{"name": param_name, "data": param_data},...], the type of
+            `param_name` must be string, and the type of `param_data` must be parameter or Tensor); If dict,
+            it can be the returned value of `mindspore.load_checkpoint()`.
         ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
         integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: ``True`` .
         async_save (bool): Whether to open an independent thread to save the checkpoint file. Default: ``False`` .
@@ -370,16 +405,14 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
             mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` .
             Default: ``"AES-GCM"`` .
         choice_func (function) : A function for saving custom selected parameters. The input value of `choice_func` is
-            a parameter name in string type, and the return value is a bool.
+            a parameter name in string type, and the returned value is a bool.
             If returns ``True`` , the Parameter that matching the custom condition will be saved.
             If returns ``False`` , the Parameter that not matching the custom condition will not
             be saved. Default: ``None`` .
         kwargs (dict): Configuration options dictionary.
 
-            - incremental (bool): Whether export checkpoint for MapParameter incrementally.
-
     Raises:
-        TypeError: If the parameter `save_obj` is not `nn.Cell` or list type.
+        TypeError: If the parameter `save_obj` is not :class:`mindspore.nn.Cell` , list or dict type.
         TypeError: If the parameter `integrated_save` or `async_save` is not bool type.
         TypeError: If the parameter `ckpt_file_name` is not string type.
 
@@ -387,17 +420,27 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
         >>> import mindspore as ms
         >>>
         >>> # Define the network structure of LeNet5. Refer to
-        >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
+        >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
         >>> net = LeNet5()
         >>> ms.save_checkpoint(net, "./lenet.ckpt",
-        >>>                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
-        >>> param_dict = ms.load_checkpoint("./lenet.ckpt")
-        >>> print(param_dict)
+        ...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
+        >>> param_dict1 = ms.load_checkpoint("./lenet.ckpt")
+        >>> print(param_dict1)
         {'conv2.weight': Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)}
+        >>> params_list = net.trainable_params()
+        >>> ms.save_checkpoint(params_list, "./lenet_list.ckpt",
+        ...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv2"))
+        >>> param_dict2 = ms.load_checkpoint("./lenet_list.ckpt")
+        >>> print(param_dict2)
+        {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
+        >>> ms.save_checkpoint(param_dict2, "./lenet_dict.ckpt")
+        >>> param_dict3 = ms.load_checkpoint("./lenet_dict.ckpt")
+        >>> print(param_dict3)
+        {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
 
     Tutorial Examples:
         - `Saving and Loading the Model - Saving and Loading the Model Weight
-          <https://mindspore.cn/tutorials/en/r2.1/beginner/save_load.html#saving-and-loading-the-model-weight>`_
+          <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
     """
     ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name)
     integrated_save = Validator.check_bool(integrated_save)
@@ -408,70 +451,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
     map_param_inc = kwargs.get('incremental', False)
     logger.info("Execute the process of saving checkpoint files.")
 
-    if isinstance(save_obj, nn.Cell):
-        if save_obj.ge_init and not save_obj.ge_sync_data:
-            from mindspore.train.callback._callback import set_cur_net
-            set_cur_net(save_obj)
-            save_obj.exec_checkpoint_graph()
-        parameter_layout_dict = save_obj.parameter_layout_dict
-        if _is_in_auto_parallel_mode() and not parameter_layout_dict:
-            parameter_layout_dict = _get_parameter_layout()
-        if not _is_in_auto_parallel_mode():
-            save_obj.init_parameters_data()
-        param_dict = OrderedDict()
-        for _, param in save_obj.parameters_and_names():
-            not_sliced = not param.sliced
-            is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
-            # All parameters are initialized immediately under PyNative mode, skip this judgement.
-            if is_graph_mode and _is_in_auto_parallel_mode() and (not_sliced or param.has_init):
-                continue
-            if choice_func is not None and not choice_func(param.name):
-                continue
-            param_dict[param.name] = param
-        param_list = []
-        if append_dict and "random_op" in append_dict:
-            phase = 'train' + '.' + str(save_obj.create_time) + '.' + str(id(save_obj)) + '.' + save_obj.arguments_key
-            if phase in save_obj.compile_cache and _executor.has_compiled(phase):
-                random_byte = _executor._graph_executor.get_random_status(phase)
-                param_list.append({"name": "random_op", "data": random_byte})
-                append_dict.pop("random_op")
-        for (key, value) in param_dict.items():
-            each_param = {"name": key}
-            if isinstance(value, MapParameter):
-                each_param["data"] = value
-                param_list.append(each_param)
-                continue
-
-            if value.data.is_persistent_data():
-                # list save persistent_data: [Tensor, shape, type, param.key]
-                param_data = ["persistent_data"]
-                param_data.append(value.data)
-                param_data.append(value.param_info.origin_shape)
-                param_data.append(str(value.dtype))
-                param_data.append(value.key)
-            elif value.data.offload_file_path() != "":
-                # list save offload data: [Param, shape, type, param.key]
-                param_data = ["offload_parameter"]
-                param_tensor = value.data
-                if key in parameter_layout_dict:
-                    param_tensor = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_tensor,
-                                                          integrated_save)
-                param_data.append(param_tensor)
-                param_data.append(param_tensor.shape)
-                param_data.append(str(param_tensor.dtype))
-                param_data.append(value.key)
-            else:
-                param_data = Tensor(value.data.asnumpy())
-
-            # in automatic model parallel scenario, some parameters were split to all the devices,
-            # which should be combined before saving
-            if key in parameter_layout_dict:
-                param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
-                                                    integrated_save)
-
-            each_param["data"] = param_data
-            param_list.append(each_param)
-        save_obj = param_list
+    save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
 
     if append_dict:
         append_info_list = []
@@ -479,7 +459,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
             if not isinstance(value, str):
                 value = Tensor(value)
             append_info_list.append({"name": k_name, "data": value})
-        save_obj.extend(append_info_list)
+        save_obj.extend(append_info_list)
 
     data_list = OrderedDict()
     with _ckpt_mutex:
@@ -499,6 +479,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
             elif param["data"][0] == "offload_parameter":
                 data_list[key].append("offload_parameter")
                 _save_param_list_data(data_list, key, param)
+            elif param["data"][0] == "BFloat16_tensor":
+                data_list[key].append("BFloat16_tensor")
+                _save_param_list_data(data_list, key, param)
+                continue
 
             if isinstance(param["data"], str):
                 data_list[key].append([0])
@@ -508,6 +492,15 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
             else:
                 if isinstance(param["data"], Parameter):
                     param["data"].init_data()
+                if isinstance(param["data"], Tensor) and param["data"].dtype == mstype.bfloat16:
+                    data_list[key].append("BFloat16_tensor")
+                    dims = []
+                    for dim in param["data"].shape:
+                        dims.append(dim)
+                    data_list[key].append(dims)
+                    data_list[key].append("BFloat16")
+                    data_list[key].append(cpu_cast(param["data"], mstype.float32))
+                    continue
                 dims = []
                 if param['data'].shape == ():
                     dims.append(0)
@@ -517,7 +510,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
                 data_list[key].append(dims)
                 tensor_type = str(param["data"].dtype)
                 data_list[key].append(tensor_type)
-                data = param["data"].asnumpy().reshape(-1)
+                if param["data"].dtype == mstype.bfloat16:
+                    data = cpu_cast(param["data"], mstype.float32).asnumpy().reshape(-1)
+                else:
+                    data = param["data"].asnumpy().reshape(-1)
                 data_list[key].append(data)
 
     if async_save:
@@ -530,6 +526,130 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
     logger.info("Saving checkpoint process is finished.")
 
 
+def _convert_list_to_param_list(save_obj, choice_func):
+    """Convert a list of Parameter to param_list."""
+    param_list = []
+    if not save_obj:
+        return param_list
+    if isinstance(save_obj[0], dict):
+        param_list = [param for param in save_obj if choice_func is None or choice_func(param["name"])]
+    else:
+        for param in save_obj:
+            if isinstance(param, Parameter):
+                if choice_func is not None and not choice_func(param.name):
+                    continue
+                each_param = {"name": param.name, "data": param}
+                param_list.append(each_param)
+            else:
+                raise TypeError(f"For save_checkpoint, when save_obj is made up by list of Parameter,"
+                                f"the param should be parameter, but got {type(param)}")
+    return param_list
+
+
+def _convert_dict_to_param_dict(save_obj, choice_func):
+    """Convert a dict of Parameter to param_list."""
+    param_list = []
+    for (key, value) in save_obj.items():
+        if isinstance(key, str) and isinstance(value, (Parameter, str)):
+            if choice_func is not None and not choice_func(key):
+                continue
+            each_param = {"name": key, "data": value}
+            param_list.append(each_param)
+        else:
+            raise TypeError(f"For save_checkpoint, when save_obj is made up by dict, the key should be str and"
+                            f"value should be Parameter, but got the type of key is {type(key)} and"
+                            f"the type of value is {type(value)}")
+    return param_list
+
+
+def _convert_cell_param_and_names_to_dict(save_obj, choice_func):
+    """Convert cell.parameters_and_names to OrderedDict."""
+    param_dict = OrderedDict()
+    for _, param in save_obj.parameters_and_names():
+        not_sliced = not param.sliced
+        is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
+        # All parameters are initialized immediately under PyNative mode, skip this judgement.
+        judgment = not_sliced or param.has_init
+        if is_graph_mode and _is_in_auto_parallel_mode() and judgment:
+            continue
+        if choice_func is not None and not choice_func(param.name):
+            continue
+        # Add suffix for cache_enabled parameter, and then parameter can carry key info.
+        # Notice that suffix needs be removed when loading into net.
+        if param.cache_enable:
+            param_dict[param.name + ".__param_key__" + str(param.key)] = param
+        else:
+            param_dict[param.name] = param
+    return param_dict
+
+
+def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func):
+    """Convert nn.Cell to param_list."""
+    param_list = []
+    parameter_layout_dict = save_obj.parameter_layout_dict
+    if _is_in_auto_parallel_mode() and not parameter_layout_dict:
+        parameter_layout_dict = _get_parameter_layout()
+    if not _is_in_auto_parallel_mode():
+        save_obj.init_parameters_data()
+    param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func)
+    if append_dict and "random_op" in append_dict:
+        phase = 'train' + '.' + str(save_obj.create_time) + '.' + str(id(save_obj)) + '.' + save_obj.arguments_key
+        if phase in save_obj.compile_cache and _executor.has_compiled(phase):
+            random_byte = _executor._graph_executor.get_random_status(phase)
+            param_list.append({"name": "random_op", "data": random_byte})
+            append_dict.pop("random_op")
+    for (key, value) in param_dict.items():
+        each_param = {"name": key}
+        if isinstance(value, MapParameter):
+            each_param["data"] = value
+            param_list.append(each_param)
+            continue
+
+        if value.data.is_persistent_data():
+            # list save persistent_data: [Tensor, shape, type, param.key]
+            param_data = ["persistent_data", value.data, value.param_info.origin_shape, str(value.dtype), value.key]
+        elif value.data.offload_file_path() != "":
+            # list save offload data: [Param, shape, type, param.key]
+            param_data = ["offload_parameter"]
+            param_tensor = value.data
+            if key in parameter_layout_dict:
+                param_tensor = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_tensor,
+                                                      integrated_save)
+            param_data.append(param_tensor)
+            param_data.append(param_tensor.shape)
+            param_data.append(str(param_tensor.dtype))
+            param_data.append(value.key)
+        elif value.data.dtype == mstype.bfloat16:
+            param_data = ["BFloat16_tensor"]
+            param_data.append(cpu_cast(value.data, mstype.float32))
+            param_data.append(value.data.shape)
+            param_data.append("BFloat16")
+            param_data.append(value.key)
+        else:
+            param_data = Tensor(value.data.asnumpy())
+
+        # in automatic model parallel scenario, some parameters were split to all the devices,
+        # which should be combined before saving
+        if key in parameter_layout_dict:
+            param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
+                                                integrated_save)
+
+        each_param["data"] = param_data
+        param_list.append(each_param)
+    return param_list
+
+
+def _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func):
+    """Convert a save_obj to param_list."""
+    if isinstance(save_obj, list):
+        return _convert_list_to_param_list(save_obj, choice_func)
+
+    if isinstance(save_obj, dict):
+        return _convert_dict_to_param_dict(save_obj, choice_func)
+
+    return _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func)
+
+
 def _save_param_list_data(data_list, key, param):
     """Save persistent data into save_obj."""
     dims = []
@@ -585,7 +705,7 @@ def load(file_name, **kwargs):
 
     - obf_func (function): A python function used for loading obfuscated MindIR model, which can refer to
       `obfuscate_model()
-      <https://www.mindspore.cn/docs/en/r2.1/api_python/mindspore/mindspore.obfuscate_model.html>`_.
+      <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore/mindspore.obfuscate_model.html>`_.
 
     Returns:
         GraphCell, a compiled graph that can be executed by `GraphCell`.
@@ -615,7 +735,7 @@ def load(file_name, **kwargs):
 
     Tutorial Examples:
         - `Saving and Loading the Model - Saving and Loading MindIR
-          <https://mindspore.cn/tutorials/en/r2.1/beginner/save_load.html#saving-and-loading-mindir>`_
+          <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-mindir>`_
     """
     if not isinstance(file_name, str):
         raise ValueError("For 'load', the argument 'file_name' must be string, but "
@@ -656,7 +776,7 @@ def load(file_name, **kwargs):
     return graph
 
 
-def export_split_mindir(file_name):
+def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=False):
     """
     Auto Split MindIR.
 
@@ -664,6 +784,10 @@ def export_split_mindir(file_name):
 
     Args:
         file_name (str): MindIR file name.
+        device_num (int): The number of devices the model is split across. Default: ``8``.
+        rank_id (int): The rank id of the current device. Default: ``0``.
+        dynamic (bool): Indicates whether the model is a dynamic-shape MindIR model. Default: ``True``.
+        sapp (bool): Indicates whether to automatically generate the split strategy through SAPP.
+            Default: ``False``.
 
     Raises:
         ValueError: MindIR file does not exist or `file_name` is not a string.
@@ -671,11 +795,9 @@ def export_split_mindir(file_name):
 
     Examples:
         >>> import mindspore as ms
-        >>> from mindspore.communication import init
-        >>> context.set_context(mode=context.GRAPH_MODE)
+        >>> ms.set_context(mode=ms.GRAPH_MODE)
         >>>
-        >>> init(backend_name="hccl")
-        >>> ms.export_split_mindir("net.mindir")
+        >>> ms.export_split_mindir("net.mindir", device_num=8, rank_id=0)
 
     """
     if not isinstance(file_name, str):
@@ -690,8 +812,11 @@ def export_split_mindir(file_name):
     file_name = os.path.abspath(file_name)
 
     logger.info("Execute the process of export and split mindir.")
-
-    graph = split_mindir(file_name)
+    if dynamic:
+        graph = split_dynamic_mindir(file_name, device_num, rank_id, sapp)
+    else:
+        graph = split_mindir(file_name)
 
     if graph is None:
         if _is_cipher_file(file_name):
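With the widened signature, a caller can now select the dynamic-shape path and the SAPP strategy generator explicitly. A hedged usage sketch — the file name is a placeholder and the keyword values simply exercise the new parameters:

```python
import mindspore as ms

# Illustrative only: split a dynamic-shape MindIR model across 8 devices,
# letting SAPP generate the split strategy ("net.mindir" is made up).
ms.export_split_mindir("net.mindir", device_num=8, rank_id=0,
                       dynamic=True, sapp=True)
```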
@@ -779,17 +904,20 @@ def obfuscate_model(obf_config, **kwargs):
     - model_inputs (list(Tensor)): The inputs of the original model, the values of Tensor can be random, which
       is the same as using :func:`mindspore.export`.
     - obf_ratio (Union(float, str)): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
-      should be in range of (0, 1] or in ["small", "medium", "large"].
+      should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large"
+      correspond to 0.1, 0.3 and 0.6 respectively.
     - customized_func (function): A python function used for customized function mode, which is used to control
-      the switch branch of obfuscation structure. The outputs of customized_func should be boolean. This
-      function needs to ensure that its result is constant for any input. Users can refer to opaque
+      the switch branch of obfuscation structure. The outputs of customized_func should be boolean and constant
+      (refer to 'my_func()' in the
+      `tutorials <https://www.mindspore.cn/mindarmour/docs/en/r2.0/dynamic_obfuscation_protection.html>`_).
+      This function needs to ensure that its result is constant for any input. Users can refer to opaque
       predicates. If customized_func is set, then it should be passed to :func:`mindspore.load` interface
       when loading obfuscated model.
-    - obf_random_seed (int): The random seed used for determine the distribution of confusion branches and the
-      weight confusion coefficient, which should be in (0, 9223372036854775807]. If `obf_random_seed` is set,
-      then it should be passed to :class:`nn.GraphCell()` interface when loading obfuscated model. It should be
-      noted that at least one of `customized_func` or `obf_random_seed` should be set, and the latter mode
-      would be applied if both of them are set.
+    - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
+      structure of obfuscated models corresponding to different random seeds is different. If
+      `obf_random_seed` is set, then it should be passed to :class:`nn.GraphCell()` interface when loading
+      obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
+      be set, and the latter mode would be applied if both of them are set.
 
     kwargs (dict): Configuration options dictionary.
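A hedged sketch of a `customized_func` satisfying the constraint above: an opaque predicate whose boolean result is constant for any input. The name `my_func` follows the tutorial's naming; the body is illustrative, not taken from the library:

```python
def my_func(x1, x2):
    # Looks data-dependent, but a square is never negative, so the
    # branch condition is constant (True) for any numeric inputs.
    return (x1 + x2) ** 2 >= 0
```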
 
@@ -928,27 +1056,27 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
     >>> print(param_dict["conv2.weight"])
     Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
     >>> def func(param_name):
-    >>>     whether_load = False
-    >>>     if param_name.startswith("conv"):
-    >>>         whether_load = True
-    >>>     if param_name.startswith("conv1"):
-    >>>         whether_load = False
-    >>>     return whether_load
+    ...     whether_load = False
+    ...     if param_name.startswith("conv"):
+    ...         whether_load = True
+    ...     if param_name.startswith("conv1"):
+    ...         whether_load = False
+    ...     return whether_load
     >>> param_dict1 = ms.load_checkpoint(ckpt_file_name, choice_func=func)
     >>> print(param_dict1["conv2.weight"])
     Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
     >>> def func(param_name):
-    >>>     whether_load = False
-    >>>     if param_name.startswith("conv1"):
-    >>>         whether_load = True
-    >>>     return whether_load
+    ...     whether_load = False
+    ...     if param_name.startswith("conv1"):
+    ...         whether_load = True
+    ...     return whether_load
     >>> param_dict2 = ms.load_checkpoint(ckpt_file_name, choice_func=func)
     >>> print(param_dict2)
     {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
 
     Tutorial Examples:
         - `Saving and Loading the Model - Saving and Loading the Model Weight
-          <https://mindspore.cn/tutorials/en/r2.1/beginner/save_load.html#saving-and-loading-the-model-weight>`_
+          <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
     """
     ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
     specify_prefix = _check_prefix(specify_prefix)
@@ -979,8 +1107,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
                 choice_func is not None and not choice_func(element.tag):
             continue
         if element.tensor.ByteSize() == 0:
-            _load_map_parameter(checkpoint_list, element, element_id,
-                                map_data_list, map_shape_list, parameter_dict)
+            _load_map_parameter(checkpoint_list, element, element_id, map_data_list, map_shape_list, parameter_dict)
             if element.tag in parameter_dict:
                 map_data_list = [[], [], []]
                 map_shape_list = [0, 0, 0]
@@ -992,6 +1119,13 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
         if data_type == 'str':
             str_length = int(len(data) / 4)
             np_type = np_type + str(str_length)
+        if data_type == "BFloat16":
+            dims = element.tensor.dims
+            param_data = np.frombuffer(data, np_type)
+            param_data = param_data.reshape(list(dims))
+            parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
+            parameter_dict[element.tag] = parameter
+            continue
         element_data = np.frombuffer(data, np_type)
         param_data_list.append(element_data)
         if (element_id == len(checkpoint_list.value) - 1) or \
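The added branch rebuilds the shaped array straight from the checkpoint element's raw bytes. A standalone sketch of that decode step, using a made-up 2x3 float32 buffer in place of checkpoint data:

```python
import numpy as np

# Stand-in for the raw bytes read from a checkpoint element.
data = np.arange(6, dtype=np.float32).tobytes()
dims = [2, 3]

param_data = np.frombuffer(data, np.float32)   # flat view over the bytes
param_data = param_data.reshape(list(dims))    # restore the saved shape
print(param_data.shape)  # (2, 3)
```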
@@ -1024,8 +1158,12 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
         raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether "
                          f"'filter_prefix' or 'specify_prefix' are set correctly.")
 
+    if _warm_up_host_cache_enabled(parameter_dict):
+        (is_worker, net_dict, warm_up_dict) = _warm_up_host_cache(parameter_dict, net)
     if net is not None:
         load_param_into_net(net, parameter_dict, strict_load)
+    if _warm_up_host_cache_enabled(parameter_dict):
+        _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict)
 
     return parameter_dict
 
@@ -1061,7 +1199,7 @@ def _load_map_parameter(checkpoint_list, element, element_id, map_data_list,
 
 
 def _check_ckpt_file_name(ckpt_file_name):
-    """Check function load_checkpoint's cket_file_name."""
+    """Check function load_checkpoint's ckpt_file_name."""
     if not isinstance(ckpt_file_name, str):
         raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, "
                         "but got {}.".format(type(ckpt_file_name)))
@@ -1175,7 +1313,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
     >>> import mindspore as ms
     >>>
     >>> # Define the network structure of LeNet5. Refer to
-    >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
+    >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
     >>> net = LeNet5()
     >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
     >>> param_dict = ms.load_checkpoint(ckpt_file_name, filter_prefix="conv1")
@@ -1185,7 +1323,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
 
     Tutorial Examples:
         - `Saving and Loading the Model - Saving and Loading the Model Weight
-          <https://mindspore.cn/tutorials/en/r2.1/beginner/save_load.html#saving-and-loading-the-model-weight>`_
+          <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
     """
     if not isinstance(net, nn.Cell):
         logger.critical("Failed to combine the net and the parameters.")
@@ -1219,6 +1357,9 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
         if isinstance(param, MapParameter):
             param.import_data(parameter_dict[param.name])
             continue
+        # Add hasattr protection when loading a server checkpoint file on a worker.
+        if not hasattr(parameter_dict[param.name], "data"):
+            continue
         new_param = copy.deepcopy(parameter_dict[param.name])
         _update_param(param, new_param, strict_load)
         ckpt_not_load.remove(param.name)
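The guard above skips dictionary entries that are not Parameter-like. A minimal sketch of the same filtering with stand-in objects (nothing here is MindSpore API):

```python
class FakeParam:
    """Stand-in for a Parameter: it carries a .data attribute."""
    data = object()

# A server checkpoint can leave non-Parameter entries (e.g. raw lists)
# in the dict; hasattr() filtering keeps only the loadable ones.
parameter_dict = {"fc.weight": FakeParam(), "embed_table": ["k", "v", "s"]}
loadable = {k: v for k, v in parameter_dict.items() if hasattr(v, "data")}
assert list(loadable) == ["fc.weight"]
```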
@@ -1243,6 +1384,72 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
     return param_not_load, ckpt_not_load
 
 
+def _warm_up_host_cache_enabled(parameter_dict):
+    """Check whether host cache warm-up is enabled."""
+    if _cache_enable():
+        return True
+    for key in parameter_dict.keys():
+        if key.find(".__param_key__") != -1:
+            return True
+    return False
+
+
+def _warm_up_host_cache(parameter_dict, net):
+    """Warm up host cache."""
+    ms_role = os.getenv("MS_ROLE")
+    is_worker = ms_role == "MS_WORKER"
+    param_key_dict = {}
+    # Traverse key, value in parameter_dict, warm up param key and record param key into param_key_dict.
+    if is_worker:
+        net.init_parameters_data()
+        net_dict = {}
+        for name, value in net.parameters_and_names():
+            net_dict[name] = value
+        for param_name, value in parameter_dict.items():
+            pos = param_name.find(".__param_key__")
+            if pos != -1:
+                net_param_name = param_name[:pos]
+                param_key_dict[param_name] = net_param_name
+                net_value = None
+                if net_param_name not in net_dict:
+                    logger.warning("net param name: %s is not in net", net_param_name)
+                else:
+                    net_value = net_dict.get(net_param_name, None)
+                pos += len(".__param_key__")
+                param_key = int(param_name[pos:])
+                value_is_map_parameter = isinstance(value, list) and len(value) == 3
+                if value_is_map_parameter and (net_value is None or isinstance(net_value, Parameter)):
+                    key_tensor = Tensor.from_numpy(value[0])
+                    value_tensor = Tensor.from_numpy(value[1])
+                    status_tensor = Tensor.from_numpy(value[2])
+                    _store_warm_up_ptr_by_tensor_list(param_key, key_tensor, value_tensor, status_tensor)
+                elif not isinstance(value, list) and isinstance(net_value, Parameter):
+                    _store_warm_up_ptr_by_tensor(param_key, value)
+                else:
+                    logger.warning("Unmatched parameter types: value %s and net_value %s",
+                                   type(value), type(net_value))
+    else:
+        for param_name, value in parameter_dict.items():
+            pos = param_name.find(".__param_key__")
+            if pos != -1:
+                net_param_name = param_name[:pos]
+                param_key_dict[param_name] = net_param_name
+    # Split param key from parameter_dict since worker cannot load param key.
+    warm_up_dict = {}
+    for key, value in param_key_dict.items():
+        if is_worker:
+            warm_up_dict[value] = parameter_dict.pop(key)
+        else:
+            parameter_dict[value] = parameter_dict.pop(key)
+    return (is_worker, parameter_dict, warm_up_dict)
+
+
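On non-worker roles the final loop above renames dictionary keys in place rather than splitting the warm-up entries out. A tiny sketch of that re-keying with made-up entries:

```python
# Made-up values: strip the param-key suffix by re-keying the dict.
parameter_dict = {"fc.weight.__param_key__3": "value"}
param_key_dict = {"fc.weight.__param_key__3": "fc.weight"}
for key, value in param_key_dict.items():
    parameter_dict[value] = parameter_dict.pop(key)
assert parameter_dict == {"fc.weight": "value"}
```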
+def _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict):
+    """Warm up host cache post process."""
+    if is_worker:
+        net_dict.update(warm_up_dict)
+        _set_checkpoint_load_status(True)
+
+
 def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load):
     """When some net parameter did not load, try to continue loading."""
     prefix_name = ""
@@ -1350,9 +1557,9 @@ def export(net, *inputs, file_name, file_format, **kwargs):
     Note:
         1. When exporting AIR, ONNX format, the size of a single tensor cannot exceed 2GB.
         2. When file_name does not have a suffix, the system will automatically add one according to the file_format.
-        3. Exporting functions decorated with 'jit' to mindir format is supported.
-        4. When exporting a function decorated with 'jit', the function should not involve class properties in
-           calculations.
+        3. Exporting functions decorated with :func:`mindspore.jit` to mindir format is supported.
+        4. When exporting a function decorated with :func:`mindspore.jit`, the function should not involve
+           class properties in calculations.
 
     Args:
         net (Union[Cell, function]): MindSpore network.
@@ -1388,17 +1595,20 @@ def export(net, *inputs, file_name, file_format, **kwargs):
 
     - type (str): The type of obfuscation, only 'dynamic' is supported until now.
     - obf_ratio (float, str): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
-      should be in range of (0, 1] or in ["small", "medium", "large"].
+      should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large"
+      correspond to 0.1, 0.3 and 0.6 respectively.
     - customized_func (function): A python function used for customized function mode, which is used to control
-      the switch branch of obfuscation structure. The outputs of customized_func should be boolean. This
-      function needs to ensure that its result is constant for any input. Users can refer to opaque
+      the switch branch of obfuscation structure. The outputs of customized_func should be boolean and constant
+      (refer to 'my_func()' in the
+      `tutorials <https://www.mindspore.cn/mindarmour/docs/en/r2.0/dynamic_obfuscation_protection.html>`_).
+      This function needs to ensure that its result is constant for any input. Users can refer to opaque
       predicates. If customized_func is set, then it should be passed to `load()` interface when loading
       obfuscated model.
-    - obf_random_seed (int): The random seed used for determine the distribution of confusion branches and the
-      weight confusion coefficient, which should be in (0, 9223372036854775807]. If `obf_random_seed` is set,
-      then it should be passed to :class:`nn.GraphCell()` interface when loading obfuscated model. It should
-      be noted that at least one of `customized_func` or `obf_random_seed` should be set, and the latter mode
-      would be applied if both of them are set.
+    - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
+      structure of obfuscated models corresponding to different random seeds is different. If
+      `obf_random_seed` is set, then it should be passed to :class:`nn.GraphCell()` interface when loading
+      obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
+      be set, and the latter mode would be applied if both of them are set.
 
     - incremental (bool): export MindIR incrementally.
 
@@ -1408,14 +1618,14 @@ def export(net, *inputs, file_name, file_format, **kwargs):
     >>> from mindspore import Tensor
     >>>
     >>> # Define the network structure of LeNet5. Refer to
-    >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
+    >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
     >>> net = LeNet5()
     >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
     >>> ms.export(net, input_tensor, file_name='lenet', file_format='MINDIR')
 
     Tutorial Examples:
         - `Saving and Loading the Model - Saving and Loading MindIR
-          <https://mindspore.cn/tutorials/en/r2.1/beginner/save_load.html#saving-and-loading-mindir>`_
+          <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-mindir>`_
     """
     old_ms_jit_value = context.get_context("jit_syntax_level")
     context.set_context(jit_syntax_level=mindspore.STRICT)
@@ -1475,7 +1685,7 @@ def _get_funcgraph(net, *inputs):
     >>> from mindspore import Tensor
     >>>
     >>> # Define the network structure of LeNet5. Refer to
-    >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
+    >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
     >>> net = LeNet5()
     >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
     >>> ms.get_funcgraph(net, input_tensor)
@@ -1657,10 +1867,17 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
         data_file_name = os.path.join(dirname, external_local)
         f, parameter_size, offset = _get_data_file(is_encrypt, kwargs, data_file_name)
         try:
+            round_ = 0
+            names = []
             for param_proto in model.graph.parameter:
                 name = param_proto.name[param_proto.name.find(":") + 1:]
+                names.append((name, param_proto))
+            names.sort(key=lambda x: x[0])
+            for pairs in names:
+                name = pairs[0]
+                param_proto = pairs[1]
                 param = net_dict[name]
-                raw_data = param.data.asnumpy().tobytes()
+                raw_data = param.data.get_bytes()
                 data_length = len(raw_data)
                 append_size = 0
                 if data_length % 64 != 0:
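The alignment arithmetic here pads each raw parameter payload up to a 64-byte boundary before it is written at `offset`. A self-contained sketch of that computation (the 100-byte payload is made up):

```python
raw_data = b"\x00" * 100        # stand-in for param.data.get_bytes()
data_length = len(raw_data)
append_size = 0
if data_length % 64 != 0:
    append_size = 64 - (data_length % 64)
# The next record starts 64-byte aligned: 100 + 28 == 128.
print(data_length + append_size)  # 128
```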
@@ -1678,6 +1895,8 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
                 offset += (data_length + append_size)
                 write_data = _encrypt_data(is_encrypt, write_data, kwargs)
                 f.write(write_data)
+                round_ += 1
+                logger.debug(f"writing {round_}th split data, name:{name}")
 
         graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir")
         if os.path.exists(graph_file_name):
@@ -1787,7 +2006,7 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
     for param_proto in model.graph.parameter:
         param_name = param_proto.name[param_proto.name.find(":") + 1:]
         if param_name in net_dict.keys():
-            param_data = net_dict[param_name].data.asnumpy().tobytes()
+            param_data = net_dict[param_name].data.get_bytes()
             param_proto.raw_data = param_data
         else:
             raise ValueError("The parameter '{}' does not belong to any cell,"
@@ -1797,10 +2016,10 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
         map_param_name = map_param_proto.name[map_param_proto.name.find(":") + 1:]
         if map_param_name in net_dict.keys():
             map_parameter = net_dict[map_param_name]
-            key_nparr, value_nparr, status_nparr = map_parameter.export_data(incremental)
-            map_param_proto.key_tensor.raw_data = key_nparr.tobytes()
-            map_param_proto.value_tensor.raw_data = value_nparr.tobytes()
-            map_param_proto.status_tensor.raw_data = status_nparr.tobytes()
+            key_bytes, value_bytes, status_bytes = map_parameter.export_bytes(incremental)
+            map_param_proto.key_tensor.raw_data = key_bytes
+            map_param_proto.value_tensor.raw_data = value_bytes
+            map_param_proto.status_tensor.raw_data = status_bytes
         else:
             raise ValueError("The map_parameter '{}' does not belong to any cell,"
                              "the data of parameter cannot be exported.".format(map_param_proto.name))
@@ -1831,7 +2050,7 @@ def _save_together(net_dict, model):
     for param_proto in model.graph.parameter:
         name = param_proto.name[param_proto.name.find(":") + 1:]
         if name in net_dict.keys():
-            data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024
+            data_total += sys.getsizeof(net_dict[name].data.get_bytes()) / 1024
         else:
             raise ValueError("The parameter '{}' does not belong to any cell,"
                              "the data of parameter cannot be exported.".format(param_proto.name))
@@ -1862,7 +2081,7 @@ def _save_dataset_to_mindir(model, dataset):
 
 def parse_print(print_file_name):
     """
-    Parse data file generated by mindspore.ops.Print.
+    Parse data file generated by :class:`mindspore.ops.Print`.
 
     Args:
         print_file_name (str): The name of the file to be parsed.
@@ -2039,8 +2258,8 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
 def restore_group_info_list(group_info_file_name):
     """
     Build rank list, the checkpoint of ranks in the rank list has the same contents with the local rank
-    who saves the group_info_file_name. To save the group info file, please export GROUP_INFO_FILE environment variables
-    like "export GROUP_INFO_FILE=/data/group_info.pb".
+    who saves the `group_info_file_name`. To save the group info file, please export the GROUP_INFO_FILE
+    environment variable like "export GROUP_INFO_FILE=/data/group_info.pb".
 
     Args:
         group_info_file_name (str): Name of group information file.
@@ -2050,7 +2269,7 @@ def restore_group_info_list(group_info_file_name):
 
     Raises:
         ValueError: group information file is incorrect.
-        TypeError: group_info_file_name is not str.
+        TypeError: `group_info_file_name` is not str.
 
     Examples:
         >>> import mindspore as ms
@@ -2072,9 +2291,6 @@ def restore_group_info_list(group_info_file_name):
 def build_searched_strategy(strategy_filename):
     """
     Build strategy of every parameter in network. Used in the case of distributed inference.
-    For details of it, please check:
-    `Saving and Loading Models in Hybrid Parallel Mode
-    <https://www.mindspore.cn/tutorials/experts/en/r2.1/parallel/save_load.html>`_.
 
     Args:
         strategy_filename (str): Name of strategy file.
@@ -2096,8 +2312,6 @@ def build_searched_strategy(strategy_filename):
 def merge_sliced_parameter(sliced_parameters, strategy=None):
     """
     Merge parameter slices into one parameter. Used in the case of distributed inference.
-    For details of it, please check:
-    `<https://www.mindspore.cn/tutorials/experts/en/r2.1/parallel/save_load.html>`_.
 
     Args:
         sliced_parameters (list[Parameter]): Parameter slices in order of rank id.
@@ -2171,7 +2385,12 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
 
     layerwise_parallel = sliced_parameters[0].layerwise_parallel
     requires_grad = sliced_parameters[0].requires_grad
-    sliced_data = [parameter.data.asnumpy() for parameter in sliced_parameters]
+    sliced_data = []
+    for parameter in sliced_parameters:
+        if parameter.data.dtype == mstype.bfloat16:
+            sliced_data.append(cpu_cast(parameter.data, mstype.float32).asnumpy())
+        else:
+            sliced_data.append(parameter.data.asnumpy())
 
     if not strategy:
         merged_tensor = Tensor(np.concatenate(sliced_data))
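numpy has no native bfloat16, so slices of that dtype are cast to float32 on CPU before `asnumpy()`; other dtypes pass through unchanged. A sketch of the subsequent merge step with illustrative float32 slices:

```python
import numpy as np

# Two rank-ordered slices of a (4, 4) parameter, illustrative values.
sliced_data = [np.ones((2, 4), np.float32), np.zeros((2, 4), np.float32)]

# Without a strategy, slices are simply concatenated along axis 0.
merged = np.concatenate(sliced_data)
print(merged.shape)  # (4, 4)
```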
@@ -2191,9 +2410,6 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
                                 train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM'):
     """
     Load checkpoint into net for distributed prediction. Used in the case of distributed inference.
-    For details of distributed inference, please check:
-    `Distributed Inference
-    <https://www.mindspore.cn/tutorials/experts/en/r2.1/parallel/distributed_inference.html>`_ .
 
     Args:
         network (Cell): Network for distributed prediction.
@@ -2218,6 +2434,104 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
     Raises:
         TypeError: The types of inputs do not match the requirements.
         ValueError: Failed to load checkpoint into net.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        .. note::
+            Before running the following examples, you need to configure the communication environment variables.
+
+            For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
+            Please see the `rank table startup
+            <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
+            for more details.
+
+            For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
+            <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
+
+            For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
+            Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
+
+        >>> import os
+        >>> import numpy as np
+        >>> import mindspore as ms
+        >>> import mindspore.dataset as ds
+        >>> from mindspore import nn, ops, train
+        >>> from mindspore.communication import init
+        >>>
+        >>> step_per_epoch = 4
+        >>> device_num = 8
+        >>>
+        >>> # Define the network structure.
+        >>> class Net(nn.Cell):
+        ...     def __init__(self, matmul_size, strategy=None):
+        ...         super().__init__()
+        ...         matmul_np = np.full(matmul_size, 0.5, dtype=np.float32)
+        ...         self.matmul_weight = ms.Parameter(ms.Tensor(matmul_np))
+        ...         self.matmul = ops.MatMul()
+        ...         self.neg = ops.Neg()
+        ...         if strategy is not None:
+        ...             self.matmul.shard(strategy)
+        ...
+        ...     def construct(self, inputs):
+        ...         x = self.matmul(inputs, self.matmul_weight)
+        ...         x = self.neg(x)
+        ...         return x
+        >>>
+        >>> # Create dataset.
+        >>> def get_dataset(*inputs):
+        ...     def generate():
+        ...         for _ in range(step_per_epoch):
+        ...             yield inputs
+        ...     return generate
+        >>>
+        >>> # Train network and save distributed checkpoint.
+        >>> def train_net():
+        ...     ms.set_context(mode=ms.GRAPH_MODE)
+        ...     init()
+        ...     np.random.seed(1)
+        ...     input_data = np.random.rand(16, 96).astype(np.float32)
+        ...     label_data = np.random.rand(16, 16).astype(np.float32)
+        ...     fake_dataset = get_dataset(input_data, label_data)
+        ...     dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
+        ...
+        ...     # Set parallel strategy.
+        ...     strategy = ((1, 4), (4, 1))
+        ...     ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num,
+        ...                                  strategy_ckpt_save_file="./train_strategy.ckpt")
+        ...     network = Net(matmul_size=(96, 16), strategy=strategy)
+        ...     net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
+        ...     net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
+        ...     model = ms.Model(network=network, loss_fn=net_loss, optimizer=net_opt)
+        ...     ckpt_config = train.CheckpointConfig(keep_checkpoint_max=1, integrated_save=False)
+        ...     global_rank_id = int(os.getenv("RANK_ID"))
+        ...     ckpt_path = "./rank_{}_ckpt".format(global_rank_id)
+        ...     ckpt_callback = train.ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
+        ...     model.train(epoch=2, train_dataset=dataset, callbacks=[ckpt_callback], dataset_sink_mode=False)
+        ...     ms.reset_auto_parallel_context()
+        >>>
+        >>> # Load distributed checkpoint and test.
+        >>> def load_model():
+        ...     ms.set_context(mode=ms.GRAPH_MODE)
+        ...     init()
+        ...     ms.set_auto_parallel_context(full_batch=True, parallel_mode="semi_auto_parallel",
+        ...                                  strategy_ckpt_load_file="./train_strategy.ckpt", device_num=device_num)
+        ...     predict_data = ms.Tensor(np.random.randn(128, 96).astype(np.float32))
+        ...     network = Net(matmul_size=(96, 16))
+        ...     model = ms.Model(network)
+        ...     predict_layout = model.infer_predict_layout(ms.Tensor(predict_data))
+        ...     ckpt_file_list = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(0, device_num)]
+        ...     ms.load_distributed_checkpoint(network, ckpt_file_list, predict_layout)
+        ...     predict_result = model.predict(predict_data)
+        ...     print(predict_result)
+        >>>
+        >>> train_net()
+        >>> load_model()
+        [[-7.3259363 -7.497216  -7.398196  ... -7.374962  -7.204874  -7.234935 ]
+         [ 3.362938   3.3535435  3.3832688 ...  3.4263954  3.279045   3.3202887]
+         ...
+         [ 1.6067538  1.6244187  1.5384722 ...  1.5449994  1.6195512  1.6176052]]
     """
     network = Validator.check_isinstance("network", network, nn.Cell)
     _check_checkpoint_file(checkpoint_filenames)
@@ -2282,7 +2596,11 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
                 param_index = list(set(param_index))
                 param_index.sort()
                 for rank_num in param_index:
-                    param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy())
+                    if param_total_dict[param.name][rank_num].data.dtype == mstype.bfloat16:
+                        param_stride.append(
+                            cpu_cast(param_total_dict[param.name][rank_num].data, mstype.float32).asnumpy())
+                    else:
+                        param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy())
 
                 sliced_param = Parameter(Tensor(np.concatenate(param_stride)), name=param.name)
             else:
@@ -2297,7 +2615,10 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
             split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy)
         opt_shard_group = predict_strategy[param.name][5] if predict_strategy else None
         if opt_shard_group:
-            data = split_param.data.asnumpy()
+            if split_param.data.dtype == mstype.bfloat16:
+                data = cpu_cast(split_param.data, mstype.float32).asnumpy()
+            else:
+                data = split_param.data.asnumpy()
             rank = get_rank(opt_shard_group)
             size = get_group_size(opt_shard_group)
             try:
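Inside the `try` block (outside this hunk) the merged array is reduced to this rank's optimizer shard. A plausible sketch of that slicing under the `rank`/`size` values obtained above — an assumption for illustration, not the actual code:

```python
import numpy as np

# Hypothetical 4-way optimizer shard group: rank 1 keeps the second
# quarter of the merged, flattened data.
data = np.arange(16, dtype=np.float32)
rank, size = 1, 4
shard = np.split(data, size)[rank]
print(shard)  # [4. 5. 6. 7.]
```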
@@ -2395,10 +2716,15 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
         return merged_param
     param_name = merged_param.name
     tensor_layout = predict_strategy[param_name]
-    split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1])
+    rank = get_rank()
+    split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank)
     requires_grad = merged_param.requires_grad
     layerwise_parallel = merged_param.layerwise_parallel
-    split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
+    data_type = merged_param.data.dtype
+    if data_type == mstype.bfloat16:
+        split_param = Parameter(Tensor(split_tensor, mstype.bfloat16), param_name, requires_grad, layerwise_parallel)
+    else:
+        split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
     return split_param
 
 
@@ -2407,7 +2733,7 @@ def _calculation_net_size(net):
     data_total = 0
     net_dict = net.parameters_dict()
     for name in net_dict:
-        data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024
+        data_total += sys.getsizeof(net_dict[name].data.get_bytes()) / 1024
 
     return data_total
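The size estimate above divides `sys.getsizeof` of the raw bytes by 1024, so the figure is KiB plus the small fixed overhead of the bytes object itself. A quick self-contained check:

```python
import sys

raw = b"\x00" * 4096
# Slightly above 4.0, because getsizeof also counts the bytes header.
print(sys.getsizeof(raw) / 1024)
```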