mindspore 2.5.0__cp311-cp311-win_amd64.whl → 2.6.0rc1__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (491)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +6 -4
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -33
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parse/__init__.py +6 -7
  14. mindspore/_extends/parse/compile_config.py +19 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +24 -193
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +97 -74
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +4 -4
  27. mindspore/atlprov.dll +0 -0
  28. mindspore/avcodec-59.dll +0 -0
  29. mindspore/avdevice-59.dll +0 -0
  30. mindspore/avfilter-8.dll +0 -0
  31. mindspore/avformat-59.dll +0 -0
  32. mindspore/avutil-57.dll +0 -0
  33. mindspore/boost/__init__.py +2 -2
  34. mindspore/boost/base.py +3 -7
  35. mindspore/boost/boost_cell_wrapper.py +2 -2
  36. mindspore/c1.dll +0 -0
  37. mindspore/c1xx.dll +0 -0
  38. mindspore/c2.dll +0 -0
  39. mindspore/common/__init__.py +4 -3
  40. mindspore/common/_grad_function.py +56 -0
  41. mindspore/common/_pijit_context.py +14 -5
  42. mindspore/common/_register_for_tensor.py +1 -1
  43. mindspore/common/_stub_tensor.py +5 -10
  44. mindspore/common/_tensor_cpp_method.py +1 -1
  45. mindspore/common/_tensor_docs.py +1915 -3287
  46. mindspore/common/api.py +341 -354
  47. mindspore/common/auto_dynamic_shape.py +41 -44
  48. mindspore/common/dtype.py +5 -2
  49. mindspore/common/dump.py +7 -5
  50. mindspore/common/file_system.py +3 -0
  51. mindspore/common/hook_handle.py +5 -3
  52. mindspore/common/initializer.py +10 -6
  53. mindspore/common/jit_begin_end.py +94 -0
  54. mindspore/common/jit_config.py +6 -1
  55. mindspore/common/jit_context.py +76 -0
  56. mindspore/common/jit_trace.py +378 -0
  57. mindspore/common/lazy_inline.py +2 -2
  58. mindspore/common/mutable.py +5 -4
  59. mindspore/common/parameter.py +106 -39
  60. mindspore/common/seed.py +2 -2
  61. mindspore/common/sparse_tensor.py +23 -17
  62. mindspore/common/tensor.py +297 -714
  63. mindspore/communication/__init__.py +7 -5
  64. mindspore/communication/_comm_helper.py +47 -2
  65. mindspore/communication/comm_func.py +70 -53
  66. mindspore/communication/management.py +83 -17
  67. mindspore/context.py +214 -560
  68. mindspore/dataset/__init__.py +44 -20
  69. mindspore/dataset/audio/__init__.py +2 -8
  70. mindspore/dataset/audio/transforms.py +3 -17
  71. mindspore/dataset/core/config.py +3 -3
  72. mindspore/dataset/engine/cache_client.py +1 -1
  73. mindspore/dataset/engine/datasets.py +102 -120
  74. mindspore/dataset/engine/datasets_audio.py +22 -22
  75. mindspore/dataset/engine/datasets_standard_format.py +43 -24
  76. mindspore/dataset/engine/datasets_text.py +78 -85
  77. mindspore/dataset/engine/datasets_user_defined.py +108 -76
  78. mindspore/dataset/engine/datasets_vision.py +111 -108
  79. mindspore/dataset/engine/iterators.py +5 -3
  80. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  81. mindspore/dataset/engine/samplers.py +279 -57
  82. mindspore/dataset/engine/serializer_deserializer.py +2 -1
  83. mindspore/dataset/engine/validators.py +10 -0
  84. mindspore/dataset/text/__init__.py +7 -6
  85. mindspore/dataset/text/transforms.py +6 -5
  86. mindspore/dataset/text/utils.py +3 -3
  87. mindspore/dataset/transforms/__init__.py +0 -9
  88. mindspore/dataset/transforms/transforms.py +3 -3
  89. mindspore/dataset/utils/browse_dataset.py +1 -1
  90. mindspore/dataset/vision/__init__.py +2 -9
  91. mindspore/dataset/vision/transforms.py +202 -158
  92. mindspore/dataset/vision/utils.py +7 -5
  93. mindspore/device_context/ascend/op_debug.py +60 -1
  94. mindspore/device_context/ascend/op_tuning.py +0 -4
  95. mindspore/device_manager.py +39 -3
  96. mindspore/dnnl.dll +0 -0
  97. mindspore/dpcmi.dll +0 -0
  98. mindspore/experimental/es/embedding_service.py +35 -27
  99. mindspore/experimental/map_parameter.py +4 -4
  100. mindspore/experimental/optim/adadelta.py +22 -26
  101. mindspore/experimental/optim/adagrad.py +4 -4
  102. mindspore/experimental/optim/adam.py +4 -0
  103. mindspore/experimental/optim/adamax.py +4 -4
  104. mindspore/experimental/optim/adamw.py +4 -0
  105. mindspore/experimental/optim/asgd.py +1 -1
  106. mindspore/experimental/optim/lr_scheduler.py +40 -22
  107. mindspore/experimental/optim/radam.py +5 -5
  108. mindspore/experimental/optim/rprop.py +1 -1
  109. mindspore/experimental/optim/sgd.py +1 -1
  110. mindspore/hal/contiguous_tensors_handle.py +6 -10
  111. mindspore/hal/device.py +55 -81
  112. mindspore/hal/event.py +38 -55
  113. mindspore/hal/memory.py +93 -144
  114. mindspore/hal/stream.py +81 -125
  115. mindspore/include/dataset/constants.h +7 -4
  116. mindspore/include/dataset/execute.h +2 -2
  117. mindspore/jpeg62.dll +0 -0
  118. mindspore/log.py +40 -2
  119. mindspore/mindrecord/__init__.py +20 -7
  120. mindspore/mindspore_backend_common.dll +0 -0
  121. mindspore/mindspore_backend_manager.dll +0 -0
  122. mindspore/mindspore_common.dll +0 -0
  123. mindspore/mindspore_core.dll +0 -0
  124. mindspore/mindspore_dump.dll +0 -0
  125. mindspore/mindspore_frontend.dll +0 -0
  126. mindspore/mindspore_glog.dll +0 -0
  127. mindspore/mindspore_memory_pool.dll +0 -0
  128. mindspore/mindspore_ms_backend.dll +0 -0
  129. mindspore/mindspore_ops.dll +0 -0
  130. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  131. mindspore/mindspore_ops_kernel_common.dll +0 -0
  132. mindspore/mindspore_profiler.dll +0 -0
  133. mindspore/mindspore_pyboost.dll +0 -0
  134. mindspore/mindspore_pynative.dll +0 -0
  135. mindspore/mindspore_res_manager.dll +0 -0
  136. mindspore/mindspore_runtime_pipeline.dll +0 -0
  137. mindspore/mint/__init__.py +131 -700
  138. mindspore/mint/distributed/__init__.py +5 -1
  139. mindspore/mint/distributed/distributed.py +194 -109
  140. mindspore/mint/linalg/__init__.py +2 -0
  141. mindspore/mint/nn/__init__.py +280 -18
  142. mindspore/mint/nn/functional.py +282 -64
  143. mindspore/mint/nn/layer/__init__.py +4 -0
  144. mindspore/mint/nn/layer/_functions.py +7 -3
  145. mindspore/mint/nn/layer/activation.py +120 -13
  146. mindspore/mint/nn/layer/conv.py +218 -24
  147. mindspore/mint/nn/layer/normalization.py +15 -16
  148. mindspore/mint/nn/layer/padding.py +1 -1
  149. mindspore/mint/nn/layer/pooling.py +66 -1
  150. mindspore/mint/optim/__init__.py +2 -1
  151. mindspore/mint/optim/sgd.py +171 -0
  152. mindspore/msobj140.dll +0 -0
  153. mindspore/mspdb140.dll +0 -0
  154. mindspore/mspdbcore.dll +0 -0
  155. mindspore/mspdbst.dll +0 -0
  156. mindspore/mspft140.dll +0 -0
  157. mindspore/msvcdis140.dll +0 -0
  158. mindspore/msvcp140_1.dll +0 -0
  159. mindspore/msvcp140_2.dll +0 -0
  160. mindspore/msvcp140_atomic_wait.dll +0 -0
  161. mindspore/msvcp140_codecvt_ids.dll +0 -0
  162. mindspore/nn/__init__.py +4 -1
  163. mindspore/nn/cell.py +1250 -176
  164. mindspore/nn/layer/activation.py +23 -21
  165. mindspore/nn/layer/basic.py +22 -16
  166. mindspore/nn/layer/container.py +1 -1
  167. mindspore/nn/layer/conv.py +22 -17
  168. mindspore/nn/layer/embedding.py +9 -8
  169. mindspore/nn/layer/normalization.py +48 -42
  170. mindspore/nn/layer/pooling.py +75 -31
  171. mindspore/nn/layer/transformer.py +11 -10
  172. mindspore/nn/learning_rate_schedule.py +4 -2
  173. mindspore/nn/loss/loss.py +27 -19
  174. mindspore/nn/optim/ada_grad.py +6 -5
  175. mindspore/nn/optim/adadelta.py +9 -7
  176. mindspore/nn/optim/adafactor.py +1 -1
  177. mindspore/nn/optim/adam.py +16 -12
  178. mindspore/nn/optim/adamax.py +8 -7
  179. mindspore/nn/optim/adasum.py +5 -5
  180. mindspore/nn/optim/asgd.py +1 -1
  181. mindspore/nn/optim/ftrl.py +11 -9
  182. mindspore/nn/optim/lamb.py +1 -1
  183. mindspore/nn/optim/lazyadam.py +12 -10
  184. mindspore/nn/optim/momentum.py +7 -6
  185. mindspore/nn/optim/optimizer.py +2 -2
  186. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  187. mindspore/nn/optim/rmsprop.py +13 -12
  188. mindspore/nn/optim/rprop.py +9 -7
  189. mindspore/nn/optim/sgd.py +9 -6
  190. mindspore/nn/optim/tft_wrapper.py +5 -2
  191. mindspore/nn/probability/bijector/bijector.py +17 -11
  192. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  193. mindspore/nn/probability/bijector/invert.py +2 -2
  194. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  195. mindspore/nn/probability/bijector/softplus.py +3 -2
  196. mindspore/nn/probability/distribution/beta.py +3 -3
  197. mindspore/nn/probability/distribution/categorical.py +1 -1
  198. mindspore/nn/probability/distribution/cauchy.py +4 -2
  199. mindspore/nn/probability/distribution/exponential.py +6 -7
  200. mindspore/nn/probability/distribution/gamma.py +2 -2
  201. mindspore/nn/probability/distribution/gumbel.py +2 -2
  202. mindspore/nn/probability/distribution/half_normal.py +5 -3
  203. mindspore/nn/probability/distribution/logistic.py +5 -3
  204. mindspore/nn/probability/distribution/poisson.py +1 -1
  205. mindspore/nn/probability/distribution/uniform.py +5 -3
  206. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  207. mindspore/nn/reinforcement/tensor_array.py +1 -1
  208. mindspore/nn/wrap/__init__.py +6 -6
  209. mindspore/nn/wrap/cell_wrapper.py +178 -117
  210. mindspore/nn/wrap/grad_reducer.py +45 -36
  211. mindspore/nn/wrap/loss_scale.py +3 -3
  212. mindspore/numpy/array_creations.py +3 -3
  213. mindspore/numpy/array_ops.py +1 -1
  214. mindspore/numpy/math_ops.py +4 -4
  215. mindspore/numpy/utils.py +1 -2
  216. mindspore/numpy/utils_const.py +1 -2
  217. mindspore/opencv_core452.dll +0 -0
  218. mindspore/opencv_imgcodecs452.dll +0 -0
  219. mindspore/opencv_imgproc452.dll +0 -0
  220. mindspore/ops/__init__.py +3 -2
  221. mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
  222. mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
  223. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  224. mindspore/ops/_register_for_op.py +0 -11
  225. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  226. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
  227. mindspore/ops/_vmap/vmap_array_ops.py +7 -6
  228. mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
  229. mindspore/ops/_vmap/vmap_math_ops.py +4 -7
  230. mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
  231. mindspore/ops/auto_generate/__init__.py +4 -3
  232. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +102 -49
  233. mindspore/ops/auto_generate/gen_extend_func.py +281 -135
  234. mindspore/ops/auto_generate/gen_ops_def.py +2574 -2326
  235. mindspore/ops/auto_generate/gen_ops_prim.py +8566 -2755
  236. mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
  237. mindspore/ops/composite/__init__.py +2 -1
  238. mindspore/ops/composite/base.py +19 -24
  239. mindspore/ops/composite/math_ops.py +6 -16
  240. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  241. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -3
  242. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  243. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  244. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  245. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  246. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  247. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  248. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  249. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  250. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  251. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  252. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  253. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  254. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  255. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  256. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  257. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  258. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  259. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  260. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  261. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  262. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  263. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  264. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  265. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  266. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  267. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
  268. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  269. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  270. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  271. mindspore/ops/function/__init__.py +28 -2
  272. mindspore/ops/function/_add_attr_func.py +58 -0
  273. mindspore/ops/function/array_func.py +1629 -2345
  274. mindspore/ops/function/clip_func.py +38 -45
  275. mindspore/ops/function/debug_func.py +36 -44
  276. mindspore/ops/function/grad/__init__.py +1 -0
  277. mindspore/ops/function/grad/grad_func.py +104 -71
  278. mindspore/ops/function/image_func.py +1 -1
  279. mindspore/ops/function/linalg_func.py +46 -78
  280. mindspore/ops/function/math_func.py +3035 -3705
  281. mindspore/ops/function/nn_func.py +676 -241
  282. mindspore/ops/function/other_func.py +159 -1
  283. mindspore/ops/function/parameter_func.py +17 -30
  284. mindspore/ops/function/random_func.py +204 -361
  285. mindspore/ops/function/reshard_func.py +4 -70
  286. mindspore/ops/function/sparse_func.py +3 -3
  287. mindspore/ops/function/sparse_unary_func.py +5 -5
  288. mindspore/ops/function/spectral_func.py +25 -58
  289. mindspore/ops/function/vmap_func.py +24 -17
  290. mindspore/ops/functional.py +6 -4
  291. mindspore/ops/functional_overload.py +547 -4
  292. mindspore/ops/op_info_register.py +32 -244
  293. mindspore/ops/operations/__init__.py +10 -5
  294. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  295. mindspore/ops/operations/_grad_ops.py +1 -10
  296. mindspore/ops/operations/_inner_ops.py +5 -76
  297. mindspore/ops/operations/_ms_kernel.py +4 -10
  298. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  299. mindspore/ops/operations/_scalar_ops.py +3 -2
  300. mindspore/ops/operations/_sequence_ops.py +1 -1
  301. mindspore/ops/operations/_tensor_array.py +1 -1
  302. mindspore/ops/operations/array_ops.py +37 -22
  303. mindspore/ops/operations/comm_ops.py +150 -107
  304. mindspore/ops/operations/custom_ops.py +221 -23
  305. mindspore/ops/operations/debug_ops.py +115 -16
  306. mindspore/ops/operations/inner_ops.py +1 -1
  307. mindspore/ops/operations/linalg_ops.py +1 -58
  308. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  309. mindspore/ops/operations/manually_defined/ops_def.py +746 -79
  310. mindspore/ops/operations/math_ops.py +21 -18
  311. mindspore/ops/operations/nn_ops.py +65 -191
  312. mindspore/ops/operations/other_ops.py +62 -9
  313. mindspore/ops/operations/random_ops.py +13 -7
  314. mindspore/ops/operations/reshard_ops.py +1 -1
  315. mindspore/ops/operations/sparse_ops.py +2 -2
  316. mindspore/ops/primitive.py +43 -32
  317. mindspore/ops/tensor_method.py +232 -13
  318. mindspore/ops_generate/__init__.py +0 -5
  319. mindspore/ops_generate/aclnn/__init__.py +0 -0
  320. mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
  321. mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
  322. mindspore/ops_generate/api/__init__.py +0 -0
  323. mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
  324. mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
  325. mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
  326. mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
  327. mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
  328. mindspore/ops_generate/api/gen_api.py +103 -0
  329. mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
  330. mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
  331. mindspore/ops_generate/common/__init__.py +0 -0
  332. mindspore/ops_generate/common/gen_constants.py +91 -0
  333. mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
  334. mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
  335. mindspore/ops_generate/{template.py → common/template.py} +96 -84
  336. mindspore/ops_generate/gen_ops.py +23 -325
  337. mindspore/ops_generate/op_def/__init__.py +0 -0
  338. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  339. mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
  340. mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -7
  341. mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
  342. mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
  343. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  344. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  345. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  346. mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
  347. mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
  348. mindspore/ops_generate/pyboost/__init__.py +0 -0
  349. mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
  350. mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
  351. mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
  352. mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
  353. mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
  354. mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
  355. mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
  356. mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
  357. mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
  358. mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
  359. mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
  360. mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
  361. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
  362. mindspore/ops_generate/resources/__init__.py +0 -0
  363. mindspore/ops_generate/resources/resource_list.py +30 -0
  364. mindspore/ops_generate/resources/resource_loader.py +36 -0
  365. mindspore/ops_generate/resources/resource_manager.py +64 -0
  366. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  367. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  368. mindspore/parallel/__init__.py +6 -2
  369. mindspore/parallel/_auto_parallel_context.py +133 -6
  370. mindspore/parallel/_cell_wrapper.py +130 -15
  371. mindspore/parallel/_parallel_serialization.py +95 -4
  372. mindspore/parallel/_ps_context.py +1 -1
  373. mindspore/parallel/_recovery_context.py +7 -2
  374. mindspore/parallel/_tensor.py +142 -18
  375. mindspore/parallel/_utils.py +198 -25
  376. mindspore/parallel/algo_parameter_config.py +3 -3
  377. mindspore/parallel/auto_parallel.py +732 -0
  378. mindspore/parallel/checkpoint_convert.py +159 -0
  379. mindspore/parallel/checkpoint_transform.py +656 -37
  380. mindspore/parallel/cluster/process_entity/_api.py +151 -19
  381. mindspore/parallel/cluster/run.py +1 -1
  382. mindspore/parallel/function/__init__.py +24 -0
  383. mindspore/parallel/function/reshard_func.py +259 -0
  384. mindspore/parallel/nn/__init__.py +25 -0
  385. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  386. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  387. mindspore/parallel/parameter_broadcast.py +24 -13
  388. mindspore/parallel/shard.py +137 -61
  389. mindspore/parallel/transform_safetensors.py +287 -95
  390. mindspore/pgodb140.dll +0 -0
  391. mindspore/pgort140.dll +0 -0
  392. mindspore/profiler/__init__.py +9 -5
  393. mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
  394. mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
  395. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
  396. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +22 -0
  397. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  398. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
  399. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
  400. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
  401. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
  402. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
  403. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
  404. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
  405. mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
  406. mindspore/profiler/common/constant.py +12 -0
  407. mindspore/profiler/common/msprof_cmd_tool.py +42 -23
  408. mindspore/profiler/common/path_manager.py +24 -0
  409. mindspore/profiler/common/profiler_context.py +26 -2
  410. mindspore/profiler/common/profiler_meta_data.py +74 -0
  411. mindspore/profiler/common/profiler_parameters.py +59 -18
  412. mindspore/profiler/common/profiler_path_manager.py +66 -7
  413. mindspore/profiler/dynamic_profiler.py +112 -79
  414. mindspore/profiler/envprofiler.py +26 -1
  415. mindspore/profiler/experimental_config.py +197 -0
  416. mindspore/profiler/mstx.py +57 -14
  417. mindspore/profiler/platform/npu_profiler.py +33 -7
  418. mindspore/profiler/profiler.py +541 -45
  419. mindspore/profiler/profiler_action_controller.py +1 -1
  420. mindspore/profiler/profiler_interface.py +4 -0
  421. mindspore/profiler/schedule.py +57 -22
  422. mindspore/rewrite/api/node.py +15 -13
  423. mindspore/rewrite/api/symbol_tree.py +1 -1
  424. mindspore/run_check/_check_version.py +25 -14
  425. mindspore/run_check/run_check.py +1 -1
  426. mindspore/runtime/__init__.py +2 -2
  427. mindspore/runtime/executor.py +40 -11
  428. mindspore/runtime/memory.py +25 -8
  429. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  430. mindspore/swresample-4.dll +0 -0
  431. mindspore/swscale-6.dll +0 -0
  432. mindspore/tbbmalloc.dll +0 -0
  433. mindspore/tinyxml2.dll +0 -0
  434. mindspore/train/__init__.py +8 -8
  435. mindspore/train/_utils.py +35 -7
  436. mindspore/train/amp.py +1 -1
  437. mindspore/train/callback/__init__.py +2 -2
  438. mindspore/train/callback/_callback.py +2 -16
  439. mindspore/train/callback/_checkpoint.py +24 -40
  440. mindspore/train/callback/_cluster_monitor.py +14 -18
  441. mindspore/train/callback/_flops_collector.py +2 -3
  442. mindspore/train/callback/_history.py +7 -4
  443. mindspore/train/callback/_lambda_callback.py +2 -2
  444. mindspore/train/callback/_landscape.py +0 -3
  445. mindspore/train/callback/_loss_monitor.py +2 -1
  446. mindspore/train/callback/_on_request_exit.py +6 -5
  447. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  448. mindspore/train/callback/_summary_collector.py +8 -13
  449. mindspore/train/callback/_time_monitor.py +2 -1
  450. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +179 -103
  451. mindspore/train/data_sink.py +25 -2
  452. mindspore/train/dataset_helper.py +4 -5
  453. mindspore/train/loss_scale_manager.py +8 -7
  454. mindspore/train/metrics/accuracy.py +3 -3
  455. mindspore/train/metrics/confusion_matrix.py +9 -9
  456. mindspore/train/metrics/error.py +3 -3
  457. mindspore/train/metrics/hausdorff_distance.py +4 -4
  458. mindspore/train/metrics/mean_surface_distance.py +3 -3
  459. mindspore/train/metrics/metric.py +0 -12
  460. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  461. mindspore/train/metrics/precision.py +8 -6
  462. mindspore/train/metrics/recall.py +9 -9
  463. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  464. mindspore/train/mind_ir_pb2.py +19 -12
  465. mindspore/train/model.py +176 -103
  466. mindspore/train/serialization.py +246 -988
  467. mindspore/train/summary/_summary_adapter.py +2 -2
  468. mindspore/train/summary/summary_record.py +1 -1
  469. mindspore/turbojpeg.dll +0 -0
  470. mindspore/utils/__init__.py +3 -2
  471. mindspore/utils/dryrun.py +4 -2
  472. mindspore/utils/hooks.py +81 -0
  473. mindspore/utils/utils.py +138 -4
  474. mindspore/vcmeta.dll +0 -0
  475. mindspore/vcruntime140.dll +0 -0
  476. mindspore/vcruntime140_1.dll +0 -0
  477. mindspore/version.py +1 -1
  478. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +2 -1
  479. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +483 -438
  480. mindspore/_install_custom.py +0 -43
  481. mindspore/common/_register_for_adapter.py +0 -74
  482. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  483. mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
  484. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  485. mindspore/ops_generate/gen_constants.py +0 -190
  486. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  487. mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
  488. /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
  489. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  490. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +0 -0
  491. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020-2024 Huawei Technologies Co., Ltd
1
+ # Copyright 2020-2025 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -15,13 +15,26 @@
15
15
  """cell"""
16
16
  from __future__ import absolute_import
17
17
 
18
- import gc
19
18
  import inspect
20
19
  import os
21
20
  import time
22
- from collections import OrderedDict
23
- import numpy
24
-
21
+ import warnings
22
+ import itertools
23
+ from collections import OrderedDict, namedtuple
24
+ from typing import (
25
+ Dict,
26
+ Optional,
27
+ Set,
28
+ Callable,
29
+ List,
30
+ Tuple,
31
+ Iterator,
32
+ Any,
33
+ TypeVar,
34
+ Mapping
35
+ )
36
+
37
+ import mindspore as ms
25
38
  from mindspore._checkparam import args_type_check, check_hook_fn
26
39
  from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
27
40
  from mindspore import log as logger
@@ -34,19 +47,62 @@ from mindspore import _checkparam as Validator
34
47
  from mindspore.common import dtype as mstype
35
48
  from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache, \
36
49
  _no_grad
37
- from mindspore.common.api import _generate_branch_control_input, _convert_python_data, _get_args_for_run_predict
50
+ from mindspore.common.api import _convert_python_data, _get_args_for_run_predict
38
51
  from mindspore.common.api import _process_dyn_args, _generate_dyn_compile_args
39
- from mindspore.common.parameter import Parameter, ParameterTuple
52
+ from mindspore.common.parameter import _Buffer, Parameter, ParameterTuple
40
53
  from mindspore.common.tensor import Tensor
41
54
  from mindspore.ops.operations import Cast
42
55
  from mindspore.ops.primitive import Primitive
43
56
  from mindspore.ops.operations import _inner_ops as inner
44
57
  from mindspore.parallel.shard import Shard
58
+ from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
45
59
  from mindspore._check_jit_forbidden_api import jit_forbidden_register
46
60
  from mindspore.common._decorator import deprecated
47
61
  from mindspore.common._register_for_recompute import recompute_registry
48
62
 
49
63
 
64
+ __all__ = [
65
+ "register_cell_buffer_registration_hook",
66
+ ]
67
+
68
+ _global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict()
69
+ _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
70
+
71
+
72
+ class _IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),):
73
+ def __repr__(self):
74
+ if not self.missing_keys and not self.unexpected_keys:
75
+ return "<All keys matched successfully>"
76
+ return super().__repr__()
77
+
78
+ __str__ = __repr__
79
+
80
+
81
+ def register_cell_buffer_registration_hook(hook: Callable[..., None],):
82
+ r"""Register a buffer registration hook common to all cells.
83
+
84
+ .. warning ::
85
+
86
+ This adds global state to the `nn.Cell` cell
87
+
88
+ The hook will be called every time :func:`register_buffer` is invoked.
89
+ It should have the following signature::
90
+
91
+ hook(cell, name, buffer) -> None or new buffer
92
+
93
+ The hook can modify the input or return a single modified value in the hook.
94
+
95
+ Returns:
96
+ A handle that can be used to remove the added hook by calling
97
+ `handle.remove()`.
98
+ """
99
+ from mindspore.utils.hooks import _RemovableHandle
100
+ handle = _RemovableHandle(_global_buffer_registration_hooks)
101
+ _global_buffer_registration_hooks[handle.id] = hook
102
+ return handle
103
+
104
+
105
+
50
106
  class Cell(Cell_):
51
107
  """
52
108
  The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this
@@ -108,6 +164,8 @@ class Cell(Cell_):
108
164
  '_attr_synced', 'pynative', 'requires_grad', 'cell_type',
109
165
  '_parameters_forward_hook', '_parameters_backward_hook']
110
166
  total_instance_count = 0
167
+ _buffers: Dict[str, Optional[Tensor]]
168
+ _non_persistent_buffers_set: Set[str]
111
169
 
112
170
  def __init__(self, auto_prefix=True, flags=None):
113
171
  Cell_.__init__(self, self._cell_tag)
@@ -115,10 +173,17 @@ class Cell(Cell_):
115
173
  self.instance_count = Cell.total_instance_count
116
174
  self._params = OrderedDict()
117
175
  self._cells = OrderedDict()
176
+ super().__setattr__("_buffers", {})
177
+ super().__setattr__("_non_persistent_buffers_set", set())
178
+ super().__setattr__("_state_dict_hooks", OrderedDict())
179
+ super().__setattr__("_state_dict_pre_hooks", OrderedDict())
180
+ super().__setattr__("_load_state_dict_pre_hooks", OrderedDict())
181
+ super().__setattr__("_load_state_dict_post_hooks", OrderedDict())
118
182
  self._params_list = OrderedDict()
119
183
  self._primitives = OrderedDict()
120
184
  self.training = False
121
185
  self.requires_grad = False
186
+ self.is_top_cell = False
122
187
  self.pynative = False
123
188
  self._attr_synced = False
124
189
  self._param_prefix = ''
@@ -135,8 +200,8 @@ class Cell(Cell_):
135
200
  cells_compile_cache[id(self)] = self.compile_cache
136
201
  self.parameter_broadcast_done = False
137
202
  self._id = 1
138
- self.exist_names = set("")
139
- self.exist_objs = set()
203
+ self._exist_objs = None
204
+ self._exist_names = None
140
205
  self._recompute_cell = None
141
206
  self.mixed_precision_type = None
142
207
  self.sig = inspect.signature(self.construct)
@@ -146,7 +211,6 @@ class Cell(Cell_):
146
211
  if os.getenv('GC_COLLECT_IN_CELL') == '1':
147
212
  logger.warning("The convenient environment 'GC_COLLECT_IN_CELL' is deprecated from version 2.5 "
148
213
  "and will be removed in a future version.")
149
- gc.collect()
150
214
 
151
215
  if flags:
152
216
  self.add_flags(**flags)
@@ -209,6 +273,21 @@ class Cell(Cell_):
209
273
  def cell_init_args(self):
210
274
  return self._cell_init_args
211
275
 
276
+ @property
277
+ def exist_names(self):
278
+ """
279
+ Get exist parameter names adding by tuple or list of parameter.
280
+ """
281
+ if self._exist_names is None:
282
+ self._exist_names = set("")
283
+ return self._exist_names
284
+
285
+ @property
286
+ def exist_objs(self):
287
+ if self._exist_objs is None:
288
+ self._exist_objs = set()
289
+ return self._exist_objs
290
+
212
291
  @property
213
292
  def param_prefix(self):
214
293
  """
@@ -237,11 +316,6 @@ class Cell(Cell_):
237
316
  def bprop_debug(self):
238
317
  """
239
318
  Get whether cell custom bprop debug is enabled.
240
-
241
- Tutorial Examples:
242
- - `Custom Neural Network Layers - Custom Cell Reverse
243
- <https://mindspore.cn/docs/en/master/model_train/custom_program/network_custom.html
244
- #custom-cell-reverse>`_
245
319
  """
246
320
  return self._bprop_debug
247
321
 
@@ -358,8 +432,6 @@ class Cell(Cell_):
358
432
  raise ValueError("For 'Cell', the property 'pipeline_stage' "
359
433
  "can not be less than 0, but got {}".format(value))
360
434
  self._pipeline_stage = value
361
- for item in self.trainable_params():
362
- item.add_pipeline_stage(value)
363
435
 
364
436
  @property
365
437
  def pipeline_segment(self):
@@ -395,6 +467,374 @@ class Cell(Cell_):
395
467
  def enable_backward_hook(self):
396
468
  return self._enable_backward_hook
397
469
 
470
+ @jit_forbidden_register
471
+ def register_buffer(
472
+ self, name: str, tensor: Optional[Tensor], persistent: bool = True
473
+ ) -> None:
474
+ r"""Add a buffer to the cell.
475
+
476
+ This is typically used to register a buffer that should not to be
477
+ considered a model parameter. For example, BatchNorm's `running_mean`
478
+ is not a parameter, but is part of the cell's state. Buffers, by
479
+ default, are persistent and will be saved alongside parameters. This
480
+ behavior can be changed by setting `persistent` to ``False`` . The
481
+ only difference between a persistent buffer and a non-persistent buffer
482
+ is that the latter will not be a part of this cell's :attr:`state_dict` .
483
+
484
+ Buffers can be accessed as attributes using given names.
485
+
486
+ Args:
487
+ name (str): name of the buffer. The buffer can be accessed
488
+ from this cell using the given name.
489
+ tensor (Tensor): Buffer to be registered. If ``None`` ,
490
+ the buffer is not included in the cell's :attr:`state_dict` .
491
+ persistent (bool, optional): Whether the buffer is part of this cell's :attr:`state_dict`. Default ``True``.
492
+
493
+ Examples:
494
+ >>> import mindspore
495
+ ...
496
+ >>> class Net(mindspore.nn.Cell):
497
+ ... def __init__(self):
498
+ ... super().__init__()
499
+ ... self.register_buffer("buffer0", mindspore.tensor([1, 2, 3]))
500
+ ...
501
+ ... def construct(self, x):
502
+ ... return x + self.net_buffer
503
+ ...
504
+ >>> net = Net()
505
+ >>> net.register_buffer("buffer0", mindspore.tensor([4, 5, 6]))
506
+ >>> print(net.buffer0)
507
+ [4 5 6]
508
+ """
509
+
510
+ if "_buffers" not in self.__dict__:
511
+ raise AttributeError("cannot assign buffer before Cell.__init__() call")
512
+ if not isinstance(name, str):
513
+ raise TypeError(
514
+ f"buffer name should be a string.But got this type: {type(name)}"
515
+ )
516
+ if "." in name:
517
+ raise KeyError('buffer name can\'t contain "."')
518
+ if name == "":
519
+ raise KeyError('buffer name can\'t be empty string ""')
520
+ if hasattr(self, name) and name not in self._buffers:
521
+ raise KeyError(f"attribute '{name}' already exists")
522
+ if tensor is not None and not isinstance(tensor, Tensor):
523
+ raise TypeError(
524
+ f"cannot assign '{type(tensor)}' object to buffer '{name}' "
525
+ "(mindspore Tensor or None required)"
526
+ )
527
+ for hook in _global_buffer_registration_hooks.values():
528
+ output = hook(self, name, tensor)
529
+ if output is not None:
530
+ tensor = output
531
+ if tensor is not None:
532
+ tensor._is_buffer = True
533
+ self._buffers[name] = tensor
534
+ if persistent:
535
+ self._non_persistent_buffers_set.discard(name)
536
+ else:
537
+ self._non_persistent_buffers_set.add(name)
538
+
539
+ @jit_forbidden_register
540
+ def get_buffer(self, target: str) -> "Tensor":
541
+ """Return the buffer given by `target` if it exists, otherwise throw an error.
542
+
543
+ See the docstring for `get_sub_cell` for a more detailed
544
+ explanation of this method's functionality as well as how to
545
+ correctly specify `target` .
546
+
547
+ Args:
548
+ target (str): The fully-qualified string name of the buffer
549
+ to look for. (See `get_sub_cell` for how to specify a
550
+ fully-qualified string.)
551
+
552
+ Returns:
553
+ Tensor
554
+
555
+ Examples:
556
+ >>> import mindspore
557
+ ...
558
+ ...
559
+ >>> class NetC(mindspore.nn.Cell):
560
+ ... def __init__(self):
561
+ ... super().__init__()
562
+ ... self.register_buffer("buffer_c", mindspore.tensor([0, 0, 0]))
563
+ ...
564
+ ... def construct(self, x):
565
+ ... return x + self.buffer_c
566
+ ...
567
+ ...
568
+ >>> class NetB(mindspore.nn.Cell):
569
+ ... def __init__(self, net_c):
570
+ ... super().__init__()
571
+ ... self.net_c = net_c
572
+ ... self.register_buffer("buffer_b", mindspore.tensor([1, 2, 3]))
573
+ ...
574
+ ... def construct(self, x):
575
+ ... return self.net_c(x) + self.buffer_b
576
+ ...
577
+ ...
578
+ >>> class NetA(mindspore.nn.Cell):
579
+ ... def __init__(self, net_b):
580
+ ... super().__init__()
581
+ ... self.net_b = net_b
582
+ ... self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
583
+ ...
584
+ ... def construct(self, x):
585
+ ... return self.net_b(x) + self.buffer_a
586
+ ...
587
+ ...
588
+ >>> net_c = NetC()
589
+ >>> net_b = NetB(net_c)
590
+ >>> net_a = NetA(net_b)
591
+ >>> buffer_c = net_a.get_buffer("net_b.net_c.buffer_c")
592
+ >>> print(f'buffer_c is {buffer_c}')
593
+ buffer_c is [0 0 0]
594
+
595
+ """
596
+ cell_path, _, buffer_name = target.rpartition(".")
597
+
598
+ cell = self.get_sub_cell(cell_path)
599
+
600
+ if not hasattr(cell, buffer_name):
601
+ raise AttributeError(
602
+ cell._get_name() + " has no attribute `" + buffer_name + "`"
603
+ )
604
+
605
+ buffer = getattr(cell, buffer_name)
606
+
607
+ if buffer_name not in cell._buffers:
608
+ raise AttributeError("`" + buffer_name + "` is not a buffer")
609
+
610
+ return buffer
611
+
612
+ @jit_forbidden_register
613
+ def named_buffers(
614
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
615
+ ) -> Iterator[Tuple[str, Tensor]]:
616
+ r"""Return an iterator over cell buffers, yielding both the name of the buffer as well as the buffer itself.
617
+
618
+ Args:
619
+ prefix (str, optional): prefix to prepend to all buffer names. Default ``""``.
620
+ recurse (bool, optional): if ``True`` , then yields buffers of this cell
621
+ and all sub cells. Otherwise, yields only buffers that
622
+ are direct members of this cell. Default ``True``.
623
+ remove_duplicate (bool, optional): Whether to remove the duplicated buffers in the result. Default ``True``.
624
+
625
+ Returns:
626
+ Iterator[Tuple[str, Tensor]], an iterator of tuple containing the name and buffer.
627
+
628
+ Examples:
629
+ >>> import mindspore
630
+ ...
631
+ ...
632
+ >>> class NetB(mindspore.nn.Cell):
633
+ ... def __init__(self):
634
+ ... super().__init__()
635
+ ... self.register_buffer("buffer_b", mindspore.tensor([1, 2, 3]))
636
+ ...
637
+ ... def construct(self, x):
638
+ ... return x + self.buffer_b
639
+ ...
640
+ ...
641
+ >>> class NetA(mindspore.nn.Cell):
642
+ ... def __init__(self, net_b):
643
+ ... super().__init__()
644
+ ... self.net_b = net_b
645
+ ... self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
646
+ ...
647
+ ... def construct(self, x):
648
+ ... return self.net_b(x) + self.buffer_a
649
+ ...
650
+ ...
651
+ >>> net_b = NetB()
652
+ >>> net_a = NetA(net_b)
653
+ >>>
654
+ >>> for name, buffer in net_a.named_buffers():
655
+ ... print(f'buffer name is {name}, buffer is {buffer}')
656
+ buffer name is buffer_a, buffer is [4 5 6]
657
+ buffer name is net_b.buffer_b, buffer is [1 2 3]
658
+
659
+ """
660
+ gen = self._named_members(
661
+ lambda cell: cell._buffers.items(),
662
+ prefix=prefix,
663
+ recurse=recurse,
664
+ remove_duplicate=remove_duplicate,
665
+ )
666
+ yield from gen
667
+
668
+ @jit_forbidden_register
669
+ def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
670
+ r"""Return an iterator over cell buffers.
671
+
672
+ Args:
673
+ recurse (bool, optional): If ``True`` , then yields buffers of this cell
674
+ and all sub cells. Otherwise, yields only buffers that
675
+ are direct members of this cell. Default ``True``.
676
+
677
+ Returns:
678
+ Iterator[Tensor], an iterator of buffer.
679
+
680
+ Examples:
681
+ >>> import mindspore
682
+ ...
683
+ ...
684
+ >>> class NetB(mindspore.nn.Cell):
685
+ ... def __init__(self):
686
+ ... super().__init__()
687
+ ... self.register_buffer("buffer_b", mindspore.tensor([1, 2, 3]))
688
+ ...
689
+ ... def construct(self, x):
690
+ ... return x + self.buffer_b
691
+ ...
692
+ ...
693
+ >>> class NetA(mindspore.nn.Cell):
694
+ ... def __init__(self, net_b):
695
+ ... super().__init__()
696
+ ... self.net_b = net_b
697
+ ... self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
698
+ ...
699
+ ... def construct(self, x):
700
+ ... return self.net_b(x) + self.buffer_a
701
+ ...
702
+ ...
703
+ >>> net_b = NetB()
704
+ >>> net_a = NetA(net_b)
705
+ >>>
706
+ >>> for buffer in net_a.buffers():
707
+ ... print(f'buffer is {buffer}')
708
+ buffer is [4 5 6]
709
+ buffer is [1 2 3]
710
+
711
+ """
712
+ for _, buf in self.named_buffers(recurse=recurse):
713
+ yield buf
714
+
715
+ def _named_members(self, get_members_fn, prefix="", recurse=True, remove_duplicate: bool = True):
716
+ r"""Help yield various names + members of cells."""
717
+ memo = set()
718
+ cells = (
719
+ self.cells_and_names(name_prefix=prefix)
720
+ if recurse
721
+ else [(prefix, self)]
722
+ )
723
+ for cell_prefix, cell in cells:
724
+ members = get_members_fn(cell)
725
+ for k, v in members:
726
+ if v is None or v in memo:
727
+ continue
728
+ if remove_duplicate:
729
+ memo.add(v)
730
+ name = cell_prefix + ("." if cell_prefix else "") + k
731
+ yield name, v
732
+
733
+ @jit_forbidden_register
734
+ def get_sub_cell(self, target: str) -> "Cell":
735
+ """Return the sub cell given by `target` if it exists, otherwise throw an error.
736
+
737
+ For example, let's say you have an ``nn.Cell`` ``A`` that
738
+ looks like this:
739
+
740
+ .. code-block:: text
741
+
742
+ A(
743
+ (net_b): NetB(
744
+ (net_c): NetC(
745
+ (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
746
+ )
747
+ (dense): Dense(in_features=100, out_features=200, bias=True)
748
+ )
749
+ )
750
+
751
+ (The diagram shows an ``nn.Cell`` ``A``. ``A`` has a nested
752
+ sub cell ``net_b``, which itself has two sub cells ``net_c``
753
+ and ``dense``. ``net_c`` then has a sub cell ``conv``.)
754
+
755
+ To check whether we have the ``dense`` sub cell, we
756
+ would call `get_sub_cell("net_b.dense")`. To check whether
757
+ we have the ``conv`` sub cell, we would call
758
+ `get_sub_cell("net_b.net_c.conv")`.
759
+
760
+ The runtime of ``get_sub_cell`` is bounded by the degree
761
+ of cell nesting in `target`. A query against
762
+ `name_cells` achieves the same result, but it is O(N) in
763
+ the number of transitive cells. So, for a simple check to see
764
+ if some sub cells exist, ``get_sub_cell`` should always be
765
+ used.
766
+
767
+ Args:
768
+ target (str): The fully-qualified string name of the sub cell
769
+ to look for. (See above example for how to specify a
770
+ fully-qualified string.)
771
+
772
+ Returns:
773
+ Cell
774
+
775
+ Examples:
776
+ >>> import mindspore
777
+ ...
778
+ ...
779
+ >>> class NetC(mindspore.nn.Cell):
780
+ ... def __init__(self):
781
+ ... super().__init__()
782
+ ... self.register_buffer("buffer_c", mindspore.tensor([0, 0, 0]))
783
+ ... self.dense_c = mindspore.nn.Dense(5, 3)
784
+ ...
785
+ ... def construct(self, x):
786
+ ... return self.dense_c(x) + self.buffer_c
787
+ ...
788
+ ...
789
+ >>> class NetB(mindspore.nn.Cell):
790
+ ... def __init__(self, net_c):
791
+ ... super().__init__()
792
+ ... self.net_c = net_c
793
+ ... self.register_buffer("buffer_b", mindspore.tensor([1, 2, 3]))
794
+ ...
795
+ ... def construct(self, x):
796
+ ... return self.net_c(x) + self.buffer_b
797
+ ...
798
+ ...
799
+ >>> class NetA(mindspore.nn.Cell):
800
+ ... def __init__(self, net_b):
801
+ ... super().__init__()
802
+ ... self.net_b = net_b
803
+ ... self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
804
+ ...
805
+ ... def construct(self, x):
806
+ ... return self.net_b(x) + self.buffer_a
807
+ ...
808
+ ...
809
+ >>> net_c = NetC()
810
+ >>> net_b = NetB(net_c)
811
+ >>> net_a = NetA(net_b)
812
+ >>> net_c = net_a.get_sub_cell("net_b.net_c")
813
+ >>> print(f'net_c is {net_c}')
814
+ net_c is NetC(
815
+ (dense_c): Dense(input_channels=5, output_channels=3, has_bias=True)
816
+ )
817
+
818
+ """
819
+ if target == "":
820
+ return self
821
+
822
+ atoms: List[str] = target.split(".")
823
+ cell = self
824
+
825
+ for item in atoms:
826
+ if not hasattr(cell, item):
827
+ raise AttributeError(
828
+ cell._get_name() + " has no " "attribute `" + item + "`"
829
+ )
830
+
831
+ cell = getattr(cell, item)
832
+
833
+ if not isinstance(cell, Cell):
834
+ raise AttributeError("`" + item + "` is not " "an nn.Cell")
835
+
836
+ return cell
837
+
398
838
  def get_func_graph_proto(self):
399
839
  """Return graph binary proto."""
400
840
  exec_id = ".".join([self.phase, str(self.create_time), str(id(self))])
@@ -405,6 +845,10 @@ class Cell(Cell_):
405
845
  params = self.__dict__['_params']
406
846
  if name in params:
407
847
  return params[name]
848
+ if '_buffers' in self.__dict__:
849
+ buffers = self.__dict__['_buffers']
850
+ if name in buffers:
851
+ return buffers[name]
408
852
  if '_cells' in self.__dict__:
409
853
  cells = self.__dict__['_cells']
410
854
  if name in cells:
@@ -427,6 +871,8 @@ class Cell(Cell_):
427
871
  def __delattr__(self, name):
428
872
  if name in self._params:
429
873
  del self._params[name]
874
+ elif name in self._buffers:
875
+ del self._buffers[name]
430
876
  elif name in self._cells:
431
877
  del self._cells[name]
432
878
  elif '_params_list' in self.__dict__ and name in self._params_list:
@@ -600,6 +1046,89 @@ class Cell(Cell_):
600
1046
  for prim in all_prims:
601
1047
  prim.add_prim_attr("strategy_gen_mode", "data_parallel")
602
1048
 
1049
+ def offload(self, backward_prefetch="Auto"):
1050
+ """
1051
+ Set the cell offload. All primitive ops in the cell will be set offload. For the intermediate
1052
+ activations calculated by these primitive ops, we will not save them in the forward pass, but
1053
+ offload them and onload them in the backward pass.
1054
+
1055
+ Note:
1056
+ - If Cell.offload is called, the mode should be set to "GRAPH_MODE".
1057
+ - If Cell.offload is called, lazyinline should be enabled.
1058
+
1059
+ Args:
1060
+ backward_prefetch(Union[str, int], optional): The timing for prefetching activations in advance in backward
1061
+ pass. Default: ``"Auto"``. If set it to ``"Auto"``, framework
1062
+ will start to prefetch activations one operator in advance.
1063
+ If set it to a positive int value, framework will start to
1064
+ prefetch activations ``backward_prefetch`` operators in
1065
+ advance, such as 1, 20, 100.
1066
+ Examples:
1067
+ >>> import mindspore.nn as nn
1068
+ >>> from mindspore import ops
1069
+ >>> from mindspore.common import Tensor, Parameter
1070
+ >>> from mindspore.common.lazy_inline import lazy_inline
1071
+ >>>
1072
+ >>> class Block(nn.Cell):
1073
+ ... def __init__(self):
1074
+ ... super(Block, self).__init__()
1075
+ ... self.transpose1 = ops.Transpose()
1076
+ ... self.transpose2 = ops.Transpose()
1077
+ ... self.transpose3 = ops.Transpose()
1078
+ ... self.transpose4 = ops.Transpose()
1079
+ ... self.real_div1 = ops.RealDiv()
1080
+ ... self.real_div2 = ops.RealDiv()
1081
+ ... self.batch_matmul1 = ops.BatchMatMul()
1082
+ ... self.batch_matmul2 = ops.BatchMatMul()
1083
+ ... self.softmax = ops.Softmax(-1)
1084
+ ... self.expand_dims = ops.ExpandDims()
1085
+ ... self.sub = ops.Sub()
1086
+ ... self.y = Parameter(Tensor(np.ones((1024, 128, 128)).astype(np.float32)))
1087
+ ... def construct(self, x):
1088
+ ... transpose1 = self.transpose1(x, (0, 2, 1, 3))
1089
+ ... real_div1 = self.real_div1(transpose1, Tensor(2.37891))
1090
+ ... transpose2 = self.transpose2(x, (0, 2, 3, 1))
1091
+ ... real_div2 = self.real_div2(transpose2, Tensor(2.37891))
1092
+ ... batch_matmul1 = self.batch_matmul1(real_div1, real_div2)
1093
+ ... expand_dims = self.expand_dims(self.y, 1)
1094
+ ... sub = self.sub(Tensor([1.0]), expand_dims)
1095
+ ... soft_max = self.softmax(sub)
1096
+ ... transpose3 = self.transpose3(x, (0, 2, 1, 3))
1097
+ ... batch_matmul2 = self.batch_matmul2(soft_max[0], transpose3)
1098
+ ... transpose4 = self.transpose4(batch_matmul2, (0, 2, 1, 3))
1099
+ ... return transpose4
1100
+ >>>
1101
+ >>> class OuterBlock(nn.Cell):
1102
+ ... @lazy_inline
1103
+ ... def __init__(self):
1104
+ ... super(OuterBlock, self).__init__()
1105
+ ... self.block = Block()
1106
+ ... def construct(self, x):
1107
+ ... return self.block(x)
1108
+ >>>
1109
+ >>> class Nets(nn.Cell):
1110
+ ... def __init__(self):
1111
+ ... super(Nets, self).__init__()
1112
+ ... self.blocks = nn.CellList()
1113
+ ... for _ in range(3):
1114
+ ... b = OuterBlock()
1115
+ ... b.offload()
1116
+ ... self.blocks.append(b)
1117
+ ... def construct(self, x):
1118
+ ... out = x
1119
+ ... for i in range(3):
1120
+ ... out = self.blocks[i](out)
1121
+ ... return out
1122
+ """
1123
+ if context._get_mode() == context.PYNATIVE_MODE:
1124
+ raise ValueError("The Cell offload does not support PyNative mode now.")
1125
+ if isinstance(backward_prefetch, str):
1126
+ Validator.check_string(backward_prefetch, ['Auto'], 'backward_prefetch', self.cls_name)
1127
+ else:
1128
+ Validator.check_non_negative_int(backward_prefetch)
1129
+ for prim in self._get_prims_recursively():
1130
+ prim._offload(backward_prefetch=backward_prefetch)
1131
+
603
1132
  def shard(self, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
604
1133
  """
605
1134
  Defining the input and output layouts of this cell and the parallel strategies of remaining ops will be
@@ -628,7 +1157,7 @@ class Cell(Cell_):
628
1157
  If the parameter name is incorrect or the corresponding parameter
629
1158
  has been set, the parameter setting will be ignored.
630
1159
  Default: ``None`` .
631
- device (string): Select a certain device target. It is not in use right now.
1160
+ device (str): Select a certain device target. It is not in use right now.
632
1161
  Support [ ``"CPU"`` , ``"GPU"`` , ``"Ascend"`` ]. Default: ``"Ascend"`` .
633
1162
  level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation
634
1163
  over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
@@ -660,10 +1189,8 @@ class Cell(Cell_):
660
1189
  ... x = self.block2_shard(x)
661
1190
  ... return x
662
1191
  """
663
- if context.get_auto_parallel_context("parallel_mode") not in ["auto_parallel", "semi_auto_parallel"]:
664
- raise AssertionError(f"Cell shard only supports auto parallel or semi_auto_parallel "
665
- f"Please check the parallel mode in parallel context.")
666
-
1192
+ if ms.communication.management.get_group_size() == 1:
1193
+ return self
667
1194
  shard_fn = Shard()
668
1195
  fn = shard_fn(self, in_strategy, out_strategy, parameter_plan, device, level)
669
1196
  self._shard_fn = fn
@@ -766,7 +1293,8 @@ class Cell(Cell_):
766
1293
  """
767
1294
  Process cell info before call construct
768
1295
  """
769
- if self.requires_grad:
1296
+ if self.requires_grad and (not _pynative_executor.grad_flag() or _pynative_executor.high_order()):
1297
+ self.is_top_cell = True
770
1298
  _pynative_executor.set_grad_flag(True)
771
1299
  _pynative_executor.new_graph(self, *args, **kwargs)
772
1300
  elif self._dynamic_shape_inputs is not None:
@@ -780,8 +1308,9 @@ class Cell(Cell_):
780
1308
  """
781
1309
  Process cell info after call construct
782
1310
  """
783
- if self.requires_grad:
1311
+ if self.requires_grad and self.is_top_cell:
784
1312
  _pynative_executor.end_graph(self, output, *args, **kwargs)
1313
+ self.is_top_cell = False
785
1314
  elif self._dynamic_shape_inputs is not None:
786
1315
  _pynative_executor.set_cell_use_dynamic_shape_process(False)
787
1316
 
@@ -826,52 +1355,41 @@ class Cell(Cell_):
826
1355
  self._add_attr(key, value)
827
1356
  self._attr_synced = True
828
1357
 
829
- def _set_attr_for_parameter(self, name, value):
830
- """Set attr for parameter."""
831
- cells = self.__dict__.get('_cells')
832
- params = self.__dict__.get('_params')
833
- if params is None:
834
- raise AttributeError("For 'Cell', can not assign params before Cell.__init__() is called.")
835
- if name in self.__dict__:
836
- if self.__dict__[name] is not None:
837
- raise TypeError(f"For 'Cell', the {name} should not be Parameter.")
838
- del self.__dict__[name]
839
- if cells and name in cells:
840
- raise TypeError(f"For 'Cell', the {name} must be Cell, but got Parameter.")
841
- self.insert_param_to_cell(name, value)
842
-
843
- def _set_attr_for_parameter_tuple(self, name, value):
844
- """Set attr for parameter in ParameterTuple."""
845
- params = self.__dict__.get('_params')
846
- params_list = self.__dict__.get('_params_list')
847
- if params is None:
848
- raise AttributeError("For 'Cell', can not assign params before Cell.__init__() is called.")
849
- exist_names = set("")
850
- exist_objs = set()
851
- for item in value:
852
- if item in exist_objs:
853
- # If there are multiple identical objects, their names only check once.
854
- continue
855
- exist_objs.add(item)
856
- if item.name == PARAMETER_NAME_DEFAULT:
857
- logger.warning("For 'Cell', the parameter definition is deprecated.\n"
858
- "Please set a unique name for the parameter in ParameterTuple '{}'.".format(value))
859
- item.name = item.name + "$" + str(self._id)
860
- self._id += 1
861
- self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
862
- if item.name in exist_names:
863
- raise ValueError("The value {} , its name '{}' already exists. "
864
- "Please set a unique name for the parameter.".format(value, item.name))
865
- exist_names.add(item.name)
866
-
867
- if context._get_mode() == context.PYNATIVE_MODE:
1358
+ def _set_attr_for_param_or_param_tuple(self, name, value):
1359
+ """Set attr for param and tensor."""
1360
+ if isinstance(value, Parameter):
868
1361
  if name in self.__dict__:
869
1362
  del self.__dict__[name]
870
- if name in params:
871
- del params[name]
872
- params_list[name] = value
873
- else:
874
- object.__setattr__(self, name, value)
1363
+ self.insert_param_to_cell(name, value)
1364
+ elif isinstance(value, ParameterTuple):
1365
+ exist_names = set("")
1366
+ exist_objs = set()
1367
+ for item in value:
1368
+ if item in exist_objs:
1369
+ # If there are multiple identical objects, their names are checked only once.
1370
+ continue
1371
+ exist_objs.add(item)
1372
+ if item.name == PARAMETER_NAME_DEFAULT:
1373
+ logger.warning("For 'Cell', the parameter definition is deprecated.\n"
1374
+ "Please set a unique name for the parameter in ParameterTuple '{}'.".format(value))
1375
+ item.name = item.name + "$" + str(self._id)
1376
+ self._id += 1
1377
+ self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
1378
+ if item.name in exist_names:
1379
+ raise ValueError("The value {} , its name '{}' already exists. "
1380
+ "Please set a unique name for the parameter.".format(value, item.name))
1381
+ exist_names.add(item.name)
1382
+
1383
+ if context._get_mode() == context.PYNATIVE_MODE:
1384
+ if name in self.__dict__:
1385
+ del self.__dict__[name]
1386
+ params = self.__dict__.get('_params')
1387
+ if name in params:
1388
+ del params[name]
1389
+ params_list = self.__dict__.get('_params_list')
1390
+ params_list[name] = value
1391
+ else:
1392
+ object.__setattr__(self, name, value)
875
1393
 
876
1394
  def _set_attr_for_parameter_in_list_or_tuple(self, name, value):
877
1395
  """Set attr for parameter in list or tuple."""
@@ -884,24 +1402,18 @@ class Cell(Cell_):
884
1402
  item.name = item.name + "$" + str(self._id)
885
1403
  self._id += 1
886
1404
  if item.name in self.exist_names:
887
- raise ValueError("The value {} , its name '{}' already exists. "
888
- "Please set a unique name for the parameter.".format(value, item.name))
1405
+ raise ValueError(f"The value {value} , its name '{item.name}' already exists. "
1406
+ "Please set a unique name for the parameter.")
889
1407
  self.exist_names.add(item.name)
890
1408
  object.__setattr__(self, name, value)
891
1409
 
892
1410
  def _set_attr_for_cell(self, name, value):
893
1411
  """Set attr for cell."""
894
- cells = self.__dict__.get('_cells')
895
- params = self.__dict__.get('_params')
896
- if cells is None:
897
- raise AttributeError("For 'Cell', can not assign cells before Cell.__init__() is called.")
898
1412
  if name in self.__dict__:
899
1413
  del self.__dict__[name]
900
- if params and name in params:
901
- raise TypeError(f"For 'Cell', the {name} must be Parameter, but got Cell.")
902
1414
  if self._auto_prefix:
903
1415
  value.update_parameters_name(name + '.')
904
- cells[name] = value
1416
+ self.insert_child_to_cell(name, value)
905
1417
  if hasattr(self, '_cell_init_args'):
906
1418
  self.cell_init_args += str({name: value})
907
1419
 
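
`_set_attr_for_cell` now registers the child through `insert_child_to_cell`, and with `auto_prefix` the child's parameter names are prefixed by the attribute name. An illustrative sketch:

from mindspore import nn

class Outer(nn.Cell):
    def __init__(self):
        super().__init__()
        self.inner = nn.Dense(2, 2)  # parameters become 'inner.weight' / 'inner.bias'

    def construct(self, x):
        return self.inner(x)

net = Outer()
print([p.name for p in net.get_parameters()])
# ['inner.weight', 'inner.bias']
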
@@ -914,30 +1426,57 @@ class Cell(Cell_):
914
1426
  else:
915
1427
  self.insert_param_to_cell(name, None)
916
1428
 
917
- def __setattr__(self, name, value):
918
- cells = self.__dict__.get('_cells')
1429
+ def _set_attr_for_object(self, name, value):
1430
+ """Set attr for py object."""
919
1431
  params = self.__dict__.get('_params')
920
- if isinstance(value, Parameter):
921
- self._set_attr_for_parameter(name, value)
922
- elif isinstance(value, ParameterTuple):
923
- self._set_attr_for_parameter_tuple(name, value)
924
- elif isinstance(value, (list, tuple)) and value and _check_param_list_tuple(value):
1432
+ if params is not None and name in params:
1433
+ if value is not None:
1434
+ if isinstance(value, Tensor):
1435
+ params[name].set_data(value)
1436
+ return
1437
+ raise TypeError(
1438
+ f"Parameter '{name}' already exists in network, "
1439
+ f"can not assign this type: '{type(value)}' as a parameter.")
1440
+ params[name] = None
1441
+ return
1442
+ cells = self.__dict__.get('_cells')
1443
+ if cells is not None and name in cells:
1444
+ if value is not None:
1445
+ raise TypeError(
1446
+ f"Sub cell '{name}' already exists in network, "
1447
+ f"can not assign this type: '{type(value)}' as a cell.")
1448
+ cells[name] = None
1449
+ return
1450
+ buffers = self.__dict__.get('_buffers')
1451
+ if buffers is not None and name in buffers:
1452
+ if value is not None:
1453
+ raise TypeError(
1454
+ f"Buffer '{name}' already exists in network, "
1455
+ f"can not assign this type: '{type(value)}' as a buffer.")
1456
+ buffers[name] = None
1457
+ return
1458
+ object.__setattr__(self, name, value)
1459
+
1460
+ def __setattr__(self, name, value):
1461
+ if isinstance(value, (Parameter, ParameterTuple)):
1462
+ self._set_attr_for_param_or_param_tuple(name, value)
1463
+ elif _is_parameter_list_or_tuple(value):
925
1464
  self._set_attr_for_parameter_in_list_or_tuple(name, value)
926
1465
  elif isinstance(value, Cell):
927
1466
  self._set_attr_for_cell(name, value)
928
- elif params and name in params:
929
- self._set_attr_for_params(name, value)
930
- elif cells and name in cells:
931
- if value is not None:
932
- raise TypeError(f"For 'Cell', the type of {name} must be cell, but got {type(value).__name__}.")
933
- self._cells[name] = None
934
- else:
935
- if isinstance(value, Primitive):
936
- value.set_prim_instance_name(name)
937
- self._primitives[name] = value
1467
+ elif isinstance(value, _Buffer):
1468
+ if name in self.__dict__:
1469
+ del self.__dict__[name]
1470
+ self.register_buffer(name, value)
1471
+ elif isinstance(value, Primitive):
1472
+ value.set_prim_instance_name(name)
1473
+ self._primitives[name] = value
938
1474
  object.__setattr__(self, name, value)
939
- if name not in Cell.IGNORE_LIST:
940
- self._attr_synced = False
1475
+ else:
1476
+ self._set_attr_for_object(name, value)
1477
+
1478
+ def _get_name(self):
1479
+ return self.__class__.__name__
941
1480
 
942
1481
  def extend_repr(self):
943
1482
  """
@@ -951,19 +1490,28 @@ class Cell(Cell_):
951
1490
  return self.__repr__()
952
1491
 
953
1492
  def __repr__(self):
954
- extra_str = self.extend_repr()
955
- info_str = self.__class__.__name__ + '<'
956
- if self._cells:
957
- sub_str = '\n'
958
- if extra_str:
959
- sub_str += '{}\n'.format(self.extend_repr())
960
- for key, value in self._cells.items():
961
- sub_str += '({}): {}\n'.format(key, repr(value))
962
- sub_str = sub_str.replace('\n', '\n ') + '>'
963
- info_str += sub_str
964
- else:
965
- info_str += extra_str + '>'
966
- return info_str
1493
+ extra_lines = []
1494
+ extend_repr = self.extend_repr()
1495
+ # an empty string would otherwise be split into the list ['']
1496
+ if extend_repr:
1497
+ extra_lines = extend_repr.split("\n")
1498
+ child_lines = []
1499
+ for key, cell in self._cells.items():
1500
+ cell_str = repr(cell)
1501
+ cell_str = _addindent(cell_str, 2)
1502
+ child_lines.append("(" + key + "): " + cell_str)
1503
+ lines = extra_lines + child_lines
1504
+
1505
+ main_str = self._get_name() + "("
1506
+ if lines:
1507
+ # simple one-liner info, which most builtin cells will use
1508
+ if len(extra_lines) == 1 and not child_lines:
1509
+ main_str += extra_lines[0]
1510
+ else:
1511
+ main_str += "\n " + "\n ".join(lines) + "\n"
1512
+
1513
+ main_str += ")"
1514
+ return main_str
967
1515
 
968
1516
  def load_parameter_slice(self, params):
969
1517
  """
@@ -1129,9 +1677,11 @@ class Cell(Cell_):
1129
1677
  args (tuple): Args of the Cell object.
1130
1678
  kwargs (dict): Kwargs of the Cell object.
1131
1679
  """
1680
+ _init_auto_parallel_context(self)
1132
1681
  self._compile_args = self._get_compile_args(args)
1133
1682
  _cell_graph_executor.compile(self, *self._compile_args, phase=self.phase,
1134
1683
  jit_config_dict=self._jit_config_dict, **kwargs)
1684
+ _clear_auto_parallel_context(self)
1135
1685
 
1136
1686
  def compile_and_run(self, *args, **kwargs):
1137
1687
  """
@@ -1262,9 +1812,9 @@ class Cell(Cell_):
1262
1812
  >>> net2 = nn.Dense(2, 2)
1263
1813
  >>> net1.insert_child_to_cell("child", net2)
1264
1814
  >>> print(net1)
1265
- ReLU<
1266
- (child): Dense<input_channels=2, output_channels=2, has_bias=True>
1267
- >
1815
+ ReLU(
1816
+ (child): Dense(input_channels=2, output_channels=2, has_bias=True)
1817
+ )
1268
1818
  """
1269
1819
  if not isinstance(child_name, str):
1270
1820
  raise TypeError(f"For 'insert_child_to_cell', the type of parameter 'child_name' must be str, "
@@ -1322,13 +1872,22 @@ class Cell(Cell_):
1322
1872
  new_param_tuple.append(param)
1323
1873
  cell.__dict__[key] = ParameterTuple(new_param_tuple)
1324
1874
 
1875
+ def _get_cell_parallel_mode(self):
1876
+ """Determine whether the current cell is in parallel mode."""
1877
+ is_parallel_mode = False
1878
+ for _, param in self.parameters_and_names():
1879
+ if param.param_info.is_param_init:
1880
+ is_parallel_mode = True
1881
+ break
1882
+ return is_parallel_mode
1883
+
1325
1884
  def init_parameters_data(self, auto_parallel_mode=False):
1326
1885
  """
1327
1886
  Initialize all parameters and replace the original saved parameters in cell.
1328
1887
 
1329
1888
  Note:
1330
1889
  trainable_params() and other similar interfaces may return different parameter instance after
1331
- `init_parameters_data`, do not save these results.
1890
+ `init_parameters_data`. It is not recommended to save these results.
1332
1891
 
1333
1892
  Args:
1334
1893
  auto_parallel_mode (bool): If running in auto_parallel_mode. Default: ``False`` .
@@ -1366,9 +1925,18 @@ class Cell(Cell_):
1366
1925
 
1367
1926
  # replace all original usage.
1368
1927
  cells = self.cells_and_names()
1928
+ is_parallel_mode = self._get_cell_parallel_mode()
1929
+ is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
1930
+
1369
1931
  for _, cell in cells:
1370
1932
  params = cell._params.items()
1371
1933
  for param_name, param in params:
1934
+ if param.param_info.is_pipeline_shared_param:
1935
+ continue
1936
+ if is_graph_mode and is_parallel_mode and not param.sliced:
1937
+ continue
1372
1940
  if not auto_parallel_mode:
1373
1941
  cell._params[param_name] = _updata(param)
1374
1942
  continue
@@ -1380,6 +1948,12 @@ class Cell(Cell_):
1380
1948
  param_tuple = cell_dict[key]
1381
1949
  new_param_tuple = []
1382
1950
  for param in param_tuple:
1951
+ if param.param_info.is_pipeline_shared_param:
1952
+ continue
1953
+ if is_graph_mode and is_parallel_mode and not param.sliced:
1954
+ continue
1383
1957
  if not auto_parallel_mode:
1384
1958
  new_param_tuple.append(_updata(param))
1385
1959
  continue
@@ -1687,7 +2261,7 @@ class Cell(Cell_):
1687
2261
  ... return x
1688
2262
  >>> net = Net()
1689
2263
  >>> print(net.cells())
1690
- odict_values([Dense<input_channels=2, output_channels=2, has_bias=True>])
2264
+ odict_values([Dense(input_channels=2, output_channels=2, has_bias=True)])
1691
2265
  """
1692
2266
  return self.name_cells().values()
1693
2267
 
@@ -1748,7 +2322,7 @@ class Cell(Cell_):
1748
2322
  ... return x
1749
2323
  >>> net = Net()
1750
2324
  >>> print(net.name_cells())
1751
- OrderedDict([('dense', Dense<input_channels=2, output_channels=2, has_bias=True>)])
2325
+ OrderedDict([('dense', Dense(input_channels=2, output_channels=2, has_bias=True))])
1752
2326
  """
1753
2327
  value_set = set()
1754
2328
  cells = OrderedDict()
@@ -1789,10 +2363,10 @@ class Cell(Cell_):
1789
2363
  ... if isinstance(cell, nn.Dense):
1790
2364
  ... cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
1791
2365
  >>> net.apply(func)
1792
- SequentialCell<
1793
- (0): Dense<input_channels=2, output_channels=2, has_bias=True>
1794
- (1): Dense<input_channels=2, output_channels=2, has_bias=True>
1795
- >
2366
+ SequentialCell(
2367
+ (0): Dense(input_channels=2, output_channels=2, has_bias=True)
2368
+ (1): Dense(input_channels=2, output_channels=2, has_bias=True)
2369
+ )
1796
2370
  >>> print(net[0].weight.asnumpy())
1797
2371
  [[1. 1.]
1798
2372
  [1. 1.]]
@@ -1832,9 +2406,6 @@ class Cell(Cell_):
1832
2406
  if not hasattr(self, "_func_graph_flags"):
1833
2407
  self._func_graph_flags = {}
1834
2408
  self._func_graph_flags.update({**flags})
1835
- if context._get_mode() == context.PYNATIVE_MODE and self._func_graph_flags.get("output_no_recompute"):
1836
- raise TypeError("Recompute is not supported in PyNative mode currently, you can use "
1837
- "'context.set_context(mode=context.GRAPH_MODE)' or @jit to set graph mode.")
1838
2409
  self.__dict__.update({**flags})
1839
2410
  self._add_mixed_precision_flag(**flags)
1840
2411
  return self
@@ -1927,8 +2498,8 @@ class Cell(Cell_):
1927
2498
  >>>
1928
2499
  >>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
1929
2500
  >>> net.to_float(mstype.float16)
1930
- Conv2d<input_channels=120, output_channels=240, kernel_size=(4, 4), stride=(1, 1), pad_mode=same,
1931
- padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=None, format=NCHW>
2501
+ Conv2d(input_channels=120, output_channels=240, kernel_size=(4, 4), stride=(1, 1), pad_mode=same,
2502
+ padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=None, format=NCHW)
1932
2503
  """
1933
2504
  if dst_type not in (mstype.float16, mstype.float32, mstype.bfloat16):
1934
2505
  raise ValueError("For 'to_float', the argument 'dst_type' must be mstype.float32, mstype.float16 or "
@@ -2133,8 +2704,7 @@ class Cell(Cell_):
2133
2704
  """
2134
2705
  if context._get_mode() == context.GRAPH_MODE:
2135
2706
  return HookHandle()
2136
- if not check_hook_fn("register_forward_pre_hook", hook_fn):
2137
- return HookHandle()
2707
+ check_hook_fn(hook_fn)
2138
2708
  handle = HookHandle(self._forward_pre_hook)
2139
2709
  self._forward_pre_hook[handle.handle_id] = hook_fn
2140
2710
  return handle
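
`check_hook_fn` now raises on an invalid hook instead of silently returning an empty handle. A sketch of registering a forward pre-hook in PyNative mode (hook signature as documented for recent MindSpore releases):

import mindspore
from mindspore import nn

mindspore.set_context(mode=mindspore.PYNATIVE_MODE)

def pre_hook(cell, inputs):
    # runs before construct(); returning None leaves the inputs unchanged
    print("about to run:", cell.cls_name)

net = nn.Dense(2, 2)
handle = net.register_forward_pre_hook(pre_hook)
net(mindspore.tensor([[1.0, 2.0]]))
handle.remove()
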
@@ -2233,8 +2803,7 @@ class Cell(Cell_):
2233
2803
  return HookHandle()
2234
2804
  if context._get_mode() == context.GRAPH_MODE:
2235
2805
  return HookHandle()
2236
- if not check_hook_fn("register_forward_hook", hook_fn):
2237
- return HookHandle()
2806
+ check_hook_fn(hook_fn)
2238
2807
  handle = HookHandle(self._forward_hook)
2239
2808
  self._forward_hook[handle.handle_id] = hook_fn
2240
2809
  return handle
@@ -2324,8 +2893,7 @@ class Cell(Cell_):
2324
2893
  """
2325
2894
  if context._get_mode() == context.GRAPH_MODE:
2326
2895
  return HookHandle()
2327
- if not check_hook_fn("register_backward_pre_hook", hook_fn):
2328
- return HookHandle()
2896
+ check_hook_fn(hook_fn)
2329
2897
  handle = HookHandle(self._backward_pre_hook)
2330
2898
  self._backward_pre_hook[handle.handle_id] = hook_fn
2331
2899
  if self._cell_backward_pre_hook is None:
@@ -2361,6 +2929,527 @@ class Cell(Cell_):
2361
2929
  len(ret), len(outputs)))
2362
2930
  return ret
2363
2931
 
2932
+ def get_extra_state(self) -> Any:
2933
+ """Return any extra state to include in the cell's state_dict.
2934
+
2935
+ This function is called from ``state_dict``.
2936
+ Implement this and a corresponding ``set_extra_state`` for your cell
2937
+ if you need to store extra state.
2938
+
2939
+ Note that extra state should be picklable to ensure working serialization
2940
+ of the state_dict. We only provide backwards compatibility guarantees
2941
+ for serializing tensors; other objects may break backwards compatibility if
2942
+ their serialized pickled form changes.
2943
+
2944
+ Returns:
2945
+ object, any extra state to store in the cell's state_dict.
2946
+ """
2947
+ raise RuntimeError(
2948
+ "Reached a code path in Cell.get_extra_state() that should never be called."
2949
+ )
2951
+
2952
+ def set_extra_state(self, state: Any) -> None:
2953
+ """Set extra state contained in the loaded `state_dict`.
2954
+
2955
+ This function is called from `load_state_dict` to handle any extra state
2956
+ found within the `state_dict`. Implement this function and a corresponding
2957
+ `get_extra_state` for your cell if you need to store extra state within its
2958
+ `state_dict`.
2959
+
2960
+ Args:
2961
+ state (dict): Extra state from the `state_dict`.
2962
+ """
2963
+ raise RuntimeError(
2964
+ "Reached a code path in Cell.set_extra_state() that should never be called."
2965
+ )
2966
+
2967
+ @jit_forbidden_register
2968
+ def register_state_dict_post_hook(self, hook):
2969
+ r"""Register a post-hook for the :func:`mindspore.nn.Cell.state_dict` method.
2970
+
2971
+ It should have the following signature:
2972
+
2973
+ hook(cell, state_dict, prefix, local_metadata) -> None
2974
+
2975
+ The registered hooks can modify the ``state_dict`` inplace.
2976
+
2977
+ Args:
2978
+ hook (Callable): The hook function after `state_dict` is called.
2979
+
2980
+ Returns:
2981
+ A handle that can be used to remove the added hook by calling
2982
+ `handle.remove()`.
2983
+ """
2984
+ from mindspore.utils.hooks import _RemovableHandle
2985
+ handle = _RemovableHandle(self._state_dict_hooks)
2986
+ self._state_dict_hooks[handle.id] = hook
2987
+ return handle
2988
+
2989
+ @jit_forbidden_register
2990
+ def register_state_dict_pre_hook(self, hook):
2991
+ r"""Register a pre-hook for the :func:`mindspore.nn.Cell.state_dict` method.
2992
+
2993
+ It should have the following signature:
2994
+
2995
+ hook(cell, prefix, keep_vars) -> None
2996
+
2997
+ The registered hooks can be used to perform pre-processing before the `state_dict`
2998
+ call is made.
2999
+
3000
+ Args:
3001
+ hook (Callable): The hook function before `state_dict` is called.
3002
+
3003
+ Returns:
3004
+ A handle that can be used to remove the added hook by calling
3005
+ `handle.remove()`.
3006
+
3007
+ Examples:
3008
+ >>> import mindspore
3009
+ ...
3010
+ ...
3011
+ >>> class NetA(mindspore.nn.Cell):
3012
+ ... def __init__(self):
3013
+ ... super().__init__()
3014
+ ... self.register_buffer("buffer_a", mindspore.tensor([1, 2, 3]))
3015
+ ... self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3]))
3016
+ ...
3017
+ ... def construct(self, x):
3018
+ ... return x + self.buffer_a + self.param_a
3019
+ ...
3020
+ ...
3021
+ >>> def _add_extra_param(cell, prefix, keep_vars):
3022
+ ... cell._params["extra_param"] = mindspore.Parameter(mindspore.tensor([4, 5, 6]))
3023
+ ...
3024
+ ...
3025
+ >>> net = NetA()
3026
+ >>> handle = net.register_state_dict_pre_hook(_add_extra_param)
3027
+ >>> net_state_dict = net.state_dict()
3028
+ >>> handle.remove()
3029
+ >>> print("extra_param" in net_state_dict)
3030
+ True
3031
+ """
3032
+ from mindspore.utils.hooks import _RemovableHandle
3033
+ handle = _RemovableHandle(self._state_dict_pre_hooks)
3034
+ self._state_dict_pre_hooks[handle.id] = hook
3035
+ return handle
3036
+
3037
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
3038
+ r"""Save cell state to the `destination` dictionary.
3039
+
3040
+ The `destination` dictionary will contain the state
3041
+ of the cell, but not its descendants. This is called on every
3042
+ sub cell in :func:`mindspore.nn.Cell.state_dict`.
3043
+
3044
+ In rare cases, subclasses can achieve class-specific behavior by
3045
+ overriding this method with custom logic.
3046
+
3047
+ Args:
3048
+ destination (dict): a dict where state will be stored
3049
+ prefix (str): the prefix for parameters and buffers used in this
3050
+ cell
3051
+ """
3052
+ for name, param in self._params.items():
3053
+ if param is not None:
3054
+ destination[prefix + name] = param
3055
+ for name, buf in self._buffers.items():
3056
+ if buf is not None and name not in self._non_persistent_buffers_set:
3057
+ destination[prefix + name] = buf
3058
+ extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
3059
+ if (
3060
+ getattr(self.__class__, "get_extra_state", Cell.get_extra_state)
3061
+ is not Cell.get_extra_state
3062
+ ):
3063
+ destination[extra_state_key] = self.get_extra_state()
3064
+
3065
+ # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
3066
+ # back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
3067
+ T_destination = TypeVar("T_destination", bound=Dict[str, Any])
3068
+
3069
+ @jit_forbidden_register
3070
+ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
3071
+ r"""Return a dictionary containing references to the whole state of the cell.
3072
+
3073
+ Both parameters and persistent buffers (e.g. running averages) are
3074
+ included. Keys are corresponding parameter and buffer names.
3075
+ Parameters and buffers set to ``None`` are not included.
3076
+
3077
+ .. note::
3078
+ The returned object is a shallow copy. It contains references
3079
+ to the cell's parameters and buffers.
3080
+
3081
+ .. warning::
3082
+ - Currently ``state_dict()`` also accepts positional arguments for
3083
+ ``destination``, ``prefix`` and ``keep_vars`` in order. However,
3084
+ this is being deprecated and keyword arguments will be enforced in
3085
+ future releases.
3086
+
3087
+ - Please avoid the use of argument ``destination`` as it is not
3088
+ designed for end-users.
3089
+
3090
+ Args:
3091
+ destination (dict, optional): If provided, the state of cell will
3092
+ be updated into the dict and the same object is returned.
3093
+ Otherwise, an ``OrderedDict`` will be created and returned.
3094
+ Default: ``None``.
3095
+ prefix (str, optional): A prefix added to parameter and buffer
3096
+ names to compose the keys in state_dict. Default: ``''``.
3097
+ keep_vars (bool, optional): Whether the state_dict returns a copy. Default: ``False``, which returns references.
3098
+
3099
+ Returns:
3100
+ Dict, a dictionary containing a whole state of the cell.
3101
+
3102
+ Examples:
3103
+ >>> import mindspore
3104
+ >>> class Model(mindspore.nn.Cell):
3105
+ ... def __init__(self):
3106
+ ... super().__init__()
3107
+ ... self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
3108
+ ... self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3]))
3109
+ ...
3110
+ ... def construct(self, x):
3111
+ ... return x + self.buffer_a + self.param_a
3112
+ ...
3113
+ ...
3114
+ >>> model = Model()
3115
+ >>> print(model.state_dict())
3116
+ OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \
3117
+ ('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))])
3118
+ """
3119
+ # TODO: Remove `args` and the parsing logic when BC allows.
3120
+ if args:
3121
+ # DeprecationWarning is ignored by default
3122
+ warnings.warn(
3123
+ "Positional args are being deprecated, use kwargs instead. Refer to "
3124
+ "https://www.mindspore.cn/docs/zh-CN/master/api_python/nn/mindspore.nn.Cell.html"
3125
+ " for details.",
3126
+ FutureWarning,
3127
+ stacklevel=2,
3128
+ )
3129
+ if destination is None:
3130
+ destination = args[0]
3131
+ if len(args) > 1 and prefix == "":
3132
+ prefix = args[1]
3133
+ if len(args) > 2 and keep_vars is False:
3134
+ keep_vars = args[2]
3135
+ if destination is not None and not isinstance(destination, dict):
3136
+ raise TypeError(f"The type of destination must be OrderedDict, but got {type(destination)}")
3137
+ if not isinstance(prefix, str):
3138
+ raise TypeError(f"The type of prefix must be string, but got {type(prefix)}")
3139
+ if not isinstance(keep_vars, bool):
3140
+ raise TypeError(f"The type of keep_vars must be bool, but got {type(keep_vars)}")
3141
+
3142
+ if destination is None:
3143
+ destination = OrderedDict()
3144
+ destination._metadata = OrderedDict()
3145
+
3146
+ local_metadata = {}
3147
+ if hasattr(destination, "_metadata"):
3148
+ destination._metadata[prefix[:-1]] = local_metadata
3149
+
3150
+ for hook in self._state_dict_pre_hooks.values():
3151
+ hook(self, prefix, keep_vars)
3152
+ self._save_to_state_dict(destination, prefix, keep_vars)
3153
+ for name, cell in self._cells.items():
3154
+ if cell is not None:
3155
+ cell.state_dict(
3156
+ destination=destination,
3157
+ prefix=prefix + name + ".",
3158
+ keep_vars=keep_vars,
3159
+ )
3160
+ for hook in self._state_dict_hooks.values():
3161
+ hook_result = hook(self, destination, prefix, local_metadata)
3162
+ if hook_result is not None:
3163
+ raise RuntimeError("state_dict post-hook must return None")
3164
+ return destination
3165
+
3166
+ @jit_forbidden_register
3167
+ def register_load_state_dict_pre_hook(self, hook):
3168
+ r"""Register a pre-hook to be run before cell's :func:`mindspore.nn.Cell.load_state_dict` is called.
3169
+
3170
+ It should have the following signature:
3171
+
3172
+ hook(cell, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None # noqa: B950
3173
+
3174
+ Args:
3175
+ hook (Callable): The hook function before `load_state_dict` is called.
3176
+
3177
+ Returns:
3178
+ A handle that can be used to remove the added hook by calling
3179
+ `handle.remove()`.
3180
+ """
3181
+ from mindspore.utils.hooks import _RemovableHandle
3182
+ handle = _RemovableHandle(self._load_state_dict_pre_hooks)
3183
+ self._load_state_dict_pre_hooks[handle.id] = hook
3184
+ return handle
3185
+
3186
+ @jit_forbidden_register
3187
+ def register_load_state_dict_post_hook(self, hook):
3188
+ r"""Register a post-hook to be run after cell's :func:`mindspore.nn.Cell.load_state_dict` is called.
3189
+
3190
+ It should have the following signature:
3191
+
3192
+ hook(cell, incompatible_keys) -> None
3193
+
3194
+ The ``cell`` argument is the current cell that this hook is registered
3195
+ on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting
3196
+ of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys``
3197
+ is a ``list`` of ``str`` containing the missing keys and
3198
+ ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys.
3199
+
3200
+ The given incompatible_keys can be modified inplace if needed.
3201
+
3202
+ Note that the checks performed when calling :func:`load_state_dict` with
3203
+ ``strict=True`` are affected by modifications the hook makes to
3204
+ ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either
3205
+ set of keys will result in an error being thrown when ``strict=True``, and
3206
+ clearing out both missing and unexpected keys will avoid an error.
3207
+
3208
+ Args:
3209
+ hook (Callable): The hook function after `load_state_dict` is called.
3210
+
3211
+ Returns:
3212
+ A handle that can be used to remove the added hook by calling
3213
+ `handle.remove()`.
3214
+ """
3215
+ from mindspore.utils.hooks import _RemovableHandle
3216
+ handle = _RemovableHandle(self._load_state_dict_post_hooks)
3217
+ self._load_state_dict_post_hooks[handle.id] = hook
3218
+ return handle
3219
+
3220
+ def _load_from_state_dict(
3221
+ self,
3222
+ state_dict,
3223
+ prefix,
3224
+ local_metadata,
3225
+ strict,
3226
+ missing_keys,
3227
+ unexpected_keys,
3228
+ error_msgs,
3229
+ ):
3230
+ r"""Copy parameters and buffers from :attr:`state_dict` into only this cell, but not its descendants.
3231
+
3232
+ This is called on every sub cell
3233
+ in :func:`mindspore.nn.Cell.load_state_dict`. Metadata saved for this
3234
+ cell in input :attr:`state_dict` is provided as :attr:`local_metadata`.
3235
+ For state dicts without metadata, :attr:`local_metadata` is empty.
3236
+ Subclasses can achieve class-specific backward compatible loading using
3237
+ the version number at `local_metadata.get("version", None)`.
3238
+
3239
+ .. note::
3240
+ :attr:`state_dict` is not the same object as the input
3241
+ :attr:`state_dict` to :func:`mindspore.nn.Cell.load_state_dict`. So
3242
+ it can be modified.
3243
+
3244
+ Args:
3245
+ state_dict (dict): a dict containing parameters and
3246
+ persistent buffers.
3247
+ prefix (str): the prefix for parameters and buffers used in this
3248
+ cell
3249
+ local_metadata (dict): a dict containing the metadata for this cell.
3251
+ strict (bool): whether to strictly enforce that the keys in
3252
+ :attr:`state_dict` with :attr:`prefix` match the names of
3253
+ parameters and buffers in this cell
3254
+ missing_keys (list of str): if ``strict=True``, add missing keys to
3255
+ this list
3256
+ unexpected_keys (list of str): if ``strict=True``, add unexpected
3257
+ keys to this list
3258
+ error_msgs (list of str): error messages should be added to this
3259
+ list, and will be reported together in
3260
+ :func:`mindspore.nn.Cell.load_state_dict`
3261
+ """
3262
+ for hook in self._load_state_dict_pre_hooks.values():
3263
+ hook(
3264
+ self,
3265
+ state_dict,
3266
+ prefix,
3267
+ local_metadata,
3268
+ strict,
3269
+ missing_keys,
3270
+ unexpected_keys,
3271
+ error_msgs,
3272
+ )
3273
+
3274
+ persistent_buffers = {
3275
+ k: v
3276
+ for k, v in self._buffers.items()
3277
+ if k not in self._non_persistent_buffers_set
3278
+ }
3279
+ local_name_params = itertools.chain(
3280
+ self._params.items(), persistent_buffers.items()
3281
+ )
3282
+ local_state = {k: v for k, v in local_name_params if v is not None}
3283
+
3284
+ for name, param in local_state.items():
3285
+ key = prefix + name
3286
+ if key in state_dict:
3287
+ input_param = state_dict[key]
3288
+ if not isinstance(input_param, Tensor):
3289
+ error_msgs.append(
3290
+ f'While copying the parameter named "{key}", '
3291
+ "expected Tensor or Tensor-like object from checkpoint but "
3292
+ f"received {type(input_param)}"
3293
+ )
3294
+ continue
3295
+
3296
+ if input_param.shape != param.shape:
3297
+ # local shape should match the one in checkpoint
3298
+ error_msgs.append(
3299
+ f"size mismatch for {key}: copying a param with shape {input_param.shape} from checkpoint, "
3300
+ f"the shape in current model is {param.shape}."
3301
+ )
3302
+ continue
3303
+ try:
3304
+ param.assign_value(Tensor(input_param.asnumpy(), dtype=param.dtype))
3305
+ except Exception as ex: # pylint: disable=W0703
3306
+ error_msgs.append(
3307
+ f'While copying the parameter named "{key}", '
3308
+ f"whose shape in the model is {param.shape} and "
3309
+ f"whose shape in the checkpoint is {input_param.shape}, "
3310
+ f"an exception occurred: {ex.args}."
3311
+ )
3312
+ elif strict:
3313
+ missing_keys.append(key)
3314
+
3315
+ extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
3316
+ if getattr(self.__class__, "set_extra_state", Cell.set_extra_state) is not Cell.set_extra_state:
3317
+ if extra_state_key in state_dict:
3318
+ self.set_extra_state(state_dict[extra_state_key])
3319
+ elif strict:
3320
+ missing_keys.append(extra_state_key)
3321
+ elif strict and (extra_state_key in state_dict):
3322
+ unexpected_keys.append(extra_state_key)
3323
+
3324
+ if strict:
3325
+ for key in state_dict.keys():
3326
+ if key.startswith(prefix) and key != extra_state_key:
3327
+ input_name = key[len(prefix):].split(".", 1)
3328
+ # Must be a cell if it has attributes
3329
+ if len(input_name) > 1:
3330
+ if input_name[0] not in self._cells:
3331
+ unexpected_keys.append(key)
3332
+ elif input_name[0] not in local_state:
3333
+ unexpected_keys.append(key)
3334
+
3335
+ @jit_forbidden_register
3336
+ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
3337
+ r"""Copy parameters and buffers from :attr:`state_dict` into this cell and its descendants.
3338
+
3339
+ If :attr:`strict` is ``True``, then
3340
+ the keys of :attr:`state_dict` must exactly match the keys returned
3341
+ by this cell's :func:`mindspore.nn.Cell.state_dict` function.
3342
+
3343
+ Args:
3344
+ state_dict (dict): A dict containing parameters and
3345
+ persistent buffers.
3346
+ strict (bool, optional): Whether to strictly enforce that the keys
3347
+ in input `state_dict` match the keys returned by this cell's
3348
+ :func:`mindspore.nn.Cell.state_dict` function. Default ``True`` .
3349
+
3350
+ Returns:
3351
+ A namedtuple with ``missing_keys`` and ``unexpected_keys`` fields,
3352
+
3353
+ - `missing_keys` is a list of str containing any keys that are expected
3354
+ by this cell but missing from the provided ``state_dict``.
3355
+
3356
+ - `unexpected_keys` is a list of str containing the keys that are not
3357
+ expected by this cell but present in the provided ``state_dict``.
3358
+
3359
+ Note:
3360
+ If `strict` is ``True`` and a parameter or buffer is registered as ``None``, but its corresponding key
3361
+ exists in :attr:`state_dict`, :func:`mindspore.nn.Cell.load_state_dict` will raise a ``RuntimeError``.
3362
+
3363
+ Examples:
3364
+ >>> import mindspore
3365
+ >>> import os
3366
+ >>> class Model(mindspore.nn.Cell):
3367
+ ... def __init__(self):
3368
+ ... super().__init__()
3369
+ ... self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
3370
+ ... self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3]))
3371
+ ...
3372
+ ... def construct(self, x):
3373
+ ... return x + self.buffer_a + self.param_a
3374
+ ...
3375
+ ...
3376
+ >>> model = Model()
3377
+ >>> print(model.state_dict())
3378
+ >>> mindspore.save_checkpoint(model.state_dict(), './model_state_dict_ckpt')
3379
+ >>> new_model = Model()
3380
+ >>> new_model.load_state_dict(mindspore.load_checkpoint('./model_state_dict_ckpt'))
3381
+ >>> print(new_model.state_dict())
3382
+ >>> os.remove('./model_state_dict_ckpt')
3383
+ OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \
3384
+ ('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))])
3385
+ OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \
3386
+ ('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))])
3387
+ """
3388
+ if not isinstance(state_dict, Mapping):
3389
+ raise TypeError(
3390
+ f"Expected state_dict to be dict-like, got {type(state_dict)}."
3391
+ )
3392
+
3393
+ missing_keys: List[str] = []
3394
+ unexpected_keys: List[str] = []
3395
+ error_msgs: List[str] = []
3396
+
3397
+ # copy state_dict so _load_from_state_dict can modify it
3398
+ metadata = getattr(state_dict, "_metadata", None)
3399
+ state_dict = OrderedDict(state_dict)
3400
+ if metadata is not None:
3401
+ # mypy isn't aware that "_metadata" exists in state_dict
3402
+ state_dict._metadata = metadata # type: ignore[attr-defined]
3403
+
3404
+ def load(cell, local_state_dict, prefix=""):
3405
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
3406
+ cell._load_from_state_dict(
3407
+ local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
3408
+ )
3409
+ for name, child in cell._cells.items():
3410
+ if child is not None:
3411
+ child_prefix = prefix + name + "."
3412
+ child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
3413
+ load(child, child_state_dict, child_prefix) # noqa: F821
3414
+
3415
+ # Note that the hook can modify missing_keys and unexpected_keys.
3416
+ incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
3417
+ for hook in cell._load_state_dict_post_hooks.values():
3418
+ out = hook(cell, incompatible_keys)
3419
+ if out is not None:
3420
+ raise RuntimeError(
3421
+ "Hooks registered with ``register_load_state_dict_post_hook`` are not"
3422
+ "expected to return new values, if incompatible_keys need to be modified,"
3423
+ "it should be done inplace."
3424
+ )
3425
+
3426
+ load(self, state_dict)
3427
+ del load
3428
+
3429
+ if strict:
3430
+ if unexpected_keys:
3431
+ error_msgs.insert(
3432
+ 0,
3433
+ "Unexpected key(s) in state_dict: {}. ".format(
3434
+ ", ".join(f'"{k}"' for k in unexpected_keys)
3435
+ ),
3436
+ )
3437
+ if missing_keys:
3438
+ error_msgs.insert(
3439
+ 0,
3440
+ "Missing key(s) in state_dict: {}. ".format(
3441
+ ", ".join(f'"{k}"' for k in missing_keys)
3442
+ ),
3443
+ )
3444
+
3445
+ if error_msgs:
3446
+ raise RuntimeError(
3447
+ "Error(s) in loading state_dict for {}:\n\t{}".format(
3448
+ self.__class__.__name__, "\n\t".join(error_msgs)
3449
+ )
3450
+ )
3451
+ return _IncompatibleKeys(missing_keys, unexpected_keys)
3452
+
2364
3453
  def register_backward_hook(self, hook_fn):
2365
3454
  """
2366
3455
  Register the backward hook function.
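
Taken together, the new `state_dict` / `load_state_dict` machinery gives a PyTorch-style round trip. A minimal sketch, assuming a fresh instance of the same cell class:

from mindspore import nn

src = nn.Dense(2, 2)
state = src.state_dict()            # OrderedDict of parameters and buffers

dst = nn.Dense(2, 2)
result = dst.load_state_dict(state, strict=True)
# On an exact key match, both result lists are empty.
assert not result.missing_keys and not result.unexpected_keys
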
@@ -2420,8 +3509,7 @@ class Cell(Cell_):
2420
3509
  """
2421
3510
  if context._get_mode() == context.GRAPH_MODE:
2422
3511
  return HookHandle()
2423
- if not check_hook_fn("register_backward_hook", hook_fn):
2424
- return HookHandle()
3512
+ check_hook_fn(hook_fn)
2425
3513
  handle = HookHandle(self._backward_hook)
2426
3514
  self._backward_hook[handle.handle_id] = hook_fn
2427
3515
  if self._cell_backward_hook is None:
@@ -2565,8 +3653,9 @@ class Cell(Cell_):
2565
3653
  if not self._has_config_recompute:
2566
3654
  self._has_config_recompute = True
2567
3655
  else:
2568
- raise RuntimeError("The recompute interface can be configured only once."
2569
- " When the parent cell is configured, the child cell should not be configured")
3656
+ logger.info("The recompute interface can be configured only once."
3657
+ " When the parent cell is configured, the child cell should not be configured")
3658
+ return
2570
3659
  self._set_recompute_scope(mode)
2571
3660
  if mode and not output_recompute:
2572
3661
  self.add_flags(output_no_recompute=True)
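
Reconfiguring recompute is thus downgraded from a hard error to an info log. A sketch of the new behavior (a toy block, not from the diff):

from mindspore import nn

block = nn.SequentialCell(nn.Dense(2, 4), nn.ReLU())
block.recompute()      # configures the whole block
block[0].recompute()   # previously RuntimeError; now logs info and returns
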
@@ -2606,18 +3695,13 @@ class Cell(Cell_):
2606
3695
  """
2607
3696
  if context.get_context("mode") == context.PYNATIVE_MODE:
2608
3697
  self._recompute_cell = recompute_registry.get()(self.construct)
2609
- self._add_recompute_flag()
2610
- return
2611
3698
  self._recompute()
2612
3699
  if 'mp_comm_recompute' in kwargs.keys():
2613
3700
  self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
2614
3701
  if 'parallel_optimizer_comm_recompute' in kwargs.keys():
2615
- if (kwargs.get('parallel_optimizer_comm_recompute', False) and
2616
- context.get_auto_parallel_context("pipeline_stages") > 1):
3702
+ if kwargs.get('parallel_optimizer_comm_recompute', False):
2617
3703
  logger.warning("Currently, the communication operator allgathers introduced by optimizer shard "
2618
- "are not support recomputation in pipeline parallel.")
2619
- elif context.get_auto_parallel_context("pipeline_stages") == 1:
2620
- self._parallel_optimizer_comm_recompute(kwargs.get('parallel_optimizer_comm_recompute', False))
3704
+ "is replaced with zero3.")
2621
3705
  if 'recompute_slice_activation' in kwargs:
2622
3706
  self._recompute_slice_activation(kwargs.get('recompute_slice_activation', False))
2623
3707
 
@@ -2709,18 +3793,6 @@ class Cell(Cell_):
2709
3793
  if hasattr(network, "_amp_level"):
2710
3794
  self._amp_level = getattr(network, "_amp_level")
2711
3795
 
2712
- def _add_recompute_flag(self):
2713
- """
2714
- Set pynative cell recomputed.
2715
- """
2716
- if not self._has_config_recompute:
2717
- self._has_config_recompute = True
2718
- else:
2719
- logger.info("The recompute interface can be configured only once."
2720
- " If the parent cell is configured, the child cell should not be configured")
2721
- for cell in self.cells():
2722
- cell._add_recompute_flag()
2723
-
2724
3796
  def _register_parameters_hook(self, forward_hook=None, backward_hook=None, all=False):
2725
3797
  """
2726
3798
  Register the forward hook for parameters and register the backward hook for the corresponding gradient.
@@ -2807,6 +3879,7 @@ class Cell(Cell_):
2807
3879
  cell._parameters_forward_hook = forward_hook
2808
3880
  cell._parameters_backward_hook = backward_hook
2809
3881
 
3882
+
2810
3883
  class GraphCell(Cell):
2811
3884
  """
2812
3885
  Base class for running the graph loaded from a MindIR.
@@ -2820,12 +3893,10 @@ class GraphCell(Cell):
2820
3893
  The key is the parameter name whose type is str, and the value is a Tensor or Parameter.
2821
3894
  If the parameter exists in the graph according to the name, update its value.
2822
3895
  If the parameter does not exist, ignore it. Default: ``None`` .
2823
- obf_random_seed (Union[int, None]): The random seed used for dynamic obfuscation. "dynamic obfuscation" is
2824
- used for model protection, which can refer to :func:`mindspore.obfuscate_model`. If the input `graph` is
2825
- a func_graph loaded from a mindir file obfuscated with `obf_random_seed` , then `obf_random_seed` should be
2826
- provided. `obf_random_seed` should be in (0, 9223372036854775807]. default: ``None`` .
3896
+ obf_random_seed (Union[int, None]): The random seed used for dynamic obfuscation, which is currently not supported.
2827
3897
 
2828
3898
  Raises:
3899
+ NotImplementedError: Dynamic structure obfuscation is currently not supported.
2829
3900
  TypeError: If the `graph` is not a FuncGraph.
2830
3901
  TypeError: If the `params_init` is not a dict.
2831
3902
  TypeError: If the key of the `params_init` is not a str.
@@ -2855,20 +3926,12 @@ class GraphCell(Cell):
2855
3926
 
2856
3927
  def __init__(self, graph, params_init=None, obf_random_seed=None):
2857
3928
  super(GraphCell, self).__init__(auto_prefix=True)
3929
+ if obf_random_seed is not None:
3930
+ raise NotImplementedError("Dynamic structure obfuscation is not supported now.")
2858
3931
  if not isinstance(graph, FuncGraph):
2859
3932
  raise TypeError(f"For 'GraphCell', the argument 'graph' must be a FuncGraph loaded from MindIR, "
2860
3933
  f"but got type {type(graph)}.")
2861
3934
  self.graph = graph
2862
- self.obf_random_seed = obf_random_seed
2863
- if obf_random_seed is not None:
2864
- if not isinstance(obf_random_seed, int):
2865
- raise TypeError("'obf_random_seed' must be int, but got {}.".format(type(obf_random_seed)))
2866
- int_64_max = 9223372036854775807
2867
- if obf_random_seed <= 0 or obf_random_seed > int_64_max:
2868
- raise ValueError(
2869
- "'obf_random_seed' must be larger than 0, and less or equal than int64 ({}),"
2870
- "but got {}.".format(int_64_max, obf_random_seed))
2871
- self._branch_control_input = _generate_branch_control_input(self.obf_random_seed)
2872
3935
  params_init = {} if params_init is None else params_init
2873
3936
  if not isinstance(params_init, dict):
2874
3937
  raise TypeError(f"For 'GraphCell', the argument 'params_init' must be a dict, but got {type(params_init)}.")
@@ -2888,19 +3951,30 @@ class GraphCell(Cell):
2888
3951
  def __call__(self, *args, **kwargs):
2889
3952
  self.phase = "graph_load_from_mindir"
2890
3953
  self._add_attr("graph_load_from_mindir", self.graph)
2891
- if not self.obf_random_seed:
2892
- return self.compile_and_run(*args, **kwargs)
2893
- append_input = Tensor((numpy.ones((1,)) * self._branch_control_input).astype(numpy.int32))
2894
- return self.compile_and_run(*args, append_input, **kwargs)
3954
+ return self.compile_and_run(*args, **kwargs)
2895
3955
 
2896
3956
 
2897
- def _check_param_list_tuple(value):
3957
+ def _is_parameter_list_or_tuple(value):
2898
3958
  """
2899
3959
  Check whether all elements of a list or tuple are Parameter.
2900
3960
  :param value: list or tuple.
2901
3961
  :return: True if all elements are Parameter, otherwise False.
2902
3962
  """
2903
- for item in value:
2904
- if not isinstance(item, Parameter):
2905
- return False
2906
- return True
3963
+ if isinstance(value, (list, tuple)) and value:
3964
+ for item in value:
3965
+ if not isinstance(item, Parameter):
3966
+ return False
3967
+ return True
3968
+ return False
3969
+
3970
+
3971
+ def _addindent(s_, num_spaces):
3972
+ s = s_.split("\n")
3973
+ # don't do anything for single-line stuff
3974
+ if len(s) == 1:
3975
+ return s_
3976
+ first = s.pop(0)
3977
+ s = [(num_spaces * " ") + line for line in s]
3978
+ s = "\n".join(s)
3979
+ s = first + "\n" + s
3980
+ return s
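
For reference, the `_addindent` helper above indents every line after the first, which is how child reprs are nested in `Cell.__repr__`. A quick illustration with a made-up repr string:

print(_addindent("Dense(\n  in_channels=2\n)", 2))
# Dense(
#     in_channels=2
#   )
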