mindspore 2.5.0__cp39-cp39-win_amd64.whl → 2.6.0rc1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (491)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +6 -4
  5. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -33
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parse/__init__.py +6 -7
  14. mindspore/_extends/parse/compile_config.py +19 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +24 -193
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +97 -74
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +4 -4
  27. mindspore/atlprov.dll +0 -0
  28. mindspore/avcodec-59.dll +0 -0
  29. mindspore/avdevice-59.dll +0 -0
  30. mindspore/avfilter-8.dll +0 -0
  31. mindspore/avformat-59.dll +0 -0
  32. mindspore/avutil-57.dll +0 -0
  33. mindspore/boost/__init__.py +2 -2
  34. mindspore/boost/base.py +3 -7
  35. mindspore/boost/boost_cell_wrapper.py +2 -2
  36. mindspore/c1.dll +0 -0
  37. mindspore/c1xx.dll +0 -0
  38. mindspore/c2.dll +0 -0
  39. mindspore/common/__init__.py +4 -3
  40. mindspore/common/_grad_function.py +56 -0
  41. mindspore/common/_pijit_context.py +14 -5
  42. mindspore/common/_register_for_tensor.py +1 -1
  43. mindspore/common/_stub_tensor.py +5 -10
  44. mindspore/common/_tensor_cpp_method.py +1 -1
  45. mindspore/common/_tensor_docs.py +1915 -3287
  46. mindspore/common/api.py +341 -354
  47. mindspore/common/auto_dynamic_shape.py +41 -44
  48. mindspore/common/dtype.py +5 -2
  49. mindspore/common/dump.py +7 -5
  50. mindspore/common/file_system.py +3 -0
  51. mindspore/common/hook_handle.py +5 -3
  52. mindspore/common/initializer.py +10 -6
  53. mindspore/common/jit_begin_end.py +94 -0
  54. mindspore/common/jit_config.py +6 -1
  55. mindspore/common/jit_context.py +76 -0
  56. mindspore/common/jit_trace.py +378 -0
  57. mindspore/common/lazy_inline.py +2 -2
  58. mindspore/common/mutable.py +5 -4
  59. mindspore/common/parameter.py +106 -39
  60. mindspore/common/seed.py +2 -2
  61. mindspore/common/sparse_tensor.py +23 -17
  62. mindspore/common/tensor.py +297 -714
  63. mindspore/communication/__init__.py +7 -5
  64. mindspore/communication/_comm_helper.py +47 -2
  65. mindspore/communication/comm_func.py +70 -53
  66. mindspore/communication/management.py +83 -17
  67. mindspore/context.py +214 -560
  68. mindspore/dataset/__init__.py +44 -20
  69. mindspore/dataset/audio/__init__.py +2 -8
  70. mindspore/dataset/audio/transforms.py +3 -17
  71. mindspore/dataset/core/config.py +3 -3
  72. mindspore/dataset/engine/cache_client.py +1 -1
  73. mindspore/dataset/engine/datasets.py +102 -120
  74. mindspore/dataset/engine/datasets_audio.py +22 -22
  75. mindspore/dataset/engine/datasets_standard_format.py +43 -24
  76. mindspore/dataset/engine/datasets_text.py +78 -85
  77. mindspore/dataset/engine/datasets_user_defined.py +108 -76
  78. mindspore/dataset/engine/datasets_vision.py +111 -108
  79. mindspore/dataset/engine/iterators.py +5 -3
  80. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  81. mindspore/dataset/engine/samplers.py +279 -57
  82. mindspore/dataset/engine/serializer_deserializer.py +2 -1
  83. mindspore/dataset/engine/validators.py +10 -0
  84. mindspore/dataset/text/__init__.py +7 -6
  85. mindspore/dataset/text/transforms.py +6 -5
  86. mindspore/dataset/text/utils.py +3 -3
  87. mindspore/dataset/transforms/__init__.py +0 -9
  88. mindspore/dataset/transforms/transforms.py +3 -3
  89. mindspore/dataset/utils/browse_dataset.py +1 -1
  90. mindspore/dataset/vision/__init__.py +2 -9
  91. mindspore/dataset/vision/transforms.py +202 -158
  92. mindspore/dataset/vision/utils.py +7 -5
  93. mindspore/device_context/ascend/op_debug.py +60 -1
  94. mindspore/device_context/ascend/op_tuning.py +0 -4
  95. mindspore/device_manager.py +39 -3
  96. mindspore/dnnl.dll +0 -0
  97. mindspore/dpcmi.dll +0 -0
  98. mindspore/experimental/es/embedding_service.py +35 -27
  99. mindspore/experimental/map_parameter.py +4 -4
  100. mindspore/experimental/optim/adadelta.py +22 -26
  101. mindspore/experimental/optim/adagrad.py +4 -4
  102. mindspore/experimental/optim/adam.py +4 -0
  103. mindspore/experimental/optim/adamax.py +4 -4
  104. mindspore/experimental/optim/adamw.py +4 -0
  105. mindspore/experimental/optim/asgd.py +1 -1
  106. mindspore/experimental/optim/lr_scheduler.py +40 -22
  107. mindspore/experimental/optim/radam.py +5 -5
  108. mindspore/experimental/optim/rprop.py +1 -1
  109. mindspore/experimental/optim/sgd.py +1 -1
  110. mindspore/hal/contiguous_tensors_handle.py +6 -10
  111. mindspore/hal/device.py +55 -81
  112. mindspore/hal/event.py +38 -55
  113. mindspore/hal/memory.py +93 -144
  114. mindspore/hal/stream.py +81 -125
  115. mindspore/include/dataset/constants.h +7 -4
  116. mindspore/include/dataset/execute.h +2 -2
  117. mindspore/jpeg62.dll +0 -0
  118. mindspore/log.py +40 -2
  119. mindspore/mindrecord/__init__.py +20 -7
  120. mindspore/mindspore_backend_common.dll +0 -0
  121. mindspore/mindspore_backend_manager.dll +0 -0
  122. mindspore/mindspore_common.dll +0 -0
  123. mindspore/mindspore_core.dll +0 -0
  124. mindspore/mindspore_dump.dll +0 -0
  125. mindspore/mindspore_frontend.dll +0 -0
  126. mindspore/mindspore_glog.dll +0 -0
  127. mindspore/mindspore_memory_pool.dll +0 -0
  128. mindspore/mindspore_ms_backend.dll +0 -0
  129. mindspore/mindspore_ops.dll +0 -0
  130. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  131. mindspore/mindspore_ops_kernel_common.dll +0 -0
  132. mindspore/mindspore_profiler.dll +0 -0
  133. mindspore/mindspore_pyboost.dll +0 -0
  134. mindspore/mindspore_pynative.dll +0 -0
  135. mindspore/mindspore_res_manager.dll +0 -0
  136. mindspore/mindspore_runtime_pipeline.dll +0 -0
  137. mindspore/mint/__init__.py +131 -700
  138. mindspore/mint/distributed/__init__.py +5 -1
  139. mindspore/mint/distributed/distributed.py +194 -109
  140. mindspore/mint/linalg/__init__.py +2 -0
  141. mindspore/mint/nn/__init__.py +280 -18
  142. mindspore/mint/nn/functional.py +282 -64
  143. mindspore/mint/nn/layer/__init__.py +4 -0
  144. mindspore/mint/nn/layer/_functions.py +7 -3
  145. mindspore/mint/nn/layer/activation.py +120 -13
  146. mindspore/mint/nn/layer/conv.py +218 -24
  147. mindspore/mint/nn/layer/normalization.py +15 -16
  148. mindspore/mint/nn/layer/padding.py +1 -1
  149. mindspore/mint/nn/layer/pooling.py +66 -1
  150. mindspore/mint/optim/__init__.py +2 -1
  151. mindspore/mint/optim/sgd.py +171 -0
  152. mindspore/msobj140.dll +0 -0
  153. mindspore/mspdb140.dll +0 -0
  154. mindspore/mspdbcore.dll +0 -0
  155. mindspore/mspdbst.dll +0 -0
  156. mindspore/mspft140.dll +0 -0
  157. mindspore/msvcdis140.dll +0 -0
  158. mindspore/msvcp140_1.dll +0 -0
  159. mindspore/msvcp140_2.dll +0 -0
  160. mindspore/msvcp140_atomic_wait.dll +0 -0
  161. mindspore/msvcp140_codecvt_ids.dll +0 -0
  162. mindspore/nn/__init__.py +4 -1
  163. mindspore/nn/cell.py +1250 -176
  164. mindspore/nn/layer/activation.py +23 -21
  165. mindspore/nn/layer/basic.py +22 -16
  166. mindspore/nn/layer/container.py +1 -1
  167. mindspore/nn/layer/conv.py +22 -17
  168. mindspore/nn/layer/embedding.py +9 -8
  169. mindspore/nn/layer/normalization.py +48 -42
  170. mindspore/nn/layer/pooling.py +75 -31
  171. mindspore/nn/layer/transformer.py +11 -10
  172. mindspore/nn/learning_rate_schedule.py +4 -2
  173. mindspore/nn/loss/loss.py +27 -19
  174. mindspore/nn/optim/ada_grad.py +6 -5
  175. mindspore/nn/optim/adadelta.py +9 -7
  176. mindspore/nn/optim/adafactor.py +1 -1
  177. mindspore/nn/optim/adam.py +16 -12
  178. mindspore/nn/optim/adamax.py +8 -7
  179. mindspore/nn/optim/adasum.py +5 -5
  180. mindspore/nn/optim/asgd.py +1 -1
  181. mindspore/nn/optim/ftrl.py +11 -9
  182. mindspore/nn/optim/lamb.py +1 -1
  183. mindspore/nn/optim/lazyadam.py +12 -10
  184. mindspore/nn/optim/momentum.py +7 -6
  185. mindspore/nn/optim/optimizer.py +2 -2
  186. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  187. mindspore/nn/optim/rmsprop.py +13 -12
  188. mindspore/nn/optim/rprop.py +9 -7
  189. mindspore/nn/optim/sgd.py +9 -6
  190. mindspore/nn/optim/tft_wrapper.py +5 -2
  191. mindspore/nn/probability/bijector/bijector.py +17 -11
  192. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  193. mindspore/nn/probability/bijector/invert.py +2 -2
  194. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  195. mindspore/nn/probability/bijector/softplus.py +3 -2
  196. mindspore/nn/probability/distribution/beta.py +3 -3
  197. mindspore/nn/probability/distribution/categorical.py +1 -1
  198. mindspore/nn/probability/distribution/cauchy.py +4 -2
  199. mindspore/nn/probability/distribution/exponential.py +6 -7
  200. mindspore/nn/probability/distribution/gamma.py +2 -2
  201. mindspore/nn/probability/distribution/gumbel.py +2 -2
  202. mindspore/nn/probability/distribution/half_normal.py +5 -3
  203. mindspore/nn/probability/distribution/logistic.py +5 -3
  204. mindspore/nn/probability/distribution/poisson.py +1 -1
  205. mindspore/nn/probability/distribution/uniform.py +5 -3
  206. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  207. mindspore/nn/reinforcement/tensor_array.py +1 -1
  208. mindspore/nn/wrap/__init__.py +6 -6
  209. mindspore/nn/wrap/cell_wrapper.py +178 -117
  210. mindspore/nn/wrap/grad_reducer.py +45 -36
  211. mindspore/nn/wrap/loss_scale.py +3 -3
  212. mindspore/numpy/array_creations.py +3 -3
  213. mindspore/numpy/array_ops.py +1 -1
  214. mindspore/numpy/math_ops.py +4 -4
  215. mindspore/numpy/utils.py +1 -2
  216. mindspore/numpy/utils_const.py +1 -2
  217. mindspore/opencv_core452.dll +0 -0
  218. mindspore/opencv_imgcodecs452.dll +0 -0
  219. mindspore/opencv_imgproc452.dll +0 -0
  220. mindspore/ops/__init__.py +3 -2
  221. mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
  222. mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
  223. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  224. mindspore/ops/_register_for_op.py +0 -11
  225. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  226. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
  227. mindspore/ops/_vmap/vmap_array_ops.py +7 -6
  228. mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
  229. mindspore/ops/_vmap/vmap_math_ops.py +4 -7
  230. mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
  231. mindspore/ops/auto_generate/__init__.py +4 -3
  232. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +102 -49
  233. mindspore/ops/auto_generate/gen_extend_func.py +281 -135
  234. mindspore/ops/auto_generate/gen_ops_def.py +2574 -2326
  235. mindspore/ops/auto_generate/gen_ops_prim.py +8566 -2755
  236. mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
  237. mindspore/ops/composite/__init__.py +2 -1
  238. mindspore/ops/composite/base.py +19 -24
  239. mindspore/ops/composite/math_ops.py +6 -16
  240. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  241. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -3
  242. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  243. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  244. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  245. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  246. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  247. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  248. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  249. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  250. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  251. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  252. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  253. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  254. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  255. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  256. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  257. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  258. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  259. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  260. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  261. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  262. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  263. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  264. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  265. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  266. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  267. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
  268. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  269. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  270. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  271. mindspore/ops/function/__init__.py +28 -2
  272. mindspore/ops/function/_add_attr_func.py +58 -0
  273. mindspore/ops/function/array_func.py +1629 -2345
  274. mindspore/ops/function/clip_func.py +38 -45
  275. mindspore/ops/function/debug_func.py +36 -44
  276. mindspore/ops/function/grad/__init__.py +1 -0
  277. mindspore/ops/function/grad/grad_func.py +104 -71
  278. mindspore/ops/function/image_func.py +1 -1
  279. mindspore/ops/function/linalg_func.py +46 -78
  280. mindspore/ops/function/math_func.py +3035 -3705
  281. mindspore/ops/function/nn_func.py +676 -241
  282. mindspore/ops/function/other_func.py +159 -1
  283. mindspore/ops/function/parameter_func.py +17 -30
  284. mindspore/ops/function/random_func.py +204 -361
  285. mindspore/ops/function/reshard_func.py +4 -70
  286. mindspore/ops/function/sparse_func.py +3 -3
  287. mindspore/ops/function/sparse_unary_func.py +5 -5
  288. mindspore/ops/function/spectral_func.py +25 -58
  289. mindspore/ops/function/vmap_func.py +24 -17
  290. mindspore/ops/functional.py +6 -4
  291. mindspore/ops/functional_overload.py +547 -4
  292. mindspore/ops/op_info_register.py +32 -244
  293. mindspore/ops/operations/__init__.py +10 -5
  294. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  295. mindspore/ops/operations/_grad_ops.py +1 -10
  296. mindspore/ops/operations/_inner_ops.py +5 -76
  297. mindspore/ops/operations/_ms_kernel.py +4 -10
  298. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  299. mindspore/ops/operations/_scalar_ops.py +3 -2
  300. mindspore/ops/operations/_sequence_ops.py +1 -1
  301. mindspore/ops/operations/_tensor_array.py +1 -1
  302. mindspore/ops/operations/array_ops.py +37 -22
  303. mindspore/ops/operations/comm_ops.py +150 -107
  304. mindspore/ops/operations/custom_ops.py +221 -23
  305. mindspore/ops/operations/debug_ops.py +115 -16
  306. mindspore/ops/operations/inner_ops.py +1 -1
  307. mindspore/ops/operations/linalg_ops.py +1 -58
  308. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  309. mindspore/ops/operations/manually_defined/ops_def.py +746 -79
  310. mindspore/ops/operations/math_ops.py +21 -18
  311. mindspore/ops/operations/nn_ops.py +65 -191
  312. mindspore/ops/operations/other_ops.py +62 -9
  313. mindspore/ops/operations/random_ops.py +13 -7
  314. mindspore/ops/operations/reshard_ops.py +1 -1
  315. mindspore/ops/operations/sparse_ops.py +2 -2
  316. mindspore/ops/primitive.py +43 -32
  317. mindspore/ops/tensor_method.py +232 -13
  318. mindspore/ops_generate/__init__.py +0 -5
  319. mindspore/ops_generate/aclnn/__init__.py +0 -0
  320. mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
  321. mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
  322. mindspore/ops_generate/api/__init__.py +0 -0
  323. mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
  324. mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
  325. mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
  326. mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
  327. mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
  328. mindspore/ops_generate/api/gen_api.py +103 -0
  329. mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
  330. mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
  331. mindspore/ops_generate/common/__init__.py +0 -0
  332. mindspore/ops_generate/common/gen_constants.py +91 -0
  333. mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
  334. mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
  335. mindspore/ops_generate/{template.py → common/template.py} +96 -84
  336. mindspore/ops_generate/gen_ops.py +23 -325
  337. mindspore/ops_generate/op_def/__init__.py +0 -0
  338. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  339. mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
  340. mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -7
  341. mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
  342. mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
  343. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  344. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  345. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  346. mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
  347. mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
  348. mindspore/ops_generate/pyboost/__init__.py +0 -0
  349. mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
  350. mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
  351. mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
  352. mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
  353. mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
  354. mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
  355. mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
  356. mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
  357. mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
  358. mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
  359. mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
  360. mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
  361. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
  362. mindspore/ops_generate/resources/__init__.py +0 -0
  363. mindspore/ops_generate/resources/resource_list.py +30 -0
  364. mindspore/ops_generate/resources/resource_loader.py +36 -0
  365. mindspore/ops_generate/resources/resource_manager.py +64 -0
  366. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  367. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  368. mindspore/parallel/__init__.py +6 -2
  369. mindspore/parallel/_auto_parallel_context.py +133 -6
  370. mindspore/parallel/_cell_wrapper.py +130 -15
  371. mindspore/parallel/_parallel_serialization.py +95 -4
  372. mindspore/parallel/_ps_context.py +1 -1
  373. mindspore/parallel/_recovery_context.py +7 -2
  374. mindspore/parallel/_tensor.py +142 -18
  375. mindspore/parallel/_utils.py +198 -25
  376. mindspore/parallel/algo_parameter_config.py +3 -3
  377. mindspore/parallel/auto_parallel.py +732 -0
  378. mindspore/parallel/checkpoint_convert.py +159 -0
  379. mindspore/parallel/checkpoint_transform.py +656 -37
  380. mindspore/parallel/cluster/process_entity/_api.py +151 -19
  381. mindspore/parallel/cluster/run.py +1 -1
  382. mindspore/parallel/function/__init__.py +24 -0
  383. mindspore/parallel/function/reshard_func.py +259 -0
  384. mindspore/parallel/nn/__init__.py +25 -0
  385. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  386. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  387. mindspore/parallel/parameter_broadcast.py +24 -13
  388. mindspore/parallel/shard.py +137 -61
  389. mindspore/parallel/transform_safetensors.py +287 -95
  390. mindspore/pgodb140.dll +0 -0
  391. mindspore/pgort140.dll +0 -0
  392. mindspore/profiler/__init__.py +9 -5
  393. mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
  394. mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
  395. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
  396. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +22 -0
  397. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  398. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
  399. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
  400. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
  401. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
  402. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
  403. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
  404. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
  405. mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
  406. mindspore/profiler/common/constant.py +12 -0
  407. mindspore/profiler/common/msprof_cmd_tool.py +42 -23
  408. mindspore/profiler/common/path_manager.py +24 -0
  409. mindspore/profiler/common/profiler_context.py +26 -2
  410. mindspore/profiler/common/profiler_meta_data.py +74 -0
  411. mindspore/profiler/common/profiler_parameters.py +59 -18
  412. mindspore/profiler/common/profiler_path_manager.py +66 -7
  413. mindspore/profiler/dynamic_profiler.py +112 -79
  414. mindspore/profiler/envprofiler.py +26 -1
  415. mindspore/profiler/experimental_config.py +197 -0
  416. mindspore/profiler/mstx.py +57 -14
  417. mindspore/profiler/platform/npu_profiler.py +33 -7
  418. mindspore/profiler/profiler.py +541 -45
  419. mindspore/profiler/profiler_action_controller.py +1 -1
  420. mindspore/profiler/profiler_interface.py +4 -0
  421. mindspore/profiler/schedule.py +57 -22
  422. mindspore/rewrite/api/node.py +15 -13
  423. mindspore/rewrite/api/symbol_tree.py +1 -1
  424. mindspore/run_check/_check_version.py +25 -14
  425. mindspore/run_check/run_check.py +1 -1
  426. mindspore/runtime/__init__.py +2 -2
  427. mindspore/runtime/executor.py +40 -11
  428. mindspore/runtime/memory.py +25 -8
  429. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  430. mindspore/swresample-4.dll +0 -0
  431. mindspore/swscale-6.dll +0 -0
  432. mindspore/tbbmalloc.dll +0 -0
  433. mindspore/tinyxml2.dll +0 -0
  434. mindspore/train/__init__.py +8 -8
  435. mindspore/train/_utils.py +35 -7
  436. mindspore/train/amp.py +1 -1
  437. mindspore/train/callback/__init__.py +2 -2
  438. mindspore/train/callback/_callback.py +2 -16
  439. mindspore/train/callback/_checkpoint.py +24 -40
  440. mindspore/train/callback/_cluster_monitor.py +14 -18
  441. mindspore/train/callback/_flops_collector.py +2 -3
  442. mindspore/train/callback/_history.py +7 -4
  443. mindspore/train/callback/_lambda_callback.py +2 -2
  444. mindspore/train/callback/_landscape.py +0 -3
  445. mindspore/train/callback/_loss_monitor.py +2 -1
  446. mindspore/train/callback/_on_request_exit.py +6 -5
  447. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  448. mindspore/train/callback/_summary_collector.py +8 -13
  449. mindspore/train/callback/_time_monitor.py +2 -1
  450. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +179 -103
  451. mindspore/train/data_sink.py +25 -2
  452. mindspore/train/dataset_helper.py +4 -5
  453. mindspore/train/loss_scale_manager.py +8 -7
  454. mindspore/train/metrics/accuracy.py +3 -3
  455. mindspore/train/metrics/confusion_matrix.py +9 -9
  456. mindspore/train/metrics/error.py +3 -3
  457. mindspore/train/metrics/hausdorff_distance.py +4 -4
  458. mindspore/train/metrics/mean_surface_distance.py +3 -3
  459. mindspore/train/metrics/metric.py +0 -12
  460. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  461. mindspore/train/metrics/precision.py +8 -6
  462. mindspore/train/metrics/recall.py +9 -9
  463. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  464. mindspore/train/mind_ir_pb2.py +19 -12
  465. mindspore/train/model.py +176 -103
  466. mindspore/train/serialization.py +246 -988
  467. mindspore/train/summary/_summary_adapter.py +2 -2
  468. mindspore/train/summary/summary_record.py +1 -1
  469. mindspore/turbojpeg.dll +0 -0
  470. mindspore/utils/__init__.py +3 -2
  471. mindspore/utils/dryrun.py +4 -2
  472. mindspore/utils/hooks.py +81 -0
  473. mindspore/utils/utils.py +138 -4
  474. mindspore/vcmeta.dll +0 -0
  475. mindspore/vcruntime140.dll +0 -0
  476. mindspore/vcruntime140_1.dll +0 -0
  477. mindspore/version.py +1 -1
  478. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +2 -1
  479. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +483 -438
  480. mindspore/_install_custom.py +0 -43
  481. mindspore/common/_register_for_adapter.py +0 -74
  482. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  483. mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
  484. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  485. mindspore/ops_generate/gen_constants.py +0 -190
  486. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  487. mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
  488. mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
  489. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  490. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +0 -0
  491. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -27,9 +27,9 @@ import stat
