mindspore-2.4.10-cp310-cp310-win_amd64.whl → mindspore-2.6.0-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore has been flagged as potentially problematic.

Files changed (602)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +13 -6
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -38
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +6 -7
  15. mindspore/_extends/parse/compile_config.py +83 -0
  16. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  17. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  18. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  19. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  20. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  21. mindspore/_extends/parse/parser.py +47 -198
  22. mindspore/_extends/parse/resources.py +1 -5
  23. mindspore/_extends/parse/standard_method.py +229 -99
  24. mindspore/_extends/pijit/__init__.py +2 -2
  25. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  26. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  27. mindspore/_extends/utils.py +1 -1
  28. mindspore/amp.py +11 -5
  29. mindspore/atlprov.dll +0 -0
  30. mindspore/avcodec-59.dll +0 -0
  31. mindspore/avdevice-59.dll +0 -0
  32. mindspore/avfilter-8.dll +0 -0
  33. mindspore/avformat-59.dll +0 -0
  34. mindspore/avutil-57.dll +0 -0
  35. mindspore/boost/__init__.py +2 -2
  36. mindspore/boost/base.py +3 -7
  37. mindspore/boost/boost_cell_wrapper.py +138 -43
  38. mindspore/c1.dll +0 -0
  39. mindspore/c1xx.dll +0 -0
  40. mindspore/c2.dll +0 -0
  41. mindspore/common/__init__.py +6 -3
  42. mindspore/common/_grad_function.py +56 -0
  43. mindspore/common/_pijit_context.py +14 -5
  44. mindspore/common/_register_for_tensor.py +1 -2
  45. mindspore/common/_stub_tensor.py +30 -14
  46. mindspore/common/_tensor_cpp_method.py +17 -0
  47. mindspore/common/_tensor_docs.py +4760 -0
  48. mindspore/common/api.py +480 -372
  49. mindspore/common/auto_dynamic_shape.py +41 -44
  50. mindspore/common/dtype.py +39 -36
  51. mindspore/common/dump.py +9 -6
  52. mindspore/common/file_system.py +9 -1
  53. mindspore/common/generator.py +5 -0
  54. mindspore/common/hook_handle.py +6 -2
  55. mindspore/common/initializer.py +13 -10
  56. mindspore/common/jit_begin_end.py +94 -0
  57. mindspore/common/jit_config.py +6 -1
  58. mindspore/common/jit_context.py +76 -0
  59. mindspore/common/jit_trace.py +378 -0
  60. mindspore/common/lazy_inline.py +9 -3
  61. mindspore/common/mindir_util.py +10 -2
  62. mindspore/common/mutable.py +5 -4
  63. mindspore/common/parameter.py +135 -52
  64. mindspore/common/seed.py +2 -2
  65. mindspore/common/sparse_tensor.py +23 -17
  66. mindspore/common/tensor.py +975 -1981
  67. mindspore/communication/__init__.py +7 -5
  68. mindspore/communication/_comm_helper.py +52 -2
  69. mindspore/communication/comm_func.py +240 -181
  70. mindspore/communication/management.py +95 -26
  71. mindspore/context.py +324 -573
  72. mindspore/dataset/__init__.py +65 -37
  73. mindspore/dataset/audio/__init__.py +2 -8
  74. mindspore/dataset/audio/transforms.py +3 -17
  75. mindspore/dataset/callback/ds_callback.py +2 -1
  76. mindspore/dataset/core/config.py +87 -6
  77. mindspore/dataset/engine/cache_admin.py +3 -3
  78. mindspore/dataset/engine/cache_client.py +6 -5
  79. mindspore/dataset/engine/datasets.py +292 -267
  80. mindspore/dataset/engine/datasets_audio.py +22 -8
  81. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  82. mindspore/dataset/engine/datasets_text.py +78 -48
  83. mindspore/dataset/engine/datasets_user_defined.py +183 -117
  84. mindspore/dataset/engine/datasets_vision.py +120 -44
  85. mindspore/dataset/engine/iterators.py +283 -63
  86. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  87. mindspore/dataset/engine/obs/util.py +8 -0
  88. mindspore/dataset/engine/queue.py +40 -0
  89. mindspore/dataset/engine/samplers.py +289 -43
  90. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  91. mindspore/dataset/engine/validators.py +53 -11
  92. mindspore/dataset/text/__init__.py +7 -6
  93. mindspore/dataset/text/transforms.py +6 -5
  94. mindspore/dataset/text/utils.py +3 -3
  95. mindspore/dataset/transforms/__init__.py +0 -9
  96. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  97. mindspore/dataset/transforms/transforms.py +31 -14
  98. mindspore/dataset/utils/browse_dataset.py +1 -1
  99. mindspore/dataset/vision/__init__.py +2 -9
  100. mindspore/dataset/vision/transforms.py +202 -158
  101. mindspore/dataset/vision/utils.py +7 -5
  102. mindspore/dataset/vision/validators.py +1 -2
  103. mindspore/device_context/__init__.py +21 -0
  104. mindspore/device_context/ascend/__init__.py +25 -0
  105. mindspore/device_context/ascend/device.py +72 -0
  106. mindspore/device_context/ascend/op_debug.py +153 -0
  107. mindspore/device_context/ascend/op_precision.py +193 -0
  108. mindspore/device_context/ascend/op_tuning.py +123 -0
  109. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  110. mindspore/device_context/cpu/device.py +62 -0
  111. mindspore/device_context/cpu/op_tuning.py +43 -0
  112. mindspore/device_context/gpu/__init__.py +21 -0
  113. mindspore/device_context/gpu/device.py +70 -0
  114. mindspore/device_context/gpu/op_precision.py +67 -0
  115. mindspore/device_context/gpu/op_tuning.py +175 -0
  116. mindspore/device_manager.py +170 -0
  117. mindspore/dnnl.dll +0 -0
  118. mindspore/dpcmi.dll +0 -0
  119. mindspore/experimental/es/embedding_service.py +35 -27
  120. mindspore/experimental/llm_boost/__init__.py +1 -0
  121. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  122. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
  123. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  124. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  125. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  126. mindspore/experimental/llm_boost/register.py +1 -0
  127. mindspore/experimental/map_parameter.py +4 -4
  128. mindspore/experimental/optim/adadelta.py +6 -6
  129. mindspore/experimental/optim/adagrad.py +4 -4
  130. mindspore/experimental/optim/adam.py +7 -0
  131. mindspore/experimental/optim/adamax.py +4 -4
  132. mindspore/experimental/optim/adamw.py +4 -0
  133. mindspore/experimental/optim/asgd.py +1 -1
  134. mindspore/experimental/optim/lr_scheduler.py +73 -46
  135. mindspore/experimental/optim/radam.py +34 -31
  136. mindspore/experimental/optim/rprop.py +1 -1
  137. mindspore/experimental/optim/sgd.py +1 -1
  138. mindspore/hal/contiguous_tensors_handle.py +6 -10
  139. mindspore/hal/device.py +55 -53
  140. mindspore/hal/event.py +52 -52
  141. mindspore/hal/memory.py +179 -120
  142. mindspore/hal/stream.py +150 -109
  143. mindspore/include/api/context.h +0 -1
  144. mindspore/include/dataset/constants.h +7 -4
  145. mindspore/include/dataset/execute.h +2 -2
  146. mindspore/jpeg62.dll +0 -0
  147. mindspore/log.py +50 -0
  148. mindspore/mindrecord/__init__.py +21 -8
  149. mindspore/mindrecord/config.py +17 -316
  150. mindspore/mindrecord/filereader.py +1 -9
  151. mindspore/mindrecord/filewriter.py +5 -15
  152. mindspore/mindrecord/mindpage.py +1 -9
  153. mindspore/mindspore_backend_common.dll +0 -0
  154. mindspore/mindspore_backend_manager.dll +0 -0
  155. mindspore/mindspore_common.dll +0 -0
  156. mindspore/mindspore_core.dll +0 -0
  157. mindspore/mindspore_dump.dll +0 -0
  158. mindspore/mindspore_frontend.dll +0 -0
  159. mindspore/mindspore_glog.dll +0 -0
  160. mindspore/mindspore_memory_pool.dll +0 -0
  161. mindspore/mindspore_ms_backend.dll +0 -0
  162. mindspore/mindspore_ops.dll +0 -0
  163. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  164. mindspore/mindspore_ops_kernel_common.dll +0 -0
  165. mindspore/mindspore_profiler.dll +0 -0
  166. mindspore/mindspore_pyboost.dll +0 -0
  167. mindspore/mindspore_pynative.dll +0 -0
  168. mindspore/mindspore_res_manager.dll +0 -0
  169. mindspore/mindspore_runtime_pipeline.dll +0 -0
  170. mindspore/mint/__init__.py +798 -761
  171. mindspore/mint/distributed/__init__.py +70 -4
  172. mindspore/mint/distributed/distributed.py +2679 -44
  173. mindspore/mint/linalg/__init__.py +8 -0
  174. mindspore/mint/nn/__init__.py +743 -22
  175. mindspore/mint/nn/functional.py +716 -23
  176. mindspore/mint/nn/layer/__init__.py +21 -4
  177. mindspore/mint/nn/layer/_functions.py +334 -0
  178. mindspore/mint/nn/layer/activation.py +276 -1
  179. mindspore/mint/nn/layer/basic.py +123 -0
  180. mindspore/mint/nn/layer/conv.py +933 -0
  181. mindspore/mint/nn/layer/normalization.py +223 -28
  182. mindspore/mint/nn/layer/padding.py +797 -0
  183. mindspore/mint/nn/layer/pooling.py +235 -0
  184. mindspore/mint/optim/__init__.py +3 -1
  185. mindspore/mint/optim/adam.py +223 -0
  186. mindspore/mint/optim/adamw.py +26 -19
  187. mindspore/mint/optim/sgd.py +171 -0
  188. mindspore/mint/special/__init__.py +2 -1
  189. mindspore/msobj140.dll +0 -0
  190. mindspore/mspdb140.dll +0 -0
  191. mindspore/mspdbcore.dll +0 -0
  192. mindspore/mspdbst.dll +0 -0
  193. mindspore/mspft140.dll +0 -0
  194. mindspore/msvcdis140.dll +0 -0
  195. mindspore/msvcp140_1.dll +0 -0
  196. mindspore/msvcp140_2.dll +0 -0
  197. mindspore/msvcp140_atomic_wait.dll +0 -0
  198. mindspore/msvcp140_codecvt_ids.dll +0 -0
  199. mindspore/multiprocessing/__init__.py +5 -0
  200. mindspore/nn/__init__.py +4 -1
  201. mindspore/nn/cell.py +1373 -192
  202. mindspore/nn/dynamic_lr.py +2 -1
  203. mindspore/nn/layer/activation.py +29 -27
  204. mindspore/nn/layer/basic.py +51 -35
  205. mindspore/nn/layer/channel_shuffle.py +3 -3
  206. mindspore/nn/layer/container.py +1 -1
  207. mindspore/nn/layer/conv.py +53 -42
  208. mindspore/nn/layer/embedding.py +12 -11
  209. mindspore/nn/layer/normalization.py +56 -49
  210. mindspore/nn/layer/padding.py +4 -3
  211. mindspore/nn/layer/pooling.py +120 -42
  212. mindspore/nn/layer/rnn_cells.py +1 -1
  213. mindspore/nn/layer/rnns.py +2 -1
  214. mindspore/nn/layer/timedistributed.py +5 -5
  215. mindspore/nn/layer/transformer.py +59 -36
  216. mindspore/nn/learning_rate_schedule.py +8 -4
  217. mindspore/nn/loss/loss.py +58 -55
  218. mindspore/nn/optim/ada_grad.py +7 -5
  219. mindspore/nn/optim/adadelta.py +11 -9
  220. mindspore/nn/optim/adafactor.py +1 -1
  221. mindspore/nn/optim/adam.py +19 -15
  222. mindspore/nn/optim/adamax.py +8 -7
  223. mindspore/nn/optim/adasum.py +5 -5
  224. mindspore/nn/optim/asgd.py +3 -1
  225. mindspore/nn/optim/ftrl.py +11 -9
  226. mindspore/nn/optim/lamb.py +1 -1
  227. mindspore/nn/optim/lars.py +1 -4
  228. mindspore/nn/optim/lazyadam.py +12 -10
  229. mindspore/nn/optim/momentum.py +7 -6
  230. mindspore/nn/optim/optimizer.py +3 -3
  231. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  232. mindspore/nn/optim/rmsprop.py +13 -12
  233. mindspore/nn/optim/rprop.py +11 -9
  234. mindspore/nn/optim/sgd.py +9 -6
  235. mindspore/nn/optim/tft_wrapper.py +5 -2
  236. mindspore/nn/optim/thor.py +2 -1
  237. mindspore/nn/probability/bijector/bijector.py +17 -11
  238. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  239. mindspore/nn/probability/bijector/invert.py +2 -2
  240. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  241. mindspore/nn/probability/bijector/softplus.py +3 -2
  242. mindspore/nn/probability/distribution/beta.py +3 -3
  243. mindspore/nn/probability/distribution/categorical.py +1 -1
  244. mindspore/nn/probability/distribution/cauchy.py +4 -2
  245. mindspore/nn/probability/distribution/exponential.py +6 -7
  246. mindspore/nn/probability/distribution/gamma.py +2 -2
  247. mindspore/nn/probability/distribution/gumbel.py +2 -2
  248. mindspore/nn/probability/distribution/half_normal.py +5 -3
  249. mindspore/nn/probability/distribution/logistic.py +5 -3
  250. mindspore/nn/probability/distribution/poisson.py +1 -1
  251. mindspore/nn/probability/distribution/uniform.py +5 -3
  252. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  253. mindspore/nn/reinforcement/tensor_array.py +1 -1
  254. mindspore/nn/utils/init.py +13 -11
  255. mindspore/nn/wrap/__init__.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +181 -122
  257. mindspore/nn/wrap/grad_reducer.py +45 -36
  258. mindspore/nn/wrap/loss_scale.py +6 -7
  259. mindspore/numpy/array_creations.py +63 -65
  260. mindspore/numpy/array_ops.py +149 -144
  261. mindspore/numpy/logic_ops.py +41 -42
  262. mindspore/numpy/math_ops.py +361 -359
  263. mindspore/numpy/utils.py +17 -18
  264. mindspore/numpy/utils_const.py +5 -6
  265. mindspore/opencv_core452.dll +0 -0
  266. mindspore/opencv_imgcodecs452.dll +0 -0
  267. mindspore/opencv_imgproc452.dll +0 -0
  268. mindspore/ops/__init__.py +5 -3
  269. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  270. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  271. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  272. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  273. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  274. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  275. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  276. mindspore/ops/_register_for_op.py +0 -11
  277. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  278. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  279. mindspore/ops/_vmap/vmap_array_ops.py +52 -25
  280. mindspore/ops/_vmap/vmap_base.py +0 -2
  281. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  282. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  283. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  284. mindspore/ops/auto_generate/__init__.py +4 -3
  285. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
  286. mindspore/ops/auto_generate/gen_extend_func.py +757 -185
  287. mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
  288. mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
  289. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  290. mindspore/ops/composite/__init__.py +2 -1
  291. mindspore/ops/composite/base.py +20 -25
  292. mindspore/ops/composite/math_ops.py +6 -16
  293. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  294. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  295. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  296. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  301. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  302. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  304. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  305. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  306. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  308. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  309. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  310. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  311. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  312. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  313. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  314. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  315. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  316. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  317. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  320. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  321. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  324. mindspore/ops/function/__init__.py +40 -2
  325. mindspore/ops/function/_add_attr_func.py +58 -0
  326. mindspore/ops/function/array_func.py +2089 -2403
  327. mindspore/ops/function/clip_func.py +80 -23
  328. mindspore/ops/function/debug_func.py +57 -57
  329. mindspore/ops/function/grad/__init__.py +1 -0
  330. mindspore/ops/function/grad/grad_func.py +104 -71
  331. mindspore/ops/function/image_func.py +2 -2
  332. mindspore/ops/function/linalg_func.py +47 -78
  333. mindspore/ops/function/math_func.py +4351 -3813
  334. mindspore/ops/function/nn_func.py +1712 -637
  335. mindspore/ops/function/other_func.py +159 -1
  336. mindspore/ops/function/parameter_func.py +18 -84
  337. mindspore/ops/function/random_func.py +452 -387
  338. mindspore/ops/function/reshard_func.py +4 -70
  339. mindspore/ops/function/sparse_func.py +3 -3
  340. mindspore/ops/function/sparse_unary_func.py +6 -6
  341. mindspore/ops/function/spectral_func.py +25 -58
  342. mindspore/ops/function/vmap_func.py +26 -18
  343. mindspore/ops/functional.py +23 -7
  344. mindspore/ops/functional_overload.py +1548 -0
  345. mindspore/ops/op_info_register.py +32 -244
  346. mindspore/ops/operations/__init__.py +23 -15
  347. mindspore/ops/operations/_custom_ops_utils.py +235 -0
  348. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  349. mindspore/ops/operations/_grad_ops.py +2 -43
  350. mindspore/ops/operations/_infer_ops.py +2 -1
  351. mindspore/ops/operations/_inner_ops.py +43 -84
  352. mindspore/ops/operations/_ms_kernel.py +4 -10
  353. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  354. mindspore/ops/operations/_scalar_ops.py +3 -2
  355. mindspore/ops/operations/_sequence_ops.py +1 -1
  356. mindspore/ops/operations/_tensor_array.py +1 -1
  357. mindspore/ops/operations/array_ops.py +81 -324
  358. mindspore/ops/operations/comm_ops.py +154 -108
  359. mindspore/ops/operations/custom_ops.py +298 -87
  360. mindspore/ops/operations/debug_ops.py +157 -59
  361. mindspore/ops/operations/inner_ops.py +7 -5
  362. mindspore/ops/operations/linalg_ops.py +1 -57
  363. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  364. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  365. mindspore/ops/operations/math_ops.py +32 -234
  366. mindspore/ops/operations/nn_ops.py +212 -531
  367. mindspore/ops/operations/other_ops.py +62 -9
  368. mindspore/ops/operations/random_ops.py +13 -7
  369. mindspore/ops/operations/reshard_ops.py +1 -1
  370. mindspore/ops/operations/sparse_ops.py +2 -2
  371. mindspore/ops/primitive.py +66 -53
  372. mindspore/ops/tensor_method.py +1895 -0
  373. mindspore/ops_generate/__init__.py +0 -5
  374. mindspore/ops_generate/aclnn/__init__.py +0 -0
  375. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  376. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  377. mindspore/ops_generate/api/__init__.py +0 -0
  378. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  379. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  380. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  381. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  382. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  383. mindspore/ops_generate/api/gen_api.py +103 -0
  384. mindspore/ops_generate/api/op_api_proto.py +235 -0
  385. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  386. mindspore/ops_generate/common/__init__.py +0 -0
  387. mindspore/ops_generate/common/base_generator.py +11 -0
  388. mindspore/ops_generate/common/gen_constants.py +91 -0
  389. mindspore/ops_generate/common/gen_utils.py +348 -0
  390. mindspore/ops_generate/common/op_proto.py +473 -0
  391. mindspore/ops_generate/common/template.py +523 -0
  392. mindspore/ops_generate/gen_ops.py +22 -1069
  393. mindspore/ops_generate/op_def/__init__.py +0 -0
  394. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  395. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  396. mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
  397. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  398. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  399. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  400. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  401. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  402. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  403. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  404. mindspore/ops_generate/pyboost/__init__.py +0 -0
  405. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  406. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  407. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  408. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  409. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  410. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  411. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  412. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  413. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  414. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  415. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  416. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  417. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  418. mindspore/ops_generate/resources/__init__.py +0 -0
  419. mindspore/ops_generate/resources/resource_list.py +30 -0
  420. mindspore/ops_generate/resources/resource_loader.py +36 -0
  421. mindspore/ops_generate/resources/resource_manager.py +64 -0
  422. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  423. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  424. mindspore/parallel/__init__.py +7 -3
  425. mindspore/parallel/_auto_parallel_context.py +159 -40
  426. mindspore/parallel/_cell_wrapper.py +132 -15
  427. mindspore/parallel/_parallel_serialization.py +107 -5
  428. mindspore/parallel/_ps_context.py +1 -1
  429. mindspore/parallel/_recovery_context.py +7 -2
  430. mindspore/parallel/_tensor.py +142 -18
  431. mindspore/parallel/_utils.py +199 -23
  432. mindspore/parallel/algo_parameter_config.py +4 -4
  433. mindspore/parallel/auto_parallel.py +732 -0
  434. mindspore/parallel/checkpoint_convert.py +159 -0
  435. mindspore/parallel/checkpoint_transform.py +700 -35
  436. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  437. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  438. mindspore/parallel/cluster/run.py +21 -4
  439. mindspore/parallel/function/__init__.py +24 -0
  440. mindspore/parallel/function/reshard_func.py +258 -0
  441. mindspore/parallel/nn/__init__.py +25 -0
  442. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  443. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  444. mindspore/parallel/parameter_broadcast.py +25 -14
  445. mindspore/parallel/shard.py +137 -59
  446. mindspore/parallel/transform_safetensors.py +364 -305
  447. mindspore/pgodb140.dll +0 -0
  448. mindspore/pgort140.dll +0 -0
  449. mindspore/profiler/__init__.py +22 -5
  450. mindspore/profiler/analysis/__init__.py +0 -0
  451. mindspore/profiler/analysis/parser/__init__.py +0 -0
  452. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  453. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  454. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  455. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  456. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  457. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  458. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  459. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  460. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
  461. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  462. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  463. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  464. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  465. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  466. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  467. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  468. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  469. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  470. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  471. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  472. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  473. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  474. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  475. mindspore/profiler/analysis/task_manager.py +131 -0
  476. mindspore/profiler/analysis/time_converter.py +84 -0
  477. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  478. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  479. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  480. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  481. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  482. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  483. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  484. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  485. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  486. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  487. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  488. mindspore/profiler/analysis/work_flow.py +73 -0
  489. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  490. mindspore/profiler/common/command_executor.py +90 -0
  491. mindspore/profiler/common/constant.py +186 -3
  492. mindspore/profiler/common/file_manager.py +208 -0
  493. mindspore/profiler/common/log.py +130 -0
  494. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  495. mindspore/profiler/common/path_manager.py +395 -0
  496. mindspore/profiler/common/process_bar.py +168 -0
  497. mindspore/profiler/common/process_pool.py +9 -3
  498. mindspore/profiler/common/profiler_context.py +500 -0
  499. mindspore/profiler/common/profiler_info.py +304 -0
  500. mindspore/profiler/common/profiler_meta_data.py +74 -0
  501. mindspore/profiler/common/profiler_output_path.py +284 -0
  502. mindspore/profiler/common/profiler_parameters.py +251 -0
  503. mindspore/profiler/common/profiler_path_manager.py +179 -0
  504. mindspore/profiler/common/record_function.py +76 -0
  505. mindspore/profiler/common/tlv_decoder.py +76 -0
  506. mindspore/profiler/common/util.py +75 -2
  507. mindspore/profiler/dynamic_profiler.py +341 -75
  508. mindspore/profiler/envprofiler.py +163 -0
  509. mindspore/profiler/experimental_config.py +197 -0
  510. mindspore/profiler/mstx.py +242 -0
  511. mindspore/profiler/platform/__init__.py +21 -0
  512. mindspore/profiler/platform/base_profiler.py +40 -0
  513. mindspore/profiler/platform/cpu_profiler.py +124 -0
  514. mindspore/profiler/platform/gpu_profiler.py +74 -0
  515. mindspore/profiler/platform/npu_profiler.py +335 -0
  516. mindspore/profiler/profiler.py +1073 -90
  517. mindspore/profiler/profiler_action_controller.py +187 -0
  518. mindspore/profiler/profiler_interface.py +118 -0
  519. mindspore/profiler/schedule.py +243 -0
  520. mindspore/rewrite/api/node.py +15 -13
  521. mindspore/rewrite/api/symbol_tree.py +2 -3
  522. mindspore/run_check/_check_version.py +27 -20
  523. mindspore/run_check/run_check.py +1 -1
  524. mindspore/runtime/__init__.py +37 -0
  525. mindspore/runtime/device.py +27 -0
  526. mindspore/runtime/event.py +209 -0
  527. mindspore/runtime/executor.py +177 -0
  528. mindspore/runtime/memory.py +416 -0
  529. mindspore/runtime/stream.py +460 -0
  530. mindspore/runtime/thread_bind_core.py +401 -0
  531. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  532. mindspore/swresample-4.dll +0 -0
  533. mindspore/swscale-6.dll +0 -0
  534. mindspore/tbbmalloc.dll +0 -0
  535. mindspore/tinyxml2.dll +0 -0
  536. mindspore/train/__init__.py +8 -8
  537. mindspore/train/_utils.py +96 -27
  538. mindspore/train/amp.py +9 -5
  539. mindspore/train/callback/__init__.py +2 -2
  540. mindspore/train/callback/_callback.py +2 -16
  541. mindspore/train/callback/_checkpoint.py +53 -55
  542. mindspore/train/callback/_cluster_monitor.py +14 -18
  543. mindspore/train/callback/_early_stop.py +1 -1
  544. mindspore/train/callback/_flops_collector.py +103 -68
  545. mindspore/train/callback/_history.py +8 -5
  546. mindspore/train/callback/_lambda_callback.py +2 -2
  547. mindspore/train/callback/_landscape.py +0 -3
  548. mindspore/train/callback/_loss_monitor.py +2 -1
  549. mindspore/train/callback/_on_request_exit.py +6 -5
  550. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  551. mindspore/train/callback/_summary_collector.py +52 -19
  552. mindspore/train/callback/_time_monitor.py +2 -1
  553. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
  554. mindspore/train/data_sink.py +25 -2
  555. mindspore/train/dataset_helper.py +15 -16
  556. mindspore/train/loss_scale_manager.py +8 -7
  557. mindspore/train/metrics/accuracy.py +3 -3
  558. mindspore/train/metrics/confusion_matrix.py +9 -9
  559. mindspore/train/metrics/error.py +3 -3
  560. mindspore/train/metrics/hausdorff_distance.py +4 -4
  561. mindspore/train/metrics/mean_surface_distance.py +3 -3
  562. mindspore/train/metrics/metric.py +0 -12
  563. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  564. mindspore/train/metrics/precision.py +11 -10
  565. mindspore/train/metrics/recall.py +9 -9
  566. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  567. mindspore/train/mind_ir_pb2.py +174 -46
  568. mindspore/train/model.py +269 -136
  569. mindspore/train/serialization.py +622 -978
  570. mindspore/train/summary/_summary_adapter.py +2 -2
  571. mindspore/train/summary/summary_record.py +2 -3
  572. mindspore/train/train_thor/model_thor.py +1 -1
  573. mindspore/turbojpeg.dll +0 -0
  574. mindspore/utils/__init__.py +6 -3
  575. mindspore/utils/dryrun.py +140 -0
  576. mindspore/utils/hooks.py +81 -0
  577. mindspore/utils/runtime_execution_order_check.py +552 -0
  578. mindspore/utils/utils.py +138 -4
  579. mindspore/vcmeta.dll +0 -0
  580. mindspore/vcruntime140.dll +0 -0
  581. mindspore/vcruntime140_1.dll +0 -0
  582. mindspore/version.py +1 -1
  583. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
  584. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +587 -418
  585. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
  586. mindspore/_install_custom.py +0 -43
  587. mindspore/common/_register_for_adapter.py +0 -74
  588. mindspore/common/_tensor_overload.py +0 -139
  589. mindspore/mindspore_np_dtype.dll +0 -0
  590. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  591. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  592. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  593. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  594. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  595. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  596. mindspore/ops_generate/gen_utils.py +0 -209
  597. mindspore/ops_generate/op_proto.py +0 -145
  598. mindspore/ops_generate/template.py +0 -261
  599. mindspore/profiler/envprofiling.py +0 -254
  600. mindspore/profiler/profiling.py +0 -1926
  601. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
  602. {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2020-2024 Huawei Technologies Co., Ltd
+# Copyright 2020-2025 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,13 +15,26 @@
 """cell"""
 from __future__ import absolute_import

-import gc
 import inspect
 import os
 import time
-from collections import OrderedDict
-import numpy
-
+import warnings
+import itertools
+from collections import OrderedDict, namedtuple
+from typing import (
+    Dict,
+    Optional,
+    Set,
+    Callable,
+    List,
+    Tuple,
+    Iterator,
+    Any,
+    TypeVar,
+    Mapping
+)
+
+import mindspore as ms
 from mindspore._checkparam import args_type_check, check_hook_fn
 from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
 from mindspore import log as logger
@@ -34,19 +47,62 @@ from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
 from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache, \
     _no_grad
-from mindspore.common.api import _generate_branch_control_input, _convert_python_data, _get_args_for_run_predict
+from mindspore.common.api import _convert_python_data, _get_args_for_run_predict
 from mindspore.common.api import _process_dyn_args, _generate_dyn_compile_args
-from mindspore.common.parameter import Parameter, ParameterTuple
+from mindspore.common.parameter import _Buffer, Parameter, ParameterTuple
 from mindspore.common.tensor import Tensor
 from mindspore.ops.operations import Cast
 from mindspore.ops.primitive import Primitive
 from mindspore.ops.operations import _inner_ops as inner
 from mindspore.parallel.shard import Shard
+from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
 from mindspore._check_jit_forbidden_api import jit_forbidden_register
 from mindspore.common._decorator import deprecated
 from mindspore.common._register_for_recompute import recompute_registry


+__all__ = [
+    "register_cell_buffer_registration_hook",
+]
+
+_global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict()
+_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
+
+
+class _IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),):
+    def __repr__(self):
+        if not self.missing_keys and not self.unexpected_keys:
+            return "<All keys matched successfully>"
+        return super().__repr__()
+
+    __str__ = __repr__
+
+
+def register_cell_buffer_registration_hook(hook: Callable[..., None],):
+    r"""Register a buffer registration hook common to all cells.
+
+    .. warning ::
+
+        This adds global state to the `nn.Cell` cell
+
+    The hook will be called every time :func:`register_buffer` is invoked.
+    It should have the following signature::
+
+        hook(cell, name, buffer) -> None or new buffer
+
+    The hook can modify the input or return a single modified value in the hook.
+
+    Returns:
+        A handle that can be used to remove the added hook by calling
+        `handle.remove()`.
+    """
+    from mindspore.utils.hooks import _RemovableHandle
+    handle = _RemovableHandle(_global_buffer_registration_hooks)
+    _global_buffer_registration_hooks[handle.id] = hook
+    return handle
+
+
+
 class Cell(Cell_):
     """
     The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this
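The hunk above documents only the hook signature (`hook(cell, name, buffer) -> None or new buffer`) and ships no usage example. A minimal sketch of how the global hook might be used, assuming the 2.6.0 API exactly as declared in this hunk; the halving hook and the `mindspore.nn.cell` import path are illustrative assumptions, not part of the diff:

import mindspore
from mindspore.nn.cell import register_cell_buffer_registration_hook

def halve_buffers(cell, name, buffer):
    # Runs on every register_buffer() call; returning a Tensor replaces
    # the buffer being registered, returning None keeps it unchanged.
    if buffer is not None:
        return buffer * 0.5
    return None

handle = register_cell_buffer_registration_hook(halve_buffers)
try:
    net = mindspore.nn.Cell()
    net.register_buffer("scale", mindspore.tensor([2.0]))
    print(net.scale)  # [1.] -- rewritten by the hook
finally:
    handle.remove()  # the hook is global state; always detach it when done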
@@ -60,7 +116,7 @@ class Cell(Cell_):
     .. note::
         Cell is the inference mode by default. For a class that inherits a Cell,
         if the training and inference have different structures, the subclass performs the inference branch by default.
-        To set the training mode, refer to `mindspore.nn.Cell.set_train` .
+        To set the training mode, refer to :func:`mindspore.nn.Cell.set_train` .

     .. warning::
         In the subclass of Cell, it's not allowed to define a method named 'cast' and not allowed to define an attribute
@@ -105,8 +161,11 @@
                    '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase', '_bprop_debug',
                    '_forward_pre_hook', '_forward_hook', '_backward_pre_hook', '_backward_hook',
                    '_cell_backward_pre_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
-                    '_attr_synced', 'pynative', 'requires_grad', 'cell_type']
+                    '_attr_synced', 'pynative', 'requires_grad', 'cell_type',
+                    '_parameters_forward_hook', '_parameters_backward_hook']
     total_instance_count = 0
+    _buffers: Dict[str, Optional[Tensor]]
+    _non_persistent_buffers_set: Set[str]

     def __init__(self, auto_prefix=True, flags=None):
         Cell_.__init__(self, self._cell_tag)
@@ -114,10 +173,17 @@
         self.instance_count = Cell.total_instance_count
         self._params = OrderedDict()
         self._cells = OrderedDict()
+        super().__setattr__("_buffers", {})
+        super().__setattr__("_non_persistent_buffers_set", set())
+        super().__setattr__("_state_dict_hooks", OrderedDict())
+        super().__setattr__("_state_dict_pre_hooks", OrderedDict())
+        super().__setattr__("_load_state_dict_pre_hooks", OrderedDict())
+        super().__setattr__("_load_state_dict_post_hooks", OrderedDict())
         self._params_list = OrderedDict()
         self._primitives = OrderedDict()
         self.training = False
         self.requires_grad = False
+        self.is_top_cell = False
         self.pynative = False
         self._attr_synced = False
         self._param_prefix = ''
@@ -134,8 +200,8 @@
         cells_compile_cache[id(self)] = self.compile_cache
         self.parameter_broadcast_done = False
         self._id = 1
-        self.exist_names = set("")
-        self.exist_objs = set()
+        self._exist_objs = None
+        self._exist_names = None
         self._recompute_cell = None
         self.mixed_precision_type = None
         self.sig = inspect.signature(self.construct)
@@ -143,7 +209,8 @@

         # call gc to release GE session resources used by non-used cell objects
         if os.getenv('GC_COLLECT_IN_CELL') == '1':
-            gc.collect()
+            logger.warning("The convenient environment 'GC_COLLECT_IN_CELL' is deprecated from version 2.5 "
+                           "and will be removed in a future version.")

         if flags:
             self.add_flags(**flags)
@@ -158,6 +225,10 @@
         self._cell_backward_hook = None
         self._is_recursion_hook = False

+        # parameters hook
+        self._parameters_forward_hook = None
+        self._parameters_backward_hook = None
+
         self.cell_type = None
         self.cast = Cast()
         self._has_config_recompute = False
@@ -202,6 +273,21 @@
     def cell_init_args(self):
         return self._cell_init_args

+    @property
+    def exist_names(self):
+        """
+        Get exist parameter names adding by tuple or list of parameter.
+        """
+        if self._exist_names is None:
+            self._exist_names = set("")
+        return self._exist_names
+
+    @property
+    def exist_objs(self):
+        if self._exist_objs is None:
+            self._exist_objs = set()
+        return self._exist_objs
+
     @property
     def param_prefix(self):
         """
@@ -230,11 +316,6 @@
     def bprop_debug(self):
         """
         Get whether cell custom bprop debug is enabled.
-
-        Tutorial Examples:
-            - `Custom Neural Network Layers - Custom Cell Reverse
-              <https://mindspore.cn/docs/en/master/model_train/custom_program/network_custom.html
-              #custom-cell-reverse>`_
         """
         return self._bprop_debug

@@ -351,8 +432,6 @@
             raise ValueError("For 'Cell', the property 'pipeline_stage' "
                              "can not be less than 0, but got {}".format(value))
         self._pipeline_stage = value
-        for item in self.trainable_params():
-            item.add_pipeline_stage(value)

     @property
     def pipeline_segment(self):
@@ -388,6 +467,374 @@
     def enable_backward_hook(self):
         return self._enable_backward_hook

+    @jit_forbidden_register
+    def register_buffer(
+            self, name: str, tensor: Optional[Tensor], persistent: bool = True
+    ) -> None:
+        r"""Add a buffer to the cell.
+
+        This is typically used to register a buffer that should not to be
+        considered a model parameter. For example, BatchNorm's `running_mean`
+        is not a parameter, but is part of the cell's state. Buffers, by
+        default, are persistent and will be saved alongside parameters. This
+        behavior can be changed by setting `persistent` to ``False`` . The
+        only difference between a persistent buffer and a non-persistent buffer
+        is that the latter will not be a part of this cell's :attr:`state_dict` .
+
+        Buffers can be accessed as attributes using given names.
+
+        Args:
+            name (str): name of the buffer. The buffer can be accessed
+                from this cell using the given name.
+            tensor (Tensor): Buffer to be registered. If ``None`` ,
+                the buffer is not included in the cell's :attr:`state_dict` .
+            persistent (bool, optional): Whether the buffer is part of this cell's :attr:`state_dict`. Default ``True``.
+
+        Examples:
+            >>> import mindspore
+            ...
+            >>> class Net(mindspore.nn.Cell):
+            ...     def __init__(self):
+            ...         super().__init__()
+            ...         self.register_buffer("buffer0", mindspore.tensor([1, 2, 3]))
+            ...
+            ...     def construct(self, x):
+            ...         return x + self.net_buffer
+            ...
+            >>> net = Net()
+            >>> net.register_buffer("buffer0", mindspore.tensor([4, 5, 6]))
+            >>> print(net.buffer0)
+            [4 5 6]
+        """
+
+        if "_buffers" not in self.__dict__:
+            raise AttributeError("cannot assign buffer before Cell.__init__() call")
+        if not isinstance(name, str):
+            raise TypeError(
+                f"buffer name should be a string.But got this type: {type(name)}"
+            )
+        if "." in name:
+            raise KeyError('buffer name can\'t contain "."')
+        if name == "":
+            raise KeyError('buffer name can\'t be empty string ""')
+        if hasattr(self, name) and name not in self._buffers:
+            raise KeyError(f"attribute '{name}' already exists")
+        if tensor is not None and not isinstance(tensor, Tensor):
+            raise TypeError(
+                f"cannot assign '{type(tensor)}' object to buffer '{name}' "
+                "(mindspore Tensor or None required)"
+            )
+        for hook in _global_buffer_registration_hooks.values():
+            output = hook(self, name, tensor)
+            if output is not None:
+                tensor = output
+        if tensor is not None:
+            tensor._is_buffer = True
+        self._buffers[name] = tensor
+        if persistent:
+            self._non_persistent_buffers_set.discard(name)
+        else:
+            self._non_persistent_buffers_set.add(name)
+
+    @jit_forbidden_register
+    def get_buffer(self, target: str) -> "Tensor":
+        """Return the buffer given by `target` if it exists, otherwise throw an error.
+
+        See the docstring for `get_sub_cell` for a more detailed
+        explanation of this method's functionality as well as how to
+        correctly specify `target` .
+
+        Args:
+            target (str): The fully-qualified string name of the buffer
+                to look for. (See `get_sub_cell` for how to specify a
+                fully-qualified string.)
+
+        Returns:
+            Tensor
+
+        Examples:
+            >>> import mindspore
+            ...
+            ...
+            >>> class NetC(mindspore.nn.Cell):
+            ...     def __init__(self):
+            ...         super().__init__()
+            ...         self.register_buffer("buffer_c", mindspore.tensor([0, 0, 0]))
+            ...
+            ...     def construct(self, x):
+            ...         return x + self.buffer_c
+            ...
+            ...
+            >>> class NetB(mindspore.nn.Cell):
+            ...     def __init__(self, net_c):
+            ...         super().__init__()
+            ...         self.net_c = net_c
+            ...         self.register_buffer("buffer_b", mindspore.tensor([1, 2, 3]))
+            ...
+            ...     def construct(self, x):
+            ...         return self.net_c(x) + self.buffer_b
+            ...
+            ...
+            >>> class NetA(mindspore.nn.Cell):
+            ...     def __init__(self, net_b):
+            ...         super().__init__()
+            ...         self.net_b = net_b
+            ...         self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
+            ...
+            ...     def construct(self, x):
+            ...         return self.net_b(x) + self.buffer_a
+            ...
+            ...
+            >>> net_c = NetC()
+            >>> net_b = NetB(net_c)
+            >>> net_a = NetA(net_b)
+            >>> buffer_c = net_a.get_buffer("net_b.net_c.buffer_c")
+            >>> print(f'buffer_c is {buffer_c}')
+            buffer_c is [0 0 0]
+
+        """
+        cell_path, _, buffer_name = target.rpartition(".")
+
+        cell = self.get_sub_cell(cell_path)
+
+        if not hasattr(cell, buffer_name):
+            raise AttributeError(
+                cell._get_name() + " has no attribute `" + buffer_name + "`"
+            )
+
+        buffer = getattr(cell, buffer_name)
+
+        if buffer_name not in cell._buffers:
+            raise AttributeError("`" + buffer_name + "` is not a buffer")
+
+        return buffer
+
+    @jit_forbidden_register
+    def named_buffers(
+            self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, Tensor]]:
+        r"""Return an iterator over cell buffers, yielding both the name of the buffer as well as the buffer itself.
+
+        Args:
+            prefix (str, optional): prefix to prepend to all buffer names. Default ``""``.
+            recurse (bool, optional): if ``True`` , then yields buffers of this cell
+                and all sub cells. Otherwise, yields only buffers that
+                are direct members of this cell. Default ``True``.
+            remove_duplicate (bool, optional): Whether to remove the duplicated buffers in the result. Default ``True``.
+
+        Returns:
+            Iterator[Tuple[str, Tensor]], an iterator of tuple containing the name and buffer.
+
+        Examples:
+            >>> import mindspore
+            ...
+            ...
+            >>> class NetB(mindspore.nn.Cell):
+            ...     def __init__(self):
+            ...         super().__init__()
+            ...         self.register_buffer("buffer_b", mindspore.tensor([1, 2, 3]))
+            ...
+            ...     def construct(self, x):
+            ...         return x + self.buffer_b
+            ...
+            ...
+            >>> class NetA(mindspore.nn.Cell):
+            ...     def __init__(self, net_b):
+            ...         super().__init__()
+            ...         self.net_b = net_b
+            ...         self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
+            ...
+            ...     def construct(self, x):
+            ...         return self.net_b(x) + self.buffer_a
+            ...
+            ...
+            >>> net_b = NetB()
+            >>> net_a = NetA(net_b)
+            >>>
+            >>> for name, buffer in net_a.named_buffers():
+            ...     print(f'buffer name is {name}, buffer is {buffer}')
+            buffer name is buffer_a, buffer is [4 5 6]
+            buffer name is net_b.buffer_b, buffer is [1 2 3]
+
+        """
+        gen = self._named_members(
+            lambda cell: cell._buffers.items(),
+            prefix=prefix,
+            recurse=recurse,
+            remove_duplicate=remove_duplicate,
+        )
+        yield from gen
+
+    @jit_forbidden_register
+    def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
+        r"""Return an iterator over cell buffers.
+
+        Args:
+            recurse (bool, optional): If ``True`` , then yields buffers of this cell
+                and all sub cells. Otherwise, yields only buffers that
+                are direct members of this cell. Default ``True``.
+
+        Returns:
+            Iterator[Tensor], an iterator of buffer.
+
+        Examples:
+            >>> import mindspore
+            ...
+            ...
+            >>> class NetB(mindspore.nn.Cell):
+            ...     def __init__(self):
+            ...         super().__init__()
+            ...         self.register_buffer("buffer_b", mindspore.tensor([1, 2, 3]))
+            ...
+            ...     def construct(self, x):
+            ...         return x + self.buffer_b
+            ...
+            ...
+            >>> class NetA(mindspore.nn.Cell):
+            ...     def __init__(self, net_b):
+            ...         super().__init__()
+            ...         self.net_b = net_b
+            ...         self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
+            ...
+            ...     def construct(self, x):
+            ...         return self.net_b(x) + self.buffer_a
+            ...
+            ...
+            >>> net_b = NetB()
+            >>> net_a = NetA(net_b)
+            >>>
+            >>> for buffer in net_a.buffers():
+            ...     print(f'buffer is {buffer}')
+            buffer is [4 5 6]
+            buffer is [1 2 3]
+
+        """
+        for _, buf in self.named_buffers(recurse=recurse):
+            yield buf
+
+    def _named_members(self, get_members_fn, prefix="", recurse=True, remove_duplicate: bool = True):
+        r"""Help yield various names + members of cells."""
+        memo = set()
+        cells = (
+            self.cells_and_names(name_prefix=prefix)
+            if recurse
+            else [(prefix, self)]
+        )
+        for cell_prefix, cell in cells:
+            members = get_members_fn(cell)
+            for k, v in members:
+                if v is None or v in memo:
+                    continue
+                if remove_duplicate:
+                    memo.add(v)
+                name = cell_prefix + ("." if cell_prefix else "") + k
+                yield name, v
+
+    @jit_forbidden_register
+    def get_sub_cell(self, target: str) -> "Cell":
+        """Return the sub cell given by `target` if it exists, otherwise throw an error.
+
+        For example, let's say you have an ``nn.Cell`` ``A`` that
+        looks like this:
+
+        .. code-block:: text
+
+            A(
+                (net_b): NetB(
+                    (net_c): NetC(
+                        (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
+                    )
+                    (dense): Dense(in_features=100, out_features=200, bias=True)
+                )
+            )
+
+        (The diagram shows an ``nn.Cell`` ``A``. ``A`` has a nested
+        sub cell ``net_b``, which itself has two sub cells ``net_c``
+        and ``dense``. ``net_c`` then has a sub cell ``conv``.)
+
+        To check whether we have the ``dense`` sub cell, we
+        would call `get_sub_cell("net_b.dense")`. To check whether
+        we have the ``conv`` sub cell, we would call
+        `get_sub_cell("net_b.net_c.conv")`.
+
+        The runtime of ``get_sub_cell`` is bounded by the degree
+        of cell nesting in `target`. A query against
+        `name_cells` achieves the same result, but it is O(N) in
+        the number of transitive cells. So, for a simple check to see
+        if some sub cells exist, ``get_sub_cell`` should always be
+        used.
+
+        Args:
+            target (str): The fully-qualified string name of the sub cell
+                to look for. (See above example for how to specify a
+                fully-qualified string.)
+
+        Returns:
+            Cell
+
+        Examples:
+            >>> import mindspore
+            ...
+            ...
+            >>> class NetC(mindspore.nn.Cell):
+            ...     def __init__(self):
+            ...         super().__init__()
+            ...         self.register_buffer("buffer_c", mindspore.tensor([0, 0, 0]))
+            ...         self.dense_c = mindspore.nn.Dense(5, 3)
+            ...
+            ...     def construct(self, x):
+            ...         return self.dense_c(x) + self.buffer_c
+            ...
+            ...
+            >>> class NetB(mindspore.nn.Cell):
+            ...     def __init__(self, net_c):
+            ...         super().__init__()
+            ...         self.net_c = net_c
+            ...         self.register_buffer("buffer_b", mindspore.tensor([1, 2, 3]))
+            ...
+            ...     def construct(self, x):
+            ...         return self.net_c(x) + self.buffer_b
+            ...
+            ...
+            >>> class NetA(mindspore.nn.Cell):
+            ...     def __init__(self, net_b):
+            ...         super().__init__()
+            ...         self.net_b = net_b
+            ...         self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
+            ...
+            ...     def construct(self, x):
+            ...         return self.net_b(x) + self.buffer_a
+            ...
+            ...
+            >>> net_c = NetC()
+            >>> net_b = NetB(net_c)
+            >>> net_a = NetA(net_b)
+            >>> net_c = net_a.get_sub_cell("net_b.net_c")
+            >>> print(f'net_c is {net_c}')
+            net_c is NetC(
+              (dense_c): Dense(input_channels=5, output_channels=3, has_bias=True)
+            )
+
+        """
+        if target == "":
+            return self
+
+        atoms: List[str] = target.split(".")
+        cell = self
+
+        for item in atoms:
+            if not hasattr(cell, item):
+                raise AttributeError(
+                    cell._get_name() + " has no " "attribute `" + item + "`"
+                )
+
+            cell = getattr(cell, item)
+
+            if not isinstance(cell, Cell):
+                raise AttributeError("`" + item + "` is not " "an nn.Cell")
+
+        return cell
+
     def get_func_graph_proto(self):
         """Return graph binary proto."""
         exec_id = ".".join([self.phase, str(self.create_time), str(id(self))])
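The buffer APIs added in this hunk mirror the existing parameter and sub-cell traversal. A short sketch of persistent vs. non-persistent buffers, assuming the 2.6.0 behavior exactly as documented above; the `Net` class here is illustrative, not from the diff:

import mindspore
from mindspore import nn

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        # Persistent by default: part of the cell's state_dict.
        self.register_buffer("running_mean", mindspore.tensor([0.0, 0.0]))
        # Non-persistent: still an attribute and still yielded by
        # buffers()/named_buffers(), but excluded from state_dict
        # per the register_buffer docstring above.
        self.register_buffer("scratch", mindspore.tensor([1.0, 1.0]), persistent=False)

    def construct(self, x):
        return x + self.running_mean

net = Net()
print([name for name, _ in net.named_buffers()])  # ['running_mean', 'scratch']
print(net.get_buffer("scratch"))                  # [1. 1.]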
@@ -398,6 +845,10 @@ class Cell(Cell_):
         params = self.__dict__['_params']
         if name in params:
             return params[name]
+        if '_buffers' in self.__dict__:
+            buffers = self.__dict__['_buffers']
+            if name in buffers:
+                return buffers[name]
         if '_cells' in self.__dict__:
             cells = self.__dict__['_cells']
             if name in cells:
@@ -420,6 +871,8 @@
     def __delattr__(self, name):
         if name in self._params:
             del self._params[name]
+        elif name in self._buffers:
+            del self._buffers[name]
         elif name in self._cells:
             del self._cells[name]
         elif '_params_list' in self.__dict__ and name in self._params_list:
@@ -492,14 +945,17 @@
         if self._forward_pre_hook:
             inputs = self._run_forward_pre_hook(inputs)

-        if self._backward_hook:
-            output = self._backward_hook_construct(*inputs, **kwargs)
-        elif self._shard_fn is not None:
+        if self._shard_fn is not None:
             output = self._shard_fn(*inputs, **kwargs)
-        elif self._recompute_cell is not None:
-            output = self._recompute_cell(*inputs, **kwargs)
-        elif self.has_bprop and _pynative_executor.requires_grad():
-            output = self._call_custom_bprop(*inputs, **kwargs)
+        elif _pynative_executor.requires_grad():
+            if self._backward_hook:
+                output = self._backward_hook_construct(*inputs, **kwargs)
+            elif self._recompute_cell is not None:
+                output = self._recompute_cell(*inputs, **kwargs)
+            elif self.has_bprop:
+                output = self._call_custom_bprop(*inputs, **kwargs)
+            else:
+                output = self.construct(*inputs, **kwargs)
         else:
             output = self.construct(*inputs, **kwargs)

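The `__call__` hunk above narrows when the backward-hook, recompute, and custom-bprop paths are taken: they are now consulted only while the PyNative executor is recording gradients, with shard still taking precedence unconditionally. As an illustration of the consequence (my reading of the hunk, not text from the diff), a cell with a custom bprop runs plain `construct` for inference and only routes through `_call_custom_bprop` under gradient recording:

import mindspore
from mindspore import nn, ops

class SquareWithCustomGrad(nn.Cell):
    def construct(self, x):
        return x * x

    def bprop(self, x, out, dout):
        # Custom gradient: d(x*x)/dx = 2x, scaled by the incoming gradient.
        return (2 * x * dout,)

net = SquareWithCustomGrad()
x = mindspore.tensor([3.0])
print(net(x))            # inference: plain construct(), bprop machinery skipped
print(ops.grad(net)(x))  # gradient recording: the custom bprop is used -> [6.]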
@@ -590,6 +1046,89 @@ class Cell(Cell_):
590
1046
  for prim in all_prims:
591
1047
  prim.add_prim_attr("strategy_gen_mode", "data_parallel")
592
1048
 
1049
+ def offload(self, backward_prefetch="Auto"):
1050
+ """
1051
+ Enable offload for the cell. All primitive ops in the cell will be marked for offload. The intermediate
1052
+ activations computed by these ops are not kept in device memory during the forward pass; instead they are
1053
+ offloaded, then loaded back during the backward pass.
1054
+
1055
+ Note:
1056
+ - If Cell.offload is called, the mode should be set to "GRAPH_MODE".
1057
+ - If Cell.offload is called, lazy_inline should be enabled.
1058
+
1059
+ Args:
1060
+ backward_prefetch (Union[str, int], optional): When to prefetch activations in the backward
1061
+ pass. Default: ``"Auto"``. If set to ``"Auto"``, the framework
1062
+ prefetches activations one operator in advance.
1063
+ If set to a positive int, the framework
1064
+ prefetches activations ``backward_prefetch`` operators in
1065
+ advance, such as 1, 20, 100.
+
1066
+ Examples:
1067
+ >>> import numpy as np
+ >>> import mindspore.nn as nn
1068
+ >>> from mindspore import ops
1069
+ >>> from mindspore.common import Tensor, Parameter
1070
+ >>> from mindspore.common.lazy_inline import lazy_inline
1071
+ >>>
1072
+ >>> class Block(nn.Cell):
1073
+ ... def __init__(self):
1074
+ ... super(Block, self).__init__()
1075
+ ... self.transpose1 = ops.Transpose()
1076
+ ... self.transpose2 = ops.Transpose()
1077
+ ... self.transpose3 = ops.Transpose()
1078
+ ... self.transpose4 = ops.Transpose()
1079
+ ... self.real_div1 = ops.RealDiv()
1080
+ ... self.real_div2 = ops.RealDiv()
1081
+ ... self.batch_matmul1 = ops.BatchMatMul()
1082
+ ... self.batch_matmul2 = ops.BatchMatMul()
1083
+ ... self.softmax = ops.Softmax(-1)
1084
+ ... self.expand_dims = ops.ExpandDims()
1085
+ ... self.sub = ops.Sub()
1086
+ ... self.y = Parameter(Tensor(np.ones((1024, 128, 128)).astype(np.float32)))
1087
+ ... def construct(self, x):
1088
+ ... transpose1 = self.transpose1(x, (0, 2, 1, 3))
1089
+ ... real_div1 = self.real_div1(transpose1, Tensor(2.37891))
1090
+ ... transpose2 = self.transpose2(x, (0, 2, 3, 1))
1091
+ ... real_div2 = self.real_div2(transpose2, Tensor(2.37891))
1092
+ ... batch_matmul1 = self.batch_matmul1(real_div1, real_div2)
1093
+ ... expand_dims = self.expand_dims(self.y, 1)
1094
+ ... sub = self.sub(Tensor([1.0]), expand_dims)
1095
+ ... soft_max = self.softmax(sub)
1096
+ ... transpose3 = self.transpose3(x, (0, 2, 1, 3))
1097
+ ... batch_matmul2 = self.batch_matmul2(soft_max[0], transpose3)
1098
+ ... transpose4 = self.transpose4(batch_matmul2, (0, 2, 1, 3))
1099
+ ... return transpose4
1100
+ >>>
1101
+ >>> class OuterBlock(nn.Cell):
1102
+ ... @lazy_inline
1103
+ ... def __init__(self):
1104
+ ... super(OuterBlock, self).__init__()
1105
+ ... self.block = Block()
1106
+ ... def construct(self, x):
1107
+ ... return self.block(x)
1108
+ >>>
1109
+ >>> class Nets(nn.Cell):
1110
+ ... def __init__(self):
1111
+ ... super(Nets, self).__init__()
1112
+ ... self.blocks = nn.CellList()
1113
+ ... for _ in range(3):
1114
+ ... b = OuterBlock()
1115
+ ... b.offload()
1116
+ ... self.blocks.append(b)
1117
+ ... def construct(self, x):
1118
+ ... out = x
1119
+ ... for i in range(3):
1120
+ ... out = self.blocks[i](out)
1121
+ ... return out
1122
+ """
1123
+ if context._get_mode() == context.PYNATIVE_MODE:
1124
+ raise ValueError("The Cell offload does not support PyNative mode now.")
1125
+ if isinstance(backward_prefetch, str):
1126
+ Validator.check_string(backward_prefetch, ['Auto'], 'backward_prefetch', self.cls_name)
1127
+ else:
1128
+ Validator.check_non_negative_int(backward_prefetch)
1129
+ for prim in self._get_prims_recursively():
1130
+ prim._offload(backward_prefetch=backward_prefetch)
1131
+
593
1132
  def shard(self, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
594
1133
  """
595
1134
  Defining the input and output layouts of this cell and the parallel strategies of remaining ops will be
@@ -598,13 +1137,13 @@ class Cell(Cell_):
598
1137
  strategy for others will be set by sharding propagation.
599
1138
  in_strategy and out_strategy define the input and output layout respectively.
600
1139
  in_strategy/out_strategy should be a tuple, each element of which corresponds to the desired layout of
601
- this input/output, which can refer to the description of `mindspore.ops.Primitive.shard`.
1140
+ this input/output, which can refer to the description of :func:`mindspore.ops.Primitive.shard`.
602
1141
  The parallel strategies of remaining operators are derived from the strategy specified by the input and output.
603
1142
 
604
1143
  Note:
605
- If Cell.shard is called, the parallel mode in `set_auto_parallel_context` (parallel_mode) will be set to
606
- "auto_parallel" and the search mode (search_mode) to "sharding_propagation".
607
- If the input contain Parameter, its strategy should be set in `in_strategy`.
1144
+ - It is valid only in semi-auto parallel or auto parallel mode.
1145
+ In other parallel modes, the strategies set here will be ignored.
1146
+ - If the input contains a Parameter, its strategy should be set in `in_strategy`.
608
1147
 
609
1148
  Args:
610
1149
  in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple. Tuple
@@ -618,7 +1157,7 @@ class Cell(Cell_):
618
1157
  If the parameter name is incorrect or the corresponding parameter
619
1158
  has been set, the parameter setting will be ignored.
620
1159
  Default: ``None`` .
621
- device (string): Select a certain device target. It is not in use right now.
1160
+ device (str): Select a certain device target. It is not in use right now.
622
1161
  Support [ ``"CPU"`` , ``"GPU"`` , ``"Ascend"`` ]. Default: ``"Ascend"`` .
623
1162
  level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation
624
1163
  over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
@@ -650,10 +1189,8 @@ class Cell(Cell_):
650
1189
  ... x = self.block2_shard(x)
651
1190
  ... return x
652
1191
  """
653
- if context.get_auto_parallel_context("parallel_mode") not in ["auto_parallel", "semi_auto_parallel"]:
654
- raise AssertionError(f"Cell shard only supports auto parallel or semi_auto_parallel "
655
- f"Please check the parallel mode in parallel context.")
656
-
1192
+ if ms.communication.management.get_group_size() == 1:
1193
+ return self
657
1194
  shard_fn = Shard()
658
1195
  fn = shard_fn(self, in_strategy, out_strategy, parameter_plan, device, level)
659
1196
  self._shard_fn = fn
@@ -756,7 +1293,8 @@ class Cell(Cell_):
756
1293
  """
757
1294
  Process cell info before call construct
758
1295
  """
759
- if self.requires_grad:
1296
+ if self.requires_grad and (not _pynative_executor.grad_flag() or _pynative_executor.high_order()):
1297
+ self.is_top_cell = True
760
1298
  _pynative_executor.set_grad_flag(True)
761
1299
  _pynative_executor.new_graph(self, *args, **kwargs)
762
1300
  elif self._dynamic_shape_inputs is not None:
@@ -770,8 +1308,9 @@ class Cell(Cell_):
770
1308
  """
771
1309
  Process cell info after call construct
772
1310
  """
773
- if self.requires_grad:
1311
+ if self.requires_grad and self.is_top_cell:
774
1312
  _pynative_executor.end_graph(self, output, *args, **kwargs)
1313
+ self.is_top_cell = False
775
1314
  elif self._dynamic_shape_inputs is not None:
776
1315
  _pynative_executor.set_cell_use_dynamic_shape_process(False)
777
1316
 
@@ -816,52 +1355,41 @@ class Cell(Cell_):
816
1355
  self._add_attr(key, value)
817
1356
  self._attr_synced = True
818
1357
 
819
- def _set_attr_for_parameter(self, name, value):
820
- """Set attr for parameter."""
821
- cells = self.__dict__.get('_cells')
822
- params = self.__dict__.get('_params')
823
- if params is None:
824
- raise AttributeError("For 'Cell', can not assign params before Cell.__init__() is called.")
825
- if name in self.__dict__:
826
- if self.__dict__[name] is not None:
827
- raise TypeError(f"For 'Cell', the {name} should not be Parameter.")
828
- del self.__dict__[name]
829
- if cells and name in cells:
830
- raise TypeError(f"For 'Cell', the {name} must be Cell, but got Parameter.")
831
- self.insert_param_to_cell(name, value)
832
-
833
- def _set_attr_for_parameter_tuple(self, name, value):
834
- """Set attr for parameter in ParameterTuple."""
835
- params = self.__dict__.get('_params')
836
- params_list = self.__dict__.get('_params_list')
837
- if params is None:
838
- raise AttributeError("For 'Cell', can not assign params before Cell.__init__() is called.")
839
- exist_names = set("")
840
- exist_objs = set()
841
- for item in value:
842
- if item in exist_objs:
843
- # If there are multiple identical objects, their names only check once.
844
- continue
845
- exist_objs.add(item)
846
- if item.name == PARAMETER_NAME_DEFAULT:
847
- logger.warning("For 'Cell', the parameter definition is deprecated.\n"
848
- "Please set a unique name for the parameter in ParameterTuple '{}'.".format(value))
849
- item.name = item.name + "$" + str(self._id)
850
- self._id += 1
851
- self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
852
- if item.name in exist_names:
853
- raise ValueError("The value {} , its name '{}' already exists. "
854
- "Please set a unique name for the parameter.".format(value, item.name))
855
- exist_names.add(item.name)
856
-
857
- if context._get_mode() == context.PYNATIVE_MODE:
1358
+ def _set_attr_for_param_or_param_tuple(self, name, value):
1359
+ """Set attr for param and tensor."""
1360
+ if isinstance(value, Parameter):
858
1361
  if name in self.__dict__:
859
1362
  del self.__dict__[name]
860
- if name in params:
861
- del params[name]
862
- params_list[name] = value
863
- else:
864
- object.__setattr__(self, name, value)
1363
+ self.insert_param_to_cell(name, value)
1364
+ elif isinstance(value, ParameterTuple):
1365
+ exist_names = set("")
1366
+ exist_objs = set()
1367
+ for item in value:
1368
+ if item in exist_objs:
1369
+ # If there are multiple identical objects, their names only check once.
1370
+ continue
1371
+ exist_objs.add(item)
1372
+ if item.name == PARAMETER_NAME_DEFAULT:
1373
+ logger.warning("For 'Cell', the parameter definition is deprecated.\n"
1374
+ "Please set a unique name for the parameter in ParameterTuple '{}'.".format(value))
1375
+ item.name = item.name + "$" + str(self._id)
1376
+ self._id += 1
1377
+ self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
1378
+ if item.name in exist_names:
1379
+ raise ValueError("The value {} , its name '{}' already exists. "
1380
+ "Please set a unique name for the parameter.".format(value, item.name))
1381
+ exist_names.add(item.name)
1382
+
1383
+ if context._get_mode() == context.PYNATIVE_MODE:
1384
+ if name in self.__dict__:
1385
+ del self.__dict__[name]
1386
+ params = self.__dict__.get('_params')
1387
+ if name in params:
1388
+ del params[name]
1389
+ params_list = self.__dict__.get('_params_list')
1390
+ params_list[name] = value
1391
+ else:
1392
+ object.__setattr__(self, name, value)
865
1393
 
866
1394
  def _set_attr_for_parameter_in_list_or_tuple(self, name, value):
867
1395
  """Set attr for parameter in list or tuple."""
@@ -874,24 +1402,18 @@ class Cell(Cell_):
874
1402
  item.name = item.name + "$" + str(self._id)
875
1403
  self._id += 1
876
1404
  if item.name in self.exist_names:
877
- raise ValueError("The value {} , its name '{}' already exists. "
878
- "Please set a unique name for the parameter.".format(value, item.name))
1405
+ raise ValueError(f"The value {value} , its name '{item.name}' already exists. "
1406
+ "Please set a unique name for the parameter.")
879
1407
  self.exist_names.add(item.name)
880
1408
  object.__setattr__(self, name, value)
881
1409
 
882
1410
  def _set_attr_for_cell(self, name, value):
883
1411
  """Set attr for cell."""
884
- cells = self.__dict__.get('_cells')
885
- params = self.__dict__.get('_params')
886
- if cells is None:
887
- raise AttributeError("For 'Cell', can not assign cells before Cell.__init__() is called.")
888
1412
  if name in self.__dict__:
889
1413
  del self.__dict__[name]
890
- if params and name in params:
891
- raise TypeError(f"For 'Cell', the {name} must be Parameter, but got Cell.")
892
1414
  if self._auto_prefix:
893
1415
  value.update_parameters_name(name + '.')
894
- cells[name] = value
1416
+ self.insert_child_to_cell(name, value)
895
1417
  if hasattr(self, '_cell_init_args'):
896
1418
  self.cell_init_args += str({name: value})
897
1419
 
@@ -904,30 +1426,57 @@ class Cell(Cell_):
904
1426
  else:
905
1427
  self.insert_param_to_cell(name, None)
906
1428
 
907
- def __setattr__(self, name, value):
908
- cells = self.__dict__.get('_cells')
1429
+ def _set_attr_for_object(self, name, value):
1430
+ """Set attr for py object."""
909
1431
  params = self.__dict__.get('_params')
910
- if isinstance(value, Parameter):
911
- self._set_attr_for_parameter(name, value)
912
- elif isinstance(value, ParameterTuple):
913
- self._set_attr_for_parameter_tuple(name, value)
914
- elif isinstance(value, (list, tuple)) and value and _check_param_list_tuple(value):
1432
+ if params is not None and name in params:
1433
+ if value is not None:
1434
+ if isinstance(value, Tensor):
1435
+ params[name].set_data(value)
1436
+ return
1437
+ raise TypeError(
1438
+ f"Parameter '{name}' already exists in network, "
1439
+ f"can not assign this type: '{type(value)}' as a parameter.")
1440
+ params[name] = None
1441
+ return
1442
+ cells = self.__dict__.get('_cells')
1443
+ if cells is not None and name in cells:
1444
+ if value is not None:
1445
+ raise TypeError(
1446
+ f"Sub cell '{name}' already exists in network, "
1447
+ f"can not assign this type: '{type(value)}' as a cell.")
1448
+ cells[name] = None
1449
+ return
1450
+ buffers = self.__dict__.get('_buffers')
1451
+ if buffers is not None and name in buffers:
1452
+ if value is not None:
1453
+ raise TypeError(
1454
+ f"Buffer '{name}' already exists in network, "
1455
+ f"can not assign this type: '{type(value)}' as a buffer.")
1456
+ buffers[name] = None
1457
+ return
1458
+ object.__setattr__(self, name, value)
1459
+
1460
+ def __setattr__(self, name, value):
1461
+ if isinstance(value, (Parameter, ParameterTuple)):
1462
+ self._set_attr_for_param_or_param_tuple(name, value)
1463
+ elif _is_parameter_list_or_tuple(value):
915
1464
  self._set_attr_for_parameter_in_list_or_tuple(name, value)
916
1465
  elif isinstance(value, Cell):
917
1466
  self._set_attr_for_cell(name, value)
918
- elif params and name in params:
919
- self._set_attr_for_params(name, value)
920
- elif cells and name in cells:
921
- if value is not None:
922
- raise TypeError(f"For 'Cell', the type of {name} must be cell, but got {type(value).__name__}.")
923
- self._cells[name] = None
924
- else:
925
- if isinstance(value, Primitive):
926
- value.set_prim_instance_name(name)
927
- self._primitives[name] = value
1467
+ elif isinstance(value, _Buffer):
1468
+ if name in self.__dict__:
1469
+ del self.__dict__[name]
1470
+ self.register_buffer(name, value)
1471
+ elif isinstance(value, Primitive):
1472
+ value.set_prim_instance_name(name)
1473
+ self._primitives[name] = value
928
1474
  object.__setattr__(self, name, value)
929
- if name not in Cell.IGNORE_LIST:
930
- self._attr_synced = False
1475
+ else:
1476
+ self._set_attr_for_object(name, value)
1477
+
1478
+ def _get_name(self):
1479
+ return self.__class__.__name__
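With `_set_attr_for_object`, names already registered as parameters, sub-cells or buffers now reject incompatible reassignment instead of silently shadowing: a `Tensor` assigned over an existing parameter updates it via `set_data`, ``None`` clears the slot, and anything else raises. A sketch of the observable behavior, assuming a standard `Dense` cell:

```python
import mindspore

net = mindspore.nn.Dense(2, 2)
# a Tensor assigned to an existing parameter name updates it in place
net.weight = mindspore.ops.zeros((2, 2), mindspore.float32)
# any other non-None type now raises instead of shadowing the parameter
try:
    net.weight = 3.14
except TypeError as err:
    print(err)  # Parameter 'weight' already exists in network, ...
```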
931
1480
 
932
1481
  def extend_repr(self):
933
1482
  """
@@ -941,19 +1490,28 @@ class Cell(Cell_):
941
1490
  return self.__repr__()
942
1491
 
943
1492
  def __repr__(self):
944
- extra_str = self.extend_repr()
945
- info_str = self.__class__.__name__ + '<'
946
- if self._cells:
947
- sub_str = '\n'
948
- if extra_str:
949
- sub_str += '{}\n'.format(self.extend_repr())
950
- for key, value in self._cells.items():
951
- sub_str += '({}): {}\n'.format(key, repr(value))
952
- sub_str = sub_str.replace('\n', '\n ') + '>'
953
- info_str += sub_str
954
- else:
955
- info_str += extra_str + '>'
956
- return info_str
1493
+ extra_lines = []
1494
+ extend_repr = self.extend_repr()
1495
+ # empty string will be split into list ['']
1496
+ if extend_repr:
1497
+ extra_lines = extend_repr.split("\n")
1498
+ child_lines = []
1499
+ for key, cell in self._cells.items():
1500
+ cell_str = repr(cell)
1501
+ cell_str = _addindent(cell_str, 2)
1502
+ child_lines.append("(" + key + "): " + cell_str)
1503
+ lines = extra_lines + child_lines
1504
+
1505
+ main_str = self._get_name() + "("
1506
+ if lines:
1507
+ # simple one-liner info, which most builtin Modules will use
1508
+ if len(extra_lines) == 1 and not child_lines:
1509
+ main_str += extra_lines[0]
1510
+ else:
1511
+ main_str += "\n " + "\n ".join(lines) + "\n"
1512
+
1513
+ main_str += ")"
1514
+ return main_str
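The repr moves from the old `Name<...>` angle-bracket form to a PyTorch-style parenthesized tree built with `_addindent` (added at the end of this file). Roughly what it prints now, matching the docstring outputs updated elsewhere in this diff:

```python
import mindspore

net = mindspore.nn.SequentialCell(
    mindspore.nn.Dense(2, 3),
    mindspore.nn.ReLU(),
)
print(net)
# SequentialCell(
#   (0): Dense(input_channels=2, output_channels=3, has_bias=True)
#   (1): ReLU()
# )
```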
957
1515
 
958
1516
  def load_parameter_slice(self, params):
959
1517
  """
@@ -1119,9 +1677,11 @@ class Cell(Cell_):
1119
1677
  args (tuple): Args of the Cell object.
1120
1678
  kwargs (dict): Kwargs of the Cell object.
1121
1679
  """
1680
+ _init_auto_parallel_context(self)
1122
1681
  self._compile_args = self._get_compile_args(args)
1123
1682
  _cell_graph_executor.compile(self, *self._compile_args, phase=self.phase,
1124
1683
  jit_config_dict=self._jit_config_dict, **kwargs)
1684
+ _clear_auto_parallel_context(self)
1125
1685
 
1126
1686
  def compile_and_run(self, *args, **kwargs):
1127
1687
  """
@@ -1252,9 +1812,9 @@ class Cell(Cell_):
1252
1812
  >>> net2 = nn.Dense(2, 2)
1253
1813
  >>> net1.insert_child_to_cell("child", net2)
1254
1814
  >>> print(net1)
1255
- ReLU<
1256
- (child): Dense<input_channels=2, output_channels=2, has_bias=True>
1257
- >
1815
+ ReLU(
1816
+ (child): Dense(input_channels=2, output_channels=2, has_bias=True)
1817
+ )
1258
1818
  """
1259
1819
  if not isinstance(child_name, str):
1260
1820
  raise TypeError(f"For 'insert_child_to_cell', the type of parameter 'child_name' must be str, "
@@ -1312,13 +1872,22 @@ class Cell(Cell_):
1312
1872
  new_param_tuple.append(param)
1313
1873
  cell.__dict__[key] = ParameterTuple(new_param_tuple)
1314
1874
 
1875
+ def _get_cell_parallel_mode(self):
1876
+ """Determine whether the current cell is in parallel mode."""
1877
+ is_parallel_mode = False
1878
+ for _, param in self.parameters_and_names():
1879
+ if param.param_info.is_param_init:
1880
+ is_parallel_mode = True
1881
+ break
1882
+ return is_parallel_mode
1883
+
1315
1884
  def init_parameters_data(self, auto_parallel_mode=False):
1316
1885
  """
1317
1886
  Initialize all parameters and replace the original saved parameters in cell.
1318
1887
 
1319
1888
  Note:
1320
1889
  trainable_params() and other similar interfaces may return different parameter instance after
1321
- `init_parameters_data`, do not save these results.
1890
+ `init_parameters_data`. It is not recommended to save these results.
1322
1891
 
1323
1892
  Args:
1324
1893
  auto_parallel_mode (bool): If running in auto_parallel_mode. Default: ``False`` .
@@ -1350,15 +1919,24 @@ class Cell(Cell_):
1350
1919
  def _updata(param):
1351
1920
  if param in replace:
1352
1921
  return replace.get(param)
1353
- new_p = param.init_data(None, set_sliced=False)
1922
+ new_p = param.init_data(None, set_sliced=param.sliced)
1354
1923
  replace[param] = new_p
1355
1924
  return new_p
1356
1925
 
1357
1926
  # replace all original usage.
1358
1927
  cells = self.cells_and_names()
1928
+ is_parallel_mode = self._get_cell_parallel_mode()
1929
+ is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
1930
+
1359
1931
  for _, cell in cells:
1360
1932
  params = cell._params.items()
1361
1933
  for param_name, param in params:
1934
+ judgment = not param.sliced
1936
+ if param.param_info.is_pipeline_shared_param:
1937
+ continue
1938
+ if is_graph_mode and is_parallel_mode and judgment:
1939
+ continue
1362
1940
  if not auto_parallel_mode:
1363
1941
  cell._params[param_name] = _updata(param)
1364
1942
  continue
@@ -1370,6 +1948,12 @@ class Cell(Cell_):
1370
1948
  param_tuple = cell_dict[key]
1371
1949
  new_param_tuple = []
1372
1950
  for param in param_tuple:
1951
+ judgment = not param.sliced
1953
+ if param.param_info.is_pipeline_shared_param:
1954
+ continue
1955
+ if is_graph_mode and is_parallel_mode and judgment:
1956
+ continue
1373
1957
  if not auto_parallel_mode:
1374
1958
  new_param_tuple.append(_updata(param))
1375
1959
  continue
@@ -1677,7 +2261,7 @@ class Cell(Cell_):
1677
2261
  ... return x
1678
2262
  >>> net = Net()
1679
2263
  >>> print(net.cells())
1680
- odict_values([Dense<input_channels=2, output_channels=2, has_bias=True>])
2264
+ odict_values([Dense(input_channels=2, output_channels=2, has_bias=True)])
1681
2265
  """
1682
2266
  return self.name_cells().values()
1683
2267
 
@@ -1738,7 +2322,7 @@ class Cell(Cell_):
1738
2322
  ... return x
1739
2323
  >>> net = Net()
1740
2324
  >>> print(net.name_cells())
1741
- OrderedDict([('dense', Dense<input_channels=2, output_channels=2, has_bias=True>)])
2325
+ OrderedDict([('dense', Dense(input_channels=2, output_channels=2, has_bias=True))])
1742
2326
  """
1743
2327
  value_set = set()
1744
2328
  cells = OrderedDict()
@@ -1779,10 +2363,10 @@ class Cell(Cell_):
1779
2363
  ... if isinstance(cell, nn.Dense):
1780
2364
  ... cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
1781
2365
  >>> net.apply(func)
1782
- SequentialCell<
1783
- (0): Dense<input_channels=2, output_channels=2, has_bias=True>
1784
- (1): Dense<input_channels=2, output_channels=2, has_bias=True>
1785
- >
2366
+ SequentialCell(
2367
+ (0): Dense(input_channels=2, output_channels=2, has_bias=True)
2368
+ (1): Dense(input_channels=2, output_channels=2, has_bias=True)
2369
+ )
1786
2370
  >>> print(net[0].weight.asnumpy())
1787
2371
  [[1. 1.]
1788
2372
  [1. 1.]]
@@ -1914,8 +2498,8 @@ class Cell(Cell_):
1914
2498
  >>>
1915
2499
  >>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
1916
2500
  >>> net.to_float(mstype.float16)
1917
- Conv2d<input_channels=120, output_channels=240, kernel_size=(4, 4), stride=(1, 1), pad_mode=same,
1918
- padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=None, format=NCHW>
2501
+ Conv2d(input_channels=120, output_channels=240, kernel_size=(4, 4), stride=(1, 1), pad_mode=same,
2502
+ padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=None, format=NCHW)
1919
2503
  """
1920
2504
  if dst_type not in (mstype.float16, mstype.float32, mstype.bfloat16):
1921
2505
  raise ValueError("For 'to_float', the argument 'dst_type' must be mstype.float32, mstype.float16 or "
@@ -1955,9 +2539,8 @@ class Cell(Cell_):
1955
2539
 
1956
2540
  def set_grad(self, requires_grad=True):
1957
2541
  """
1958
- Sets the cell flag for gradient. In pynative mode, this parameter specifies whether the network requires
1959
- gradients. If ``true`` , the backward network needed to compute the gradients will be generated when the forward
1960
- network is executed.
2542
+ Sets the cell flag for gradient.
1961
2544
 
1962
2545
  Args:
1963
2546
  requires_grad (bool): Specifies if the net need to grad, if it is
@@ -2121,8 +2704,7 @@ class Cell(Cell_):
2121
2704
  """
2122
2705
  if context._get_mode() == context.GRAPH_MODE:
2123
2706
  return HookHandle()
2124
- if not check_hook_fn("register_forward_pre_hook", hook_fn):
2125
- return HookHandle()
2707
+ check_hook_fn(hook_fn)
2126
2708
  handle = HookHandle(self._forward_pre_hook)
2127
2709
  self._forward_pre_hook[handle.handle_id] = hook_fn
2128
2710
  return handle
@@ -2217,10 +2799,11 @@ class Cell(Cell_):
2217
2799
  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
2218
2800
  value= [ 2.00000000e+00]))
2219
2801
  """
2220
- if context._get_mode() == context.GRAPH_MODE:
2802
+ if self.has_bprop:
2221
2803
  return HookHandle()
2222
- if not check_hook_fn("register_forward_hook", hook_fn):
2804
+ if context._get_mode() == context.GRAPH_MODE:
2223
2805
  return HookHandle()
2806
+ check_hook_fn(hook_fn)
2224
2807
  handle = HookHandle(self._forward_hook)
2225
2808
  self._forward_hook[handle.handle_id] = hook_fn
2226
2809
  return handle
@@ -2310,8 +2893,7 @@ class Cell(Cell_):
2310
2893
  """
2311
2894
  if context._get_mode() == context.GRAPH_MODE:
2312
2895
  return HookHandle()
2313
- if not check_hook_fn("register_backward_pre_hook", hook_fn):
2314
- return HookHandle()
2896
+ check_hook_fn(hook_fn)
2315
2897
  handle = HookHandle(self._backward_pre_hook)
2316
2898
  self._backward_pre_hook[handle.handle_id] = hook_fn
2317
2899
  if self._cell_backward_pre_hook is None:
@@ -2334,9 +2916,12 @@ class Cell(Cell_):
2334
2916
  Supported Platforms:
2335
2917
  ``Ascend`` ``GPU`` ``CPU``
2336
2918
  """
2337
- ret = self._cell_backward_pre_hook(outputs)
2338
2919
  if isinstance(outputs, tuple):
2339
- if not isinstance(ret, tuple):
2920
+ ret = self._cell_backward_pre_hook(*outputs)
2921
+ else:
2922
+ ret = self._cell_backward_pre_hook(outputs)
2923
+ if isinstance(outputs, tuple):
2924
+ if len(outputs) == 1:
2340
2925
  ret = (ret,)
2341
2926
  if len(ret) != len(outputs):
2342
2927
  raise TypeError(
@@ -2344,6 +2929,527 @@ class Cell(Cell_):
2344
2929
  len(ret), len(outputs)))
2345
2930
  return ret
2346
2931
 
2932
+ def get_extra_state(self) -> Any:
2933
+ """Return any extra state to include in the cell's state_dict.
2934
+
2935
+ This function is called from ``state_dict``.
2936
+ Implement this and a corresponding ``set_extra_state`` for your cell
2937
+ if you need to store extra state.
2938
+
2939
+ Note that extra state should be picklable to ensure working serialization
2940
+ of the state_dict. Backwards compatibility is only guaranteed
2941
+ for serializing tensors; other objects may break backwards compatibility if
2942
+ their serialized pickled form changes.
2943
+
2944
+ Returns:
2945
+ object, any extra state to store in the cell's state_dict.
2946
+ """
2947
+ raise RuntimeError(
2948
+ "Reached a code path in Cell.get_extra_state() that should never be called."
2950
+ )
2951
+
2952
+ def set_extra_state(self, state: Any) -> None:
2953
+ """Set extra state contained in the loaded `state_dict`.
2954
+
2955
+ This function is called from `load_state_dict` to handle any extra state
2956
+ found within the `state_dict`. Implement this function and a corresponding
2957
+ `get_extra_state` for your cell if you need to store extra state within its
2958
+ `state_dict`.
2959
+
2960
+ Args:
2961
+ state (dict): Extra state from the `state_dict`.
2962
+ """
2963
+ raise RuntimeError(
2964
+ "Reached a code path in Cell.set_extra_state() that should never be called."
2965
+ )
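Together, these two overridable hooks let a subclass round-trip arbitrary picklable state through `state_dict`/`load_state_dict`; `_save_to_state_dict` below stores the value under the `_extra_state` key only when `get_extra_state` is actually overridden. A minimal sketch (the `CounterCell` name and its `step` field are illustrative):

```python
import mindspore

class CounterCell(mindspore.nn.Cell):
    def __init__(self):
        super().__init__()
        self.step = 0

    def construct(self, x):
        return x

    def get_extra_state(self):
        # saved under the '<prefix>_extra_state' key by _save_to_state_dict
        return {"step": self.step}

    def set_extra_state(self, state):
        self.step = state["step"]

src = CounterCell()
src.step = 7
dst = CounterCell()
dst.load_state_dict(src.state_dict())
print(dst.step)  # 7
```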
2966
+
2967
+ @jit_forbidden_register
2968
+ def register_state_dict_post_hook(self, hook):
2969
+ r"""Register a post-hook for the :func:`mindspore.nn.Cell.state_dict` method.
2970
+
2971
+ It should have the following signature:
2972
+
2973
+ hook(cell, state_dict, prefix, local_metadata) -> None
2974
+
2975
+ The registered hooks can modify the ``state_dict`` inplace.
2976
+
2977
+ Args:
2978
+ hook (Callable): The hook function after `state_dict` is called.
2979
+
2980
+ Returns:
2981
+ A handle that can be used to remove the added hook by calling
2982
+ `handle.remove()`.
2983
+ """
2984
+ from mindspore.utils.hooks import _RemovableHandle
2985
+ handle = _RemovableHandle(self._state_dict_hooks)
2986
+ self._state_dict_hooks[handle.id] = hook
2987
+ return handle
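Unlike the pre-hook, a post-hook runs after the dict is populated; it must mutate `state_dict` in place and return ``None`` (a non-``None`` return raises inside `state_dict`). A sketch with a hypothetical `_drop_buffers` hook:

```python
import mindspore

class NetA(mindspore.nn.Cell):
    def __init__(self):
        super().__init__()
        self.register_buffer("buffer_a", mindspore.tensor([1, 2, 3]))
        self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3]))

def _drop_buffers(cell, state_dict, prefix, local_metadata):
    # mutate in place; returning anything non-None raises a RuntimeError
    for key in [k for k in state_dict if k.endswith("buffer_a")]:
        del state_dict[key]

net = NetA()
handle = net.register_state_dict_post_hook(_drop_buffers)
print("buffer_a" in net.state_dict())  # False
handle.remove()
```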
2988
+
2989
+ @jit_forbidden_register
2990
+ def register_state_dict_pre_hook(self, hook):
2991
+ r"""Register a pre-hook for the :func:`mindspore.nn.Cell.state_dict` method.
2992
+
2993
+ It should have the following signature:
2994
+
2995
+ hook(cell, prefix, keep_vars) -> None
2996
+
2997
+ The registered hooks can be used to perform pre-processing before the `state_dict`
2998
+ call is made.
2999
+
3000
+ Args:
3001
+ hook (Callable): The hook function before `state_dict` is called.
3002
+
3003
+ Returns:
3004
+ A handle that can be used to remove the added hook by calling
3005
+ `handle.remove()`.
3006
+
3007
+ Examples:
3008
+ >>> import mindspore
3009
+ ...
3010
+ ...
3011
+ >>> class NetA(mindspore.nn.Cell):
3012
+ ... def __init__(self):
3013
+ ... super().__init__()
3014
+ ... self.register_buffer("buffer_a", mindspore.tensor([1, 2, 3]))
3015
+ ... self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3]))
3016
+ ...
3017
+ ... def construct(self, x):
3018
+ ... return x + self.buffer_a + self.param_a
3019
+ ...
3020
+ ...
3021
+ >>> def _add_extra_param(cell, prefix, keep_vars):
3022
+ ... cell._params["extra_param"] = mindspore.Parameter(mindspore.tensor([4, 5, 6]))
3023
+ ...
3024
+ ...
3025
+ >>> net = NetA()
3026
+ >>> handle = net.register_state_dict_pre_hook(_add_extra_param)
3027
+ >>> net_state_dict = net.state_dict()
3028
+ >>> handle.remove()
3029
+ >>> print("extra_param" in net_state_dict)
3030
+ True
3031
+ """
3032
+ from mindspore.utils.hooks import _RemovableHandle
3033
+ handle = _RemovableHandle(self._state_dict_pre_hooks)
3034
+ self._state_dict_pre_hooks[handle.id] = hook
3035
+ return handle
3036
+
3037
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
3038
+ r"""Save cell state to the `destination` dictionary.
3039
+
3040
+ The `destination` dictionary will contain the state
3041
+ of the cell, but not its descendants. This is called on every
3042
+ sub cell in :func:`mindspore.nn.Cell.state_dict`.
3043
+
3044
+ In rare cases, subclasses can achieve class-specific behavior by
3045
+ overriding this method with custom logic.
3046
+
3047
+ Args:
3048
+ destination (dict): a dict where state will be stored
3049
+ prefix (str): the prefix for parameters and buffers used in this
3050
+ cell
+ keep_vars (bool): not used by the default implementation
3051
+ """
3052
+ for name, param in self._params.items():
3053
+ if param is not None:
3054
+ destination[prefix + name] = param
3055
+ for name, buf in self._buffers.items():
3056
+ if buf is not None and name not in self._non_persistent_buffers_set:
3057
+ destination[prefix + name] = buf
3058
+ extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
3059
+ if (
3060
+ getattr(self.__class__, "get_extra_state", Cell.get_extra_state)
3061
+ is not Cell.get_extra_state
3062
+ ):
3063
+ destination[extra_state_key] = self.get_extra_state()
3064
+
3065
+ # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
3066
+ # back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
3067
+ T_destination = TypeVar("T_destination", bound=Dict[str, Any])
3068
+
3069
+ @jit_forbidden_register
3070
+ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
3071
+ r"""Return a dictionary containing references to the whole state of the cell.
3072
+
3073
+ Both parameters and persistent buffers (e.g. running averages) are
3074
+ included. Keys are corresponding parameter and buffer names.
3075
+ Parameters and buffers set to ``None`` are not included.
3076
+
3077
+ .. note::
3078
+ The returned object is a shallow copy. It contains references
3079
+ to the cell's parameters and buffers.
3080
+
3081
+ .. warning::
3082
+ - Currently ``state_dict()`` also accepts positional arguments for
3083
+ ``destination``, ``prefix`` and ``keep_vars`` in order. However,
3084
+ this is being deprecated and keyword arguments will be enforced in
3085
+ future releases.
3086
+
3087
+ - Please avoid the use of argument ``destination`` as it is not
3088
+ designed for end-users.
3089
+
3090
+ Args:
3091
+ destination (dict, optional): If provided, the state of cell will
3092
+ be updated into the dict and the same object is returned.
3093
+ Otherwise, an ``OrderedDict`` will be created and returned.
3094
+ Default: ``None``.
3095
+ prefix (str, optional): A prefix added to parameter and buffer
3096
+ names to compose the keys in state_dict. Default: ``''``.
3097
+ keep_vars (bool, optional): Whether the state_dict returns copies of the states. Default: ``False``, which returns references.
3098
+
3099
+ Returns:
3100
+ Dict, a dictionary containing a whole state of the cell.
3101
+
3102
+ Examples:
3103
+ >>> import mindspore
3104
+ >>> class Model(mindspore.nn.Cell):
3105
+ ... def __init__(self):
3106
+ ... super().__init__()
3107
+ ... self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
3108
+ ... self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3]))
3109
+ ...
3110
+ ... def construct(self, x):
3111
+ ... return x + self.buffer_a + self.param_a
3112
+ ...
3113
+ ...
3114
+ >>> model = Model()
3115
+ >>> print(model.state_dict())
3116
+ OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \
3117
+ ('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))])
3118
+ """
3119
+ # TODO: Remove `args` and the parsing logic when BC allows.
3120
+ if args:
3121
+ # DeprecationWarning is ignored by default
3122
+ warnings.warn(
3123
+ "Positional args are being deprecated, use kwargs instead. Refer to "
3124
+ "https://www.mindspore.cn/docs/zh-CN/master/api_python/nn/mindspore.nn.Cell.html"
3125
+ " for details.",
3126
+ FutureWarning,
3127
+ stacklevel=2,
3128
+ )
3129
+ if destination is None:
3130
+ destination = args[0]
3131
+ if len(args) > 1 and prefix == "":
3132
+ prefix = args[1]
3133
+ if len(args) > 2 and keep_vars is False:
3134
+ keep_vars = args[2]
3135
+ if destination is not None and not isinstance(destination, dict):
3136
+ raise TypeError(f"The type of destination must be OrderedDict, but got {type(destination)}")
3137
+ if not isinstance(prefix, str):
3138
+ raise TypeError(f"The type of prefix must be string, but got {type(prefix)}")
3139
+ if not isinstance(keep_vars, bool):
3140
+ raise TypeError(f"The type of keep_vars must be bool, but got {type(keep_vars)}")
3141
+
3142
+ if destination is None:
3143
+ destination = OrderedDict()
3144
+ destination._metadata = OrderedDict()
3145
+
3146
+ local_metadata = {}
3147
+ if hasattr(destination, "_metadata"):
3148
+ destination._metadata[prefix[:-1]] = local_metadata
3149
+
3150
+ for hook in self._state_dict_pre_hooks.values():
3151
+ hook(self, prefix, keep_vars)
3152
+ self._save_to_state_dict(destination, prefix, keep_vars)
3153
+ for name, cell in self._cells.items():
3154
+ if cell is not None:
3155
+ cell.state_dict(
3156
+ destination=destination,
3157
+ prefix=prefix + name + ".",
3158
+ keep_vars=keep_vars,
3159
+ )
3160
+ for hook in self._state_dict_hooks.values():
3161
+ hook_result = hook(self, destination, prefix, local_metadata)
3162
+ if hook_result is not None:
3163
+ raise RuntimeError("state_dict post-hook must return None")
3164
+ return destination
3165
+
3166
+ @jit_forbidden_register
3167
+ def register_load_state_dict_pre_hook(self, hook):
3168
+ r"""Register a pre-hook to be run before cell's :func:`mindspore.nn.Cell.load_state_dict` is called.
3169
+
3170
+ It should have the following signature:
3171
+
3172
+ hook(cell, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None
3173
+
3174
+ Args:
3175
+ hook (Callable): The hook function before `load_state_dict` is called.
3176
+
3177
+ Returns:
3178
+ A handle that can be used to remove the added hook by calling
3179
+ `handle.remove()`.
3180
+ """
3181
+ from mindspore.utils.hooks import _RemovableHandle
3182
+ handle = _RemovableHandle(self._load_state_dict_pre_hooks)
3183
+ self._load_state_dict_pre_hooks[handle.id] = hook
3184
+ return handle
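Pre-hooks receive the mutable copy of the incoming ``state_dict`` before any values are copied, which makes them a natural place to migrate legacy checkpoint keys. A sketch with a hypothetical rename (the old key name `w` is made up for illustration):

```python
import mindspore

net = mindspore.nn.Dense(2, 2)

def _rename_legacy_keys(cell, state_dict, prefix, local_metadata, strict,
                        missing_keys, unexpected_keys, error_msgs):
    # hypothetical migration: an old checkpoint stored the weight under 'w'
    old_key, new_key = prefix + "w", prefix + "weight"
    if old_key in state_dict:
        state_dict[new_key] = state_dict.pop(old_key)

handle = net.register_load_state_dict_pre_hook(_rename_legacy_keys)
legacy = {"w": mindspore.ops.zeros((2, 2), mindspore.float32)}
print(net.load_state_dict(legacy, strict=False))  # weight loaded, bias reported missing
handle.remove()
```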
3185
+
3186
+ @jit_forbidden_register
3187
+ def register_load_state_dict_post_hook(self, hook):
3188
+ r"""Register a post-hook to be run after cell's :func:`mindspore.nn.Cell.load_state_dict` is called.
3189
+
3190
+ It should have the following signature:
3191
+
3192
+ hook(cell, incompatible_keys) -> None
3193
+
3194
+ The ``cell`` argument is the current cell that this hook is registered
3195
+ on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting
3196
+ of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys``
3197
+ is a ``list`` of ``str`` containing the missing keys and
3198
+ ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys.
3199
+
3200
+ The given incompatible_keys can be modified inplace if needed.
3201
+
3202
+ Note that the checks performed when calling :func:`load_state_dict` with
3203
+ ``strict=True`` are affected by modifications the hook makes to
3204
+ ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either
3205
+ set of keys will result in an error being raised when ``strict=True``, and
3206
+ clearing out both missing and unexpected keys will avoid an error.
3207
+
3208
+ Args:
3209
+ hook (Callable): The hook function after `load_state_dict` is called.
3210
+
3211
+ Returns:
3212
+ A handle that can be used to remove the added hook by calling
3213
+ `handle.remove()`.
3214
+ """
3215
+ from mindspore.utils.hooks import _RemovableHandle
3216
+ handle = _RemovableHandle(self._load_state_dict_post_hooks)
3217
+ self._load_state_dict_post_hooks[handle.id] = hook
3218
+ return handle
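As the docstring notes, a post-hook may edit ``incompatible_keys`` in place and thereby relax (or tighten) the ``strict=True`` check. A sketch with a hypothetical hook that tolerates a missing bias:

```python
import mindspore

net = mindspore.nn.Dense(2, 2)

def _ignore_missing_bias(cell, incompatible_keys):
    # a missing bias is acceptable here; with missing_keys emptied,
    # strict=True no longer raises for it
    if "bias" in incompatible_keys.missing_keys:
        incompatible_keys.missing_keys.remove("bias")

net.register_load_state_dict_post_hook(_ignore_missing_bias)
partial = {"weight": mindspore.ops.zeros((2, 2), mindspore.float32)}
net.load_state_dict(partial)  # no RuntimeError despite the absent bias
```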
3219
+
3220
+ def _load_from_state_dict(
3221
+ self,
3222
+ state_dict,
3223
+ prefix,
3224
+ local_metadata,
3225
+ strict,
3226
+ missing_keys,
3227
+ unexpected_keys,
3228
+ error_msgs,
3229
+ ):
3230
+ r"""Copy parameters and buffers from :attr:`state_dict` into only this cell, but not its descendants.
3231
+
3232
+ This is called on every sub cell
3233
+ in :func:`mindspore.nn.Cell.load_state_dict`. Metadata saved for this
3234
+ cell in input :attr:`state_dict` is provided as :attr:`local_metadata`.
3235
+ For state dicts without metadata, :attr:`local_metadata` is empty.
3236
+ Subclasses can achieve class-specific backward compatible loading using
3237
+ the version number at `local_metadata.get("version", None)`.
3238
+
3239
+ .. note::
3240
+ :attr:`state_dict` is not the same object as the input
3241
+ :attr:`state_dict` to :func:`mindspore.nn.Cell.load_state_dict`. So
3242
+ it can be modified.
3243
+
3244
+ Args:
3245
+ state_dict (dict): a dict containing parameters and
3246
+ persistent buffers.
3247
+ prefix (str): the prefix for parameters and buffers used in this
3248
+ cell
3249
+ local_metadata (dict): a dict containing the metadata for this cell.
3251
+ strict (bool): whether to strictly enforce that the keys in
3252
+ :attr:`state_dict` with :attr:`prefix` match the names of
3253
+ parameters and buffers in this cell
3254
+ missing_keys (list of str): if ``strict=True``, add missing keys to
3255
+ this list
3256
+ unexpected_keys (list of str): if ``strict=True``, add unexpected
3257
+ keys to this list
3258
+ error_msgs (list of str): error messages should be added to this
3259
+ list, and will be reported together in
3260
+ :func:`mindspore.nn.Cell.load_state_dict`
3261
+ """
3262
+ for hook in self._load_state_dict_pre_hooks.values():
3263
+ hook(
3264
+ self,
3265
+ state_dict,
3266
+ prefix,
3267
+ local_metadata,
3268
+ strict,
3269
+ missing_keys,
3270
+ unexpected_keys,
3271
+ error_msgs,
3272
+ )
3273
+
3274
+ persistent_buffers = {
3275
+ k: v
3276
+ for k, v in self._buffers.items()
3277
+ if k not in self._non_persistent_buffers_set
3278
+ }
3279
+ local_name_params = itertools.chain(
3280
+ self._params.items(), persistent_buffers.items()
3281
+ )
3282
+ local_state = {k: v for k, v in local_name_params if v is not None}
3283
+
3284
+ for name, param in local_state.items():
3285
+ key = prefix + name
3286
+ if key in state_dict:
3287
+ input_param = state_dict[key]
3288
+ if not isinstance(input_param, Tensor):
3289
+ error_msgs.append(
3290
+ f'While copying the parameter named "{key}", '
3291
+ "expected Tensor or Tensor-like object from checkpoint but "
3292
+ f"received {type(input_param)}"
3293
+ )
3294
+ continue
3295
+
3296
+ if input_param.shape != param.shape:
3297
+ # local shape should match the one in checkpoint
3298
+ error_msgs.append(
3299
+ f"size mismatch for {key}: copying a param with shape {input_param.shape} from checkpoint, "
3300
+ f"the shape in current model is {param.shape}."
3301
+ )
3302
+ continue
3303
+ try:
3304
+ param.assign_value(Tensor(input_param.asnumpy(), dtype=param.dtype))
3305
+ except Exception as ex: # pylint: disable=W0703
3306
+ error_msgs.append(
3307
+ f'While copying the parameter named "{key}", '
3308
+ f"whose shape in the model is {param.shape} and "
3309
+ f"whose shape in the checkpoint is {input_param.shape}, "
3310
+ f"an exception occurred: {ex.args}."
3311
+ )
3312
+ elif strict:
3313
+ missing_keys.append(key)
3314
+
3315
+ extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
3316
+ if getattr(self.__class__, "set_extra_state", Cell.set_extra_state) is not Cell.set_extra_state:
3317
+ if extra_state_key in state_dict:
3318
+ self.set_extra_state(state_dict[extra_state_key])
3319
+ elif strict:
3320
+ missing_keys.append(extra_state_key)
3321
+ elif strict and (extra_state_key in state_dict):
3322
+ unexpected_keys.append(extra_state_key)
3323
+
3324
+ if strict:
3325
+ for key in state_dict.keys():
3326
+ if key.startswith(prefix) and key != extra_state_key:
3327
+ input_name = key[len(prefix):].split(".", 1)
3328
+ # A key with a nested name must refer to a sub cell
3329
+ if len(input_name) > 1:
3330
+ if input_name[0] not in self._cells:
3331
+ unexpected_keys.append(key)
3332
+ elif input_name[0] not in local_state:
3333
+ unexpected_keys.append(key)
3334
+
3335
+ @jit_forbidden_register
3336
+ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
3337
+ r"""Copy parameters and buffers from :attr:`state_dict` into this cell and its descendants.
3338
+
3339
+ If :attr:`strict` is ``True``, then
3340
+ the keys of :attr:`state_dict` must exactly match the keys returned
3341
+ by this cell's :func:`mindspore.nn.Cell.state_dict` function.
3342
+
3343
+ Args:
3344
+ state_dict (dict): A dict containing parameters and
3345
+ persistent buffers.
3346
+ strict (bool, optional): Whether to strictly enforce that the keys
3347
+ in input `state_dict` match the keys returned by this cell's
3348
+ :func:`mindspore.nn.Cell.state_dict` function. Default ``True`` .
3349
+
3350
+ Returns:
3351
+ A namedtuple with ``missing_keys`` and ``unexpected_keys`` fields,
3352
+
3353
+ - `missing_keys` is a list of str containing any keys that are expected
3354
+ by this cell but missing from the provided ``state_dict``.
3355
+
3356
+ - `unexpected_keys` is a list of str containing the keys that are not
3357
+ expected by this cell but present in the provided ``state_dict``.
3358
+
3359
+ Note:
3360
+ If `strict` is ``True`` and a parameter or buffer is registered as ``None``, but its corresponding key
3361
+ exists in :attr:`state_dict`, :func:`mindspore.nn.Cell.load_state_dict` will raise a ``RuntimeError``.
3362
+
3363
+ Examples:
3364
+ >>> import mindspore
3365
+ >>> import os
3366
+ >>> class Model(mindspore.nn.Cell):
3367
+ ... def __init__(self):
3368
+ ... super().__init__()
3369
+ ... self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
3370
+ ... self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3]))
3371
+ ...
3372
+ ... def construct(self, x):
3373
+ ... return x + self.buffer_a + self.param_a
3374
+ ...
3375
+ ...
3376
+ >>> model = Model()
3377
+ >>> print(model.state_dict())
3378
+ OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \
3379
+ ('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))])
3380
+ >>> mindspore.save_checkpoint(model.state_dict(), './model_state_dict_ckpt')
3381
+ >>> new_model = Model()
3382
+ >>> incompatible_keys = new_model.load_state_dict(mindspore.load_checkpoint('./model_state_dict_ckpt'))
3383
+ >>> print(new_model.state_dict())
3384
+ OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \
3385
+ ('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))])
3386
+ >>> os.remove('./model_state_dict_ckpt')
3387
+ """
3388
+ if not isinstance(state_dict, Mapping):
3389
+ raise TypeError(
3390
+ f"Expected state_dict to be dict-like, got {type(state_dict)}."
3391
+ )
3392
+
3393
+ missing_keys: List[str] = []
3394
+ unexpected_keys: List[str] = []
3395
+ error_msgs: List[str] = []
3396
+
3397
+ # copy state_dict so _load_from_state_dict can modify it
3398
+ metadata = getattr(state_dict, "_metadata", None)
3399
+ state_dict = OrderedDict(state_dict)
3400
+ if metadata is not None:
3401
+ # mypy isn't aware that "_metadata" exists in state_dict
3402
+ state_dict._metadata = metadata # type: ignore[attr-defined]
3403
+
3404
+ def load(cell, local_state_dict, prefix=""):
3405
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
3406
+ cell._load_from_state_dict(
3407
+ local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
3408
+ )
3409
+ for name, child in cell._cells.items():
3410
+ if child is not None:
3411
+ child_prefix = prefix + name + "."
3412
+ child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
3413
+ load(child, child_state_dict, child_prefix) # noqa: F821
3414
+
3415
+ # Note that the hook can modify missing_keys and unexpected_keys.
3416
+ incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
3417
+ for hook in cell._load_state_dict_post_hooks.values():
3418
+ out = hook(cell, incompatible_keys)
3419
+ if out is not None:
3420
+ raise RuntimeError(
3421
+ "Hooks registered with ``register_load_state_dict_post_hook`` are not"
3422
+ "expected to return new values, if incompatible_keys need to be modified,"
3423
+ "it should be done inplace."
3424
+ )
3425
+
3426
+ load(self, state_dict)
3427
+ del load
3428
+
3429
+ if strict:
3430
+ if unexpected_keys:
3431
+ error_msgs.insert(
3432
+ 0,
3433
+ "Unexpected key(s) in state_dict: {}. ".format(
3434
+ ", ".join(f'"{k}"' for k in unexpected_keys)
3435
+ ),
3436
+ )
3437
+ if missing_keys:
3438
+ error_msgs.insert(
3439
+ 0,
3440
+ "Missing key(s) in state_dict: {}. ".format(
3441
+ ", ".join(f'"{k}"' for k in missing_keys)
3442
+ ),
3443
+ )
3444
+
3445
+ if error_msgs:
3446
+ raise RuntimeError(
3447
+ "Error(s) in loading state_dict for {}:\n\t{}".format(
3448
+ self.__class__.__name__, "\n\t".join(error_msgs)
3449
+ )
3450
+ )
3451
+ return _IncompatibleKeys(missing_keys, unexpected_keys)
3452
+
2347
3453
  def register_backward_hook(self, hook_fn):
2348
3454
  """
2349
3455
  Register the backward hook function.
@@ -2403,8 +3509,7 @@ class Cell(Cell_):
2403
3509
  """
2404
3510
  if context._get_mode() == context.GRAPH_MODE:
2405
3511
  return HookHandle()
2406
- if not check_hook_fn("register_backward_hook", hook_fn):
2407
- return HookHandle()
3512
+ check_hook_fn(hook_fn)
2408
3513
  handle = HookHandle(self._backward_hook)
2409
3514
  self._backward_hook[handle.handle_id] = hook_fn
2410
3515
  if self._cell_backward_hook is None:
@@ -2452,9 +3557,14 @@ class Cell(Cell_):
2452
3557
  outputs = self.construct(*outputs, **kwargs)
2453
3558
  else:
2454
3559
  outputs = self.construct(outputs, **kwargs)
2455
-
2456
- outputs = self._cell_backward_hook(outputs)
2457
- return outputs
3560
+ if isinstance(outputs, tuple):
3561
+ new_outputs = self._cell_backward_hook(*outputs)
3562
+ else:
3563
+ new_outputs = self._cell_backward_hook(outputs)
3564
+ # if outputs is (X,) and new_outputs is X
3565
+ if isinstance(outputs, tuple) and len(outputs) == 1:
3566
+ new_outputs = (new_outputs,)
3567
+ return new_outputs
2458
3568
 
2459
3569
  def set_param_ps(self, recurse=True, init_in_server=False):
2460
3570
  """
@@ -2543,8 +3653,9 @@ class Cell(Cell_):
2543
3653
  if not self._has_config_recompute:
2544
3654
  self._has_config_recompute = True
2545
3655
  else:
2546
- raise RuntimeError("The recompute interface can be configured only once."
2547
- " When the parent cell is configured, the child cell should not be configured")
3656
+ logger.info("The recompute interface can be configured only once."
3657
+ " When the parent cell is configured, the child cell should not be configured")
3658
+ return
2548
3659
  self._set_recompute_scope(mode)
2549
3660
  if mode and not output_recompute:
2550
3661
  self.add_flags(output_no_recompute=True)
@@ -2584,18 +3695,13 @@ class Cell(Cell_):
2584
3695
  """
2585
3696
  if context.get_context("mode") == context.PYNATIVE_MODE:
2586
3697
  self._recompute_cell = recompute_registry.get()(self.construct)
2587
- self._add_recompute_flag()
2588
- return
2589
3698
  self._recompute()
2590
3699
  if 'mp_comm_recompute' in kwargs.keys():
2591
3700
  self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
2592
3701
  if 'parallel_optimizer_comm_recompute' in kwargs.keys():
2593
- if (kwargs.get('parallel_optimizer_comm_recompute', False) and
2594
- context.get_auto_parallel_context("pipeline_stages") > 1):
3702
+ if kwargs.get('parallel_optimizer_comm_recompute', False):
2595
3703
  logger.warning("Currently, the communication operator allgathers introduced by optimizer shard "
2596
- "are not support recomputation in pipeline parallel.")
2597
- elif context.get_auto_parallel_context("pipeline_stages") == 1:
2598
- self._parallel_optimizer_comm_recompute(kwargs.get('parallel_optimizer_comm_recompute', False))
3704
+ "is replaced with zero3.")
2599
3705
  if 'recompute_slice_activation' in kwargs:
2600
3706
  self._recompute_slice_activation(kwargs.get('recompute_slice_activation', False))
2601
3707
 
@@ -2687,17 +3793,91 @@ class Cell(Cell_):
2687
3793
  if hasattr(network, "_amp_level"):
2688
3794
  self._amp_level = getattr(network, "_amp_level")
2689
3795
 
2690
- def _add_recompute_flag(self):
2691
- """
2692
- Set pynative cell recomputed.
3796
+ def _register_parameters_hook(self, forward_hook=None, backward_hook=None, all=False):
2693
3797
  """
2694
- if not self._has_config_recompute:
2695
- self._has_config_recompute = True
3798
+ Register the forward hook for parameters and register the backward hook for the corresponding gradient.
3799
+
3800
+ .. warning::
3801
+ This is an experimental prototype that is subject to change and/or deletion.
3802
+
3803
+ Note:
3804
+ - The `_register_parameters_hook(forward_hook, backward_hook)` only works in graph mode.
3805
+ - The `forward_hook` receives `parameters`:
3806
+ the tuple of the trainable parameters of the Cell; each element in the tuple should be
3807
+ in the format of `(param_name, Parameter)`.
3808
+ - The `forward_hook` should have the following signature:
3809
+ forward_hook(parameters) -> None.
3810
+ - The `backward_hook` receives `gradients`:
3811
+ the tuple of the gradients corresponding to the trainable parameters of the Cell; each
3812
+ element in the tuple should be in the format of `(param_name, gradient)`.
3813
+ - The `backward_hook` should have the following signature:
3814
+ backward_hook(gradients) -> new gradients.
3815
+
3816
+ Args:
3817
+ forward_hook (function, optional): Python function or ``None``, the forward hook function. Default: ``None``.
3818
+ backward_hook (function, optional): Python function or ``None``, the backward hook function. Default: ``None``.
3819
+ all (bool, optional): whether to set the hooks for all sub cells recursively. Default: ``False``.
3820
+
3821
+ Returns:
3822
+ None
3823
+
3824
+ Raises:
3825
+ RuntimeError: If the `forward_hook` or `backward_hook` has unsupported syntax under GRAPH_MODE.
3826
+ TypeError: If the `forward_hook` or `backward_hook` is not defined as required.
3827
+
3828
+ Supported Platforms:
3829
+ ``Ascend`` ``GPU`` ``CPU``
3830
+
3831
+ Examples:
3832
+ >>> import mindspore as ms
3833
+ >>> from mindspore import Tensor, nn, ops, Parameter
+ >>> import numpy as np
3834
+ >>>
3835
+ >>> ms.set_context(mode=ms.GRAPH_MODE)
3836
+ >>> def parameter_hook(parameters):
3837
+ ... print("--- enter parameter hook ---")
3838
+ ... for name, param in parameters:
3839
+ ... print (name, param)
3840
+ ... print("--- leave parameter hook ---")
3841
+ ...
3842
+ >>> def gradient_hook(gradients):
3843
+ ... print("--- enter gradient hook ---")
3844
+ ... outs = []
3845
+ ... for name, gradient in gradients:
3846
+ ... print(name, gradient)
3847
+ ... outs.append(gradient * 2) # double gradient
3848
+ ... print("--- leave gradient hook ---")
3849
+ ... return outs
3850
+ ...
3851
+ >>> class Net(nn.Cell):
3852
+ ... def __init__(self):
3853
+ ... super(Net, self).__init__()
3854
+ ... self.w = Parameter(Tensor(np.array([3.0], np.float32)), name='w')
3855
+ ... def construct(self, x):
3856
+ ... return self.w * x
3857
+ ...
3858
+ >>> grad = ops.GradOperation(get_by_list=True)
3859
+ >>> net = Net()
3860
+ >>> net._register_parameters_hook(forward_hook=parameter_hook, backward_hook=gradient_hook)
3861
+ >>> x = Tensor(np.array([4.0]).astype(np.float32))
3862
+ >>> output = grad(net, net.trainable_params())(x)
3863
+ --- enter parameter hook ---
3864
+ w
3865
+ Tensor(shape=[1], dtype=Float32, value=[ 3.00000000e+00])
3866
+ --- leave parameter hook ---
3867
+ --- enter gradient hook ---
3868
+ w
3869
+ Tensor(shape=[1], dtype=Float32, value=[ 4.00000000e+00])
3870
+ --- leave gradient hook ---
3871
+ >>> print("doubled grad: ", output)
3872
+ doubled grad: (Tensor(shape=[1], dtype=Float32, value=[ 8.00000000e+00]),)
3873
+ """
3874
+ if not all:
3875
+ self._parameters_forward_hook = forward_hook
3876
+ self._parameters_backward_hook = backward_hook
2696
3877
  else:
2697
- logger.info("The recompute interface can be configured only once."
2698
- " If the parent cell is configured, the child cell should not be configured")
2699
- for cell in self.cells():
2700
- cell._add_recompute_flag()
3878
+ for _, cell in self.cells_and_names():
3879
+ cell._parameters_forward_hook = forward_hook
3880
+ cell._parameters_backward_hook = backward_hook
2701
3881
 
2702
3882
 
2703
3883
  class GraphCell(Cell):
@@ -2713,12 +3893,10 @@ class GraphCell(Cell):
2713
3893
  The key is the parameter name whose type is str, and the value is a Tensor or Parameter.
2714
3894
  If the parameter exists in the graph according to the name, update it's value.
2715
3895
  If the parameter does not exist, ignore it. Default: ``None`` .
2716
- obf_random_seed (Union[int, None]): The random seed used for dynamic obfuscation. "dynamic obfuscation" is
2717
- used for model protection, which can refer to :func:`mindspore.obfuscate_model`. If the input `graph` is
2718
- a func_graph loaded from a mindir file obfuscated with `obf_random_seed` , then `obf_random_seed` should be
2719
- provided. `obf_random_seed` should be in (0, 9223372036854775807]. default: ``None`` .
3896
+ obf_random_seed (Union[int, None]): The random seed used for dynamic obfuscation, which is not supported now.
2720
3897
 
2721
3898
  Raises:
3899
+ NotImplementedError: If `obf_random_seed` is provided, since dynamic structure obfuscation is not supported now.
2722
3900
  TypeError: If the `graph` is not a FuncGraph.
2723
3901
  TypeError: If the `params_init` is not a dict.
2724
3902
  TypeError: If the key of the `params_init` is not a str.
@@ -2748,20 +3926,12 @@ class GraphCell(Cell):
2748
3926
 
2749
3927
  def __init__(self, graph, params_init=None, obf_random_seed=None):
2750
3928
  super(GraphCell, self).__init__(auto_prefix=True)
3929
+ if obf_random_seed is not None:
3930
+ raise NotImplementedError("Dynamic structure obfuscation is not supported now.")
2751
3931
  if not isinstance(graph, FuncGraph):
2752
3932
  raise TypeError(f"For 'GraphCell', the argument 'graph' must be a FuncGraph loaded from MindIR, "
2753
3933
  f"but got type {type(graph)}.")
2754
3934
  self.graph = graph
2755
- self.obf_random_seed = obf_random_seed
2756
- if obf_random_seed is not None:
2757
- if not isinstance(obf_random_seed, int):
2758
- raise TypeError("'obf_random_seed' must be int, but got {}.".format(type(obf_random_seed)))
2759
- int_64_max = 9223372036854775807
2760
- if obf_random_seed <= 0 or obf_random_seed > int_64_max:
2761
- raise ValueError(
2762
- "'obf_random_seed' must be larger than 0, and less or equal than int64 ({}),"
2763
- "but got {}.".format(int_64_max, obf_random_seed))
2764
- self._branch_control_input = _generate_branch_control_input(self.obf_random_seed)
2765
3935
  params_init = {} if params_init is None else params_init
2766
3936
  if not isinstance(params_init, dict):
2767
3937
  raise TypeError(f"For 'GraphCell', the argument 'params_init' must be a dict, but got {type(params_init)}.")
@@ -2781,19 +3951,30 @@ class GraphCell(Cell):
2781
3951
  def __call__(self, *args, **kwargs):
2782
3952
  self.phase = "graph_load_from_mindir"
2783
3953
  self._add_attr("graph_load_from_mindir", self.graph)
2784
- if not self.obf_random_seed:
2785
- return self.compile_and_run(*args, **kwargs)
2786
- append_input = Tensor((numpy.ones((1,)) * self._branch_control_input).astype(numpy.int32))
2787
- return self.compile_and_run(*args, append_input, **kwargs)
3954
+ return self.compile_and_run(*args, **kwargs)
2788
3955
 
2789
3956
 
2790
- def _check_param_list_tuple(value):
3957
+ def _is_parameter_list_or_tuple(value):
2791
3958
  """
2792
3959
  Check the type of input in list or tuple is Parameter.
2793
3960
  :param value: list or tuple.
2794
3961
  :return: The types of all inputs are parameter.
2795
3962
  """
2796
- for item in value:
2797
- if not isinstance(item, Parameter):
2798
- return False
2799
- return True
3963
+ if isinstance(value, (list, tuple)) and value:
3964
+ for item in value:
3965
+ if not isinstance(item, Parameter):
3966
+ return False
3967
+ return True
3968
+ return False
3969
+
3970
+
3971
+ def _addindent(s_, num_spaces):
3972
+ s = s_.split("\n")
3973
+ # don't do anything for single-line stuff
3974
+ if len(s) == 1:
3975
+ return s_
3976
+ first = s.pop(0)
3977
+ s = [(num_spaces * " ") + line for line in s]
3978
+ s = "\n".join(s)
3979
+ s = first + "\n" + s
3980
+ return s
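`_addindent` keeps the first line flush so the `(name): ` prefix in `__repr__` lines up, and indents every subsequent line; for example:

```python
print(_addindent("Dense(\n  input_channels=2\n)", 2))
# Dense(
#     input_channels=2
#   )
```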