mindspore-2.6.0-cp310-cp310-win_amd64.whl → mindspore-2.7.0-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mindspore might be problematic.
Files changed (455)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +2 -2
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +42 -11
  9. mindspore/_extends/builtin_operations.py +3 -3
  10. mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
  11. mindspore/_extends/optimize/cell_utils.py +96 -0
  12. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +3 -3
  15. mindspore/_extends/parse/compile_config.py +44 -22
  16. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -2
  17. mindspore/_extends/parse/parser.py +64 -83
  18. mindspore/_extends/parse/resources.py +39 -0
  19. mindspore/_extends/parse/standard_method.py +47 -14
  20. mindspore/_extends/parse/trope.py +8 -1
  21. mindspore/_extends/pijit/__init__.py +1 -2
  22. mindspore/_extends/pijit/pijit_func_white_list.py +2 -5
  23. mindspore/amp.py +4 -22
  24. mindspore/atlprov.dll +0 -0
  25. mindspore/avcodec-59.dll +0 -0
  26. mindspore/avdevice-59.dll +0 -0
  27. mindspore/avfilter-8.dll +0 -0
  28. mindspore/avformat-59.dll +0 -0
  29. mindspore/avutil-57.dll +0 -0
  30. mindspore/boost/adasum.py +1 -1
  31. mindspore/boost/boost_cell_wrapper.py +4 -4
  32. mindspore/c1.dll +0 -0
  33. mindspore/c1xx.dll +0 -0
  34. mindspore/c2.dll +0 -0
  35. mindspore/common/__init__.py +43 -12
  36. mindspore/common/_grad_function.py +2 -1
  37. mindspore/common/_pijit_context.py +28 -7
  38. mindspore/common/_stub_tensor.py +1 -209
  39. mindspore/common/_tensor_cpp_method.py +1 -1
  40. mindspore/common/_tensor_docs.py +177 -52
  41. mindspore/common/_utils.py +9 -1
  42. mindspore/common/api.py +338 -208
  43. mindspore/common/dtype.py +108 -57
  44. mindspore/common/dump.py +11 -16
  45. mindspore/common/dynamic_shape/__init__.py +0 -0
  46. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +17 -23
  47. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  48. mindspore/common/file_system.py +59 -9
  49. mindspore/common/generator.py +2 -3
  50. mindspore/common/hook_handle.py +33 -5
  51. mindspore/common/jit_config.py +1 -1
  52. mindspore/common/jit_trace.py +84 -105
  53. mindspore/common/np_dtype.py +3 -3
  54. mindspore/common/parameter.py +27 -29
  55. mindspore/common/recompute.py +5 -7
  56. mindspore/common/sparse_tensor.py +0 -3
  57. mindspore/common/symbol.py +0 -1
  58. mindspore/common/tensor.py +84 -133
  59. mindspore/communication/_comm_helper.py +46 -4
  60. mindspore/communication/management.py +79 -7
  61. mindspore/context.py +47 -38
  62. mindspore/dataset/__init__.py +1 -1
  63. mindspore/dataset/audio/transforms.py +1 -1
  64. mindspore/dataset/core/config.py +38 -4
  65. mindspore/dataset/engine/datasets.py +350 -322
  66. mindspore/dataset/engine/datasets_user_defined.py +69 -23
  67. mindspore/dataset/engine/iterators.py +2 -2
  68. mindspore/dataset/engine/obs/config_loader.py +2 -2
  69. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
  70. mindspore/dataset/transforms/c_transforms.py +2 -2
  71. mindspore/dataset/transforms/py_transforms.py +7 -3
  72. mindspore/dataset/transforms/transforms.py +10 -6
  73. mindspore/dataset/vision/__init__.py +1 -1
  74. mindspore/dataset/vision/py_transforms.py +8 -8
  75. mindspore/dataset/vision/transforms.py +17 -5
  76. mindspore/dataset/vision/utils.py +632 -21
  77. mindspore/dataset/vision/validators.py +1 -0
  78. mindspore/device_context/ascend/device.py +1 -1
  79. mindspore/device_context/ascend/op_tuning.py +35 -1
  80. mindspore/device_context/gpu/__init__.py +2 -2
  81. mindspore/device_context/gpu/device.py +1 -1
  82. mindspore/device_context/gpu/op_precision.py +4 -2
  83. mindspore/device_context/gpu/op_tuning.py +6 -3
  84. mindspore/device_manager.py +16 -9
  85. mindspore/dnnl.dll +0 -0
  86. mindspore/dpcmi.dll +0 -0
  87. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +5 -4
  88. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  89. mindspore/experimental/optim/adadelta.py +13 -20
  90. mindspore/experimental/optim/adagrad.py +15 -22
  91. mindspore/experimental/optim/adam.py +17 -24
  92. mindspore/experimental/optim/adamax.py +14 -22
  93. mindspore/experimental/optim/adamw.py +28 -34
  94. mindspore/experimental/optim/asgd.py +15 -25
  95. mindspore/experimental/optim/lr_scheduler.py +27 -45
  96. mindspore/experimental/optim/nadam.py +14 -24
  97. mindspore/experimental/optim/optimizer.py +13 -23
  98. mindspore/experimental/optim/radam.py +18 -24
  99. mindspore/experimental/optim/rmsprop.py +14 -25
  100. mindspore/experimental/optim/rprop.py +15 -26
  101. mindspore/experimental/optim/sgd.py +9 -19
  102. mindspore/hal/__init__.py +4 -4
  103. mindspore/hal/contiguous_tensors_handle.py +2 -2
  104. mindspore/hal/memory.py +1 -0
  105. mindspore/include/api/cell.h +65 -5
  106. mindspore/include/api/cfg.h +24 -7
  107. mindspore/include/api/context.h +1 -0
  108. mindspore/include/api/delegate.h +10 -2
  109. mindspore/include/api/dual_abi_helper.h +100 -19
  110. mindspore/include/api/graph.h +14 -1
  111. mindspore/include/api/kernel.h +16 -3
  112. mindspore/include/api/kernel_api.h +9 -1
  113. mindspore/include/api/metrics/accuracy.h +9 -0
  114. mindspore/include/api/model.h +8 -1
  115. mindspore/include/api/model_group.h +4 -0
  116. mindspore/include/api/model_parallel_runner.h +2 -0
  117. mindspore/include/api/status.h +48 -10
  118. mindspore/include/api/types.h +8 -3
  119. mindspore/include/c_api/model_c.h +0 -58
  120. mindspore/include/c_api/tensor_c.h +0 -26
  121. mindspore/include/dataset/constants.h +9 -0
  122. mindspore/include/dataset/vision_ascend.h +1 -1
  123. mindspore/jpeg62.dll +0 -0
  124. mindspore/mindrecord/tools/cifar10.py +61 -11
  125. mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
  126. mindspore/mindspore_backend_common.dll +0 -0
  127. mindspore/mindspore_backend_manager.dll +0 -0
  128. mindspore/mindspore_common.dll +0 -0
  129. mindspore/mindspore_core.dll +0 -0
  130. mindspore/mindspore_cpu_res_manager.dll +0 -0
  131. mindspore/mindspore_dump.dll +0 -0
  132. mindspore/mindspore_frontend.dll +0 -0
  133. mindspore/mindspore_glog.dll +0 -0
  134. mindspore/mindspore_memory_pool.dll +0 -0
  135. mindspore/mindspore_ms_backend.dll +0 -0
  136. mindspore/mindspore_ops.dll +0 -0
  137. mindspore/mindspore_ops_host.dll +0 -0
  138. mindspore/mindspore_ops_kernel_common.dll +0 -0
  139. mindspore/mindspore_profiler.dll +0 -0
  140. mindspore/mindspore_pyboost.dll +0 -0
  141. mindspore/mindspore_pynative.dll +0 -0
  142. mindspore/mindspore_res_manager.dll +0 -0
  143. mindspore/mindspore_runtime_pipeline.dll +0 -0
  144. mindspore/mint/__init__.py +4 -44
  145. mindspore/mint/distributed/__init__.py +5 -0
  146. mindspore/mint/distributed/distributed.py +425 -19
  147. mindspore/mint/nn/__init__.py +1 -1
  148. mindspore/mint/nn/functional.py +53 -6
  149. mindspore/mint/nn/layer/_functions.py +163 -294
  150. mindspore/mint/nn/layer/activation.py +8 -6
  151. mindspore/mint/nn/layer/conv.py +125 -101
  152. mindspore/mint/nn/layer/normalization.py +11 -25
  153. mindspore/mint/optim/adam.py +19 -18
  154. mindspore/mint/optim/adamw.py +14 -8
  155. mindspore/mint/optim/sgd.py +5 -5
  156. mindspore/msobj140.dll +0 -0
  157. mindspore/mspdb140.dll +0 -0
  158. mindspore/mspdbcore.dll +0 -0
  159. mindspore/mspdbst.dll +0 -0
  160. mindspore/mspft140.dll +0 -0
  161. mindspore/msvcdis140.dll +0 -0
  162. mindspore/msvcp140_1.dll +0 -0
  163. mindspore/msvcp140_2.dll +0 -0
  164. mindspore/msvcp140_atomic_wait.dll +0 -0
  165. mindspore/msvcp140_codecvt_ids.dll +0 -0
  166. mindspore/nn/cell.py +488 -620
  167. mindspore/nn/grad/cell_grad.py +11 -12
  168. mindspore/nn/layer/activation.py +36 -36
  169. mindspore/nn/layer/basic.py +74 -77
  170. mindspore/nn/layer/channel_shuffle.py +4 -4
  171. mindspore/nn/layer/combined.py +4 -2
  172. mindspore/nn/layer/conv.py +86 -85
  173. mindspore/nn/layer/dense.py +9 -7
  174. mindspore/nn/layer/embedding.py +50 -52
  175. mindspore/nn/layer/image.py +38 -40
  176. mindspore/nn/layer/math.py +111 -112
  177. mindspore/nn/layer/normalization.py +56 -44
  178. mindspore/nn/layer/pooling.py +58 -63
  179. mindspore/nn/layer/rnn_cells.py +33 -33
  180. mindspore/nn/layer/rnns.py +56 -56
  181. mindspore/nn/layer/thor_layer.py +74 -73
  182. mindspore/nn/layer/transformer.py +11 -1
  183. mindspore/nn/learning_rate_schedule.py +20 -20
  184. mindspore/nn/loss/loss.py +79 -81
  185. mindspore/nn/optim/adam.py +2 -4
  186. mindspore/nn/optim/adasum.py +2 -2
  187. mindspore/nn/optim/lamb.py +1 -3
  188. mindspore/nn/optim/optimizer.py +1 -1
  189. mindspore/nn/optim/tft_wrapper.py +2 -3
  190. mindspore/nn/optim/thor.py +2 -2
  191. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  192. mindspore/nn/probability/distribution/exponential.py +2 -1
  193. mindspore/nn/probability/distribution/poisson.py +2 -1
  194. mindspore/nn/sparse/sparse.py +3 -3
  195. mindspore/nn/wrap/cell_wrapper.py +73 -42
  196. mindspore/nn/wrap/grad_reducer.py +37 -52
  197. mindspore/nn/wrap/loss_scale.py +72 -74
  198. mindspore/numpy/array_creations.py +7 -7
  199. mindspore/numpy/fft.py +1 -1
  200. mindspore/numpy/math_ops.py +1 -1
  201. mindspore/numpy/utils_const.py +1 -1
  202. mindspore/opencv_core452.dll +0 -0
  203. mindspore/opencv_imgcodecs452.dll +0 -0
  204. mindspore/opencv_imgproc452.dll +0 -0
  205. mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
  206. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
  207. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  208. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  209. mindspore/{experimental/es/__init__.py → ops/_op_impl/cpu/joinedstr_op.py} +12 -6
  210. mindspore/ops/_vmap/vmap_array_ops.py +6 -13
  211. mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
  212. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +29 -10
  213. mindspore/ops/auto_generate/gen_extend_func.py +5 -55
  214. mindspore/ops/auto_generate/gen_ops_def.py +753 -273
  215. mindspore/ops/auto_generate/gen_ops_prim.py +1687 -958
  216. mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
  217. mindspore/ops/composite/__init__.py +10 -0
  218. mindspore/ops/composite/base.py +9 -5
  219. mindspore/ops/composite/multitype_ops/__init__.py +12 -1
  220. mindspore/ops/composite/multitype_ops/_compile_utils.py +132 -108
  221. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  222. mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
  223. mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
  224. mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
  225. mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
  226. mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
  227. mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
  228. mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
  229. mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
  230. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
  231. mindspore/ops/function/__init__.py +4 -1
  232. mindspore/ops/function/_add_attr_func.py +11 -6
  233. mindspore/ops/function/array_func.py +17 -100
  234. mindspore/ops/function/debug_func.py +8 -5
  235. mindspore/ops/function/grad/grad_func.py +5 -13
  236. mindspore/ops/function/math_func.py +65 -399
  237. mindspore/ops/function/nn_func.py +44 -61
  238. mindspore/ops/function/other_func.py +4 -1
  239. mindspore/ops/function/random_func.py +31 -4
  240. mindspore/ops/functional.py +2 -3
  241. mindspore/ops/functional_overload.py +486 -18
  242. mindspore/ops/op_info_register.py +21 -0
  243. mindspore/ops/operations/__init__.py +5 -2
  244. mindspore/ops/operations/_custom_ops_utils.py +675 -8
  245. mindspore/ops/operations/_inner_ops.py +14 -18
  246. mindspore/ops/operations/_sequence_ops.py +1 -1
  247. mindspore/ops/operations/array_ops.py +4 -50
  248. mindspore/ops/operations/comm_ops.py +186 -41
  249. mindspore/ops/operations/custom_ops.py +244 -175
  250. mindspore/ops/operations/debug_ops.py +55 -4
  251. mindspore/ops/operations/image_ops.py +13 -13
  252. mindspore/ops/operations/manually_defined/ops_def.py +27 -28
  253. mindspore/ops/operations/math_ops.py +8 -9
  254. mindspore/ops/operations/nn_ops.py +6 -7
  255. mindspore/ops/primitive.py +9 -20
  256. mindspore/ops/tensor_method.py +52 -11
  257. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
  258. mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
  259. mindspore/ops_generate/api/functions_cc_generator.py +58 -10
  260. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
  261. mindspore/ops_generate/common/base_generator.py +14 -0
  262. mindspore/ops_generate/common/gen_constants.py +7 -2
  263. mindspore/ops_generate/common/gen_utils.py +0 -19
  264. mindspore/ops_generate/common/op_proto.py +11 -4
  265. mindspore/ops_generate/common/template.py +88 -11
  266. mindspore/ops_generate/gen_ops.py +1 -1
  267. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
  268. mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
  269. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
  270. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
  271. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
  272. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
  273. mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -16
  274. mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
  275. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
  276. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
  277. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
  278. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
  279. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
  280. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
  281. mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
  282. mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
  283. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
  284. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
  285. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
  286. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
  287. mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
  288. mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
  289. mindspore/parallel/_auto_parallel_context.py +9 -17
  290. mindspore/parallel/_cell_wrapper.py +106 -40
  291. mindspore/parallel/_parallel_serialization.py +4 -3
  292. mindspore/parallel/_ps_context.py +4 -6
  293. mindspore/parallel/_tensor.py +167 -12
  294. mindspore/parallel/_transformer/moe.py +1 -1
  295. mindspore/parallel/_transformer/transformer.py +17 -12
  296. mindspore/parallel/_utils.py +5 -11
  297. mindspore/parallel/auto_parallel.py +33 -12
  298. mindspore/parallel/checkpoint_convert.py +3 -3
  299. mindspore/parallel/checkpoint_transform.py +5 -1
  300. mindspore/parallel/cluster/process_entity/_api.py +88 -49
  301. mindspore/parallel/cluster/process_entity/_utils.py +95 -7
  302. mindspore/parallel/cluster/run.py +48 -7
  303. mindspore/parallel/function/__init__.py +8 -1
  304. mindspore/parallel/function/reshard_func.py +7 -6
  305. mindspore/parallel/nn/__init__.py +15 -2
  306. mindspore/parallel/nn/parallel_cell_wrapper.py +50 -14
  307. mindspore/parallel/nn/parallel_grad_reducer.py +7 -14
  308. mindspore/parallel/shard.py +9 -23
  309. mindspore/parallel/transform_safetensors.py +468 -174
  310. mindspore/pgodb140.dll +0 -0
  311. mindspore/pgort140.dll +0 -0
  312. mindspore/profiler/__init__.py +2 -1
  313. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
  314. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
  315. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +3 -0
  316. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
  317. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  318. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
  319. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
  320. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
  321. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
  322. mindspore/profiler/analysis/task_manager.py +1 -1
  323. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
  324. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
  325. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +10 -9
  326. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +43 -23
  327. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
  328. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
  329. mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
  330. mindspore/profiler/common/constant.py +16 -0
  331. mindspore/profiler/common/msprof_cmd_tool.py +2 -2
  332. mindspore/profiler/common/path_manager.py +9 -0
  333. mindspore/profiler/common/profiler_context.py +50 -29
  334. mindspore/profiler/common/profiler_info.py +0 -16
  335. mindspore/profiler/common/profiler_meta_data.py +1 -0
  336. mindspore/profiler/common/profiler_op_analyse.py +239 -0
  337. mindspore/profiler/common/profiler_output_path.py +23 -8
  338. mindspore/profiler/common/profiler_parameters.py +128 -35
  339. mindspore/profiler/dynamic_profile/__init__.py +0 -0
  340. mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
  341. mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
  342. mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
  343. mindspore/profiler/dynamic_profiler.py +374 -338
  344. mindspore/profiler/envprofiler.py +42 -12
  345. mindspore/profiler/experimental_config.py +112 -7
  346. mindspore/profiler/mstx.py +33 -12
  347. mindspore/profiler/platform/__init__.py +2 -3
  348. mindspore/profiler/platform/cpu_profiler.py +10 -4
  349. mindspore/profiler/platform/npu_profiler.py +30 -20
  350. mindspore/profiler/profiler.py +218 -154
  351. mindspore/profiler/profiler_action_controller.py +65 -77
  352. mindspore/profiler/profiler_interface.py +2 -2
  353. mindspore/profiler/schedule.py +10 -4
  354. mindspore/rewrite/common/config.py +1 -0
  355. mindspore/rewrite/common/namer.py +1 -0
  356. mindspore/rewrite/common/namespace.py +1 -0
  357. mindspore/rewrite/node/node.py +31 -11
  358. mindspore/rewrite/parsers/assign_parser.py +1 -1
  359. mindspore/rewrite/symbol_tree/symbol_tree.py +2 -2
  360. mindspore/run_check/_check_version.py +7 -10
  361. mindspore/runtime/__init__.py +8 -6
  362. mindspore/runtime/event.py +10 -4
  363. mindspore/runtime/executor.py +87 -45
  364. mindspore/runtime/memory.py +22 -30
  365. mindspore/runtime/thread_bind_core.py +299 -165
  366. mindspore/safeguard/rewrite_obfuscation.py +12 -13
  367. mindspore/swresample-4.dll +0 -0
  368. mindspore/swscale-6.dll +0 -0
  369. mindspore/tbbmalloc.dll +0 -0
  370. mindspore/tinyxml2.dll +0 -0
  371. mindspore/train/_utils.py +9 -5
  372. mindspore/train/amp.py +43 -23
  373. mindspore/train/callback/__init__.py +5 -5
  374. mindspore/train/callback/_callback.py +2 -1
  375. mindspore/train/callback/_checkpoint.py +4 -14
  376. mindspore/train/callback/_flops_collector.py +11 -7
  377. mindspore/train/callback/_landscape.py +0 -1
  378. mindspore/train/callback/_train_fault_tolerance.py +72 -18
  379. mindspore/train/data_sink.py +15 -6
  380. mindspore/train/dataset_helper.py +14 -5
  381. mindspore/train/model.py +49 -47
  382. mindspore/train/serialization.py +168 -126
  383. mindspore/train/summary/summary_record.py +13 -2
  384. mindspore/train/train_thor/model_thor.py +2 -2
  385. mindspore/turbojpeg.dll +0 -0
  386. mindspore/utils/__init__.py +3 -2
  387. mindspore/utils/dryrun.py +0 -6
  388. mindspore/utils/runtime_execution_order_check.py +162 -78
  389. mindspore/utils/sdc_detect.py +68 -0
  390. mindspore/utils/utils.py +14 -17
  391. mindspore/vcmeta.dll +0 -0
  392. mindspore/vcruntime140.dll +0 -0
  393. mindspore/vcruntime140_1.dll +0 -0
  394. mindspore/version.py +1 -1
  395. {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/METADATA +5 -4
  396. {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/RECORD +400 -439
  397. mindspore/_deprecated/jit.py +0 -198
  398. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  399. mindspore/communication/_hccl_management.py +0 -297
  400. mindspore/experimental/es/embedding_service.py +0 -891
  401. mindspore/experimental/es/embedding_service_layer.py +0 -581
  402. mindspore/profiler/common/validator/__init__.py +0 -14
  403. mindspore/profiler/common/validator/validate_path.py +0 -84
  404. mindspore/profiler/parser/__init__.py +0 -14
  405. mindspore/profiler/parser/aicpu_data_parser.py +0 -272
  406. mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
  407. mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
  408. mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
  409. mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
  410. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
  411. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
  412. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
  413. mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
  414. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
  415. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
  416. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
  417. mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
  418. mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
  419. mindspore/profiler/parser/ascend_flops_generator.py +0 -116
  420. mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
  421. mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
  422. mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
  423. mindspore/profiler/parser/ascend_memory_generator.py +0 -185
  424. mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
  425. mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
  426. mindspore/profiler/parser/ascend_op_generator.py +0 -334
  427. mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
  428. mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
  429. mindspore/profiler/parser/base_timeline_generator.py +0 -483
  430. mindspore/profiler/parser/container.py +0 -229
  431. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
  432. mindspore/profiler/parser/flops_parser.py +0 -531
  433. mindspore/profiler/parser/framework_enum.py +0 -111
  434. mindspore/profiler/parser/framework_parser.py +0 -464
  435. mindspore/profiler/parser/framework_struct.py +0 -61
  436. mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
  437. mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
  438. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
  439. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
  440. mindspore/profiler/parser/hccl_parser.py +0 -573
  441. mindspore/profiler/parser/hwts_log_parser.py +0 -122
  442. mindspore/profiler/parser/integrator.py +0 -526
  443. mindspore/profiler/parser/memory_usage_parser.py +0 -277
  444. mindspore/profiler/parser/minddata_analyzer.py +0 -800
  445. mindspore/profiler/parser/minddata_parser.py +0 -186
  446. mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
  447. mindspore/profiler/parser/op_intermediate_parser.py +0 -149
  448. mindspore/profiler/parser/optime_parser.py +0 -250
  449. mindspore/profiler/parser/profiler_info.py +0 -213
  450. mindspore/profiler/parser/step_trace_parser.py +0 -666
  451. mindspore/utils/hooks.py +0 -81
  452. /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  453. {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/WHEEL +0 -0
  454. {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/entry_points.txt +0 -0
  455. {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py CHANGED
@@ -15,6 +15,10 @@
 """cell"""
 from __future__ import absolute_import

+__all__ = [
+    "register_cell_buffer_registration_hook",
+]
+
 import inspect
 import os
 import time
@@ -24,7 +28,6 @@ from collections import OrderedDict, namedtuple
 from typing import (
     Dict,
     Optional,
-    Set,
     Callable,
     List,
     Tuple,
@@ -34,36 +37,30 @@ from typing import (
     Mapping
 )

+import weakref
 import mindspore as ms
+import mindspore.ops as ops
 from mindspore._checkparam import args_type_check, check_hook_fn
-from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
+from mindspore.common.dynamic_shape._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
 from mindspore import log as logger
-from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
-from mindspore.common.hook_handle import HookHandle
-from mindspore.context import ParallelMode
+from mindspore.common.hook_handle import HookHandle, _update_hook_version
 from mindspore import context
 from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
 from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
 from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache, \
-    _no_grad
-from mindspore.common.api import _convert_python_data, _get_args_for_run_predict
+    _no_grad, _get_mutable_flags
+from mindspore.common.api import _convert_python_data
 from mindspore.common.api import _process_dyn_args, _generate_dyn_compile_args
-from mindspore.common.parameter import _Buffer, Parameter, ParameterTuple
+from mindspore.common.parameter import _Buffer, Parameter, ParameterTuple, _is_parameter_generated
 from mindspore.common.tensor import Tensor
-from mindspore.ops.operations import Cast
 from mindspore.ops.primitive import Primitive
 from mindspore.ops.operations import _inner_ops as inner
 from mindspore.parallel.shard import Shard
 from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
 from mindspore._check_jit_forbidden_api import jit_forbidden_register
-from mindspore.common._decorator import deprecated
 from mindspore.common._register_for_recompute import recompute_registry
-
-
-__all__ = [
-    "register_cell_buffer_registration_hook",
-]
+from mindspore.common.jit_config import JitConfig

 _global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict()
 _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
@@ -96,13 +93,11 @@ def register_cell_buffer_registration_hook(hook: Callable[..., None],):
     A handle that can be used to remove the added hook by calling
     `handle.remove()`.
     """
-    from mindspore.utils.hooks import _RemovableHandle
-    handle = _RemovableHandle(_global_buffer_registration_hooks)
-    _global_buffer_registration_hooks[handle.id] = hook
+    handle = HookHandle(_global_buffer_registration_hooks)
+    _global_buffer_registration_hooks[handle.handle_id] = hook
     return handle


-
 class Cell(Cell_):
     """
     The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this
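Usage has not changed from the caller's side: the global hook is still registered once and detached through the returned handle. A minimal sketch, assuming the function is imported from mindspore.nn.cell and that the hook receives the cell, the buffer name, and the buffer (the hook signature is an assumption, not confirmed by this diff):

    from mindspore.nn.cell import register_cell_buffer_registration_hook

    def log_buffer(cell, name, buffer):
        # Hook body; the (cell, name, buffer) signature is assumed for illustration.
        print(f"buffer '{name}' registered on {type(cell).__name__}")

    handle = register_cell_buffer_registration_hook(log_buffer)
    # ... construct cells that register buffers ...
    handle.remove()  # detach the hook via the returned HookHandle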
@@ -160,51 +155,59 @@ class Cell(Cell_):
     IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_create_time',
                    '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase', '_bprop_debug',
                    '_forward_pre_hook', '_forward_hook', '_backward_pre_hook', '_backward_hook',
-                   '_cell_backward_pre_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
-                   '_attr_synced', 'pynative', 'requires_grad', 'cell_type',
-                   '_parameters_forward_hook', '_parameters_backward_hook']
+                   '_cell_backward_pre_hook', '_cell_backward_hook', '_param_prefix',
+                   'requires_grad', 'cell_type', '_in_strategy', '_out_strategy']
     total_instance_count = 0
     _buffers: Dict[str, Optional[Tensor]]
-    _non_persistent_buffers_set: Set[str]
+    global_cells = weakref.WeakKeyDictionary()
+    _no_auto_lazy_inline = True
+
+    def __new__(class_, *args, **kwargs):
+        # Use class_ to avoid name conflicts with input args and kwargs.
+        this = Cell_.__new__(class_, *args, **kwargs)
+        if Cell._no_auto_lazy_inline:
+            return this
+
+        Cell.global_cells[this] = (class_, args, kwargs)
+        return this

     def __init__(self, auto_prefix=True, flags=None):
         Cell_.__init__(self, self._cell_tag)
         Cell.total_instance_count += 1
-        self.instance_count = Cell.total_instance_count
-        self._params = OrderedDict()
-        self._cells = OrderedDict()
+        super().__setattr__("_params", OrderedDict())
+        super().__setattr__("_cells", OrderedDict())
         super().__setattr__("_buffers", {})
-        super().__setattr__("_non_persistent_buffers_set", set())
-        super().__setattr__("_state_dict_hooks", OrderedDict())
-        super().__setattr__("_state_dict_pre_hooks", OrderedDict())
-        super().__setattr__("_load_state_dict_pre_hooks", OrderedDict())
-        super().__setattr__("_load_state_dict_post_hooks", OrderedDict())
-        self._params_list = OrderedDict()
-        self._primitives = OrderedDict()
-        self.training = False
-        self.requires_grad = False
-        self.is_top_cell = False
-        self.pynative = False
-        self._attr_synced = False
-        self._param_prefix = ''
-        self._auto_prefix = auto_prefix
-        self._scope = None
-        self._phase = 'train'
-        self._parameter_layout_dict = {}
-        self._parallel_parameter_name_list = ()
-        self._parallel_parameter_merge_net_dict = {}
-        self._create_time = int(time.time() * 1e9)
-        self.arguments_key = ""
-        self.compile_cache = set()
-        self.phase_cache = dict()
+        super().__setattr__("_params_list", OrderedDict())
+        super().__setattr__("_primitives", OrderedDict())
+
+        super().__setattr__("_lazy_non_persistent_buffers_set", None)
+        super().__setattr__("_lazy_state_dict_hooks", None)
+        super().__setattr__("_lazy_state_dict_pre_hooks", None)
+        super().__setattr__("_lazy_load_state_dict_pre_hooks", None)
+        super().__setattr__("_lazy_load_state_dict_post_hooks", None)
+        super().__setattr__("training", False)
+        super().__setattr__("requires_grad", False)
+        super().__setattr__("is_top_cell", False)
+        super().__setattr__("_param_prefix", '')
+        super().__setattr__("_auto_prefix", auto_prefix)
+        super().__setattr__("_scope", None)
+        super().__setattr__("_phase", 'train')
+        super().__setattr__("_parameter_layout_dict", None)
+        super().__setattr__("_parallel_parameter_name_list", None)
+        super().__setattr__("_parallel_parameter_merge_net_dict", None)
+        super().__setattr__("_create_time", int(time.time() * 1e9))
+        super().__setattr__("arguments_key", "")
+        super().__setattr__("_compile_cache", None)
+        super().__setattr__("_phase_cache", None)
         cells_compile_cache[id(self)] = self.compile_cache
-        self.parameter_broadcast_done = False
-        self._id = 1
-        self._exist_objs = None
-        self._exist_names = None
-        self._recompute_cell = None
-        self.mixed_precision_type = None
-        self.sig = inspect.signature(self.construct)
+        super().__setattr__("_id", 1)
+        super().__setattr__("_exist_objs", None)
+        super().__setattr__("_exist_names", None)
+        super().__setattr__("_recompute_cell", None)
+        super().__setattr__("mixed_precision_type", None)
+        super().__setattr__("_lazy_construct_sig", None)
+        super().__setattr__("_jit_graph_name", '')
+        super().__setattr__("_compiled", False)
         init_pipeline()

         # call gc to release GE session resources used by non-used cell objects
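The new __new__ override records each constructed cell in a class-level weakref.WeakKeyDictionary (only when _no_auto_lazy_inline is disabled), so the registry never keeps cells alive by itself. A standalone sketch of that registry pattern, plain Python and not MindSpore-specific:

    import weakref

    class Registry:
        # Class-level weak registry: entries vanish when the keyed objects are collected.
        global_cells = weakref.WeakKeyDictionary()

    class Thing:
        def __init__(self, *args, **kwargs):
            Registry.global_cells[self] = (type(self), args, kwargs)

    t = Thing(1, flag=True)
    assert len(Registry.global_cells) == 1
    del t                                    # CPython frees it immediately (no cycles)
    assert len(Registry.global_cells) == 0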
@@ -214,38 +217,35 @@ class Cell(Cell_):

         if flags:
             self.add_flags(**flags)
-        self._bprop_debug = False
+        super().__setattr__("_bprop_debug", False)

         # hook
-        self._forward_pre_hook = OrderedDict()
-        self._forward_hook = OrderedDict()
-        self._backward_pre_hook = OrderedDict()
-        self._cell_backward_pre_hook = None
-        self._backward_hook = OrderedDict()
-        self._cell_backward_hook = None
-        self._is_recursion_hook = False
-
-        # parameters hook
-        self._parameters_forward_hook = None
-        self._parameters_backward_hook = None
-
-        self.cell_type = None
-        self.cast = Cast()
-        self._has_config_recompute = False
-        self._user_parameters = []
-        self._dynamic_shape_inputs = None
-        self._compile_args = None
-        self.saved_dynamic_shape = None
-        self._jit_config_dict = dict()
-        self.grad_ops_label = False
-        self.ge_sync_data = False
-        self._is_check_and_refresh = False
-        self._amp_level = ""
-        self._init_flag = False
-        self._shard_fn = None
-        self.has_bprop = False
+        super().__setattr__("_lazy_forward_pre_hook", None)
+        super().__setattr__("_lazy_forward_hook", None)
+        super().__setattr__("_lazy_backward_pre_hook", None)
+        super().__setattr__("_lazy_backward_hook", None)
+        super().__setattr__("_lazy_forward_pre_hook_with_kwargs", None)
+        super().__setattr__("_lazy_forward_hook_with_kwargs", None)
+        super().__setattr__("_cell_backward_pre_hook", None)
+        super().__setattr__("_cell_backward_hook", None)
+        super().__setattr__("_is_recursion_hook", False)
+
+        super().__setattr__("cell_type", None)
+        super().__setattr__("_has_config_recompute", False)
+        super().__setattr__("_lazy_user_parameters", None)
+        super().__setattr__("_dynamic_shape_inputs", None)
+        super().__setattr__("_has_mutable_args_list", None)
+        super().__setattr__("_jit_config_dict", dict())
+        super().__setattr__("grad_ops_label", False)
+        super().__setattr__("_is_check_and_refresh", False)
+        super().__setattr__("_amp_level", "")
+        super().__setattr__("_init_flag", False)
+        super().__setattr__("_shard_fn", None)
+        super().__setattr__("_in_strategy", None)
+        super().__setattr__("_out_strategy", None)
+        super().__setattr__("has_bprop", False)
         if hasattr(self, "bprop"):
-            self.has_bprop = True
+            super().__setattr__("has_bprop", True)

     def __getstate__(self):
         base = Cell_.__getstate__(self)
@@ -255,7 +255,6 @@ class Cell(Cell_):
         base, dict_ = state
         Cell_.__setstate__(self, base)
         self.__dict__ = dict_
-        self._attr_synced = False

     def __bool__(self):
         return True
@@ -269,6 +268,112 @@ class Cell(Cell_):
     def create_time(self):
         return self._create_time

+    @property
+    def _non_persistent_buffers_set(self):
+        """_non_persistent_buffers_set"""
+        if self._lazy_non_persistent_buffers_set is None:
+            super().__setattr__("_lazy_non_persistent_buffers_set", set())
+        return self._lazy_non_persistent_buffers_set
+
+    @property
+    def _state_dict_hooks(self):
+        """_state_dict_hooks"""
+        if self._lazy_state_dict_hooks is None:
+            super().__setattr__("_lazy_state_dict_hooks", OrderedDict())
+        return self._lazy_state_dict_hooks
+
+    @property
+    def _state_dict_pre_hooks(self):
+        """_state_dict_pre_hooks"""
+        if self._lazy_state_dict_pre_hooks is None:
+            super().__setattr__("_lazy_state_dict_pre_hooks", OrderedDict())
+        return self._lazy_state_dict_pre_hooks
+
+    @property
+    def _load_state_dict_pre_hooks(self):
+        """_load_state_dict_pre_hooks"""
+        if self._lazy_load_state_dict_pre_hooks is None:
+            super().__setattr__("_lazy_load_state_dict_pre_hooks", OrderedDict())
+        return self._lazy_load_state_dict_pre_hooks
+
+    @property
+    def _load_state_dict_post_hooks(self):
+        """_load_state_dict_post_hooks"""
+        if self._lazy_load_state_dict_post_hooks is None:
+            super().__setattr__("_lazy_load_state_dict_post_hooks", OrderedDict())
+        return self._lazy_load_state_dict_post_hooks
+
+    @property
+    def compile_cache(self):
+        """compile_cache"""
+        if self._compile_cache is None:
+            super().__setattr__("_compile_cache", set())
+        return self._compile_cache
+
+    @property
+    def phase_cache(self):
+        """phase_cache"""
+        if self._phase_cache is None:
+            super().__setattr__("_phase_cache", dict())
+        return self._phase_cache
+
+    @property
+    def _forward_pre_hook(self):
+        """_forward_pre_hook"""
+        if self._lazy_forward_pre_hook is None:
+            super().__setattr__("_lazy_forward_pre_hook", OrderedDict())
+        return self._lazy_forward_pre_hook
+
+    @property
+    def _forward_hook(self):
+        """_forward_hook"""
+        if self._lazy_forward_hook is None:
+            super().__setattr__("_lazy_forward_hook", OrderedDict())
+        return self._lazy_forward_hook
+
+    @property
+    def _backward_pre_hook(self):
+        """_backward_pre_hook"""
+        if self._lazy_backward_pre_hook is None:
+            super().__setattr__("_lazy_backward_pre_hook", OrderedDict())
+        return self._lazy_backward_pre_hook
+
+    @property
+    def _backward_hook(self):
+        """_backward_hook"""
+        if self._lazy_backward_hook is None:
+            super().__setattr__("_lazy_backward_hook", OrderedDict())
+        return self._lazy_backward_hook
+
+    @property
+    def _forward_pre_hook_with_kwargs(self):
+        """_backward_hook"""
+        if self._lazy_forward_pre_hook_with_kwargs is None:
+            super().__setattr__("_lazy_forward_pre_hook_with_kwargs", OrderedDict())
+        return self._lazy_forward_pre_hook_with_kwargs
+
+    @property
+    def _forward_hook_with_kwargs(self):
+        """_backward_hook"""
+        if self._lazy_forward_hook_with_kwargs is None:
+            super().__setattr__("_lazy_forward_hook_with_kwargs", OrderedDict())
+        return self._lazy_forward_hook_with_kwargs
+
+    @property
+    def _user_parameters(self):
+        """_user_parameters"""
+        if self._lazy_user_parameters is None:
+            super().__setattr__("_lazy_user_parameters", [])
+        return self._lazy_user_parameters
+
+    @_user_parameters.setter
+    def _user_parameters(self, value):
+        """_user_parameters"""
+        if not isinstance(value, list):
+            raise TypeError(f"For 'Cell', the property '_user_parameters' must be list type, "
+                            f"but got type {type(value)}.")
+        self._lazy_user_parameters = value
+
     @property
     def cell_init_args(self):
         return self._cell_init_args
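All of the new properties above follow the same lazy-initialization pattern: __init__ only stores None through super().__setattr__, and the backing container is allocated on first access, which keeps Cell construction cheap when the hooks or caches are never used. A standalone sketch of the pattern, plain Python and not MindSpore-specific:

    from collections import OrderedDict

    class LazyHooks:
        def __init__(self):
            # Nothing is allocated up front; None just records "not created yet".
            self._lazy_forward_hook = None

        @property
        def _forward_hook(self):
            # The OrderedDict is created only when the hook table is first touched.
            if self._lazy_forward_hook is None:
                self._lazy_forward_hook = OrderedDict()
            return self._lazy_forward_hook

    h = LazyHooks()
    assert not h._forward_hook           # empty dict, created on first access
    h._forward_hook["hook_id"] = print   # subsequent accesses reuse the same dict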
@@ -279,15 +384,21 @@ class Cell(Cell_):
         Get exist parameter names adding by tuple or list of parameter.
         """
         if self._exist_names is None:
-            self._exist_names = set("")
+            super().__setattr__("_exist_names", set(""))
         return self._exist_names

     @property
     def exist_objs(self):
         if self._exist_objs is None:
-            self._exist_objs = set()
+            super().__setattr__("_exist_objs", set())
         return self._exist_objs

+    @property
+    def _construct_sig(self):
+        if self._lazy_construct_sig is None:
+            super().__setattr__("_lazy_construct_sig", inspect.signature(self.construct))
+        return self._lazy_construct_sig
+
     @property
     def param_prefix(self):
         """
@@ -319,6 +430,13 @@ class Cell(Cell_):
         """
         return self._bprop_debug

+    @property
+    def compiled(self):
+        """
+        Get whether `Cell` is compiled in graph mode.
+        """
+        return self._compiled
+
     @bprop_debug.setter
     def bprop_debug(self, value):
         """
@@ -381,6 +499,8 @@ class Cell(Cell_):
         `parameter_layout_dict` represents the tensor layout of a parameter, which is inferred by shard strategy and
         distributed operator information.
         """
+        if self._parameter_layout_dict is None:
+            super().__setattr__("_parameter_layout_dict", {})
         return self._parameter_layout_dict

     @property
@@ -396,6 +516,8 @@ class Cell(Cell_):

     @property
     def parallel_parameter_name_list(self):
+        if self._parallel_parameter_name_list is None:
+            super().__setattr__("_parallel_parameter_name_list", ())
         return self._parallel_parameter_name_list

     @parallel_parameter_name_list.setter
@@ -435,10 +557,23 @@ class Cell(Cell_):

     @property
     def pipeline_segment(self):
+        """
+        `pipeline_segment` represents the pipeline segment of current Cell.
+        """
         return self._pipeline_segment

     @pipeline_segment.setter
     def pipeline_segment(self, value):
+        """
+        Set the `pipeline_segment` of a Cell. Only effective in zero_bubble_v scheduler.
+
+        Args:
+            value (int): The pipeline segment of a parameter.
+
+        Raises:
+            TypeError: If `value` is not int type or is a bool type.
+            ValueError: If `value` is not a positive integer.
+        """
         if not isinstance(value, int) or isinstance(value, bool):
             raise TypeError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
                             "must be int type, but got type : {}".format(type(value)))
@@ -450,6 +585,8 @@ class Cell(Cell_):

     @property
     def parallel_parameter_merge_net_dict(self):
+        if self._parallel_parameter_merge_net_dict is None:
+            super().__setattr__("_parallel_parameter_merge_net_dict", {})
         return self._parallel_parameter_merge_net_dict

     @parallel_parameter_merge_net_dict.setter
@@ -867,6 +1004,7 @@ class Cell(Cell_):
         if hasattr(self, "compile_cache") and self.compile_cache:
             _cell_graph_executor.del_net_res(self, self.compile_cache)
         Cell.total_instance_count -= 1
+        Cell.global_cells.pop(self, None)

     def __delattr__(self, name):
         if name in self._params:
@@ -879,47 +1017,15 @@ class Cell(Cell_):
             del self._params_list[name]
         else:
             object.__delattr__(self, name)
-        self._attr_synced = False
-
-    def _cast_mixed_precision_inputs(self, inputs, dst_type):
-        """Cast input for mixed precision"""
-        res = list()
-        for item in inputs:
-            if isinstance(item, tuple):
-                res.append(self._cast_mixed_precision_inputs(item, dst_type))
-            elif isinstance(item, float):
-                res.append(self.cast(item, dst_type))
-            elif hasattr(item, "dtype") and item.dtype in \
-                    {mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16} and item.dtype != dst_type:
-                res.append(self.cast(item, dst_type))
-            else:
-                res.append(item)
-        return tuple(res)

     def cast_inputs(self, inputs, dst_type):
         """
         Cast inputs to specified type.

-        Args:
-            inputs (tuple[Tensor]): The cell inputs.
-            dst_type (mindspore.dtype): The specified data type.
-
-        returns:
-            tuple[Tensor], the result with destination data type.
+        .. warning::
+            This interface will be deprecated in future versions.
         """
-        res = list()
-        for item in inputs:
-            if isinstance(item, tuple):
-                res.append(self.cast_inputs(item, dst_type))
-            else:
-                res.append(self.cast(item, dst_type))
-        return tuple(res)
-
-    def _do_parameter_broadcast(self):
-        if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL:
-            if not self.parameter_broadcast_done:
-                _pynative_executor.parameter_broadcast(self, self.phase)
-                self.parameter_broadcast_done = True
+        logger.warning(f"'cast_inputs' will be deprecated in future versions.")

     def run_construct(self, cast_inputs, kwargs):
         """
@@ -940,30 +1046,34 @@ class Cell(Cell_):
             output = self._run_construct(cast_inputs, kwargs)
         return output

-    def _run_construct(self, *inputs, **kwargs):
+    def _run_construct(self, *args, **kwargs):
         """Run the construct function"""
         if self._forward_pre_hook:
-            inputs = self._run_forward_pre_hook(inputs)
+            args, kwargs = self._run_forward_pre_hook(args, kwargs)
+
+        if self._backward_hook:
+            args = self._cell_backward_hook(args)

         if self._shard_fn is not None:
-            output = self._shard_fn(*inputs, **kwargs)
+            output = self._shard_fn(*args, **kwargs)
         elif _pynative_executor.requires_grad():
-            if self._backward_hook:
-                output = self._backward_hook_construct(*inputs, **kwargs)
-            elif self._recompute_cell is not None:
-                output = self._recompute_cell(*inputs, **kwargs)
+            if self._recompute_cell is not None:
+                output = self._recompute_cell(*args, **kwargs)
             elif self.has_bprop:
-                output = self._call_custom_bprop(*inputs, **kwargs)
+                output = self._call_custom_bprop(*args, **kwargs)
             else:
-                output = self.construct(*inputs, **kwargs)
+                output = self.construct(*args, **kwargs)
         else:
-            output = self.construct(*inputs, **kwargs)
+            output = self.construct(*args, **kwargs)

         if self._forward_hook:
-            output = self._run_forward_hook(inputs, output)
+            output = self._run_forward_hook(args, kwargs, output)
+
+        if self._backward_hook:
+            output = self._cell_backward_hook(output)

         if self._backward_pre_hook:
-            output = self._run_backward_pre_hook(output)
+            output = self._cell_backward_pre_hook(output)

         return output

@@ -998,22 +1108,6 @@ class Cell(Cell_):
                              f"{default_args} default argument, total {positional_args + default_args}, "
                              f"but got {len(args)}.")

-    def _hook_fn_registered(self):
-        '''Hook function in graph mode'''
-        # Check super().__init__() in graph mode.
-        try:
-            if self._forward_pre_hook or self._forward_hook or self._backward_pre_hook or self._backward_hook:
-                return True
-        except AttributeError as e:
-            raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
-                                 f"Please use 'super().__init__()'.") from e
-        if not self._is_recursion_hook:
-            self._is_recursion_hook = True
-            for cell in self.cells():
-                if cell._hook_fn_registered():
-                    return True
-        return False
-
     def _get_prims_recursively(self):
         all_prims = list()
         for _, value in self._primitives.items():
@@ -1039,9 +1133,6 @@ class Cell(Cell_):
         >>> net = nn.Dense(3, 4)
         >>> net.set_data_parallel()
         """
-        if context._get_mode() == context.PYNATIVE_MODE:
-            raise ValueError("set_data_parallel: does not support PyNative mode.")
-
         all_prims = self._get_prims_recursively()
         for prim in all_prims:
             prim.add_prim_attr("strategy_gen_mode", "data_parallel")
@@ -1120,8 +1211,6 @@ class Cell(Cell_):
         ...             out = self.blocks[i](out)
         ...         return out
         """
-        if context._get_mode() == context.PYNATIVE_MODE:
-            raise ValueError("The Cell offload does not support PyNative mode now.")
         if isinstance(backward_prefetch, str):
             Validator.check_string(backward_prefetch, ['Auto'], 'backward_prefetch', self.cls_name)
         else:
@@ -1129,11 +1218,10 @@ class Cell(Cell_):
         for prim in self._get_prims_recursively():
             prim._offload(backward_prefetch=backward_prefetch)

-    def shard(self, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
+    def shard(self, in_strategy, out_strategy=None, parameter_plan=None):
         """
         Defining the input and output layouts of this cell and the parallel strategies of remaining ops will be
-        generated by sharding propagation. In PyNative mode, use this method to specify a Cell for distributed
-        execution in graph mode. In Graph mode, use this method to specify distribution strategy for a Cell,
+        generated by sharding propagation. In Graph mode, use this method to specify distribution strategy for a Cell,
         strategy for others will be set by sharding propagation.
         in_strategy and out_strategy define the input and output layout respectively.
         in_strategy/out_strategy should be a tuple, each element of which corresponds to the desired layout of
@@ -1145,11 +1233,14 @@ class Cell(Cell_):
               In other parallel modes, strategies set here will be ignored.
             - If the input contain Parameter, its strategy should be set in `in_strategy`.

+        .. warning::
+            The method is currently not supported in PyNative mode.
+
         Args:
             in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple. Tuple
                 defines the layout of the corresponding input.
             out_strategy (Union[None, tuple]): Define the layout of outputs similar with in_strategy.
-                It is not in use right now. Default: ``None`` .
+                Default: ``None`` .
             parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict
                 defines the layout of the parameter like "param_name: layout".
                 The key is a parameter name of type 'str'.
@@ -1157,14 +1248,6 @@ class Cell(Cell_):
                 If the parameter name is incorrect or the corresponding parameter
                 has been set, the parameter setting will be ignored.
                 Default: ``None`` .
-            device (str): Select a certain device target. It is not in use right now.
-                Support [ ``"CPU"`` , ``"GPU"`` , ``"Ascend"`` ]. Default: ``"Ascend"`` .
-            level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation
-                over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
-                use right now. Support [ ``"0"`` , ``"1"`` , ``"2"`` ]. Default: ``0`` .
-
-        Returns:
-            Function, return the cell construct function that will be executed under auto parallel process.

         Examples:
             >>> import mindspore.nn as nn
@@ -1182,40 +1265,34 @@ class Cell(Cell_):
             ...     def __init__(self):
             ...         self.block1 = Block()
             ...         self.block2 = Block()
-            ...         self.block2_shard = self.block2.shard(in_strategy=((2, 1),),
-            ...                                               parameter_plan={'self.block2.shard.dense1.weight': (4, 1)})
+            ...         self.block2.shard(in_strategy=((2, 1),), parameter_plan={'self.block2.dense1.weight': (4, 1)})
             ...     def construct(self, x):
             ...         x = self.block1(x)
-            ...         x = self.block2_shard(x)
+            ...         x = self.block2(x)
             ...         return x
         """
         if ms.communication.management.get_group_size() == 1:
-            return self
-        shard_fn = Shard()
-        fn = shard_fn(self, in_strategy, out_strategy, parameter_plan, device, level)
-        self._shard_fn = fn
-        return fn
-
-    def auto_cast_inputs(self, inputs):
-        """
-        Auto cast inputs in mixed precision scenarios.
-
-        Args:
-            inputs (tuple): the inputs of construct.
-
-        Returns:
-            Tuple, the inputs after data type cast.
-        """
-        msg = f"'auto_cast_inputs' is deprecated from version 2.0 and will be removed in a future version."
-        logger.warning(msg)
-        cast_inputs = inputs
-        mixed_type = self.get_mixed_precision_type()
-        if mixed_type == MixedPrecisionType.FP16:
-            cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float16)
-        if mixed_type == MixedPrecisionType.FP32:
-            cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float32)
+            return

-        return cast_inputs
+        shard_fn = Shard()
+        self._shard_fn = shard_fn(self, in_strategy, out_strategy, parameter_plan)
+
+        if self._in_strategy is not None:  # pylint: disable=E0203
+            msg = (
+                "For '%s', 'Shard' has been configured more than once. "
+                "The existing in_strategy is %s and the existing out_strategy is %s. "
+                "The new in_strategy %s and out_strategy %s may not take effect. "
+                "It is recommended to configure 'Shard' only once."
+            ) % (
+                self._cell_tag,
+                self._in_strategy,  # pylint: disable=E0203
+                self._out_strategy,  # pylint: disable=E0203
+                shard_fn.in_strategy,
+                shard_fn.out_strategy,
+            )
+            logger.warning(msg)
+        self._in_strategy = shard_fn.in_strategy
+        self._out_strategy = shard_fn.out_strategy

     def _init_check(self):
         for param in self.get_parameters(expand=False):
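Note the behavioral change in shard(): the device and level parameters are gone, and the method now records the strategies on the cell in place instead of returning a wrapped construct function. A migration-style sketch, assuming a single-process run where shard() returns early because the group size is 1:

    import numpy as np
    import mindspore as ms
    import mindspore.nn as nn

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.dense1 = nn.Dense(8, 8)

        def construct(self, x):
            return self.dense1(x)

    net = Net()
    # 2.6 style: fn = net.shard(in_strategy=((2, 1),), device="Ascend"); out = fn(x)
    # 2.7 style: shard() configures the cell in place and the cell itself is called.
    net.shard(in_strategy=((2, 1),))
    out = net(ms.Tensor(np.ones((8, 8), dtype=np.float32)))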
@@ -1224,15 +1301,25 @@ class Cell(Cell_):
                 self._init_flag = True

     def _self_check(self):
-        if not self._is_check_and_refresh:
-            self.check_names_and_refresh_name()
-            self._is_check_and_refresh = True
+        try:
+            if not self._is_check_and_refresh:  # pylint: disable=E0203
+                self.check_names_and_refresh_name()
+                self._is_check_and_refresh = True
+        except AttributeError as e:
+            raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
+                                 f"Please use 'super().__init__()'.") from e

     def _predict(self, *args, **kwargs):
+        '''Graph executor for predict'''
         if not hasattr(self, "phase"):
             return False, None
         if (self.phase == "prefill" or self.phase == 'increment') and self.phase in self.phase_cache:
-            new_args = _get_args_for_run_predict(self, args, kwargs, self._compile_args)
+            new_args = _get_args_for_run(self, args, kwargs, self._has_mutable_args_list, True)
+            if self.jit_config_dict:
+                jit_config_dict = self.jit_config_dict
+            else:
+                jit_config_dict = JitConfig().jit_config_dict
+            _cell_graph_executor._graph_executor.set_jit_config(jit_config_dict)
             res = _cell_graph_executor._graph_executor(tuple(new_args), self.phase_cache[self.phase])
             res = _convert_python_data(res)
             return True, res
@@ -1241,8 +1328,9 @@ class Cell(Cell_):
     def __call__(self, *args, **kwargs):
         # Run in Graph mode.
         if context._get_mode() == context.GRAPH_MODE and os.getenv("MS_JIT") != '0':
+            self._compiled = True
             if kwargs:
-                bound_arguments = self.sig.bind(*args, **kwargs)
+                bound_arguments = self._construct_sig.bind(*args, **kwargs)
                 bound_arguments.apply_defaults()
                 args = bound_arguments.args
                 kwargs = bound_arguments.kwargs
@@ -1251,11 +1339,8 @@ class Cell(Cell_):
1251
1339
  if predict_compiled:
1252
1340
  return res
1253
1341
  self._check_construct_args(*args)
1254
-
1255
- if self._hook_fn_registered():
1256
- logger.warning(f"For 'Cell', it's not support hook function in graph mode. If you want to use hook "
1257
- f"function, please use context.set_context to set pynative mode.")
1258
1342
  self._self_check()
1343
+ self.__compile_cell_hook__ = True
1259
1344
  out = self.compile_and_run(*args, **kwargs)
1260
1345
  return out
1261
1346
 
@@ -1324,37 +1409,12 @@ class Cell(Cell_):
1324
1409
  """
1325
1410
  with _no_grad():
1326
1411
  output = self.construct(*args, **kwargs)
1327
- _pynative_executor.call_custom_bprop(self, output, *args, **kwargs)
1328
- return output
1412
+ return _pynative_executor.call_custom_bprop(self, output, *args, **kwargs)
1329
1413
 
1330
1414
  def _add_attr(self, name, value):
1331
1415
  if name and name[:2] != '__' and name not in Cell.IGNORE_LIST:
1332
1416
  super(Cell, self)._add_attr(name, value)
1333
1417
 
1334
- def _sync_attr_for_compile(self):
1335
- """Sync the attr to c++ object."""
1336
- if self._attr_synced:
1337
- return
1338
- cells = self.__dict__.get('_cells')
1339
- for key in cells:
1340
- cell = cells[key]
1341
- cell._sync_attr_for_compile()
1342
- self._add_attr(key, cell)
1343
- params = self.__dict__.get('_params')
1344
- for key in params:
1345
- if '.' in key:
1346
- continue
1347
- param = params[key]
1348
- self._add_attr(key, param)
1349
- params_list = self.__dict__.get('_params_list')
1350
- for key in params_list:
1351
- params_list_item = params_list[key]
1352
- self._add_attr(key, params_list_item)
1353
- for key in self.__dict__:
1354
- value = self.__dict__[key]
1355
- self._add_attr(key, value)
1356
- self._attr_synced = True
1357
-
1358
1418
  def _set_attr_for_param_or_param_tuple(self, name, value):
1359
1419
  """Set attr for param and tensor."""
1360
1420
  if isinstance(value, Parameter):
@@ -1369,27 +1429,16 @@ class Cell(Cell_):
1369
1429
  # If there are multiple identical objects, their names only check once.
1370
1430
  continue
1371
1431
  exist_objs.add(item)
1372
- if item.name == PARAMETER_NAME_DEFAULT:
1373
- logger.warning("For 'Cell', the parameter definition is deprecated.\n"
1374
- "Please set a unique name for the parameter in ParameterTuple '{}'.".format(value))
1375
- item.name = item.name + "$" + str(self._id)
1432
+ if _is_parameter_generated(item.name):
1433
+ item.name = "Parameter$" + str(self._id)
1376
1434
  self._id += 1
1377
- self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
1378
1435
  if item.name in exist_names:
1379
1436
  raise ValueError("The value {} , its name '{}' already exists. "
1380
1437
  "Please set a unique name for the parameter.".format(value, item.name))
1381
1438
  exist_names.add(item.name)
1439
+ self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
1382
1440
 
1383
- if context._get_mode() == context.PYNATIVE_MODE:
1384
- if name in self.__dict__:
1385
- del self.__dict__[name]
1386
- params = self.__dict__.get('_params')
1387
- if name in params:
1388
- del params[name]
1389
- params_list = self.__dict__.get('_params_list')
1390
- params_list[name] = value
1391
- else:
1392
- object.__setattr__(self, name, value)
1441
+ object.__setattr__(self, name, value)
1393
1442
 
1394
1443
  def _set_attr_for_parameter_in_list_or_tuple(self, name, value):
1395
1444
  """Set attr for parameter in list or tuple."""
@@ -1398,9 +1447,6 @@ class Cell(Cell_):
1398
1447
  # If there are multiple identical objects, their names only check once.
1399
1448
  continue
1400
1449
  self.exist_objs.add(item)
1401
- if item.name == PARAMETER_NAME_DEFAULT:
1402
- item.name = item.name + "$" + str(self._id)
1403
- self._id += 1
1404
1450
  if item.name in self.exist_names:
1405
1451
  raise ValueError(f"The value {value} , its name '{item.name}' already exists. "
1406
1452
  "Please set a unique name for the parameter.")
@@ -1513,24 +1559,6 @@ class Cell(Cell_):
1513
1559
  main_str += ")"
1514
1560
  return main_str
1515
1561
 
1516
- def load_parameter_slice(self, params):
1517
- """
1518
- Replace parameters with sliced tensors by parallel strategies.
1519
-
1520
- Note:
1521
- This interface is deprecated.
1522
- """
1523
- logger.warning("'load_parameter_slice' function is deprecated.")
1524
-
1525
- def set_parallel_input_with_inputs(self, *inputs):
1526
- """
1527
- Slice inputs tensors by parallel strategies.
1528
-
1529
- Note:
1530
- This interface is deprecated.
1531
- """
1532
- logger.warning("'set_parallel_input_with_inputs' function is deprecated.")
1533
-
1534
1562
  def set_inputs(self, *inputs, **kwargs):
1535
1563
  """
1536
1564
  Save set inputs for computation graph. The number of inputs should be the same with that of the datasets. When
@@ -1589,8 +1617,6 @@ class Cell(Cell_):
1589
1617
  _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)
1590
1618
  else:
1591
1619
  self._check_construct_args(*inputs)
1592
- # TODO(tronzhang): It may error for no actually args here. So just set in fullmode,
1593
- # which means that incremental mode is lacking dynamic input.
1594
1620
  else:
1595
1621
  self._dynamic_shape_inputs = _process_dyn_args(self.construct, kwargs)
1596
1622
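
`set_inputs` (whose docstring begins in the previous hunk) now routes keyword arguments through `_process_dyn_args`. A minimal sketch of the established positional form, declaring a dynamic batch axis with `None` (network and shapes are illustrative):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn

    ms.set_context(mode=ms.GRAPH_MODE)
    net = nn.Dense(3, 2)
    # None marks the axis whose size may change between runs.
    net.set_inputs(Tensor(shape=[None, 3], dtype=ms.float32))
    out = net(Tensor(np.ones((8, 3)), ms.float32))
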
 
@@ -1665,7 +1691,6 @@ class Cell(Cell_):
1665
1691
  _cell_graph_executor._graph_executor.check_argument_consistency(compile_args, args, "set_inputs")
1666
1692
  self._check_parameter_consistency(compile_args, args)
1667
1693
  Validator.check_symbolic_shape(compile_args, args)
1668
- self.saved_dynamic_shape = compile_args
1669
1694
  return compile_args
1670
1695
  return args
1671
1696
 
@@ -1678,8 +1703,9 @@ class Cell(Cell_):
1678
1703
  kwargs (dict): Kwargs of the Cell object.
1679
1704
  """
1680
1705
  _init_auto_parallel_context(self)
1681
- self._compile_args = self._get_compile_args(args)
1682
- _cell_graph_executor.compile(self, *self._compile_args, phase=self.phase,
1706
+ compile_args = self._get_compile_args(args)
1707
+ self._has_mutable_args_list = _get_mutable_flags(compile_args)
1708
+ _cell_graph_executor.compile(self, *compile_args, phase=self.phase,
1683
1709
  jit_config_dict=self._jit_config_dict, **kwargs)
1684
1710
  _clear_auto_parallel_context(self)
1685
1711
 
@@ -1698,25 +1724,14 @@ class Cell(Cell_):
1698
1724
  Object, the result of executing.
1699
1725
  """
1700
1726
  self.compile(*args, **kwargs)
1701
- self.add_flags(ge_sync_data=False)
1702
- new_args = _get_args_for_run(self, args, kwargs, self._compile_args)
1727
+ new_args = _get_args_for_run(self, args, kwargs, self._has_mutable_args_list, False)
1728
+ if self.jit_config_dict:
1729
+ jit_config_dict = self.jit_config_dict
1730
+ else:
1731
+ jit_config_dict = JitConfig().jit_config_dict
1732
+ _cell_graph_executor._graph_executor.set_jit_config(jit_config_dict)
1703
1733
  return _cell_graph_executor(self, *new_args, phase=self.phase)
1704
1734
 
1705
- def auto_parallel_compile_and_run(self):
1706
- """
1707
- Whether or not to execute compile and run in 'AUTO_PARALLEL' or 'SEMI_AUTO_PARALLEL' mode.
1708
-
1709
- Note:
1710
- This interface is deprecated.
1711
- """
1712
- logger.warning("'auto_parallel_compile_and_run' function is deprecated.")
1713
-
1714
- def exec_checkpoint_graph(self):
1715
- """Executes GE saving checkpoint graph operation."""
1716
- logger.warning("'exec_checkpoint_graph' function is deprecated.")
1717
- self.add_flags(ge_sync_data=True)
1718
- _cell_graph_executor(self, phase='save')
1719
-
1720
1735
  def insert_param_to_cell(self, param_name, param, check_name_contain_dot=True):
1721
1736
  """
1722
1737
  Adds a parameter to the current cell.
@@ -1762,35 +1777,10 @@ class Cell(Cell_):
1762
1777
  if not isinstance(param, Parameter) and param is not None:
1763
1778
  raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must be 'Parameter' if not None, "
1764
1779
  f"but got {type(param)}.")
1765
- if isinstance(param, Parameter) and param.name == PARAMETER_NAME_DEFAULT:
1780
+ if isinstance(param, Parameter) and _is_parameter_generated(param.name):
1766
1781
  param.name = param_name
1767
1782
  self._params[param_name] = param
1768
1783
 
1769
- def cast_param(self, param):
1770
- """
1771
- Cast parameter according to auto mix precision level in pynative mode.
1772
-
1773
- This interface is currently used in the case of auto mix precision and usually needs not to be used explicitly.
1774
-
1775
- Args:
1776
- param (Parameter): Parameters, the type of which should be cast.
1777
-
1778
- Returns:
1779
- Parameter, the input parameter with type automatically cast.
1780
- """
1781
- msg = f"'cast_param' is deprecated from version 2.0 and will be removed in a future version."
1782
- logger.warning(msg)
1783
- mixed_type = self.get_mixed_precision_type()
1784
- if mixed_type != MixedPrecisionType.NOTSET:
1785
- if mixed_type == MixedPrecisionType.FP32:
1786
- param.set_cast_dtype(mstype.float32)
1787
- elif mixed_type == MixedPrecisionType.FP16:
1788
- param.set_cast_dtype(mstype.float16)
1789
- elif hasattr(param, "set_cast_dtype"):
1790
- # retest dtype
1791
- param.set_cast_dtype()
1792
- return param
1793
-
1794
1784
  def insert_child_to_cell(self, child_name, child_cell):
1795
1785
  """
1796
1786
  Adds a child cell to the current cell with a given name.
@@ -1850,27 +1840,10 @@ class Cell(Cell_):
1850
1840
  """
1851
1841
  Remove the redundant parameters.
1852
1842
 
1853
- This interface usually needs not to be used explicitly.
1843
+ .. warning::
1844
+ This interface will be deprecated in future versions.
1854
1845
  """
1855
- cells = self.cells_and_names()
1856
- for _, cell in cells:
1857
- params = cell._params.items()
1858
- for param_name, param in list(params):
1859
- if param.name not in self.parallel_parameter_name_list:
1860
- cell._params.pop(param_name)
1861
- logger.info("remove the redundant parameter: %s", param.name)
1862
- continue
1863
- cell_dict = cell.__dict__
1864
- for key in cell_dict:
1865
- if isinstance(cell_dict[key], ParameterTuple):
1866
- param_tuple = cell_dict[key]
1867
- new_param_tuple = []
1868
- for param in param_tuple:
1869
- if param.name not in self.parallel_parameter_name_list:
1870
- logger.info("remove the redundant parameter: %s in ParameterTuple", param.name)
1871
- continue
1872
- new_param_tuple.append(param)
1873
- cell.__dict__[key] = ParameterTuple(new_param_tuple)
1846
+ logger.warning(f"'remove_redundant_parameters' will be deprecated in future versions.")
1874
1847
 
1875
1848
  def _get_cell_parallel_mode(self):
1876
1849
  """Determine whether the current cell is in parallel mode."""
@@ -1926,16 +1899,13 @@ class Cell(Cell_):
1926
1899
  # replace all original usage.
1927
1900
  cells = self.cells_and_names()
1928
1901
  is_parallel_mode = self._get_cell_parallel_mode()
1929
- is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
1930
1902
 
1931
1903
  for _, cell in cells:
1932
1904
  params = cell._params.items()
1933
1905
  for param_name, param in params:
1934
- not_sliced = not param.sliced
1935
- judgment = not_sliced
1936
1906
  if param.param_info.is_pipeline_shared_param:
1937
1907
  continue
1938
- if is_graph_mode and is_parallel_mode and judgment:
1908
+ if is_parallel_mode and not param.sliced:
1939
1909
  continue
1940
1910
  if not auto_parallel_mode:
1941
1911
  cell._params[param_name] = _updata(param)
@@ -1948,11 +1918,9 @@ class Cell(Cell_):
1948
1918
  param_tuple = cell_dict[key]
1949
1919
  new_param_tuple = []
1950
1920
  for param in param_tuple:
1951
- not_sliced = not param.sliced
1952
- judgment = not_sliced
1953
1921
  if param.param_info.is_pipeline_shared_param:
1954
1922
  continue
1955
- if is_graph_mode and is_parallel_mode and judgment:
1923
+ if is_parallel_mode and not param.sliced:
1956
1924
  continue
1957
1925
  if not auto_parallel_mode:
1958
1926
  new_param_tuple.append(_updata(param))
@@ -2591,15 +2559,6 @@ class Cell(Cell_):
2591
2559
  self.add_flags_recursive(broadcast_flag=mode)
2592
2560
  return self
2593
2561
 
2594
- def set_auto_parallel(self):
2595
- """
2596
- Set the cell to auto parallel mode.
2597
-
2598
- Note:
2599
- This interface is deprecated.
2600
- """
2601
- logger.warning("'set_auto_parallel' function is deprecated.")
2602
-
2603
2562
  def set_jit_config(self, jit_config):
2604
2563
  """
2605
2564
  Set jit config for cell.
@@ -2645,25 +2604,38 @@ class Cell(Cell_):
2645
2604
  raise ValueError(f"Negative 'fusion_size' {fusion_size} is invalid.")
2646
2605
  Tensor._flatten_tensors(self.trainable_params(), fusion_size) # pylint: disable=W0212
2647
2606
 
2648
- def register_forward_pre_hook(self, hook_fn):
2607
+ @jit_forbidden_register
2608
+ def register_forward_pre_hook(self, hook_fn, with_kwargs=False):
2649
2609
  """
2650
2610
  Register forward pre hook function for Cell object.
2651
2611
 
2612
+ The hook will be called before :func:`mindspore.nn.Cell.construct` is invoked.
2613
+
2614
+ The hook function should be one of the following signatures:
2615
+
2616
+ - `hook_fn(cell, args) -> None or new_args` , when `with_kwargs` is ``False`` .
2617
+ - `hook_fn(cell, args, kwargs) -> None or (new_args, new_kwargs)` , when `with_kwargs` is ``True`` .
2618
+
2619
+ where:
2620
+
2621
+ - `cell` (Cell): Cell object on which the hook is registered.
2622
+ - `args` (tuple): Positional arguments passed to the `construct` function.
2623
+ - `kwargs` (dict): Keyword arguments passed to the `construct` function. Only passed to `hook_fn` when
2624
+ `with_kwargs` is ``True`` .
2625
+
2652
2626
  Note:
2653
- - The `register_forward_pre_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
2654
- - 'hook_fn' must be defined as the following code.
2655
- `cell` is the object of registered Cell. `inputs` is the forward
2656
- input objects passed to the Cell. The 'hook_fn' can modify the forward input objects by returning new
2657
- forward input objects.
2658
- - It should have the following signature:
2659
- hook_fn(cell, inputs) -> new input objects or none.
2660
- - In order to prevent running failed when switching to graph mode, it is not recommended to write it in the
2661
- `construct` function of Cell object. In the pynative mode, if the `register_forward_pre_hook` function is
2662
- called in the `construct` function of the Cell object, a hook function will be added at each run time of
2663
- Cell object.
2627
+ - The `hook_fn` can modify the forward inputs by returning new inputs. If `with_kwargs` is ``False`` , a
2628
+ single value (which will be wrapped into a tuple unless already a tuple) or a tuple of args should be
2629
+ returned. If `with_kwargs` is ``True`` , both `args` and `kwargs` should be returned.
2630
+ - To avoid failures when switching to graph mode, it is not recommended to call this method in the
2631
+ `construct` function of Cell object.
2632
+ - In the pynative mode, if this method is called inside the `construct` function of the Cell object, a
2633
+ `hook_fn` will be added each time the Cell object runs.
2664
2634
 
2665
2635
  Args:
2666
2636
  hook_fn (function): Python function. Forward pre hook function.
2637
+ with_kwargs (bool, optional): Specifies whether hook_fn will be passed the kwargs given to the `construct`
2638
+ function. Default: ``False`` .
2667
2639
 
2668
2640
  Returns:
2669
2641
  A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
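
A minimal sketch of the new `with_kwargs=True` form described above, in PyNative mode; the hook returns `(new_args, new_kwargs)` as required (network and values are illustrative):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn

    ms.set_context(mode=ms.PYNATIVE_MODE)

    def pre_hook(cell, args, kwargs):
        # Scale the first positional input and pass kwargs through unchanged.
        return (args[0] * 2,) + args[1:], kwargs

    net = nn.Dense(3, 2)
    handle = net.register_forward_pre_hook(pre_hook, with_kwargs=True)
    out = net(Tensor(np.ones((1, 3)), ms.float32))
    handle.remove()
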
@@ -2702,16 +2674,41 @@ class Cell(Cell_):
2702
2674
  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
2703
2675
  value= [ 2.00000000e+00]))
2704
2676
  """
2705
- if context._get_mode() == context.GRAPH_MODE:
2706
- return HookHandle()
2707
2677
  check_hook_fn(hook_fn)
2708
- handle = HookHandle(self._forward_pre_hook)
2678
+ handle = HookHandle(self._forward_pre_hook, extra_dict=self._forward_pre_hook_with_kwargs)
2709
2679
  self._forward_pre_hook[handle.handle_id] = hook_fn
2680
+ if with_kwargs:
2681
+ self._forward_pre_hook_with_kwargs[handle.handle_id] = True
2682
+ _update_hook_version()
2710
2683
  return handle
2711
2684
 
2712
- def _run_forward_pre_hook(self, inputs):
2685
+ @jit_forbidden_register
2686
+ def _run_forward_pre_hook(self, args, kwargs):
2713
2687
  """
2714
2688
  Running forward pre hook function registered on Cell object.
2689
+ """
2690
+ for hook_id, hook_fn in self._forward_pre_hook.items():
2691
+ if hook_id in self._forward_pre_hook_with_kwargs:
2692
+ ret = hook_fn(self, args, kwargs)
2693
+ if ret is not None:
2694
+ if isinstance(ret, tuple) and len(ret) == 2:
2695
+ args, kwargs = ret
2696
+ else:
2697
+ raise RuntimeError(
2698
+ "forward pre hook with kwargs must return None or a tuple of (new_args, new_kwargs), "
2699
+ f"but got {ret}"
2700
+ )
2701
+ else:
2702
+ ret = hook_fn(self, args)
2703
+ if ret is not None:
2704
+ if not isinstance(ret, tuple):
2705
+ ret = (ret,)
2706
+ args = ret
2707
+ return args, kwargs
2708
+
2709
+ def _jit_forward_pre_hook(self, inputs):
2710
+ """
2711
+ Compile forward pre hook function registered on Cell object.
2715
2712
 
2716
2713
  Args:
2717
2714
  inputs: The input objects of cell object.
@@ -2731,34 +2728,43 @@ class Cell(Cell_):
2731
2728
  else:
2732
2729
  forward_pre_hook_inputs = ret
2733
2730
 
2734
- if isinstance(inputs, tuple):
2735
- if not isinstance(forward_pre_hook_inputs, tuple):
2736
- forward_pre_hook_inputs = (forward_pre_hook_inputs,)
2737
- if len(forward_pre_hook_inputs) != len(inputs):
2738
- raise TypeError(
2739
- "The forward pre hook return value size is {} not equal to input size {}".format(
2740
- len(forward_pre_hook_inputs), len(inputs)))
2731
+ if len(forward_pre_hook_inputs) != len(inputs):
2732
+ raise TypeError(
2733
+ "The forward pre hook return value size is {} not equal to input size {}".format(
2734
+ len(forward_pre_hook_inputs), len(inputs)))
2741
2735
  return forward_pre_hook_inputs
2742
2736
 
2743
- def register_forward_hook(self, hook_fn):
2737
+ @jit_forbidden_register
2738
+ def register_forward_hook(self, hook_fn, with_kwargs=False):
2744
2739
  """
2745
- Set the Cell forward hook function.
2740
+ Register forward hook function for Cell object.
2741
+
2742
+ This hook will be called after :func:`mindspore.nn.Cell.construct` has computed an output.
2743
+
2744
+ The hook function should be one of the following signatures:
2745
+
2746
+ - `hook_fn(cell, args, output) -> None or new_output` , when `with_kwargs` is ``False`` .
2747
+ - `hook_fn(cell, args, kwargs, output) -> None or new_output` , when `with_kwargs` is ``True`` .
2748
+
2749
+ where:
2750
+
2751
+ - `cell` (Cell): Cell object on which the hook is registered.
2752
+ - `args` (tuple): Positional arguments passed to the `construct` function.
2753
+ - `kwargs` (dict): Keyword arguments passed to the `construct` function. Only passed to `hook_fn` when
2754
+ `with_kwargs` is ``True`` .
2755
+ - `output`: Output generated by the `construct` function.
2746
2756
 
2747
2757
  Note:
2748
- - The `register_forward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
2749
- - 'hook_fn' must be defined as the following code.
2750
- `cell` is the object of registered Cell. `inputs` is the forward
2751
- input objects passed to the Cell. `output` is the forward output object of the Cell. The 'hook_fn' can
2752
- modify the forward output object by returning new forward output object.
2753
- - It should have the following signature:
2754
- hook_fn(cell, inputs, output) -> new output object or none.
2755
- - In order to prevent running failed when switching to graph mode, it is not recommended to write it in the
2756
- `construct` function of Cell object. In the pynative mode, if the `register_forward_hook` function is
2757
- called in the `construct` function of the Cell object, a hook function will be added at each run time of
2758
- Cell object.
2758
+ - The `hook_fn` can modify the forward outputs by returning new outputs.
2759
+ - To avoid failures when switching to graph mode, it is not recommended to call this method in the
2760
+ `construct` function of Cell object.
2761
+ - In the pynative mode, if this method is called inside the `construct` function of the Cell object, a
2762
+ `hook_fn` will be added each time the Cell object runs.
2759
2763
 
2760
2764
  Args:
2761
2765
  hook_fn (function): Python function. Forward hook function.
2766
+ with_kwargs (bool, optional): Specifies whether hook_fn will be passed the kwargs given to the `construct`
2767
+ function. Default: ``False`` .
2762
2768
 
2763
2769
  Returns:
2764
2770
  A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
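
A minimal sketch of the corresponding `with_kwargs=True` forward hook, whose return value replaces the forward output as noted above (network and values are illustrative):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn

    ms.set_context(mode=ms.PYNATIVE_MODE)

    def fwd_hook(cell, args, kwargs, output):
        # Returning a value replaces the forward output.
        return output + 1

    net = nn.Dense(3, 2)
    handle = net.register_forward_hook(fwd_hook, with_kwargs=True)
    out = net(Tensor(np.ones((1, 3)), ms.float32))
    handle.remove()
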
@@ -2801,16 +2807,17 @@ class Cell(Cell_):
2801
2807
  """
2802
2808
  if self.has_bprop:
2803
2809
  return HookHandle()
2804
- if context._get_mode() == context.GRAPH_MODE:
2805
- return HookHandle()
2806
2810
  check_hook_fn(hook_fn)
2807
- handle = HookHandle(self._forward_hook)
2811
+ handle = HookHandle(self._forward_hook, extra_dict=self._forward_hook_with_kwargs)
2808
2812
  self._forward_hook[handle.handle_id] = hook_fn
2813
+ if with_kwargs:
2814
+ self._forward_hook_with_kwargs[handle.handle_id] = True
2815
+ _update_hook_version()
2809
2816
  return handle
2810
2817
 
2811
- def _run_forward_hook(self, inputs, output):
2818
+ def _jit_forward_hook(self, inputs, output):
2812
2819
  """
2813
- Running forward hook function registered on Cell object.
2820
+ Compile forward hook function registered on Cell object.
2814
2821
 
2815
2822
  Args:
2816
2823
  inputs: The input objects of Cell object.
@@ -2837,12 +2844,26 @@ class Cell(Cell_):
2837
2844
  len(forward_hook_output), len(output)))
2838
2845
  return forward_hook_output
2839
2846
 
2847
+ @jit_forbidden_register
2848
+ def _run_forward_hook(self, args, kwargs, output):
2849
+ """
2850
+ Running forward hook function registered on Cell object.
2851
+ """
2852
+ for hook_id, hook_fn in self._forward_hook.items():
2853
+ if hook_id in self._forward_hook_with_kwargs:
2854
+ ret = hook_fn(self, args, kwargs, output)
2855
+ else:
2856
+ ret = hook_fn(self, args, output)
2857
+ if ret is not None:
2858
+ output = ret
2859
+ return output
2860
+
2861
+ @jit_forbidden_register
2840
2862
  def register_backward_pre_hook(self, hook_fn):
2841
2863
  """
2842
2864
  Register the backward pre hook function.
2843
2865
 
2844
2866
  Note:
2845
- - The `register_backward_pre_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
2846
2867
  - The 'hook_fn' must be defined as the following code.
2847
2868
  `cell` is the Cell object. `grad_output` is the gradient passed to the Cell.
2848
2869
  - The 'hook_fn' should have the following signature:
@@ -2891,44 +2912,17 @@ class Cell(Cell_):
2891
2912
  >>> print(output)
2892
2913
  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
2893
2914
  """
2894
- if context._get_mode() == context.GRAPH_MODE:
2895
- return HookHandle()
2896
2915
  check_hook_fn(hook_fn)
2897
- handle = HookHandle(self._backward_pre_hook)
2916
+ handle = HookHandle(self._backward_pre_hook, extra_dict=None)
2898
2917
  self._backward_pre_hook[handle.handle_id] = hook_fn
2899
- if self._cell_backward_pre_hook is None:
2918
+ if self._cell_backward_pre_hook is None: # pylint: disable=E0203
2900
2919
  # Generate a CellBackwardHook prim, and add function for it
2901
2920
  self._cell_backward_pre_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
2902
2921
  self, self._backward_pre_hook)
2903
2922
  self._cell_backward_pre_hook.register_backward_pre_hook()
2923
+ _update_hook_version()
2904
2924
  return handle
2905
2925
 
2906
- def _run_backward_pre_hook(self, outputs):
2907
- """
2908
- Running backward pre hook function registered on Cell object.
2909
-
2910
- Args:
2911
- outputs: The output objects of cell object.
2912
-
2913
- Returns:
2914
- - **outputs** - New backward gradient or None.
2915
-
2916
- Supported Platforms:
2917
- ``Ascend`` ``GPU`` ``CPU``
2918
- """
2919
- if isinstance(outputs, tuple):
2920
- ret = self._cell_backward_pre_hook(*outputs)
2921
- else:
2922
- ret = self._cell_backward_pre_hook(outputs)
2923
- if isinstance(outputs, tuple):
2924
- if len(outputs) == 1:
2925
- ret = (ret,)
2926
- if len(ret) != len(outputs):
2927
- raise TypeError(
2928
- "The backward pre hook return value size is {} not equal to output size {}".format(
2929
- len(ret), len(outputs)))
2930
- return ret
2931
-
2932
2926
  def get_extra_state(self) -> Any:
2933
2927
  """Return any extra state to include in the cell's state_dict.
2934
2928
 
@@ -2981,9 +2975,8 @@ class Cell(Cell_):
2981
2975
  A handle that can be used to remove the added hook by calling
2982
2976
  `handle.remove()`.
2983
2977
  """
2984
- from mindspore.utils.hooks import _RemovableHandle
2985
- handle = _RemovableHandle(self._state_dict_hooks)
2986
- self._state_dict_hooks[handle.id] = hook
2978
+ handle = HookHandle(self._state_dict_hooks)
2979
+ self._state_dict_hooks[handle.handle_id] = hook
2987
2980
  return handle
2988
2981
 
2989
2982
  @jit_forbidden_register
@@ -3029,9 +3022,8 @@ class Cell(Cell_):
3029
3022
  >>> print("extra_param" in net_state_dict)
3030
3023
  True
3031
3024
  """
3032
- from mindspore.utils.hooks import _RemovableHandle
3033
- handle = _RemovableHandle(self._state_dict_pre_hooks)
3034
- self._state_dict_pre_hooks[handle.id] = hook
3025
+ handle = HookHandle(self._state_dict_pre_hooks)
3026
+ self._state_dict_pre_hooks[handle.handle_id] = hook
3035
3027
  return handle
3036
3028
 
3037
3029
  def _save_to_state_dict(self, destination, prefix, keep_vars):
@@ -3116,7 +3108,6 @@ class Cell(Cell_):
3116
3108
  OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \
3117
3109
  ('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))])
3118
3110
  """
3119
- # TODO: Remove `args` and the parsing logic when BC allows.
3120
3111
  if args:
3121
3112
  # DeprecationWarning is ignored by default
3122
3113
  warnings.warn(
@@ -3169,7 +3160,7 @@ class Cell(Cell_):
3169
3160
 
3170
3161
  It should have the following signature:
3171
3162
 
3172
- hook(cell, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None # noqa: B950
3163
+ hook(cell, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None
3173
3164
 
3174
3165
  Args:
3175
3166
  hook (Callable): The hook function before `load_state_dict` is called.
@@ -3178,9 +3169,8 @@ class Cell(Cell_):
3178
3169
  A handle that can be used to remove the added hook by calling
3179
3170
  `handle.remove()`.
3180
3171
  """
3181
- from mindspore.utils.hooks import _RemovableHandle
3182
- handle = _RemovableHandle(self._load_state_dict_pre_hooks)
3183
- self._load_state_dict_pre_hooks[handle.id] = hook
3172
+ handle = HookHandle(self._load_state_dict_pre_hooks)
3173
+ self._load_state_dict_pre_hooks[handle.handle_id] = hook
3184
3174
  return handle
3185
3175
 
3186
3176
  @jit_forbidden_register
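
A minimal sketch of a pre-hook matching the signature documented above; the registration method name `register_load_state_dict_pre_hook` and the key names are assumptions, not taken from this diff:

    from mindspore import nn

    def rename_pre_hook(cell, state_dict, prefix, local_metadata, strict,
                        missing_keys, unexpected_keys, error_msgs):
        # Migrate a legacy key in place before loading (key names illustrative).
        old_key, new_key = prefix + "old_weight", prefix + "weight"
        if old_key in state_dict:
            state_dict[new_key] = state_dict.pop(old_key)

    net = nn.Dense(3, 2)
    handle = net.register_load_state_dict_pre_hook(rename_pre_hook)
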
@@ -3212,9 +3202,8 @@ class Cell(Cell_):
3212
3202
  A handle that can be used to remove the added hook by calling
3213
3203
  `handle.remove()`.
3214
3204
  """
3215
- from mindspore.utils.hooks import _RemovableHandle
3216
- handle = _RemovableHandle(self._load_state_dict_post_hooks)
3217
- self._load_state_dict_post_hooks[handle.id] = hook
3205
+ handle = HookHandle(self._load_state_dict_post_hooks)
3206
+ self._load_state_dict_post_hooks[handle.handle_id] = hook
3218
3207
  return handle
3219
3208
 
3220
3209
  def _load_from_state_dict(
@@ -3450,12 +3439,12 @@ class Cell(Cell_):
3450
3439
  )
3451
3440
  return _IncompatibleKeys(missing_keys, unexpected_keys)
3452
3441
 
3442
+ @jit_forbidden_register
3453
3443
  def register_backward_hook(self, hook_fn):
3454
3444
  """
3455
3445
  Register the backward hook function.
3456
3446
 
3457
3447
  Note:
3458
- - The `register_backward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
3459
3448
  - The 'hook_fn' must be defined as the following code.
3460
3449
  `cell` is the registered Cell object. `grad_input` is the gradient computed and passed to
3461
3450
  the next Cell or primitive, which can be return a new gradient or None. `grad_output` is the gradient
@@ -3507,65 +3496,17 @@ class Cell(Cell_):
3507
3496
  >>> print(output)
3508
3497
  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
3509
3498
  """
3510
- if context._get_mode() == context.GRAPH_MODE:
3511
- return HookHandle()
3512
3499
  check_hook_fn(hook_fn)
3513
- handle = HookHandle(self._backward_hook)
3500
+ handle = HookHandle(self._backward_hook, extra_dict=None)
3514
3501
  self._backward_hook[handle.handle_id] = hook_fn
3515
- if self._cell_backward_hook is None:
3502
+ if self._cell_backward_hook is None: # pylint: disable=E0203
3516
3503
  # Generate a CellBackwardHook prim, and add function for it
3517
3504
  self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
3518
3505
  self, self._backward_hook)
3519
3506
  self._cell_backward_hook.register_backward_hook()
3507
+ _update_hook_version()
3520
3508
  return handle
3521
3509
 
3522
- def _backward_hook_construct(self, *inputs, **kwargs):
3523
- """
3524
- Backward hook construct method to replace original construct method.
3525
-
3526
- Args:
3527
- inputs: The input objects of Cell object.
3528
- kwargs (dict): Dictionary of variable keyword parameters.
3529
-
3530
- Returns:
3531
- - **outputs** - The output objects of Cell object.
3532
-
3533
- Supported Platforms:
3534
- ``Ascend`` ``GPU`` ``CPU``
3535
- """
3536
- # cell_backward_hook has CellBackwardHook op, so keep input args as they are.
3537
- outputs = self._cell_backward_hook(*inputs)
3538
- # If the inputs have more than two args, the outputs will also have more than two args and will be wrapped into
3539
- # a tuple, so need to do unwrapping. If inputs is empty, we also need to unwrap it.
3540
- # Because when output of runop method is one, it will not wrap a tuple, we need not unwrap it.
3541
- is_need_unwrap = False
3542
- if isinstance(outputs, tuple) and len(inputs) != 1:
3543
- is_need_unwrap = True
3544
-
3545
- if self._recompute_cell is not None:
3546
- if is_need_unwrap:
3547
- outputs = self._recompute_cell(*outputs, **kwargs)
3548
- else:
3549
- outputs = self._recompute_cell(outputs, **kwargs)
3550
- elif self.has_bprop:
3551
- if is_need_unwrap:
3552
- outputs = self._call_custom_bprop(*outputs, **kwargs)
3553
- else:
3554
- outputs = self._call_custom_bprop(outputs, **kwargs)
3555
- else:
3556
- if is_need_unwrap:
3557
- outputs = self.construct(*outputs, **kwargs)
3558
- else:
3559
- outputs = self.construct(outputs, **kwargs)
3560
- if isinstance(outputs, tuple):
3561
- new_outputs = self._cell_backward_hook(*outputs)
3562
- else:
3563
- new_outputs = self._cell_backward_hook(outputs)
3564
- # if outputs is (X,) and new_outpus is X
3565
- if isinstance(outputs, tuple) and len(outputs) == 1:
3566
- new_outputs = (new_outputs,)
3567
- return new_outputs
3568
-
3569
3510
  def set_param_ps(self, recurse=True, init_in_server=False):
3570
3511
  """
3571
3512
  Set whether the trainable parameters are updated by parameter server and whether the
@@ -3584,12 +3525,6 @@ class Cell(Cell_):
3584
3525
  for param in params:
3585
3526
  param.set_param_ps(init_in_server)
3586
3527
 
3587
- @deprecated("1.8", "set_param_fl")
3588
- def set_param_fl(self, push_to_server=False, pull_from_server=False, requires_aggr=True):
3589
- params = self.parameters_and_names()
3590
- for param in params:
3591
- param[1].set_param_fl(push_to_server, pull_from_server, requires_aggr)
3592
-
3593
3528
  def set_comm_fusion(self, fusion_type, recurse=True):
3594
3529
  """
3595
3530
  Set `comm_fusion` for all the parameters in this cell. Please refer to the description of
@@ -3650,7 +3585,7 @@ class Cell(Cell_):
3650
3585
  """
3651
3586
  Validator.check_bool(mode)
3652
3587
  Validator.check_bool(output_recompute)
3653
- if not self._has_config_recompute:
3588
+ if not self._has_config_recompute: # pylint: disable=E0203
3654
3589
  self._has_config_recompute = True
3655
3590
  else:
3656
3591
  logger.info("The recompute interface can be configured only once."
@@ -3693,12 +3628,12 @@ class Cell(Cell_):
3693
3628
  introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
3694
3629
  Default: ``False`` .
3695
3630
  """
3696
- if context.get_context("mode") == context.PYNATIVE_MODE:
3631
+ if context._get_mode() == context.PYNATIVE_MODE:
3697
3632
  self._recompute_cell = recompute_registry.get()(self.construct)
3698
3633
  self._recompute()
3699
3634
  if 'mp_comm_recompute' in kwargs.keys():
3700
3635
  self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
3701
- if 'parallel_optimizer_comm_recompute' in kwargs.keys():
3636
+ if 'parallel_optimizer_comm_recompute' in kwargs:
3702
3637
  if kwargs.get('parallel_optimizer_comm_recompute', False):
3703
3638
  logger.warning("Currently, the communication operator allgathers introduced by optimizer shard "
3704
3639
  "is replaced with zero3.")
@@ -3711,38 +3646,6 @@ class Cell(Cell_):
3711
3646
  "the key kwargs must be 'mp_comm_recompute', "
3712
3647
  "'parallel_optimizer_comm_recompute', 'recompute_slice_activation'" % key)
3713
3648
 
3714
- @deprecated("2.3", "infer_param_pipeline_stage")
3715
- def infer_param_pipeline_stage(self):
3716
- """
3717
- Infer pipeline stages of all parameters in the cell.
3718
-
3719
- Note:
3720
- - The interface is deprecated from version 2.3 and will be removed in a future version.
3721
-
3722
- Returns:
3723
- The params belong to current stage in pipeline parallel.
3724
-
3725
- Raises:
3726
- RuntimeError: If there is a parameter does not belong to any stage.
3727
- """
3728
- from mindspore.parallel._utils import _get_global_rank, _get_device_num
3729
- logger.warning(f"This interface may be deleted in the future.")
3730
- stage_num = context.get_auto_parallel_context("pipeline_stages")
3731
- device_num = _get_device_num()
3732
- rank_id = _get_global_rank()
3733
- per_stage_devices = device_num // stage_num
3734
- current_stage = rank_id // per_stage_devices
3735
- params = []
3736
- for param in self.trainable_params():
3737
- if not param._pipeline_stage_list: # pylint: disable=W0212
3738
- raise RuntimeError("For 'infer_param_pipeline_stage', the parameter {} does not belong to any stage, "
3739
- "please check whether the cell where the param locates has been set "
3740
- "'pipeline_stage'. Otherwise, the parameter should use 'add_pipeline_stage' "
3741
- "to add its stage information".format(param.name))
3742
- if current_stage in param._pipeline_stage_list:
3743
- params.append(param)
3744
- return params
3745
-
3746
3649
  def place(self, role, rank_id):
3747
3650
  """
3748
3651
  Set the label for all operators in this cell.
@@ -3772,19 +3675,6 @@ class Cell(Cell_):
3772
3675
  for op in all_ops:
3773
3676
  op.place(role, rank_id)
3774
3677
 
3775
- def _mixed_precision_cast(self, inputs):
3776
- mixed_type = self.get_mixed_precision_type()
3777
- if mixed_type == MixedPrecisionType.NOTSET:
3778
- return inputs
3779
- if mixed_type == MixedPrecisionType.FP16:
3780
- cast_type = mstype.float16
3781
- elif mixed_type == MixedPrecisionType.BF16:
3782
- cast_type = mstype.bfloat16
3783
- else:
3784
- cast_type = mstype.float32
3785
- cast_inputs = self._cast_mixed_precision_inputs(inputs, cast_type)
3786
- return cast_inputs
3787
-
3788
3678
  def _get_attr_from_cell(self, network):
3789
3679
  if not isinstance(network, Cell):
3790
3680
  return
@@ -3793,92 +3683,70 @@ class Cell(Cell_):
3793
3683
  if hasattr(network, "_amp_level"):
3794
3684
  self._amp_level = getattr(network, "_amp_level")
3795
3685
 
3796
- def _register_parameters_hook(self, forward_hook=None, backward_hook=None, all=False):
3686
+ def _set_jit_graph_name(self, key):
3687
+ """
3688
+ Set jit graph name.
3797
3689
  """
3798
- Register the forward hook for parameters and register the backward hook for the corresponding gradient.
3690
+ self._jit_graph_name = key
3799
3691
 
3800
- .. warning::
3801
- This is an experimental prototype that is subject to change and/or deletion.
3692
+ def _jit_backward_pre_hook(self, grad_output):
3693
+ new_grad_output = grad_output
3694
+ if not isinstance(grad_output, tuple):
3695
+ new_grad_output = (grad_output,)
3802
3696
 
3803
- Note:
3804
- - The `_register_parameters_hook(forward_hook, backward_hook)` only work in graph mode
3805
- - The `forward_hook` must be defined as the following code.
3806
- `parameters`: the tuple of the trainble parameters of the Cell, each element in the tuple shuould be
3807
- in the format of `(param_name, Parameter)`.
3808
- - The `forward_hook` should have the following signature:
3809
- forward_hook(parameters) -> None.
3810
- - The `backward_hook` must be defined as the following code.
3811
- `gradients`: the tuple of the gradients corresponding to the trainble parameters of the Cell, each
3812
- element in the tuple shuould be in the format of `(param_name, gradient)`.
3813
- - The `backward_hook` should have the following signature:
3814
- backward_hook(parameters) -> New gradients.
3697
+ for fn in self._backward_pre_hook.values():
3698
+ ret = fn(self, new_grad_output)
3699
+ if ret is not None:
3700
+ if not isinstance(ret, tuple):
3701
+ output = (ret,)
3702
+ else:
3703
+ output = ret
3704
+ else:
3705
+ output = ops.Depend()(new_grad_output, ret)
3706
+ new_grad_output = output
3815
3707
 
3816
- Args:
3817
- forward_hook (function, optional): Python function or ``None``, Forward hook function. Default: ``None``
3818
- backward_hook (function, optional): Python function or ``None``, Backward hook function. Default ``None``
3819
- all (bool, optional): bool, whether to set hooks for all sub cells recursively. Default: ``False``
3708
+ if not isinstance(grad_output, tuple):
3709
+ if len(new_grad_output) == 1:
3710
+ return new_grad_output[0]
3711
+ raise TypeError(
3712
+ "The backward pre hook return value size is {} not equal to input size 1".format(
3713
+ len(new_grad_output)))
3820
3714
 
3821
- Returns:
3822
- None
3715
+ if len(new_grad_output) != len(grad_output):
3716
+ raise TypeError(
3717
+ "The backward pre hook return value size is {} not equal to input size {}".format(
3718
+ len(new_grad_output), len(grad_output)))
3823
3719
 
3824
- Raises:
3825
- RuntimeError: If the `forward_hook` or `backward_hook ` has unspoorted syntax under GRAPH MODE.
3826
- TypeError: If the `forward_hook` or `backward_hook` is not defined as required.
3720
+ return new_grad_output
3827
3721
 
3828
- Supported Platforms:
3829
- ``Ascend`` ``GPU`` ``CPU``
3722
+ def _jit_backward_hook(self, grad_input, grad_output):
3723
+ backward_hook_input = grad_input
3724
+ backward_hook_output = grad_output
3725
+ if not isinstance(grad_input, tuple):
3726
+ backward_hook_input = (grad_input,)
3727
+ if not isinstance(grad_output, tuple):
3728
+ backward_hook_output = (grad_output,)
3830
3729
 
3831
- Examples:
3832
- >>> import mindspore as ms
3833
- >>> from mindspore import Tensor, nn, ops, Parameter
3834
- >>>
3835
- >>> ms.set_context(mode=ms.GRAPH_MODE)
3836
- >>> def parameter_hook(parameters):
3837
- ... print("--- enter parameter hook ---")
3838
- ... for name, param in parameters:
3839
- ... print (name, param)
3840
- ... print("--- leave parameter hook ---")
3841
- ...
3842
- >>> def gradient_hook(gradients):
3843
- ... print("--- enter gradient hook ---")
3844
- ... outs = []
3845
- ... for name, gradient in gradients:
3846
- ... print(name, gradient)
3847
- ... outs.append(gradient * 2) # double gradient
3848
- ... print("--- leave gradient hook ---")
3849
- ... return outs
3850
- ...
3851
- >>> class Net(nn.Cell):
3852
- ... def __init__(self)
3853
- ... super(Net, self).__init__()
3854
- ... self.w = Parameter(Tensor(np.array([3.0], np.float32)), name='w')
3855
- ... def construct(self, x):
3856
- ... return self.w * x
3857
- ...
3858
- >>> grad = ops.GradOperation(get_by_list=True)
3859
- >>> net = Net()
3860
- >>> net._register_parameters_hook(forward_hook=parameter_hook, backward_hook=gradient_hook)
3861
- >>> x = Tensor(np.array([4.0]).astype(np.float32))
3862
- >>> output = grad(net, net.trainable_params())(x)
3863
- --- enter parameter hook ---
3864
- w
3865
- Tensor(shape=[1], dtype=Float32, value=[ 3.00000000e+00])
3866
- --- leave parameter hook ---
3867
- --- enter gradient hook ---
3868
- w
3869
- Tensor(shape=[1], dtype=Float32, value=[ 4.00000000e+00])
3870
- --- leave gradient hook ---
3871
- >>> print("doubled grad: ", output)
3872
- doubled grad: (Tensor(shape=[1], dtype=Float32, value=[ 8.00000000e+00]),)
3873
- """
3874
- if not all:
3875
- self._parameters_forward_hook = forward_hook
3876
- self._parameters_backward_hook = backward_hook
3877
- else:
3878
- for _, cell in self.cells_and_names():
3879
- cell._parameters_forward_hook = forward_hook
3880
- cell._parameters_backward_hook = backward_hook
3730
+ for fn in self._backward_hook.values():
3731
+ ret = fn(self, backward_hook_input, backward_hook_output)
3732
+ if ret is not None:
3733
+ if not isinstance(ret, tuple):
3734
+ output = (ret,)
3735
+ else:
3736
+ output = ret
3737
+ else:
3738
+ output = ops.Depend()(backward_hook_input, ret)
3739
+
3740
+ backward_hook_input = output
3881
3741
 
3742
+ if not isinstance(grad_input, tuple):
3743
+ return backward_hook_input[0]
3744
+
3745
+ if len(backward_hook_input) != len(grad_input):
3746
+ raise TypeError(
3747
+ "The backward hook return value size is {} not equal to input size {}".format(
3748
+ len(backward_hook_input), len(grad_input)))
3749
+ return backward_hook_input
3882
3750
 
3883
3751
  class GraphCell(Cell):
3884
3752
  """