27
27
  import atexit
28
28
  import threading
29
29
  from threading import Thread, RLock
30
- from multiprocessing import Pool, active_children
30
+ from multiprocessing import active_children
31
31
  import multiprocessing as mp
32
- from collections import defaultdict, OrderedDict
32
+ from collections import OrderedDict
33
33
  from io import BytesIO
34
34
 
35
35
  import math
@@ -53,37 +53,33 @@ from mindspore.log import vlog_print
53
53
  from mindspore._checkparam import check_input_data, check_input_dataset
54
54
  from mindspore import _checkparam as Validator
55
55
  from mindspore.common import dtype as mstype
56
+ from mindspore.common import np_dtype
56
57
  from mindspore.common.api import _cell_graph_executor as _executor
57
- from mindspore.common.api import _MindsporeFunctionExecutor
58
+ from mindspore.common.api import _JitExecutor
58
59
  from mindspore.common.api import _get_parameter_layout
59
- from mindspore.common.api import _generate_branch_control_input
60
60
  from mindspore.common.initializer import initializer, One
61
61
  from mindspore.common.parameter import Parameter, _offload_if_config
62
62
  from mindspore.common.tensor import Tensor
63
- from mindspore._c_expression import Tensor as Tensor_
63
+ from mindspore._c_expression import TensorPy as Tensor_
64
64
  from mindspore.common._utils import is_shape_unknown
65
65
  from mindspore.common.file_system import FileSystem, _register_basic_file_system, _register_mindio_file_system
66
66
  from mindspore.communication.management import get_rank, get_group_size
67
67
  from mindspore.experimental import MapParameter
68
68
  from mindspore.ops import Cast
69
69
  from mindspore.parallel._cell_wrapper import get_allgather_cell, _single_parameter_broadcast
70
- from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
71
- from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
72
- from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode, \
73
- _get_device_num
74
- from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context
75
- from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \
76
- _restore_group_info_list, _get_param_list_when_first_dim_sharded
70
+ from mindspore.parallel._tensor import _reshape_param_data
71
+ from mindspore.parallel._utils import _is_in_auto_parallel_mode
77
72
  from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
78
73
  _store_warm_up_ptr_by_tensor_list, _cache_enable
79
74
  from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters
80
- from mindspore.parallel.transform_safetensors import _load_parallel_checkpoint, _get_device_num_from_strategy, \
81
- _extract_pipeline_stage_num
75
+ from mindspore.parallel.checkpoint_transform import restore_group_info_list as new_restore_group_info_list
76
+ from mindspore.parallel.checkpoint_transform import load_distributed_checkpoint as new_load_distributed_checkpoint
77
+ from mindspore.parallel.checkpoint_transform import merge_sliced_parameter as new_merge_sliced_parameter
78
+ from mindspore.parallel.checkpoint_transform import build_searched_strategy as new_build_searched_strategy
82
79
  from mindspore.train._utils import read_proto, get_parameter_redundancy, _progress_bar, _load_and_transform
83
- from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
80
+ from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, \
84
81
  split_mindir, split_dynamic_mindir
85
82
  from mindspore.common.generator import Generator
86
- from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
87
83
 
88
84
  tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
89
85
  "Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64,
@@ -94,6 +90,9 @@ tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UIn
94
90
  "Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64,
95
91
  "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"}
96
92
 
93
+ if hasattr(np_dtype, "bfloat16"):
94
+ tensor_to_np_type["BFloat16"] = np_dtype.bfloat16
95
+
97
96
  np_type_convert = {"int32": np.int32, "float32": np.float32, "float16": np.float16, "float64": np.float64}
98
97
 
99
98
  mindir_to_tensor_type = {1: mstype.float32, 2: mstype.uint8, 3: mstype.int8, 4: mstype.uint16,
@@ -153,22 +152,28 @@ atexit.register(_async_save_close)
153
152
 
154
153
  def _get_cur_rank_dp(parameter_layout_dict):
155
154
  """ Get dp and tp from layout dict. """
156
- pp_num = _get_auto_parallel_context("pipeline_stages")
157
- dev_num = _get_device_num()
158
155
  global_rank = get_rank()
159
- pipe_size = dev_num // pp_num
160
- initial_rank = (global_rank // pipe_size) * pipe_size
161
- parameter_redundancy_dict = get_parameter_redundancy(
162
- parameter_layout_dict, initial_rank)
156
+ parameter_redundancy_dict = get_parameter_redundancy(parameter_layout_dict)
163
157
  value_len = sys.maxsize
164
158
  min_value = ()
159
+ min_value_set = set()
165
160
  for key, value in parameter_redundancy_dict.items():
166
- if "accu_grads" in key or "inputs" in key:
161
+ if key.startswith("accu_grads") or key.startswith("inputs"):
167
162
  continue
168
163
  for item in value:
169
- if len(item) < value_len and global_rank in item:
164
+ if global_rank not in item:
165
+ continue
166
+ # if item is subset of min_value_set, update min_value_set and min_value
167
+ if len(item) < value_len:
168
+ if min_value_set and not set(item).issubset(min_value_set):
169
+ return (global_rank,)
170
170
  value_len = len(item)
171
+ min_value_set = set(item)
171
172
  min_value = item
173
+ # if value is not smaller than len of min_value len,
174
+ # check if min_value_set is subset of current item
175
+ elif not min_value_set.issubset(set(item)):
176
+ return (global_rank,)
172
177
  return min_value
173
178
 
174
179
 
@@ -188,7 +193,7 @@ def get_ckpt_path_with_strategy(cur_ckpt_path, cur_strategy_path):
188
193
  cur_strategy_path (str): strategy file path for current rank.
189
194
 
190
195
  Returns:
191
- - new_ckpt_file (String), if found available checkpoint file , return it.
196
+ - new_ckpt_file (str), if found available checkpoint file , return it.
192
197
  - None, if not found available checkpoint, return None.
193
198
 
194
199
  Examples:
@@ -203,6 +208,9 @@ def get_ckpt_path_with_strategy(cur_ckpt_path, cur_strategy_path):
203
208
  >>> ckpt_file_new = get_ckpt_path_with_strategy(ckpt_file, strategy_file)
204
209
  >>> print(ckpt_file_new)
205
210
  """
211
+ cur_rank = get_rank()
212
+ if f"rank_{str(cur_rank)}" in cur_ckpt_path and os.path.isfile(cur_ckpt_path):
213
+ return cur_ckpt_path
206
214
  dp = _get_cur_rank_dp(cur_strategy_path)
207
215
  pattern = r'rank_\d+'
208
216
  for i in dp:
@@ -358,6 +366,8 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
358
366
  file_name_list = list(os.path.splitext(ckpt_file_name))
359
367
  file_name_list[1] = file_name_list[1].replace(f".{format}", ".tmp")
360
368
  tmp_name = ''.join(file_name_list)
369
+ if _ckpt_fs.backend == "mindio":
370
+ tmp_name = ckpt_file_name
361
371
  if os.path.exists(ckpt_file_name):
362
372
  os.chmod(ckpt_file_name, stat.S_IWUSR)
363
373
  os.remove(ckpt_file_name)
@@ -365,7 +375,7 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
365
375
  os.chmod(tmp_name, stat.S_IWUSR)
366
376
  os.remove(tmp_name)
367
377
  if format == "ckpt":
368
- ckpt_save_time_start = time.time()
378
+ ckpt_total_io_time = 0
369
379
  with _ckpt_fs.create(tmp_name, *_ckpt_fs.create_args) as f:
370
380
  plain_data = None
371
381
  if enc_key is not None:
@@ -382,20 +392,26 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
382
392
  if value[0] == "offload_parameter":
383
393
  new_value = value[1:]
384
394
  new_value[2] = value[3]
385
- _write_parameter_bytes_data(name, new_value, f, enc_key, plain_data)
395
+ _write_parameter_bytes_data(name, new_value, f, enc_key, plain_data, ckpt_total_io_time)
386
396
  _offload_if_config(value[3])
387
397
  continue
388
398
  if value[1] == "str":
389
- crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
399
+ crc_num, ckpt_total_io_time = _write_parameter_data(name, value, f, enc_key, plain_data,
400
+ crc_num, crc_check,
401
+ ckpt_total_io_time)
390
402
  continue
391
403
  if isinstance(value[2], np.ndarray):
392
- crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
404
+ crc_num, ckpt_total_io_time = _write_parameter_data(name, value, f, enc_key, plain_data,
405
+ crc_num, crc_check,
406
+ ckpt_total_io_time)
393
407
  continue
394
408
  if isinstance(value[2], Tensor) and hasattr(value[2], "slice_num") and value[2].slice_num > 1:
395
409
  _write_hugeparameter(name, value, f)
396
410
  continue
397
411
 
398
- crc_num = _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
412
+ crc_num, ckpt_total_io_time = _write_parameter_bytes_data(name, value, f, enc_key, plain_data,
413
+ crc_num, crc_check,
414
+ ckpt_total_io_time)
399
415
 
400
416
  if enc_key is not None:
401
417
  plain_data.seek(0)
@@ -406,15 +422,22 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
406
422
  block_data = plain_data.read(max_block_size)
407
423
  if crc_check:
408
424
  f.write('crc_num'.encode() + crc_num.to_bytes(10, byteorder='big'))
409
- ckpt_save_time_end = time.time()
410
- cost_time = ckpt_save_time_end - ckpt_save_time_start
411
- vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Save ckpt cost time:{cost_time}.")
425
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
426
+ f"Save ckpt io cost time:{ckpt_total_io_time}.")
427
+
412
428
  elif format == "safetensors":
413
429
  save_dict = {}
414
430
  crc_num = 0
415
431
  for name in sorted(data_list.keys()):
416
432
  value = data_list[name]
417
- save_dict[name] = value[2].asnumpy()
433
+ if isinstance(value[2], np.ndarray):
434
+ save_dict[name] = value[2]
435
+ else:
436
+ bytes_data = value[2].get_bytes()
437
+ np_type = tensor_to_np_type.get(value[1])
438
+ np_array = np.frombuffer(bytes_data, np_type)
439
+ new_np_array = np_array.reshape(value[0])
440
+ save_dict[name] = new_np_array
418
441
 
419
442
  if crc_check:
420
443
  crc_num = binascii.crc32(bytes(name, encoding='utf-8'), crc_num)
@@ -428,11 +451,11 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
428
451
  save_file(save_dict, tmp_name)
429
452
  safetensors_save_time_end = time.time()
430
453
  cost_time = safetensors_save_time_end - safetensors_save_time_start
431
- vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Save safetensors cost time:{cost_time}.")
454
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Save safetensors io cost time:{cost_time}.")
432
455
  if not os.path.exists(tmp_name):
433
456
  logger.warning(f"Rename failed, can't find {tmp_name}, it is possible that multiple processes have "
434
457
  f"simultaneously modified a file.")
435
- else:
458
+ elif _ckpt_fs.backend != "mindio":
436
459
  os.rename(tmp_name, ckpt_file_name)
437
460
  os.chmod(ckpt_file_name, stat.S_IRUSR)
438
461
  except BaseException as e:
@@ -453,7 +476,7 @@ def _write_random_seed(name, value, f):
453
476
  f.write(checkpoint_list.SerializeToString())
454
477
 
455
478
 
456
- def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False):
479
+ def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False, ckpt_total_io_time=0):
457
480
  """Write parameter data into protobuf file."""
458
481
  data_size = value[2].nbytes / 1024
459
482
  if data_size > SLICE_SIZE:
@@ -475,14 +498,18 @@ def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_ch
475
498
  output_data = checkpoint_list.SerializeToString()
476
499
  if crc_check:
477
500
  crc_num = binascii.crc32(output_data, crc_num)
501
+ io_start_time = time.time()
478
502
  f.write(output_data)
503
+ io_end_time = time.time()
504
+ io_cost_time = io_end_time - io_start_time
505
+ ckpt_total_io_time += io_cost_time
479
506
  else:
480
507
  plain_data.write(checkpoint_list.SerializeToString())
481
508
 
482
- return crc_num
509
+ return crc_num, ckpt_total_io_time
483
510
 
484
511
 
485
- def _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False):
512
+ def _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False, ckpt_total_io_time=0):
486
513
  """Write parameter bytes data into protobuf file."""
487
514
  bytes_value = value[2].get_bytes()
488
515
  chunk_size = 1024 * SLICE_SIZE
@@ -500,11 +527,15 @@ def _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num=0,
500
527
  output_data = checkpoint_list.SerializeToString()
501
528
  if crc_check:
502
529
  crc_num = binascii.crc32(output_data, crc_num)
530
+ io_start_time = time.time()
503
531
  f.write(output_data)
532
+ io_end_time = time.time()
533
+ io_cost_time = io_end_time - io_start_time
534
+ ckpt_total_io_time += io_cost_time
504
535
  else:
505
536
  plain_data.write(checkpoint_list.SerializeToString())
506
537
 
507
- return crc_num
538
+ return crc_num, ckpt_total_io_time
508
539
 
509
540
 
510
541
  def _write_mapparameter(name, value, f, map_param_inc=False):
@@ -583,15 +614,13 @@ def _check_load_checkpoint_upsupported_param(format, dec_key, dec_mode):
583
614
  f"be set to default value '{default_value}', but got '{current_value}'.")
584
615
 
585
616
 
586
- def _check_save_checkpoint_upsupported_param(format, enc_key, enc_mode, async_save=False, map_param_inc=False,
587
- global_step_num=None):
617
+ def _check_save_checkpoint_upsupported_param(format, enc_key, enc_mode, map_param_inc=False, global_step_num=None):
588
618
  """check save checkpoint unsupported param"""
589
619
  if format != "safetensors":
590
620
  return
591
621
  default_params = {
592
622
  "enc_key": None,
593
623
  "enc_mode": "AES-GCM",
594
- "async_save": False,
595
624
  "map_param_inc": False,
596
625
  "global_step_num": None
597
626
  }
@@ -633,15 +662,18 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
633
662
 
634
663
  Args:
635
664
  save_obj (Union[Cell, list, dict]): The object to be saved. The data type can be :class:`mindspore.nn.Cell`,
636
- list, or dict. If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict
637
- elements(each element is a dictionary, like [{"name": param_name, "data": param_data},...], the type of
638
- `param_name` must be string, and the type of `param_data` must be parameter or Tensor); If dict,
639
- it can be the returned value of :func:`mindspore.load_checkpoint`.
665
+ list, or dict.
666
+
667
+ - If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict
668
+ elements(each element is a dictionary, like [{"name": param_name, "data": param_data},...], the type of
669
+ `param_name` must be string, and the type of `param_data` must be parameter or Tensor).
670
+ - If dict, it can be the returned value of :func:`mindspore.load_checkpoint`.
671
+
640
672
  ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
641
673
  integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: ``True`` .
642
- async_save (Union[bool, str]): Whether to use asynchronous saving of the checkpoint file, if True,
643
- the asynchronous thread is used by default. If the type is string,
644
- the method of asynchronous saving, it can be "process" or "thread".
674
+ async_save (Union[bool, str], optional): Whether to use asynchronous saving of the checkpoint file or
675
+ safetensors file, if True, the asynchronous thread is used by default. If the type
676
+ is string, the method of asynchronous saving, it can be "process" or "thread".
645
677
  Default: ``False`` .
646
678
  append_dict (dict): Additional information that needs to be saved. The key of dict must be str, the value
647
679
  of dict must be one of int, float, bool, string, Parameter or Tensor. Default: ``None`` .
@@ -652,9 +684,12 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
652
684
  Default: ``"AES-GCM"`` .
653
685
  choice_func (function) : A function for saving custom selected parameters. The input value of `choice_func` is
654
686
  a parameter name in string type, and the returned value is a bool.
655
- If returns ``True`` , the Parameter that matching the custom condition will be saved.
656
- If returns ``False`` , the Parameter that not matching the custom condition will not
657
- be saved. Default: ``None`` .
687
+ Default: ``None`` .
688
+
689
+ - If returns ``True`` , the Parameter that matching the custom condition will be saved.
690
+ - If returns ``False`` , the Parameter that not matching the custom condition will not
691
+ be saved.
692
+
658
693
  crc_check (bool) : Whether to perform crc32 calculation when saving checkpoint and save the calculation
659
694
  result to the file. Default: ``False`` .
660
695
  format (str): Format of the output file, can be "ckpt" or "safetensors". Default: "ckpt".
@@ -693,6 +728,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
693
728
  - `Saving and Loading the Model - Saving and Loading the Model Weight
694
729
  <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
695
730
  """
731
+ start_save_time = time.time()
696
732
  ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format)
697
733
  integrated_save = Validator.check_bool(integrated_save)
698
734
  async_save = _check_async_save(async_save)
@@ -703,12 +739,15 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
703
739
  map_param_inc = kwargs.get('incremental', False)
704
740
  logger.info("Execute the process of saving checkpoint files.")
705
741
  global_step_num = kwargs.get('global_step_num', None)
706
- _check_save_checkpoint_upsupported_param(format, enc_key, enc_mode, async_save, map_param_inc, global_step_num)
742
+ _check_save_checkpoint_upsupported_param(format, enc_key, enc_mode, map_param_inc, global_step_num)
707
743
 
708
744
  if append_dict and "__exception_save__" in append_dict:
709
745
  s1 = mindspore.hal.Stream()
710
746
  with mindspore.hal.StreamCtx(s1):
711
747
  save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
748
+ for k_name, value in append_dict.items():
749
+ if isinstance(value, (Tensor, Parameter)):
750
+ append_dict[k_name] = Tensor(Tensor_.move_to(value, "CPU", False))
712
751
  s1.synchronize()
713
752
  else:
714
753
  save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
@@ -779,9 +818,11 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
779
818
  data_list[key].append(dims)
780
819
  tensor_type = str(param["data"].dtype)
781
820
  data_list[key].append(tensor_type)
782
- data = param["data"] if async_save != "process" else param["data"].asnumpy()
821
+ data = param["data"] if async_save is False else param["data"].asnumpy()
783
822
  data_list[key].append(data)
784
823
 
824
+ from mindspore.profiler import mstx
825
+ range_id = mstx.range_start('save_checkpoint', None)
785
826
  if os.getenv("AITURBO") == "1":
786
827
  from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
787
828
  ckpt_name = os.path.basename(ckpt_file_name)
@@ -819,7 +860,32 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
819
860
  else:
820
861
  _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
821
862
 
863
+ mstx.range_end(range_id)
822
864
  logger.info("Saving checkpoint process is finished.")
865
+ end_save_time = time.time()
866
+ save_checkpoint_cost_time = end_save_time - start_save_time
867
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Save checkpoint cost time {save_checkpoint_cost_time}.")
868
+
869
+
870
+ def _handle_shared_param_for_pipeline_parallel(save_obj):
871
+ """ Remove shared param for save_obj """
872
+ filtered_save_obj = []
873
+ for param_dict in save_obj:
874
+ cur_param = param_dict['data']
875
+ if isinstance(cur_param, Parameter):
876
+ if not cur_param.param_info.is_pipeline_shared_param:
877
+ filtered_save_obj.append(param_dict)
878
+ else:
879
+ filtered_save_obj.append(param_dict)
880
+ return filtered_save_obj
881
+
882
+
883
+ def _is_auto_parallel_mode(save_obj):
884
+ """Check if in auto parallel mode by verifying parameter initialization."""
885
+ for _, param in save_obj.parameters_and_names():
886
+ if param.param_info.is_param_init:
887
+ return True
888
+ return False
823
889
 
824
890
 
825
891
  def _convert_list_to_param_list(save_obj, choice_func):
@@ -860,7 +926,7 @@ def _convert_dict_to_param_dict(save_obj, choice_func):
860
926
  """Convert a dict of Parameter to param_list."""
861
927
  param_list = []
862
928
  for (key, value) in save_obj.items():
863
- if isinstance(key, str) and isinstance(value, (Parameter, str)):
929
+ if isinstance(key, str) and (isinstance(value, (Parameter, str)) or _is_buffer_type(value)):
864
930
  if choice_func is not None and not choice_func(key):
865
931
  continue
866
932
  each_param = {"name": key, "data": value}
@@ -872,15 +938,19 @@ def _convert_dict_to_param_dict(save_obj, choice_func):
872
938
  return param_list
873
939
 
874
940
 
875
- def _convert_cell_param_and_names_to_dict(save_obj, choice_func):
941
+ def _convert_cell_param_and_names_to_dict(save_obj, choice_func, is_parallel_mode):
876
942
  """Convert cell.parameters_and_names to OrderedDict."""
877
943
  param_dict = OrderedDict()
878
944
  for _, param in save_obj.parameters_and_names():
945
+ if param.name.startswith("accu_grads") or param.name.endswith("expert_load"):
946
+ continue
879
947
  not_sliced = not param.sliced
880
948
  is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
881
949
  # All parameters are initialized immediately under PyNative mode, skip this judgement.
882
950
  judgment = not_sliced or param.has_init
883
- if is_graph_mode and _is_in_auto_parallel_mode() and judgment:
951
+ if param.param_info.is_pipeline_shared_param:
952
+ continue
953
+ if is_graph_mode and is_parallel_mode and judgment:
884
954
  continue
885
955
  if choice_func is not None and not choice_func(param.name):
886
956
  continue
@@ -898,11 +968,12 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
898
968
  sync_pipeline_shared_parameters(save_obj)
899
969
  param_list = []
900
970
  parameter_layout_dict = save_obj.parameter_layout_dict
901
- if _is_in_auto_parallel_mode() and not parameter_layout_dict:
971
+ is_parallel_mode = _is_auto_parallel_mode(save_obj)
972
+ if is_parallel_mode and not parameter_layout_dict:
902
973
  parameter_layout_dict = _get_parameter_layout()
903
- if not _is_in_auto_parallel_mode():
974
+ if not is_parallel_mode:
904
975
  save_obj.init_parameters_data()
905
- param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func)
976
+ param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func, is_parallel_mode)
906
977
  if append_dict and "random_op" in append_dict:
907
978
  phase = 'train' + '.' + str(save_obj.create_time) + '.' + str(id(save_obj)) + '.' + save_obj.arguments_key
908
979
  if phase in save_obj.compile_cache and _executor.has_compiled(phase):
@@ -950,11 +1021,14 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
950
1021
 
951
1022
  def _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func):
952
1023
  """Convert a save_obj to param_list."""
953
- if isinstance(save_obj, list):
954
- return _convert_list_to_param_list(save_obj, choice_func)
1024
+ if isinstance(save_obj, (list, dict)):
1025
+ if isinstance(save_obj, list):
1026
+ save_obj = _convert_list_to_param_list(save_obj, choice_func)
955
1027
 
956
- if isinstance(save_obj, dict):
957
- return _convert_dict_to_param_dict(save_obj, choice_func)
1028
+ if isinstance(save_obj, dict):
1029
+ save_obj = _convert_dict_to_param_dict(save_obj, choice_func)
1030
+
1031
+ return _handle_shared_param_for_pipeline_parallel(save_obj)
958
1032
 
959
1033
  return _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func)
960
1034
 
@@ -985,11 +1059,8 @@ def _check_append_dict(append_dict):
985
1059
  return append_dict
986
1060
 
987
1061
 
988
- def _check_load_obfuscate(**kwargs):
989
- if 'obf_func' in kwargs.keys():
990
- customized_func = _check_customized_func(kwargs.get('obf_func'))
991
- clean_funcs()
992
- add_opaque_predicate(customized_func.__name__, customized_func)
1062
+ def _is_buffer_type(value):
1063
+ if isinstance(value, Tensor) and getattr(value, "_is_buffer", False):
993
1064
  return True
994
1065
  return False
995
1066
 
@@ -1006,20 +1077,18 @@ def load(file_name, **kwargs):
1006
1077
  kwargs (dict): Configuration options dictionary.
1007
1078
 
1008
1079
  - dec_key (bytes): Byte-type key used for decryption. The valid length is 16, 24, or 32.
1009
- - dec_mode (Union[str, function]): Specifies the decryption mode, to take effect when dec_key is set.
1080
+ - dec_mode (Union[str, function], optional):
1081
+ Specifies the decryption mode, to take effect when dec_key is set.
1010
1082
 
1011
1083
  - Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC' or customized decryption. Default: ``'AES-GCM'``.
1012
1084
  - For details of using the customized decryption, please check the `tutorial
1013
1085
  <https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_.
1014
1086
 
1015
- - obf_func (function): A python function used for loading obfuscated MindIR model, which can refer to
1016
- `obfuscate_model()
1017
- <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.obfuscate_model.html>`_.
1018
-
1019
1087
  Returns:
1020
1088
  GraphCell, a compiled graph that can executed by `GraphCell`.
1021
1089
 
1022
1090
  Raises:
1091
+ NotImplementedError: Dynamic model structure obfuscation is no longer supported.
1023
1092
  ValueError: MindIR file does not exist or `file_name` is not a string.
1024
1093
  RuntimeError: Failed to parse MindIR file.
1025
1094
 
@@ -1046,6 +1115,8 @@ def load(file_name, **kwargs):
1046
1115
  - `Saving and Loading the Model - Saving and Loading MindIR
1047
1116
  <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
1048
1117
  """
1118
+ if 'obf_func' in kwargs.keys():
1119
+ raise NotImplementedError("Dynamic model structure obfuscation is no longer supported.")
1049
1120
  if not isinstance(file_name, str):
1050
1121
  raise ValueError("For 'load', the argument 'file_name' must be string, but "
1051
1122
  "got {}.".format(type(file_name)))
@@ -1057,9 +1128,6 @@ def load(file_name, **kwargs):
1057
1128
  "please check whether the 'file_name' is correct.")
1058
1129
  file_name = os.path.realpath(file_name)
1059
1130
 
1060
- # set customized functions for dynamic obfuscation
1061
- obfuscated = _check_load_obfuscate(**kwargs)
1062
-
1063
1131
  logger.info("Execute the process of loading mindir.")
1064
1132
  if 'dec_key' in kwargs.keys():
1065
1133
  dec_key = Validator.check_isinstance('dec_key', kwargs.get('dec_key'), bytes)
@@ -1072,9 +1140,9 @@ def load(file_name, **kwargs):
1072
1140
  else:
1073
1141
  dec_mode = Validator.check_isinstance('dec_mode', kwargs.get('dec_mode'), str)
1074
1142
  graph = load_mindir(file_name, dec_key=dec_key, key_len=len(dec_key), dec_mode=dec_mode,
1075
- decrypt=dec_func, obfuscated=obfuscated)
1143
+ decrypt=dec_func)
1076
1144
  else:
1077
- graph = load_mindir(file_name, obfuscated=obfuscated)
1145
+ graph = load_mindir(file_name)
1078
1146
 
1079
1147
  if graph is None:
1080
1148
  if _is_cipher_file(file_name):
@@ -1141,181 +1209,12 @@ def _check_param_type(param_config, key, target_type, requested):
1141
1209
  if key in param_config:
1142
1210
  if not isinstance(param_config[key], target_type):
1143
1211
  raise TypeError("The type of {} must be {}, but got {}.".format(key, target_type, type(param_config[key])))
1144
- if key == 'obf_random_seed':
1145
- if param_config[key] > INT_64_MAX or param_config[key] <= 0:
1146
- raise ValueError(
1147
- "'obf_random_seed' must be in (0, INT_64_MAX({})], but got {}.".format(INT_64_MAX,
1148
- param_config[key]))
1149
1212
  return param_config[key]
1150
1213
  if requested:
1151
1214
  raise ValueError("The parameter {} is requested, but not got.".format(key))
1152
- if key == "obf_random_seed":
1153
- return 0
1154
1215
  return None
1155
1216
 
1156
1217
 
1157
- def _check_customized_func(customized_func):
1158
- """ check customized function of dynamic obfuscation """
1159
- if not callable(customized_func):
1160
- raise TypeError(
1161
- "'customized_func' must be a function, but not got {}.".format(type(customized_func)))
1162
- # test customized_func
1163
- try:
1164
- func_result = customized_func(1.0, 1.0)
1165
- except Exception as ex:
1166
- raise TypeError("customized_func must be a function with two inputs, but got exception: {}".format(ex))
1167
- else:
1168
- if not isinstance(func_result, bool):
1169
- raise TypeError("Return value of customized_func must be boolean, but got: {}".format(type(func_result)))
1170
- return customized_func
1171
-
1172
-
1173
- def _check_obfuscate_params(obf_config):
1174
- """Check obfuscation parameters, including obf_random_seed, obf_ratio, customized_func"""
1175
- if 'obf_random_seed' not in obf_config.keys() and 'customized_func' not in obf_config.keys():
1176
- raise ValueError(
1177
- "At least one of 'obf_random_seed' or 'customized_func' must be set in obf_config, but got None of them.")
1178
- obfuscate_type = _check_param_type(obf_config, "type", str, False)
1179
- if obfuscate_type not in (None, "dynamic"):
1180
- raise ValueError("Only 'dynamic' type is supported by now, but got {}.".format(obfuscate_type))
1181
- if ('obf_ratio' in obf_config) and isinstance(obf_config['obf_ratio'], str):
1182
- if obf_config['obf_ratio'] not in ["small", "medium", "large"]:
1183
- raise ValueError("'obf_ratio' can only be 'small', 'medium', 'large' or float, but got {}.".format(
1184
- obf_config['obf_ratio']))
1185
- ratio_dict = {"small": 0.1, "medium": 0.3, "large": 0.6}
1186
- obf_config['obf_ratio'] = ratio_dict.get(obf_config['obf_ratio'])
1187
- obf_ratio = _check_param_type(obf_config, "obf_ratio", float, True)
1188
- if (obf_ratio <= 0) or (obf_ratio > 1):
1189
- raise ValueError("'obf_ratio' must be in (0, 1] if it is a float, but got {}.".format(obf_config['obf_ratio']))
1190
- customized_funcs = []
1191
- if 'customized_func' in obf_config.keys():
1192
- device_target = context.get_context('device_target')
1193
- if device_target in ["GPU", "Ascend"]:
1194
- raise ValueError(
1195
- "Customized func mode only support 'device_target'='CPU, but got {}.".format(device_target))
1196
- customized_funcs.append(_check_customized_func(obf_config['customized_func']))
1197
- obf_random_seed = _check_param_type(obf_config, "obf_random_seed", int, False)
1198
- return obf_ratio, customized_funcs, obf_random_seed
1199
-
1200
-
1201
- def obfuscate_model(obf_config, **kwargs):
1202
- """
1203
- Obfuscate a model of MindIR format. Obfuscation means changing the struct of a network without affecting its
1204
- predict correctness. The obfuscated model can prevent attackers from stealing the model.
1205
-
1206
- Args:
1207
- obf_config (dict): obfuscation config.
1208
-
1209
- - type (str): The type of obfuscation, only 'dynamic' is supported until now.
1210
- - original_model_path (str): The path of MindIR format model that need to be obfuscated. If the original
1211
- model is encrypted, then enc_key and enc_mode should be provided.
1212
- - save_model_path (str): The path to save the obfuscated model.
1213
- - model_inputs (list(Tensor)): The inputs of the original model, the values of Tensor can be random, which
1214
- is the same as using :func:`mindspore.export`.
1215
- - obf_ratio (Union(float, str)): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
1216
- should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large" are
1217
- correspond to 0.1, 0.3, and 0.6 respectively.
1218
- - customized_func (function): A python function used for customized function mode, which used for control
1219
- the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
1220
- Reference to 'my_func()' in
1221
- `tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
1222
- This function needs to ensure that its result is constant for any input. Users can refer to opaque
1223
- predicates. If customized_func is set, then it should be passed to :func:`mindspore.load` interface
1224
- when loading obfuscated model.
1225
- - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
1226
- structure of obfuscated models corresponding to different random seeds is different. If
1227
- `obf_random_seed` is set, then it should be passed to :class:`mindspore.nn.GraphCell`
1228
- interface when loading
1229
- obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
1230
- be set, and the latter mode would be applied if both of them are set.
1231
-
1232
- kwargs (dict): Configuration options dictionary.
1233
-
1234
- - enc_key (bytes): Byte type key used for encryption. The valid length is 16, 24, or 32.
1235
- - enc_mode (str): Specifies the encryption mode, to take effect when dec_key is set.
1236
- Options: ``'AES-GCM'`` | ``'AES-CBC'`` | ``'SM4-CBC'``. Default: ``'AES-GCM'``.
1237
-
1238
- Raises:
1239
- TypeError: If `obf_config` is not a dict.
1240
- ValueError: If `enc_key` is passed and `enc_mode` is not in ["AES-GCM", "AES-CBC", "SM4-CBC"].
1241
- ValueError: If `original_model_path` is not provided in `obf_config`.
1242
- ValueError: If the model saved in `original_model_path` has been obfuscated.
1243
- ValueError: If `save_model_path` is not provided in `obf_config`.
1244
- ValueError: If `obf_ratio` is not provided in `obf_config`.
1245
- ValueError: If both `customized_func` and `obf_random_seed` are not provided in `obf_config`.
1246
- ValueError: If `obf_random_seed` is not in (0, 9223372036854775807].
1247
- ValueError: If `original_model_path` does not exist or `original_model_path` does not end with '.mindir'.
1248
-
1249
- Examples:
1250
- >>> import mindspore as ms
1251
- >>> import mindspore.nn as nn
1252
- >>> import numpy as np
1253
- >>> # Download ori_net.mindir
1254
- >>> # https://gitee.com/mindspore/mindspore/blob/master/tests/ut/python/mindir/ori_net.mindir
1255
- >>> input1 = ms.Tensor(np.ones((1, 1, 32, 32)).astype(np.float32))
1256
- >>> obf_config = {'original_model_path': "./net.mindir",
1257
- ... 'save_model_path': "./obf_net",
1258
- ... 'model_inputs': [input1, ],
1259
- ... 'obf_ratio': 0.1, 'obf_random_seed': 173262358423}
1260
- >>> ms.obfuscate_model(obf_config)
1261
- >>> obf_func = ms.load("obf_net.mindir")
1262
- >>> obf_net = nn.GraphCell(obf_func, obf_random_seed=173262358423)
1263
- >>> print(obf_net(input1).asnumpy())
1264
- """
1265
- if not isinstance(obf_config, dict):
1266
- raise TypeError("'obf_config' must be a dict, but got {}.".format(type(obf_config)))
1267
- file_path = _check_param_type(obf_config, "original_model_path", str, True)
1268
- if not file_path.endswith(".mindir"):
1269
- raise ValueError("For 'obfuscate_model', the argument 'file_path'(MindIR file) should end with '.mindir', "
1270
- "please input the correct 'file_path'.")
1271
- if not os.path.exists(file_path):
1272
- raise ValueError("For 'obfuscate_model', the argument 'file_path'(MindIR file) does not exist, "
1273
- "please check whether the 'file_path' is correct.")
1274
- saved_path = _check_param_type(obf_config, "save_model_path", str, True)
1275
- model_inputs = _check_param_type(obf_config, "model_inputs", list, True)
1276
- for item in model_inputs:
1277
- if not isinstance(item, Tensor):
1278
- raise TypeError("The item in 'model_inputs' must be Tensor, but got {}.".format(type(item)))
1279
- if -1 in item.shape:
1280
- raise ValueError(
1281
- "Dynamic shape input is not supported now, but got the shape of inputs: {}.".format(item.shape))
1282
- obf_ratio, customized_funcs, obf_random_seed = _check_obfuscate_params(obf_config)
1283
- if customized_funcs and obf_random_seed > 0:
1284
- logger.warning("Although 'customized_func' and 'obf_random_seed' are set, the 'obf_random_seed' mode would be"
1285
- " applied, remember to set 'obf_random_seed' when loading obfuscated model.")
1286
-
1287
- if obf_random_seed == 0: # apply customized_func mode
1288
- clean_funcs()
1289
- for func in customized_funcs:
1290
- add_opaque_predicate(func.__name__, func)
1291
- branch_control_input = 0
1292
- else: # apply password mode
1293
- branch_control_input = _generate_branch_control_input(obf_random_seed)
1294
-
1295
- if 'enc_key' in kwargs.keys():
1296
- enc_key = Validator.check_isinstance('enc_key', kwargs.get('enc_key'), bytes)
1297
- enc_mode = "AES-GCM"
1298
- if 'enc_mode' in kwargs.keys():
1299
- enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
1300
- if enc_mode not in ["AES-GCM", "AES-CBC", "SM4-CBC"]:
1301
- raise ValueError(
1302
- "Only MindIR files that encrypted with 'AES-GCM', 'AES-CBC' or 'SM4-CBC' is supported for"
1303
- "obfuscate_model(), but got {}.".format(enc_mode))
1304
- obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio,
1305
- branch_control_input=branch_control_input, dec_key=enc_key,
1306
- key_len=len(enc_key),
1307
- dec_mode=enc_mode)
1308
- else:
1309
- obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio,
1310
- branch_control_input=branch_control_input)
1311
-
1312
- obf_net = nn.GraphCell(obf_graph)
1313
- if obf_random_seed != 0:
1314
- append_y_tensor = Tensor(np.ones((1, 1)).astype(np.int32))
1315
- model_inputs += [append_y_tensor]
1316
- export(obf_net, *model_inputs, file_name=saved_path, file_format="MINDIR", **kwargs)
1317
-
1318
-
1319
1218
  def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
1320
1219
  dec_mode, crc_check, format):
1321
1220
  """load parameter into parameter_dict"""
@@ -1323,17 +1222,22 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
1323
1222
  if format == "safetensors":
1324
1223
  with safe_open(ckpt_file_name, framework='np') as f:
1325
1224
  cal_crc_num = 0
1326
- sf_load_time_start = time.time()
1225
+ total_io_cost_time = 0
1327
1226
  for k in sorted(f.keys()):
1328
1227
  if crc_check:
1329
1228
  cal_crc_num = binascii.crc32(bytes(k, encoding='utf-8'), cal_crc_num)
1330
1229
  cal_crc_num = binascii.crc32(bytes(f.get_tensor(k)), cal_crc_num)
1331
1230
  if choice_func is not None and not choice_func(k):
1332
1231
  continue
1333
- parameter_dict[k] = Parameter(Tensor.from_numpy(f.get_tensor(k)))
1334
- sf_load_time_end = time.time()
1335
- cost_time = sf_load_time_end - sf_load_time_start
1336
- vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Load safetensors cost time:{cost_time}.")
1232
+ io_start_time = time.time()
1233
+ value = f.get_tensor(k)
1234
+ io_end_time = time.time()
1235
+ io_cost_time = io_end_time - io_start_time
1236
+ total_io_cost_time += io_cost_time
1237
+ parameter_dict[k] = Parameter(Tensor.from_numpy(value))
1238
+
1239
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
1240
+ f"Load safetensors io cost time:{total_io_cost_time}.")
1337
1241
  if crc_check:
1338
1242
  if f.metadata() is None or f.metadata().get("crc_num") is None:
1339
1243
  logger.warning(
@@ -1411,38 +1315,37 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
1411
1315
  Load checkpoint info from a specified file.
1412
1316
 
1413
1317
  Note:
1414
- - `specify_prefix` and `filter_prefix` do not affect each other.
1415
- - If none of the parameters are loaded from checkpoint file, it will throw ValueError.
1416
1318
  - `specify_prefix` and `filter_prefix` are in the process of being deprecated,
1417
- `choice_func` is recommended instead.
1319
+ `choice_func` is recommended instead. `specify_prefix` and `filter_prefix` do not affect each other.
1418
1320
  And using either of those two args will override `choice_func` at the same time.
1321
+ - If none of the parameters are loaded from checkpoint file, it will throw ValueError.
1419
1322
  - When loading a checkpoint that has removed redundancy, the network should be compiled.
1420
1323
 
1421
1324
  Args:
1422
1325
  ckpt_file_name (str): Checkpoint file name.
1423
- net (Cell): The network where the parameters will be loaded. Default: ``None`` .
1424
- strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
1425
- into net when parameter name's suffix in checkpoint file is the same as the
1326
+ net (Cell, optional): The network where the parameters will be loaded. Default: ``None`` .
1327
+ strict_load (bool, optional): Whether to strict load the parameter into net. If ``False`` , it will load
1328
+ parameter into net when parameter name's suffix in checkpoint file is the same as the
1426
1329
  parameter in the network. When the types are inconsistent perform type conversion
1427
1330
  on the parameters of the same type, such as float32 to float16. Default: ``False`` .
1428
- filter_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`). Parameters starting with the
1429
- filter_prefix will not be loaded. Default: ``None`` .
1430
- dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
1431
- is not required. Default: ``None`` .
1432
- dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption
1433
- mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` .
1331
+ filter_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`).
1332
+ Parameters starting with the filter_prefix will not be loaded. Default: ``None`` .
1333
+ dec_key (Union[None, bytes], optional): Byte type key used for decryption. If the value is ``None`` ,
1334
+ the decryption is not required. Default: ``None`` .
1335
+ dec_mode (str, optional): This parameter is valid only when dec_key is not set to ``None`` . Specifies the
1336
+ decryption mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` .
1434
1337
  Default: ``"AES-GCM"`` .
1435
- specify_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`). Parameters starting with the
1436
- specify_prefix will be loaded. Default: ``None`` .
1437
- choice_func (Union[None, function]) : Input value of the function is a Parameter name of type string,
1338
+ specify_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`).
1339
+ Parameters starting with the specify_prefix will be loaded. Default: ``None`` .
1340
+ choice_func (Union[None, function], optional) : Input value of the function is a Parameter name of type string,
1438
1341
  and the return value is a bool. If returns ``True`` , the Parameter
1439
1342
  that matches the custom condition will be loaded. If returns ``False`` , the Parameter that
1440
1343
  matches the custom condition will be removed. Default: ``None`` .
1441
- crc_check (bool) : Whether to perform crc32 validation when loading checkpoint. Default: ``False`` .
1442
- remove_redundancy (bool): Whether to enable loading of checkpoint saved with redundancy removal.
1344
+ crc_check (bool, optional) : Whether to perform crc32 validation when loading checkpoint. Default: ``False`` .
1345
+ remove_redundancy (bool, optional): Whether to enable loading of checkpoint saved with redundancy removal.
1443
1346
  Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` , means
1444
1347
  redundant-free loading is not enabled.
1445
- format (str): Format of the input file, can be "ckpt" or "safetensors". Default: "ckpt".
1348
+ format (str, optional): Format of the input file, can be "ckpt" or "safetensors". Default: "ckpt".
1446
1349
 
1447
1350
  Returns:
1448
1351
  Dict, key is parameter name, value is a Parameter or string. When the `append_dict` parameter of
@@ -1487,6 +1390,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
1487
1390
  - `Saving and Loading the Model - Saving and Loading the Model Weight
1488
1391
  <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
1489
1392
  """
1393
+ start_load_time = time.time()
1490
1394
  vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin load checkpoint.")
1491
1395
  specify_prefix = _check_prefix(specify_prefix)
1492
1396
  filter_prefix = _check_prefix(filter_prefix)
@@ -1535,6 +1439,9 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
1535
1439
  _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict)
1536
1440
 
1537
1441
  vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Load checkpoint is finished.")
1442
+ end_load_time = time.time()
1443
+ load_checkpoint_cost_time = end_load_time - start_load_time
1444
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Load checkpoint cost time {load_checkpoint_cost_time}.")
1538
1445
  return parameter_dict
1539
1446
 
1540
1447
 
@@ -1554,7 +1461,7 @@ def load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_pr
1554
1461
  And using either of those two args will override `choice_func` at the same time.
1555
1462
 
1556
1463
  Args:
1557
- ckpt_file_name (str): Checkpoint file name.
1464
+ ckpt_file_name (str): Checkpoint file name. The file extension must be `ckpt` or `safetensors` .
1558
1465
  net (Cell, optional): The network where the parameters will be loaded. Default: ``None`` .
1559
1466
  strict_load (bool, optional): Whether to strict load the parameter into net. If ``False`` , it will load
1560
1467
  parameter into net when parameter name's suffix in checkpoint file is the
@@ -1612,10 +1519,11 @@ def load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_pr
1612
1519
  >>> model.train(2, dataset)
1613
1520
  >>> print("param dict len: ", len(param_dict), flush=True)
1614
1521
  """
1522
+ format = "safetensors" if ckpt_file_name.endswith(".safetensors") else "ckpt"
1615
1523
  from concurrent.futures import ThreadPoolExecutor
1616
1524
  executor = ThreadPoolExecutor(max_workers=2)
1617
1525
  param_dict_future = executor.submit(load_checkpoint, ckpt_file_name, net, strict_load, filter_prefix,
1618
- dec_key, dec_mode, specify_prefix, choice_func)
1526
+ dec_key, dec_mode, specify_prefix, choice_func, format=format)
1619
1527
  return ParamDictFuture(executor, param_dict_future)
1620
1528
 
1621
1529
 
@@ -1703,7 +1611,7 @@ def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check):
1703
1611
  pb_content = f.read()
1704
1612
  ckpt_load_time_end = time.time()
1705
1613
  cost_time = ckpt_load_time_end - ckpt_load_time_start
1706
- vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Load ckpt cost time:{cost_time}.")
1614
+ vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Load ckpt io cost time:{cost_time}.")
1707
1615
 
1708
1616
  else:
1709
1617
  pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
@@ -1774,17 +1682,18 @@ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundanc
1774
1682
  Load parameters into network, return parameter list that are not loaded in the network.
1775
1683
 
1776
1684
  Note:
1777
- - When loading a parameter dict that has removed redundancy, the network should be compiled.
1685
+ When loading a parameter dict that has removed redundancy, the network should be compiled.
1778
1686
 
1779
1687
  Args:
1780
1688
  net (Cell): The network where the parameters will be loaded.
1781
1689
  parameter_dict (dict): The dictionary generated by load checkpoint file,
1782
1690
  it is a dictionary consisting of key: parameters's name, value: parameter.
1783
- strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
1691
+ strict_load (bool, optional): Whether to strict load the parameter into net. If ``False`` ,
1692
+ it will load parameter
1784
1693
  into net when parameter name's suffix in checkpoint file is the same as the
1785
1694
  parameter in the network. When the types are inconsistent perform type conversion
1786
1695
  on the parameters of the same type, such as float32 to float16. Default: ``False`` .
1787
- remove_redundancy (bool): Whether to enable loading of checkpoint saved with redundancy removal.
1696
+ remove_redundancy (bool, optional): Whether to enable loading of checkpoint saved with redundancy removal.
1788
1697
  Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` , means
1789
1698
  redundant-free loading is not enabled.
1790
1699
 
@@ -1825,6 +1734,8 @@ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundanc
1825
1734
  param_not_load = []
1826
1735
  ckpt_not_load = list(parameter_dict.keys())
1827
1736
  for _, param in net.parameters_and_names():
1737
+ if param.param_info.is_pipeline_shared_param:
1738
+ continue
1828
1739
  if param.name in parameter_dict:
1829
1740
  if isinstance(param, MapParameter):
1830
1741
  param.import_data(parameter_dict[param.name])
@@ -1843,31 +1754,24 @@ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundanc
1843
1754
  if param_not_load and not strict_load:
1844
1755
  _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load)
1845
1756
 
1846
- logger.info("Loading parameters into net is finished.")
1847
- if param_not_load:
1848
- logger.warning("For 'load_param_into_net', "
1849
- "{} parameters in the 'net' are not loaded, because they are not in the "
1850
- "'parameter_dict', please check whether the network structure is consistent "
1851
- "when training and loading checkpoint. Another possibility is that "
1852
- "the redundant loading is not enabled, but the loaded checkpoint is saved with "
1853
- "redundancy removed. ".format(len(param_not_load)))
1854
- logger.warning("{} are not loaded.".format(param_not_load))
1855
1757
  if remove_redundancy:
1856
- parallel_mode = context.get_auto_parallel_context("parallel_mode")
1857
- if parallel_mode == "stand_alone":
1758
+ if get_group_size() == 1:
1858
1759
  raise TypeError(f"The deduplication feature for loading checkpoint can only be used "
1859
- f"in parallel scenarios, but got {parallel_mode}.")
1760
+ f"in parallel scenarios, but got stand_alone.")
1860
1761
  if not net.compile_cache and not net.parameter_layout_dict:
1861
1762
  raise ValueError("When loading a parameter dict that has removed redundancy, "
1862
1763
  "the network should be compiled.")
1863
1764
  param_layout = net.parameter_layout_dict
1864
- rank_id = get_rank()
1865
- device_num = _get_device_num()
1866
- stage_num = _get_auto_parallel_context("pipeline_stages")
1867
- chunk_size = device_num // stage_num
1868
- initial_rank = (rank_id // chunk_size) * chunk_size
1869
- _single_parameter_broadcast(net, param_layout, rank_id, initial_rank)
1765
+ _single_parameter_broadcast(net, param_layout, param_not_load)
1766
+ mindspore.hal.synchronize()
1870
1767
 
1768
+ logger.info("Loading parameters into net is finished.")
1769
+ if param_not_load:
1770
+ logger.warning("For 'load_param_into_net', "
1771
+ "{} parameters in the 'net' are not loaded, because they are not in the "
1772
+ "'parameter_dict', please check whether the network structure is consistent "
1773
+ "when training and loading checkpoint.".format(len(param_not_load)))
1774
+ logger.warning("{} are not loaded.".format(param_not_load))
1871
1775
  return param_not_load, ckpt_not_load
1872
1776
 
1873
1777
 
@@ -2050,9 +1954,6 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
2050
1954
  elif opt_shard_group:
2051
1955
  allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
2052
1956
  tuple(after_reshape_slice_shape))
2053
- elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save"):
2054
- allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
2055
- tuple(after_reshape_slice_shape))
2056
1957
  net.parallel_parameter_merge_net_dict[param_name] = allgather_net
2057
1958
  if allgather_net:
2058
1959
  param_data = allgather_net(param_data)
@@ -2106,27 +2007,6 @@ def export(net, *inputs, file_name, file_format, **kwargs):
2106
2007
 
2107
2008
  - dataset (Dataset): Specifies the preprocessing method of the dataset, which is used to import the
2108
2009
  preprocessing of the dataset into MindIR.
2109
-
2110
- - obf_config (dict): obfuscation config.
2111
-
2112
- - type (str): The type of obfuscation, only 'dynamic' is supported until now.
2113
- - obf_ratio (float, str): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
2114
- should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large" are
2115
- correspond to 0.1, 0.3, and 0.6 respectively.
2116
- - customized_func (function): A python function used for customized function mode, which used for control
2117
- the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
2118
- Reference to 'my_func()' in
2119
- `tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
2120
- This function needs to ensure that its result is constant for any input. Users can refer to opaque
2121
- predicates. If customized_func is set, then it should be passed to `load()` interface when loading
2122
- obfuscated model.
2123
- - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
2124
- structure of obfuscated models corresponding to different random seeds is different. If
2125
- `obf_random_seed` is set, then it should be passed
2126
- to :class:`mindspore.nn.GraphCell` interface when loading
2127
- obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
2128
- be set, and the latter mode would be applied if both of them are set.
2129
-
2130
2010
  - incremental (bool): export MindIR incrementally.
2131
2011
 
2132
2012
  - custom_func (function): Functions for custom defined export policies. This function will be used to
@@ -2160,6 +2040,8 @@ def export(net, *inputs, file_name, file_format, **kwargs):
2160
2040
  - `Saving and Loading the Model - Saving and Loading MindIR
2161
2041
  <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
2162
2042
  """
2043
+ if 'obf_func' in kwargs.keys():
2044
+ raise NotImplementedError("Dynamic model structure obfuscation is no longer supported.")
2163
2045
  old_ms_jit_value = context.get_context("jit_syntax_level")
2164
2046
  context.set_context(jit_syntax_level=mindspore.STRICT)
2165
2047
 
@@ -2241,8 +2123,6 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
2241
2123
  It is an internal conversion function. Export the MindSpore prediction model to a file in the specified format.
2242
2124
  """
2243
2125
  logger.info("exporting model file:%s format:%s.", file_name, file_format)
2244
- if "obf_config" in kwargs and file_format != "MINDIR":
2245
- raise ValueError(f"Dynamic obfuscation only support for MindIR format, but got {file_format} format.")
2246
2126
  if "custom_func" in kwargs and file_format != "MINDIR" and kwargs["custom_func"] is not None:
2247
2127
  raise ValueError(f"Currently only support custom_func for MindIR format, but got {file_format} format.")
2248
2128
  if file_format == 'AIR':
@@ -2456,14 +2336,13 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
2456
2336
  os.chmod(data_file_name, stat.S_IRUSR)
2457
2337
 
2458
2338
 
2459
- def _msfunc_info(net, *inputs):
2339
+ def _msfunc_info(net, jit_executor, *inputs):
2460
2340
  """Get mindir stream and parameter dict of ms_function"""
2461
2341
  # pylint: disable=protected-access
2462
2342
  net_dict = OrderedDict()
2463
- _ms_func_executor = _MindsporeFunctionExecutor(net, time.time() * 1e9)
2464
- graph_id = _ms_func_executor.compile(net.__name__, *inputs)
2465
- mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')
2466
- params = _ms_func_executor._graph_executor.get_params(graph_id)
2343
+ graph_id = jit_executor.compile(net.__name__, *inputs)
2344
+ mindir_stream = jit_executor._get_func_graph_proto(net, graph_id, 'mind_ir')
2345
+ params = jit_executor._graph_executor.get_params(graph_id)
2467
2346
  for name, value in params.items():
2468
2347
  net_dict[name] = Parameter(value, name=name)
2469
2348
  return mindir_stream, net_dict
@@ -2475,53 +2354,21 @@ def _cell_info(net, incremental, *inputs):
2475
2354
  graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
2476
2355
  # pylint: disable=protected-access
2477
2356
  mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir', incremental=incremental)
2478
- # clean obfuscation config to prevent the next call
2479
- _executor.obfuscate_config = None
2480
-
2481
2357
  net_dict = net.parameters_dict()
2482
2358
  return mindir_stream, net_dict
2483
2359
 
2484
2360
 
2485
- def _set_obfuscate_config(**kwargs):
2486
- """Set obfuscation config for executor."""
2487
- logger.warning("Obfuscate model.")
2488
- if 'enc_mode' in kwargs.keys():
2489
- enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
2490
- if enc_mode not in ["AES-GCM", "AES-CBC", "SM4-CBC"]:
2491
- raise ValueError(
2492
- "Only MindIR files that encrypted with 'AES-GCM', 'AES-CBC' or 'SM4-CBC' is supported for"
2493
- "obfuscation, but got {}.".format(enc_mode))
2494
- obf_ratio, customized_funcs, obf_random_seed = _check_obfuscate_params(kwargs.get('obf_config'))
2495
- if customized_funcs and obf_random_seed > 0:
2496
- logger.warning("Although 'customized_func' and 'obf_random_seed' are set, the 'obf_random_seed' mode would be"
2497
- " applied, remember to set 'obf_random_seed' when loading obfuscated model.")
2498
-
2499
- if obf_random_seed == 0: # apply customized_func mode
2500
- device_target = context.get_context('device_target')
2501
- if device_target in ["GPU", "Ascend"]:
2502
- raise ValueError(
2503
- "Customized func mode only support 'device_target'='CPU, but got {}.".format(device_target))
2504
- clean_funcs()
2505
- for func in customized_funcs:
2506
- add_opaque_predicate(func.__name__, func)
2507
- _executor.obfuscate_config = {'obf_ratio': obf_ratio, 'obf_random_seed': obf_random_seed}
2508
-
2509
-
2510
2361
  def _save_mindir(net, file_name, *inputs, **kwargs):
2511
2362
  """Save MindIR format file."""
2512
- # set obfuscate configs
2513
- if 'obf_config' in kwargs.keys():
2514
- _set_obfuscate_config(**kwargs)
2515
- for item in inputs:
2516
- if -1 in item.shape:
2517
- raise ValueError(
2518
- "Dynamic shape input is not supported now, but got the shape of inputs: {}.".format(item.shape))
2363
+ executor = _executor
2364
+ if not isinstance(net, nn.Cell):
2365
+ executor = _JitExecutor(net, time.time() * 1e9)
2519
2366
 
2520
2367
  incremental = kwargs.get('incremental', False)
2521
2368
 
2522
2369
  model = mindir_model()
2523
2370
  if not isinstance(net, nn.Cell):
2524
- mindir_stream, net_dict = _msfunc_info(net, *inputs)
2371
+ mindir_stream, net_dict = _msfunc_info(net, executor, *inputs)
2525
2372
  else:
2526
2373
  mindir_stream, net_dict = _cell_info(net, incremental, *inputs)
2527
2374
  model.ParseFromString(mindir_stream)
@@ -2594,8 +2441,10 @@ def _save_together(net_dict, model):
2594
2441
  if name in net_dict.keys():
2595
2442
  data_total += sys.getsizeof(net_dict[name].data.get_bytes()) / 1024
2596
2443
  else:
2597
- raise ValueError("The parameter '{}' is not belongs to any cell,"
2598
- "the data of parameter cannot be exported.".format(param_proto.name))
2444
+ raise ValueError("There's a mindspore.Parameter that wasn't created in nn.Cell, and mindspore.export() "
2445
+ f"does not support exporting such Parameters. The parameter name is: {name}.\n"
2446
+ "You can find the supported syntax range for mindspore.export() at the following link:\n"
2447
+ "https://www.mindspore.cn/tutorials/zh-CN/master/beginner/save_load.html")
2599
2448
  if data_total > TOTAL_SAVE:
2600
2449
  return False
2601
2450
  return True
@@ -2762,566 +2611,6 @@ def parse_print(print_file_name):
2762
2611
  return tensor_list
2763
2612
 
2764
2613
 
2765
- def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
2766
- """
2767
- Merge data slices to one tensor with whole data when strategy is not None.
2768
-
2769
- Args:
2770
- sliced_data (list[numpy.ndarray]): Data slices in order of rank_id.
2771
- parameter_name (str): Name of parameter.
2772
- strategy (dict): Parameter slice strategy.
2773
- is_even (bool): Slice manner that True represents slicing evenly and False represents slicing unevenly.
2774
-
2775
- Returns:
2776
- Tensor, the merged Tensor which has the whole data.
2777
-
2778
- Raises:
2779
- ValueError: Failed to merge.
2780
- """
2781
- layout = strategy.get(parameter_name)
2782
- try:
2783
- dev_mat = list(layout.dev_matrix[0].dim)
2784
- tensor_map = list(layout.tensor_map[0].dim)
2785
- param_split_shape = list(layout.param_split_shape[0].dim)
2786
- field_size = int(layout.field)
2787
- except BaseException as e:
2788
- raise ValueError(f"{e.__str__()}. For 'merge_sliced_parameter'"
2789
- f", please make sure that 'strategy' is correct.") from e
2790
-
2791
- device_count = 1
2792
- for dim in dev_mat:
2793
- device_count *= dim
2794
-
2795
- if len(sliced_data) != device_count:
2796
- raise ValueError(f"For 'merge_sliced_parameter', the length of 'sliced_parameters' should be equal to "
2797
- f"device_count. The length of 'sliced_parameters' is {len(sliced_data)}, but "
2798
- f"device_count is {device_count}.")
2799
-
2800
- if not param_split_shape:
2801
- if not is_even:
2802
- raise ValueError("For 'merge_sliced_parameter', the shape of every parameter in 'sliced_parameters' "
2803
- "should be the same when slice manner is even.")
2804
-
2805
- all_gather_tensor = Tensor(np.concatenate(sliced_data))
2806
-
2807
- if field_size > 0:
2808
- merged_tensor = _reshape_param_data_with_weight(all_gather_tensor, dev_mat, field_size)
2809
- else:
2810
- merged_tensor = _reshape_param_data(all_gather_tensor, dev_mat, tensor_map)
2811
-
2812
- else:
2813
- tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
2814
-
2815
- slice_count = 1
2816
- for dim in tensor_strategy:
2817
- slice_count *= dim
2818
-
2819
- if len(param_split_shape) != slice_count:
2820
- raise ValueError(f"For 'merge_sliced_parameter', the param_split_shape length in 'strategy' should be "
2821
- f"{slice_count}, but got {len(param_split_shape)}.")
2822
-
2823
- tensor_slices_new = list(range(slice_count))
2824
- tensor_slices = sliced_data
2825
- for i in range(device_count):
2826
- slice_index = int(_get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, i))
2827
- if tensor_slices[i].shape[0] != param_split_shape[slice_index]:
2828
- raise ValueError(f"For 'merge_sliced_parameter', the slice {slice_index} should be "
2829
- f"{param_split_shape[slice_index]} in 0 axis, but got "
2830
- f"{tensor_slices[i].shape[0]}.")
2831
- tensor_slices_new[slice_index] = np.array(tensor_slices[i])
2832
-
2833
- dim_len = len(tensor_strategy)
2834
- for i in range(dim_len):
2835
- ele_count = int(len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i])
2836
- tensor_slices_new_inner = []
2837
- for j in range(ele_count):
2838
- new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]]
2839
- for k in range(j * tensor_strategy[dim_len - 1 - i] + 1,
2840
- (j + 1) * tensor_strategy[dim_len - 1 - i]):
2841
- new_tensor = np.concatenate((new_tensor, tensor_slices_new[k]), axis=dim_len - 1 - i)
2842
- tensor_slices_new_inner.insert(len(tensor_slices_new_inner), np.array(new_tensor))
2843
- tensor_slices_new = tensor_slices_new_inner
2844
- merged_tensor = Tensor(tensor_slices_new[0])
2845
-
2846
- return merged_tensor
2847
-
2848
-
2849
- def restore_group_info_list(group_info_file_name):
2850
- """
2851
- Build rank list, the checkpoint of ranks in the rank list has the same contents with the local rank
2852
- who saves the `group_info_file_name`. To save the group info file, please export GROUP_INFO_FIL
2853
- environment variables like "export GROUP_INFO_FILE=/data/group_info.pb".
2854
-
2855
- Args:
2856
- group_info_file_name (str): Name of group information file.
2857
-
2858
- Returns:
2859
- List, the rank list.
2860
-
2861
- Raises:
2862
- ValueError: group information file is incorrect.
2863
- TypeError: `group_info_file_name` is not str.
2864
-
2865
- Examples:
2866
- >>> import mindspore as ms
2867
- >>> ms.restore_list = restore_group_info_list("./group_info.pb")
2868
- """
2869
- if not isinstance(group_info_file_name, str):
2870
- raise TypeError(f"For 'restore_group_info_list', the argument 'group_info_file_name' should be str, "
2871
- f"but got {type(group_info_file_name)}.")
2872
-
2873
- if not os.path.isfile(group_info_file_name):
2874
- raise ValueError(f"For 'restore_group_info_list', no such group information file: {group_info_file_name}.")
2875
-
2876
- if os.path.getsize(group_info_file_name) == 0:
2877
- raise ValueError("For 'restore_group_info_list', the group information file should not be empty.")
2878
-
2879
- return _restore_group_info_list(group_info_file_name)
2880
-
2881
-
2882
- def build_searched_strategy(strategy_filename):
2883
- """
2884
- Build strategy of every parameter in network. Used in the case of distributed inference.
2885
-
2886
- Args:
2887
- strategy_filename (str): Name of strategy file.
2888
-
2889
- Returns:
2890
- Dict, whose key is parameter name and value is slice strategy of this parameter.
2891
-
2892
- Raises:
2893
- ValueError: Strategy file is incorrect.
2894
- TypeError: `strategy_filename` is not a string.
2895
-
2896
- Examples:
2897
- >>> import mindspore as ms
2898
- >>> strategy = ms.build_searched_strategy("./strategy_train.ckpt")
2899
- """
2900
- return _build_searched_strategy(strategy_filename)
2901
-
2902
-
2903
- def merge_sliced_parameter(sliced_parameters, strategy=None):
2904
- """
2905
- Merge parameter slices into one parameter. Used in the case of distributed inference.
2906
-
2907
- Args:
2908
- sliced_parameters (list[Parameter]): Parameter slices in order of rank id.
2909
- strategy (Optional[dict]): Parameter slice strategy, whose key is parameter name and
2910
- value is slice strategy of this parameter. If strategy is None, just merge
2911
- parameter slices in 0 axis order. Default: ``None``.
2912
-
2913
- Returns:
2914
- Parameter, the merged parameter which has the whole data.
2915
-
2916
- Raises:
2917
- ValueError: Failed to merge.
2918
- TypeError: The sliced_parameters is incorrect or strategy is not dict.
2919
- KeyError: The parameter name is not in keys of strategy.
2920
-
2921
- Examples:
2922
- >>> import numpy as np
2923
- >>> import mindspore as ms
2924
- >>> from mindspore import Tensor, Parameter
2925
- >>>
2926
- >>> sliced_parameters = [
2927
- ... Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])),
2928
- ... "network.embedding_table"),
2929
- ... Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])),
2930
- ... "network.embedding_table"),
2931
- ... Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])),
2932
- ... "network.embedding_table"),
2933
- ... Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])),
2934
- ... "network.embedding_table")]
2935
- >>> merged_parameter = ms.merge_sliced_parameter(sliced_parameters)
2936
- >>> print(merged_parameter)
2937
- Parameter (name=network.embedding_table, shape=(12,), dtype=Float64, requires_grad=True)
2938
- """
2939
- if not isinstance(sliced_parameters, list):
2940
- raise TypeError(f"For 'merge_sliced_parameter', the argument 'sliced_parameters' should be list, "
2941
- f"but got {type(sliced_parameters)}.")
2942
-
2943
- if not sliced_parameters:
2944
- raise ValueError("For 'merge_sliced_parameter', the argument 'sliced_parameters' should not be empty.")
2945
-
2946
- if strategy and not isinstance(strategy, dict):
2947
- raise TypeError(f"For 'merge_sliced_parameter', the argument 'strategy' should be dict, "
2948
- f"but got {type(strategy)}.")
2949
-
2950
- try:
2951
- parameter_name = sliced_parameters[0].name
2952
- parameter_shape = sliced_parameters[0].data.shape
2953
- parameter_shape_length = len(parameter_shape)
2954
- except BaseException as e:
2955
- raise TypeError(e.__str__() + f" For 'merge_sliced_parameter', the element in 'sliced_parameters' should be "
2956
- f"'Parameter', but got {type(sliced_parameters[0])} at index 0.") from e
2957
-
2958
- is_even = True
2959
- for index, parameter in enumerate(sliced_parameters):
2960
- if not isinstance(parameter, Parameter):
2961
- raise TypeError(f"For 'merge_sliced_parameter', the element in 'sliced_parameters' should be 'Parameter', "
2962
- f"but got {type(parameter)} at index {index}.")
2963
-
2964
- if parameter.name != parameter_name \
2965
- or len(parameter.data.shape) != parameter_shape_length \
2966
- or parameter.data.shape[1:] != parameter_shape[1:]:
2967
- raise ValueError(f"For 'merge_sliced_parameter', please make sure that the elements in 'slice_parameters'"
2968
- f" have the same name, dimension length and shape except 0 axis. The name, dimension "
2969
- f"length, shape except 0 axis should be {parameter_name}, {parameter_shape_length}, "
2970
- f"{parameter_shape[1:]}, but got name: {parameter.name}, dimension length: "
2971
- f"{len(parameter.data.shape)}, shape except 0 axis: {parameter.data.shape[1:]} "
2972
- f"at index {index}.")
2973
-
2974
- if parameter.data.shape != parameter_shape:
2975
- is_even = False
2976
-
2977
- layerwise_parallel = sliced_parameters[0].layerwise_parallel
2978
- requires_grad = sliced_parameters[0].requires_grad
2979
- sliced_data = []
2980
- for parameter in sliced_parameters:
2981
- if parameter.data.dtype == mstype.bfloat16:
2982
- sliced_data.append(cpu_cast(parameter.data, mstype.float32).asnumpy())
2983
- else:
2984
- sliced_data.append(parameter.data.asnumpy())
2985
-
2986
- if not strategy:
2987
- merged_tensor = Tensor(np.concatenate(sliced_data))
2988
- merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)
2989
-
2990
- else:
2991
- if parameter_name not in strategy.keys():
2992
- raise KeyError(f"For 'merge_sliced_parameter', the parameter name {parameter_name} should be a key in "
2993
- f"the 'strategy'. Please check 'sliced_parameter' and 'strategy'.")
2994
- merged_tensor = _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even)
2995
- merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)
2996
-
2997
- return merged_parameter
2998
-
2999
-
3000
- def _gather_tasks_load_dis(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, dst_device_num,
3001
- output_format, name_map, return_param_dict):
3002
- """gather transform tasks"""
3003
- tasks = []
3004
- for rank in range(0, dst_device_num):
3005
- tasks.append(
3006
- (unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, rank, output_format, name_map,
3007
- return_param_dict))
3008
- return tasks
3009
-
3010
-
3011
- def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_strategy=None,
3012
- train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM',
3013
- format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None,
3014
- output_format='safetensors', name_map=None, max_process_num=64,
3015
- return_param_dict=False):
3016
- """
3017
- Load checkpoint into net for distributed predication. Used in the case of distributed inference.
3018
-
3019
- Note:
3020
- `output_format` will only take effect when `format` is set to `safetensors` and `network` is set to `None`.
3021
-
3022
- Args:
3023
- network (Cell): Network for distributed predication, When the format is `safetensors`, the network parameter
3024
- can be left blank or passed as None, and the interface will execute save mode.
3025
- checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id. Default: ``None`` .
3026
- predict_strategy (Union[dict, str]): Strategy of predication process. It means that using one device to predict
3027
- when setting predict_strategy as None. Default: ``None`` .
3028
- train_strategy_filename (str): The filename of training strategy protocol buffer file.
3029
- When train_strategy_filename is None, the training strategy file will be
3030
- obtained from context.get_auto_parallel_context("strategy_ckpt_load_file").
3031
- Therefore, the training strategy file needs to be specified
3032
- in at least one of them. Default: ``None`` .
3033
- strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
3034
- into net when parameter name's suffix in checkpoint file is the same as the
3035
- parameter in the network. When the types are inconsistent, perform type conversion
3036
- on the parameters of the same type, such as float32 to float16. Default: ``False`` .
3037
- dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
3038
- is not required. Default: ``None`` .
3039
- dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption
3040
- mode, currently supports ``'AES-GCM'`` , ``'AES-CBC'`` and ``'SM4-CBC'`` .
3041
- Default: ``'AES-GCM'`` .
3042
- format (str): Input weight format to be loaded into the network.
3043
- It can be set to either "ckpt" or "safetensors". Default: "ckpt".
3044
- unified_safetensors_dir (str): Directory of input weight files to be loaded into the network.
3045
- Default: ``None`` .
3046
- dst_safetensors_dir (str): In the save mode scenario, the save directory for weights.
3047
- rank_id (int): The logical sequence number of the card. In non save mode, it is automatically obtained
3048
- globally by initializing the network; In save mode, save the file according to the input
3049
- sequence number. If it is not input, save the entire file.
3050
- output_format (str, optional): Control the format of the output checkpoint after conversion.
3051
- It can be set to either "ckpt" or "safetensors". Default: "safetensors".
3052
- name_map (dict): The weight mapping dictionary will modify the weight names according to the mapping
3053
- dictionary before loading or saving the segmented weights into the network. Default: None.
3054
- max_process_num (int): Maximum number of processes. Default: 64.
3055
- return_param_dict (bool): Whether to return the param_dict. Default: ``False``.
3056
-
3057
- Raises:
3058
- TypeError: The type of inputs do not match the requirements.
3059
- ValueError: Failed to load checkpoint into net.
3060
-
3061
- Supported Platforms:
3062
- ``Ascend`` ``GPU`` ``CPU``
3063
-
3064
- Examples:
3065
- .. note::
3066
- Before running the following examples, you need to configure the communication environment variables.
3067
-
3068
- For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
3069
- Please see the `rank table startup
3070
- <https://www.mindspore.cn/docs/en/master/model_train/parallel/rank_table.html>`_
3071
- for more details.
3072
-
3073
- For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
3074
- <https://www.mindspore.cn/docs/en/master/model_train/parallel/mpirun.html>`_ .
3075
-
3076
- For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
3077
- Startup <https://www.mindspore.cn/docs/en/master/model_train/parallel/dynamic_cluster.html>`_ .
3078
-
3079
- >>> import os
3080
- >>> import numpy as np
3081
- >>> import mindspore as ms
3082
- >>> import mindspore.dataset as ds
3083
- >>> from mindspore import nn, ops, train
3084
- >>> from mindspore.communication import init
3085
- >>>
3086
- >>> step_per_epoch = 4
3087
- >>> device_num = 8
3088
- >>>
3089
- >>> # Define the network structure.
3090
- >>> class Net(nn.Cell):
3091
- ... def __init__(self, matmul_size, strategy=None):
3092
- ... super().__init__()
3093
- ... matmul_np = np.full(matmul_size, 0.5, dtype=np.float32)
3094
- ... self.matmul_weight = ms.Parameter(ms.Tensor(matmul_np))
3095
- ... self.matmul = ops.MatMul()
3096
- ... self.neg = ops.Neg()
3097
- ... if strategy is not None:
3098
- ... self.matmul.shard(strategy)
3099
- ...
3100
- ... def construct(self, inputs):
3101
- ... x = self.matmul(inputs, self.matmul_weight)
3102
- ... x = self.neg(x)
3103
- ... return x
3104
- >>>
3105
- >>> # Create dataset.
3106
- >>> def get_dataset(*inputs):
3107
- ... def generate():
3108
- ... for _ in range(step_per_epoch):
3109
- ... yield inputs
3110
- ... return generate
3111
- >>>
3112
- >>> # Train network and save distributed checkpoint.
3113
- >>> def train_net():
3114
- ... ms.set_context(mode=ms.GRAPH_MODE)
3115
- ... init()
3116
- ... np.random.seed(1)
3117
- ... input_data = np.random.rand(16, 96).astype(np.float32)
3118
- ... label_data = np.random.rand(16, 16).astype(np.float32)
3119
- ... fake_dataset = get_dataset(input_data, label_data)
3120
- ... dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
3121
- ...
3122
- ... # Set parallel strategy.
3123
- ... strategy = ((1, 4), (4, 1))
3124
- ... ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num,
3125
- ... strategy_ckpt_save_file="./train_strategy.ckpt")
3126
- ... network = Net(matmul_size=(96, 16), strategy=strategy)
3127
- ... net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
3128
- ... net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
3129
- ... model = ms.Model(network=network, loss_fn=net_loss, optimizer=net_opt)
3130
- ... ckpt_config = train.CheckpointConfig(keep_checkpoint_max=1, integrated_save=False)
3131
- ... global_rank_id = int(os.getenv("RANK_ID"))
3132
- ... ckpt_path = "./rank_{}_ckpt".format(global_rank_id)
3133
- ... ckpt_callback = train.ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
3134
- ... model.train(epoch=2, train_dataset=dataset, callbacks=[ckpt_callback], dataset_sink_mode=False)
3135
- ... ms.reset_auto_parallel_context()
3136
- >>>
3137
- >>> # Load distributed checkpoint and test.
3138
- >>> def load_model():
3139
- ... ms.set_context(mode=ms.GRAPH_MODE)
3140
- ... init()
3141
- ... ms.set_auto_parallel_context(full_batch=True, parallel_mode="semi_auto_parallel",
3142
- ... strategy_ckpt_load_file="./train_strategy.ckpt", device_num=device_num)
3143
- ... predict_data = ms.Tensor(np.random.randn(128, 96).astype(np.float32))
3144
- ... network = Net(matmul_size=(96, 16))
3145
- ... model = ms.Model(network)
3146
- ... predict_layout = model.infer_predict_layout(ms.Tensor(predict_data))
3147
- ... ckpt_file_list = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(0, device_num)]
3148
- ... ms.load_distributed_checkpoint(network, ckpt_file_list, predict_layout)
3149
- ... predict_result = model.predict(predict_data)
3150
- ... print(predict_result)
3151
- >>>
3152
- >>> train_net()
3153
- >>> load_model()
3154
- [[-7.3259363 -7.497216 -7.398196 ... -7.374962 -7.204874 -7.234935 ]
3155
- [ 3.362938 3.3535435 3.3832688 ... 3.4263954 3.279045 3.3202887]
3156
- ...
3157
- [ 1.6067538 1.6244187 1.5384722 ... 1.5449994 1.6195512 1.6176052]]
3158
- """
3159
- if format not in ['safetensors', 'ckpt'] or output_format not in ['safetensors', 'ckpt']:
3160
- raise ValueError(
3161
- f"For 'load_distributed_checkpoint', 'format' and 'output_format' "
3162
- f"must be 'ckpt' or 'safetensors', but got {format}.")
3163
-
3164
- if format == 'safetensors':
3165
- if unified_safetensors_dir is None:
3166
- raise ValueError(f"For 'load_distributed_checkpoint', 'unified_safetensors_dir' can not be None "
3167
- f"when format is 'safetensors'.")
3168
- unsupport_param = [checkpoint_filenames, train_strategy_filename, dec_key]
3169
- for param in unsupport_param:
3170
- if param is not None:
3171
- raise ValueError(f"For 'load_distributed_checkpoint', {param} must be None "
3172
- f"when format is 'safetensors'.")
3173
- if strict_load or dec_mode != 'AES-GCM':
3174
- raise ValueError(f"For 'load_distributed_checkpoint', strict_load and dec_mode must be default "
3175
- f"when format is 'safetensors'.")
3176
- if network is not None:
3177
- try:
3178
- rank_id = get_rank()
3179
- except RuntimeError:
3180
- rank_id = 0
3181
- logger.warning(f"Get rank failed, default loading weight for rank 0.")
3182
- param_dict = _load_parallel_checkpoint(
3183
- (unified_safetensors_dir, predict_strategy, network, None, rank_id, output_format, name_map,
3184
- return_param_dict))
3185
- return param_dict
3186
- if dst_safetensors_dir is None:
3187
- raise ValueError(f"For 'load_distributed_checkpoint', 'dst_safetensors_dir' can not be None "
3188
- f"when network is None.")
3189
- if rank_id is not None:
3190
- _load_parallel_checkpoint(
3191
- (unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
3192
- rank_id, output_format, name_map, return_param_dict))
3193
- else:
3194
- dst_strategy_dict = _build_searched_strategy(predict_strategy)
3195
- dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
3196
- dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
3197
- dst_device_num = dst_stage_device_num * dst_stage_num
3198
- tasks = _gather_tasks_load_dis(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
3199
- dst_device_num, output_format, name_map, return_param_dict)
3200
- with Pool(processes=max_process_num) as pool:
3201
- list(pool.imap(_load_parallel_checkpoint, tasks))
3202
- return True
3203
-
3204
- network = Validator.check_isinstance("network", network, nn.Cell)
3205
- _check_checkpoint_file(checkpoint_filenames)
3206
- _check_predict_strategy(predict_strategy)
3207
-
3208
- dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
3209
- dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
3210
-
3211
- if train_strategy_filename is None:
3212
- train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
3213
- _train_strategy = build_searched_strategy(train_strategy_filename)
3214
- train_strategy = _convert_to_list(_train_strategy)
3215
-
3216
- train_dev_count = 1
3217
- ckpt_file_len = len(checkpoint_filenames)
3218
- for dim in train_strategy[list(train_strategy.keys())[0]][0]:
3219
- train_dev_count *= dim
3220
- if train_dev_count != ckpt_file_len:
3221
- raise ValueError(f"For 'Load_distributed_checkpoint', the length of 'checkpoint_filenames' should be "
3222
- f"equal to the device count of training process. "
3223
- f"But got the length of 'checkpoint_filenames'"
3224
- f" is {ckpt_file_len} and the device count is {train_dev_count}.")
3225
- rank_list = _infer_rank_list(train_strategy, predict_strategy)
3226
-
3227
- param_total_dict = defaultdict(dict)
3228
- for file_index, file_name in enumerate(checkpoint_filenames):
3229
- ckpt_dict = load_checkpoint(file_name, dec_key=dec_key, dec_mode=dec_mode)
3230
- for param_name, param in ckpt_dict.items():
3231
- param_total_dict[param_name][file_index] = param
3232
-
3233
- param_dict = {}
3234
- param_not_in_strategy = []
3235
- param_not_in_ckpt = []
3236
- for _, param in network.parameters_and_names():
3237
- sliced_params = []
3238
- if param.name not in rank_list.keys():
3239
- param_not_in_strategy.append(param.name)
3240
- continue
3241
- if param.name not in param_total_dict:
3242
- param_not_in_ckpt.append(param.name)
3243
- continue
3244
-
3245
- param_rank = rank_list.get(param.name)[0]
3246
- skip_merge_split = rank_list.get(param.name)[1]
3247
- shard_stride = train_strategy.get(param.name)[4]
3248
- tensor_map = train_strategy.get(param.name)[1]
3249
- first_dim_shard_idx = tensor_map[0] if tensor_map else -1
3250
- device_arrangement = train_strategy.get(param.name)[0]
3251
- first_dim_shard_size = 1
3252
- if first_dim_shard_idx >= 0:
3253
- first_dim_shard_size = device_arrangement[-1 - first_dim_shard_idx]
3254
- if train_strategy.get(param.name)[5]:
3255
- repeat_size = int(ckpt_file_len / shard_stride / train_strategy.get(param.name)[5] / first_dim_shard_size)
3256
- else:
3257
- repeat_size = 0
3258
- for rank in param_rank:
3259
- param_total_list = list(range(0, ckpt_file_len))
3260
- if first_dim_shard_size != 1:
3261
- param_total_list = _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_shard_idx, rank)
3262
- if repeat_size > 0:
3263
- shard_size = shard_stride * train_strategy.get(param.name)[5]
3264
- rank_index = param_total_list.index(rank)
3265
- start = rank_index // shard_size * shard_size
3266
- param_total_list = param_total_list[start:start + shard_size]
3267
- if shard_stride > 0:
3268
- param_stride = []
3269
- # merge pre parameter
3270
- param_index = param_total_list[0:param_total_list.index(rank) + 1][::-1][::shard_stride]
3271
- param_index.extend(param_total_list[param_total_list.index(rank):][::shard_stride])
3272
- param_index = list(set(param_index))
3273
- param_index.sort()
3274
- for rank_num in param_index:
3275
- if param_total_dict[param.name][rank_num].data.dtype == mstype.bfloat16:
3276
- param_stride.append(
3277
- cpu_cast(param_total_dict[param.name][rank_num].data, mstype.float32).asnumpy())
3278
- else:
3279
- param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy())
3280
-
3281
- sliced_param = Parameter(Tensor(np.concatenate(param_stride)), name=param.name)
3282
- else:
3283
- sliced_param = param_total_dict[param.name][rank]
3284
-
3285
- sliced_params.append(sliced_param)
3286
- if skip_merge_split:
3287
- split_param = sliced_params[0]
3288
- else:
3289
- param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
3290
- _param_unique_strategy = _convert_to_layout(param.name, param_unique_strategy)
3291
- split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy)
3292
- opt_shard_group = predict_strategy[param.name][5] if predict_strategy else None
3293
- if opt_shard_group:
3294
- if split_param.data.dtype == mstype.bfloat16:
3295
- data = cpu_cast(split_param.data, mstype.float32).asnumpy()
3296
- else:
3297
- data = split_param.data.asnumpy()
3298
- rank = get_rank(opt_shard_group)
3299
- size = get_group_size(opt_shard_group)
3300
- try:
3301
- data_slice = np.split(data, size)[rank]
3302
- except BaseException as e:
3303
- logger.critical("Failed to load opt shard slice in load distributed checkpoint for {}. Data shape is {}"
3304
- " and group is {}".format(param.name, split_param.data.shape, opt_shard_group))
3305
- raise RuntimeError(e.__str__() + f"\nFor 'load_distributed_checkpoint', failed to load opt shard slice"
3306
- f" in load distributed checkpoint for {param.name}. Data shape is "
3307
- f"{split_param.data.shape} and group is {opt_shard_group}.") from e
3308
- split_param = Parameter(Tensor(data_slice), param.name,
3309
- split_param.requires_grad, split_param.layerwise_parallel)
3310
- param_dict[param.name] = split_param
3311
-
3312
- if param_not_in_strategy:
3313
- logger.warning("For 'load_distributed_checkpoint', {} parameters in network are not in the slice strategy, "
3314
- "you can check whether 'predict_strategy' or 'train_strategy_filename' is correct."
3315
- .format(param_not_in_strategy))
3316
- if param_not_in_ckpt:
3317
- logger.warning("For 'load_distributed_checkpoint', {} parameters in network and slice strategy but not in "
3318
- "the checkpoint file, please check whether 'checkpoint_filenames' is correct."
3319
- .format(param_not_in_ckpt))
3320
-
3321
- load_param_into_net(network, param_dict, strict_load=strict_load)
3322
- return True
3323
-
3324
-
3325
2614
  def async_ckpt_thread_status():
3326
2615
  """
3327
2616
  Get the status of asynchronous save checkpoint thread.
@@ -3346,69 +2635,6 @@ def async_ckpt_thread_status():
3346
2635
  return True in [ele.getName() == "asyn_save_ckpt" for ele in thr_list]
3347
2636
 
3348
2637
 
3349
- def _check_predict_strategy(predict_strategy):
3350
- """Check predict strategy."""
3351
-
3352
- def _check_int_list(arg):
3353
- if not isinstance(arg, list):
3354
- return False
3355
- for item in arg:
3356
- if not isinstance(item, int):
3357
- return False
3358
- return True
3359
-
3360
- if predict_strategy is None:
3361
- return
3362
-
3363
- flag = True
3364
- predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict)
3365
- for key in predict_strategy.keys():
3366
- if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \
3367
- or len(predict_strategy[key]) < 4:
3368
- flag = False
3369
- dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4]
3370
- if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \
3371
- not (_check_int_list(param_split_shape) or not param_split_shape) or \
3372
- not (isinstance(field_size, int) and field_size == 0):
3373
- flag = False
3374
-
3375
- if not flag:
3376
- raise ValueError(f"For 'load_distributed_checkpoint', the argument 'predict_strategy' is dict, "
3377
- f"the key of it must be string, and the value of it must be list or tuple that "
3378
- f"the first four elements must be dev_matrix (list[int]), tensor_map (list[int]), "
3379
- f"param_split_shape (list[int]) and field_size (int, which value is 0)."
3380
- f"Please check whether 'predict_strategy' is correct.")
3381
-
3382
-
3383
- def _check_checkpoint_file(checkpoint_filenames):
3384
- """Check checkpoint file name."""
3385
- for index, filename in enumerate(checkpoint_filenames):
3386
- if not isinstance(filename, str) or not os.path.exists(filename) \
3387
- or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0:
3388
- raise ValueError(f"For 'load_distributed_checkpoint', please check 'checkpoint_filenames', and "
3389
- f"make sure the {filename} at index {index} is a valid checkpoint file, it must "
3390
- f"be a string ending with '.ckpt', and the checkpoint file it represents must "
3391
- f"be exist and not empty.")
3392
-
3393
-
3394
- def _merge_and_split(sliced_params, train_strategy, predict_strategy):
3395
- """Merge sliced parameter and split it according to the predict strategy."""
3396
- merged_param = merge_sliced_parameter(sliced_params, train_strategy)
3397
- if predict_strategy is None:
3398
- return merged_param
3399
- param_name = merged_param.name
3400
- tensor_layout = predict_strategy[param_name]
3401
- rank = get_rank()
3402
- split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank_id=rank)
3403
- requires_grad = merged_param.requires_grad
3404
- layerwise_parallel = merged_param.layerwise_parallel
3405
- if merged_param.data.dtype == mstype.bfloat16:
3406
- split_param = Parameter(Tensor(split_tensor, mstype.bfloat16), param_name, requires_grad, layerwise_parallel)
3407
- else:
3408
- split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
3409
- return split_param
3410
-
3411
-
3412
2638
  def _calculation_net_size(net):
3413
2639
  """Calculate the size of parameters in the network."""
3414
2640
  data_total = 0
@@ -3702,3 +2928,35 @@ def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_rege
3702
2928
  ckpt_filename = os.path.basename(file_path).replace(".safetensors", ".ckpt")
3703
2929
  dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), ckpt_filename)
3704
2930
  mindspore.save_checkpoint(param_dict_tensor, dst_file)
2931
+
2932
+
2933
def restore_group_info_list(group_info_file_name):
    """
    Build the rank list stored in a group info file.

    The checkpoints of the ranks in the returned list have the same contents as the checkpoint of
    the local rank that saved `group_info_file_name`. To have the group info file saved, export the
    GROUP_INFO_FILE environment variable, for example "export GROUP_INFO_FILE=/data/group_info.pb".

    This function is a thin wrapper: the file name is forwarded unchanged to
    `new_restore_group_info_list` and that function's result is handed back.

    Args:
        group_info_file_name: Name of the group info file; passed through as-is.

    Returns:
        The value returned by `new_restore_group_info_list`.
    """
    restored_rank_list = new_restore_group_info_list(group_info_file_name)
    return restored_rank_list
2940
+
2941
+
2942
def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_strategy=None,
                                train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM',
                                format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None,
                                output_format='safetensors', name_map=None, max_process_num=64,
                                return_param_dict=False):
    """
    Load checkpoint into net for distributed predication. Used in the case of distributed inference.

    This is a thin compatibility wrapper: every argument is forwarded positionally, in the order
    declared here, to `new_load_distributed_checkpoint`, and that function's result is returned.

    NOTE(review): the parameter descriptions below follow the documentation of the earlier
    in-file implementation with the same signature; the authoritative contract is the one of
    `new_load_distributed_checkpoint` - confirm against it.

    Args:
        network (Cell): Network for distributed predication. May be None when `format` is
            "safetensors" (save mode).
        checkpoint_filenames (list[str]): Checkpoint file names in order of rank id. Default: ``None`` .
        predict_strategy (Union[dict, str]): Strategy of the predication process. Default: ``None`` .
        train_strategy_filename (str): File name of the training strategy file. Default: ``None`` .
        strict_load (bool): Whether to strictly load the parameters into the net. Default: ``False`` .
        dec_key (Union[None, bytes]): Key used for decryption, or ``None`` . Default: ``None`` .
        dec_mode (str): Decryption mode, only relevant when `dec_key` is set. Default: ``'AES-GCM'`` .
        format (str): Input weight format, "ckpt" or "safetensors". Default: "ckpt".
        unified_safetensors_dir (str): Directory of the input weight files. Default: ``None`` .
        dst_safetensors_dir (str): Save directory for weights in save mode. Default: ``None`` .
        rank_id (int): Logical sequence number of the card. Default: ``None`` .
        output_format (str): Format of the converted output checkpoint. Default: "safetensors".
        name_map (dict): Weight-name mapping applied before loading or saving. Default: ``None`` .
        max_process_num (int): Maximum number of processes. Default: 64.
        return_param_dict (bool): Whether to return the param_dict. Default: ``False`` .

    Returns:
        The value returned by `new_load_distributed_checkpoint`.
    """
    # Hand the delegate's result back to the caller. Dropping it made this wrapper always
    # return None, unlike the sibling wrappers in this file, so a result requested through
    # `return_param_dict` could never reach the caller.
    return new_load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy,
                                           train_strategy_filename, strict_load, dec_key, dec_mode,
                                           format, unified_safetensors_dir, dst_safetensors_dir, rank_id,
                                           output_format, name_map, max_process_num,
                                           return_param_dict)
2953
+
2954
+
2955
def merge_sliced_parameter(sliced_parameters, strategy=None):
    """
    Merge parameter slices into one parameter. Used in the case of distributed inference.

    Thin wrapper: both arguments are forwarded positionally to `new_merge_sliced_parameter`.

    Args:
        sliced_parameters: The parameter slices to merge; passed through unchanged.
        strategy: Optional strategy forwarded to the delegate. Default: ``None`` .

    Returns:
        The value returned by `new_merge_sliced_parameter`.
    """
    merged_parameter = new_merge_sliced_parameter(sliced_parameters, strategy)
    return merged_parameter
2958
+
2959
+
2960
def build_searched_strategy(strategy_filename):
    """
    Build strategy of every parameter in network. Used in the case of distributed inference.

    Thin wrapper: the file name is forwarded unchanged to `new_build_searched_strategy`.

    Args:
        strategy_filename: Name of the strategy file; passed through as-is.

    Returns:
        The value returned by `new_build_searched_strategy`.
    """
    searched_strategy = new_build_searched_strategy(strategy_filename)
    return searched_strategy