mindspore-2.3.0-cp39-cp39-win_amd64.whl → mindspore-2.4.0-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (285)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/common/__init__.py +6 -4
  18. mindspore/common/_pijit_context.py +190 -0
  19. mindspore/common/_register_for_tensor.py +2 -1
  20. mindspore/common/_tensor_overload.py +139 -0
  21. mindspore/common/api.py +102 -87
  22. mindspore/common/dump.py +5 -6
  23. mindspore/common/generator.py +1 -7
  24. mindspore/common/hook_handle.py +14 -26
  25. mindspore/common/mindir_util.py +2 -2
  26. mindspore/common/parameter.py +46 -13
  27. mindspore/common/recompute.py +39 -9
  28. mindspore/common/sparse_tensor.py +7 -3
  29. mindspore/common/tensor.py +209 -29
  30. mindspore/communication/__init__.py +1 -1
  31. mindspore/communication/_comm_helper.py +38 -3
  32. mindspore/communication/comm_func.py +310 -55
  33. mindspore/communication/management.py +14 -14
  34. mindspore/context.py +123 -22
  35. mindspore/dataset/__init__.py +1 -1
  36. mindspore/dataset/audio/__init__.py +1 -1
  37. mindspore/dataset/core/config.py +7 -0
  38. mindspore/dataset/core/validator_helpers.py +7 -0
  39. mindspore/dataset/engine/cache_client.py +1 -1
  40. mindspore/dataset/engine/datasets.py +72 -44
  41. mindspore/dataset/engine/datasets_audio.py +7 -7
  42. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  43. mindspore/dataset/engine/datasets_text.py +20 -20
  44. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  45. mindspore/dataset/engine/datasets_vision.py +33 -33
  46. mindspore/dataset/engine/iterators.py +29 -0
  47. mindspore/dataset/engine/obs/util.py +7 -0
  48. mindspore/dataset/engine/queue.py +114 -60
  49. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  50. mindspore/dataset/engine/validators.py +34 -14
  51. mindspore/dataset/text/__init__.py +1 -4
  52. mindspore/dataset/transforms/__init__.py +0 -3
  53. mindspore/dataset/utils/line_reader.py +2 -0
  54. mindspore/dataset/vision/__init__.py +1 -4
  55. mindspore/dataset/vision/utils.py +1 -1
  56. mindspore/dataset/vision/validators.py +2 -1
  57. mindspore/dnnl.dll +0 -0
  58. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  59. mindspore/experimental/es/embedding_service.py +883 -0
  60. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  61. mindspore/experimental/llm_boost/__init__.py +21 -0
  62. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  63. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  64. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  65. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  66. mindspore/experimental/llm_boost/register.py +129 -0
  67. mindspore/experimental/llm_boost/utils.py +31 -0
  68. mindspore/experimental/optim/adamw.py +85 -0
  69. mindspore/experimental/optim/optimizer.py +3 -0
  70. mindspore/hal/__init__.py +3 -3
  71. mindspore/hal/contiguous_tensors_handle.py +175 -0
  72. mindspore/hal/stream.py +18 -0
  73. mindspore/include/api/model_group.h +13 -1
  74. mindspore/include/api/types.h +10 -10
  75. mindspore/include/dataset/config.h +2 -2
  76. mindspore/include/dataset/constants.h +2 -2
  77. mindspore/include/dataset/execute.h +2 -2
  78. mindspore/include/dataset/vision.h +4 -0
  79. mindspore/jpeg62.dll +0 -0
  80. mindspore/log.py +1 -1
  81. mindspore/mindrecord/filewriter.py +68 -51
  82. mindspore/mindspore_backend.dll +0 -0
  83. mindspore/mindspore_common.dll +0 -0
  84. mindspore/mindspore_core.dll +0 -0
  85. mindspore/mindspore_glog.dll +0 -0
  86. mindspore/mindspore_np_dtype.dll +0 -0
  87. mindspore/mindspore_ops.dll +0 -0
  88. mindspore/mint/__init__.py +495 -46
  89. mindspore/mint/distributed/__init__.py +31 -0
  90. mindspore/mint/distributed/distributed.py +254 -0
  91. mindspore/mint/nn/__init__.py +266 -21
  92. mindspore/mint/nn/functional.py +125 -19
  93. mindspore/mint/nn/layer/__init__.py +39 -0
  94. mindspore/mint/nn/layer/activation.py +133 -0
  95. mindspore/mint/nn/layer/normalization.py +477 -0
  96. mindspore/mint/nn/layer/pooling.py +110 -0
  97. mindspore/mint/optim/adamw.py +28 -7
  98. mindspore/mint/special/__init__.py +63 -0
  99. mindspore/multiprocessing/__init__.py +2 -1
  100. mindspore/nn/__init__.py +0 -1
  101. mindspore/nn/cell.py +275 -93
  102. mindspore/nn/layer/activation.py +211 -44
  103. mindspore/nn/layer/basic.py +113 -3
  104. mindspore/nn/layer/embedding.py +120 -2
  105. mindspore/nn/layer/normalization.py +101 -5
  106. mindspore/nn/layer/padding.py +34 -48
  107. mindspore/nn/layer/pooling.py +161 -7
  108. mindspore/nn/layer/transformer.py +3 -3
  109. mindspore/nn/loss/__init__.py +2 -2
  110. mindspore/nn/loss/loss.py +84 -6
  111. mindspore/nn/optim/__init__.py +2 -1
  112. mindspore/nn/optim/adadelta.py +1 -1
  113. mindspore/nn/optim/adam.py +1 -1
  114. mindspore/nn/optim/lamb.py +1 -1
  115. mindspore/nn/optim/tft_wrapper.py +127 -0
  116. mindspore/nn/wrap/cell_wrapper.py +12 -23
  117. mindspore/nn/wrap/grad_reducer.py +5 -5
  118. mindspore/nn/wrap/loss_scale.py +17 -3
  119. mindspore/numpy/__init__.py +1 -1
  120. mindspore/numpy/array_creations.py +65 -68
  121. mindspore/numpy/array_ops.py +64 -60
  122. mindspore/numpy/fft.py +610 -75
  123. mindspore/numpy/logic_ops.py +11 -10
  124. mindspore/numpy/math_ops.py +85 -84
  125. mindspore/numpy/utils_const.py +4 -4
  126. mindspore/opencv_core452.dll +0 -0
  127. mindspore/opencv_imgcodecs452.dll +0 -0
  128. mindspore/opencv_imgproc452.dll +0 -0
  129. mindspore/ops/__init__.py +6 -4
  130. mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
  131. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  132. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  133. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  134. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  135. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
  136. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  137. mindspore/ops/auto_generate/gen_extend_func.py +734 -13
  138. mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
  139. mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
  140. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  141. mindspore/ops/composite/base.py +85 -48
  142. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  143. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  144. mindspore/ops/function/__init__.py +22 -0
  145. mindspore/ops/function/array_func.py +490 -153
  146. mindspore/ops/function/debug_func.py +113 -1
  147. mindspore/ops/function/fft_func.py +15 -2
  148. mindspore/ops/function/grad/grad_func.py +3 -2
  149. mindspore/ops/function/math_func.py +558 -207
  150. mindspore/ops/function/nn_func.py +817 -383
  151. mindspore/ops/function/other_func.py +3 -2
  152. mindspore/ops/function/random_func.py +184 -8
  153. mindspore/ops/function/reshard_func.py +13 -11
  154. mindspore/ops/function/sparse_unary_func.py +1 -1
  155. mindspore/ops/function/vmap_func.py +3 -2
  156. mindspore/ops/functional.py +24 -14
  157. mindspore/ops/op_info_register.py +3 -3
  158. mindspore/ops/operations/__init__.py +6 -1
  159. mindspore/ops/operations/_grad_ops.py +2 -76
  160. mindspore/ops/operations/_infer_ops.py +1 -1
  161. mindspore/ops/operations/_inner_ops.py +71 -94
  162. mindspore/ops/operations/array_ops.py +12 -146
  163. mindspore/ops/operations/comm_ops.py +42 -53
  164. mindspore/ops/operations/custom_ops.py +83 -19
  165. mindspore/ops/operations/debug_ops.py +42 -10
  166. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  167. mindspore/ops/operations/manually_defined/ops_def.py +265 -10
  168. mindspore/ops/operations/math_ops.py +12 -223
  169. mindspore/ops/operations/nn_ops.py +20 -114
  170. mindspore/ops/operations/other_ops.py +7 -4
  171. mindspore/ops/operations/random_ops.py +46 -1
  172. mindspore/ops/primitive.py +18 -6
  173. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  174. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  175. mindspore/ops_generate/gen_constants.py +36 -0
  176. mindspore/ops_generate/gen_ops.py +67 -52
  177. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  178. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  179. mindspore/ops_generate/op_proto.py +10 -3
  180. mindspore/ops_generate/pyboost_utils.py +14 -1
  181. mindspore/ops_generate/template.py +43 -21
  182. mindspore/parallel/__init__.py +3 -1
  183. mindspore/parallel/_auto_parallel_context.py +28 -8
  184. mindspore/parallel/_cell_wrapper.py +83 -0
  185. mindspore/parallel/_parallel_serialization.py +47 -19
  186. mindspore/parallel/_tensor.py +81 -11
  187. mindspore/parallel/_utils.py +13 -1
  188. mindspore/parallel/algo_parameter_config.py +5 -5
  189. mindspore/parallel/checkpoint_transform.py +46 -39
  190. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  191. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  192. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  193. mindspore/parallel/parameter_broadcast.py +3 -4
  194. mindspore/parallel/shard.py +162 -31
  195. mindspore/parallel/transform_safetensors.py +993 -0
  196. mindspore/profiler/__init__.py +2 -1
  197. mindspore/profiler/common/constant.py +29 -0
  198. mindspore/profiler/common/registry.py +47 -0
  199. mindspore/profiler/common/util.py +28 -0
  200. mindspore/profiler/dynamic_profiler.py +694 -0
  201. mindspore/profiler/envprofiling.py +17 -19
  202. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  203. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  204. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  205. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  206. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  207. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  208. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  209. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  210. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  211. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  212. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  213. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  214. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  215. mindspore/profiler/parser/framework_parser.py +1 -391
  216. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  217. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  218. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  219. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  220. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  221. mindspore/profiler/parser/profiler_info.py +78 -6
  222. mindspore/profiler/profiler.py +153 -0
  223. mindspore/profiler/profiling.py +280 -412
  224. mindspore/rewrite/__init__.py +1 -2
  225. mindspore/rewrite/common/namespace.py +4 -4
  226. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  227. mindspore/run_check/_check_version.py +36 -103
  228. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  229. mindspore/swresample-4.dll +0 -0
  230. mindspore/swscale-6.dll +0 -0
  231. mindspore/tinyxml2.dll +0 -0
  232. mindspore/train/__init__.py +4 -3
  233. mindspore/train/_utils.py +28 -2
  234. mindspore/train/amp.py +171 -53
  235. mindspore/train/callback/__init__.py +2 -2
  236. mindspore/train/callback/_callback.py +4 -4
  237. mindspore/train/callback/_checkpoint.py +85 -22
  238. mindspore/train/callback/_cluster_monitor.py +1 -1
  239. mindspore/train/callback/_flops_collector.py +1 -0
  240. mindspore/train/callback/_loss_monitor.py +3 -3
  241. mindspore/train/callback/_on_request_exit.py +134 -31
  242. mindspore/train/callback/_summary_collector.py +5 -5
  243. mindspore/train/callback/_tft_register.py +352 -0
  244. mindspore/train/dataset_helper.py +7 -3
  245. mindspore/train/metrics/metric.py +3 -3
  246. mindspore/train/metrics/roc.py +4 -4
  247. mindspore/train/mind_ir_pb2.py +44 -39
  248. mindspore/train/model.py +134 -58
  249. mindspore/train/serialization.py +336 -112
  250. mindspore/turbojpeg.dll +0 -0
  251. mindspore/utils/__init__.py +21 -0
  252. mindspore/utils/utils.py +60 -0
  253. mindspore/version.py +1 -1
  254. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
  255. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +258 -252
  256. mindspore/include/c_api/ms/abstract.h +0 -67
  257. mindspore/include/c_api/ms/attribute.h +0 -197
  258. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  259. mindspore/include/c_api/ms/base/macros.h +0 -32
  260. mindspore/include/c_api/ms/base/status.h +0 -33
  261. mindspore/include/c_api/ms/base/types.h +0 -283
  262. mindspore/include/c_api/ms/context.h +0 -102
  263. mindspore/include/c_api/ms/graph.h +0 -160
  264. mindspore/include/c_api/ms/node.h +0 -606
  265. mindspore/include/c_api/ms/tensor.h +0 -161
  266. mindspore/include/c_api/ms/value.h +0 -84
  267. mindspore/mindspore_shared_lib.dll +0 -0
  268. mindspore/nn/extend/basic.py +0 -140
  269. mindspore/nn/extend/embedding.py +0 -143
  270. mindspore/nn/extend/layer/normalization.py +0 -109
  271. mindspore/nn/extend/pooling.py +0 -117
  272. mindspore/nn/layer/embedding_service.py +0 -531
  273. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  274. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  275. mindspore/ops/extend/__init__.py +0 -53
  276. mindspore/ops/extend/array_func.py +0 -218
  277. mindspore/ops/extend/math_func.py +0 -76
  278. mindspore/ops/extend/nn_func.py +0 -308
  279. mindspore/ops/silent_check.py +0 -162
  280. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  281. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  282. mindspore/train/callback/_mindio_ttp.py +0 -443
  283. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
  284. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
  285. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py CHANGED
@@ -32,7 +32,7 @@ from mindspore import context
  from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
  from mindspore import _checkparam as Validator
  from mindspore.common import dtype as mstype
- from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
+ from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache, _no_grad
  from mindspore.common.api import _generate_branch_control_input, _convert_python_data, _get_args_for_run_predict
  from mindspore.common.api import _process_dyn_args, _generate_dyn_compile_args
  from mindspore.common.parameter import Parameter, ParameterTuple
@@ -45,7 +45,6 @@ from mindspore._check_jit_forbidden_api import jit_forbidden_register
  from mindspore.common._decorator import deprecated
  from mindspore.common._register_for_recompute import recompute_registry

-
  class Cell(Cell_):
  """
  The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this
@@ -101,9 +100,9 @@ class Cell(Cell_):
  """

  IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_create_time',
- '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase',
- '_forward_pre_hook', '_forward_hook', '_enable_forward_pre_hook', '_enable_forward_hook',
- '_bprop_debug', '_enable_backward_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
+ '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase', '_bprop_debug',
+ '_forward_pre_hook', '_forward_hook', '_backward_pre_hook', '_backward_hook',
+ '_cell_backward_pre_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
  '_attr_synced', 'pynative', 'requires_grad', 'cell_type']
  total_instance_count = 0

@@ -135,7 +134,8 @@ class Cell(Cell_):
  self._id = 1
  self.exist_names = set("")
  self.exist_objs = set()
- self.recompute_cell = None
+ self._recompute_cell = None
+ self.mixed_precision_type = None
  self.sig = inspect.signature(self.construct)
  init_pipeline()

@@ -146,13 +146,16 @@ class Cell(Cell_):
  if flags:
  self.add_flags(**flags)
  self._bprop_debug = False
+
+ # hook
  self._forward_pre_hook = OrderedDict()
  self._forward_hook = OrderedDict()
- self._enable_forward_pre_hook = False
- self._enable_forward_hook = False
- self._enable_backward_hook = False
+ self._backward_pre_hook = OrderedDict()
+ self._cell_backward_pre_hook = None
+ self._backward_hook = OrderedDict()
  self._cell_backward_hook = None
  self._is_recursion_hook = False
+
  self.cell_type = None
  self.cast = Cast()
  self._has_config_recompute = False
@@ -166,6 +169,10 @@ class Cell(Cell_):
  self._is_check_and_refresh = False
  self._amp_level = ""
  self._init_flag = False
+ self._shard_fn = None
+ self.has_bprop = False
+ if hasattr(self, "bprop"):
+ self.has_bprop = True

  def __getstate__(self):
  base = Cell_.__getstate__(self)
@@ -223,8 +230,9 @@ class Cell(Cell_):
  Get whether cell custom bprop debug is enabled.

  Tutorial Examples:
- - `Cell and Parameter - Custom Cell Reverse
- <https://mindspore.cn/tutorials/en/master/advanced/modules/layer.html#custom-cell-reverse>`_
+ - `Custom Neural Network Layers - Custom Cell Reverse
+ <https://mindspore.cn/docs/en/master/model_train/custom_program/network_custom.html
+ #custom-cell-reverse>`_
  """
  return self._bprop_debug

@@ -374,6 +382,10 @@ class Cell(Cell_):
  def jit_config_dict(self):
  return self._jit_config_dict

+ @property
+ def enable_backward_hook(self):
+ return self._enable_backward_hook
+
  def get_func_graph_proto(self):
  """Return graph binary proto."""
  exec_id = ".".join([self.phase, str(self.create_time), str(id(self))])
@@ -401,8 +413,6 @@ class Cell(Cell_):
  cells_compile_cache.pop(id(self), None)
  if hasattr(self, "compile_cache") and self.compile_cache:
  _cell_graph_executor.del_net_res(self, self.compile_cache)
- if isinstance(self, GraphCell):
- _cell_graph_executor.dec_graph_cell_count()
  Cell.total_instance_count -= 1

  def __delattr__(self, name):
@@ -475,21 +485,28 @@ class Cell(Cell_):
  output = self._run_construct(cast_inputs, kwargs)
  return output

- def _run_construct(self, cast_inputs, kwargs):
+ def _run_construct(self, *inputs, **kwargs):
  """Run the construct function"""
- if self._enable_forward_pre_hook:
- cast_inputs = self._run_forward_pre_hook(cast_inputs)
- if self._enable_backward_hook:
- output = self._backward_hook_construct(*cast_inputs, **kwargs)
- elif hasattr(self, "_shard_fn"):
- output = self._shard_fn(*cast_inputs, **kwargs)
+ if self._forward_pre_hook:
+ inputs = self._run_forward_pre_hook(inputs)
+
+ if self._backward_hook:
+ output = self._backward_hook_construct(*inputs, **kwargs)
+ elif self._shard_fn is not None:
+ output = self._shard_fn(*inputs, **kwargs)
+ elif self._recompute_cell is not None:
+ output = self._recompute_cell(*inputs, **kwargs)
+ elif self.has_bprop and _pynative_executor.requires_grad():
+ output = self._call_custom_bprop(*inputs, **kwargs)
  else:
- if self.recompute_cell is not None:
- output = self.recompute_cell(*cast_inputs, **kwargs)
- else:
- output = self.construct(*cast_inputs, **kwargs)
- if self._enable_forward_hook:
- output = self._run_forward_hook(cast_inputs, output)
+ output = self.construct(*inputs, **kwargs)
+
+ if self._forward_hook:
+ output = self._run_forward_hook(inputs, output)
+
+ if self._backward_pre_hook:
+ output = self._run_backward_pre_hook(output)
+
  return output

  def _check_construct_args(self, *args):
@@ -527,7 +544,7 @@ class Cell(Cell_):
  '''Hook function in graph mode'''
  # Check super().__init__() in graph mode.
  try:
- if self._enable_forward_pre_hook or self._enable_forward_hook or self._enable_backward_hook:
+ if self._forward_pre_hook or self._forward_hook or self._backward_pre_hook or self._backward_hook:
  return True
  except AttributeError as e:
  raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
@@ -579,8 +596,7 @@ class Cell(Cell_):
  strategy for others will be set by sharding propagation.
  in_strategy and out_strategy define the input and output layout respectively.
  in_strategy/out_strategy should be a tuple, each element of which corresponds to the desired layout of
- this input/output, and None represents data_parallel,
- which can refer to the description of `mindspore.ops.Primitive.shard`.
+ this input/output, which can refer to the description of `mindspore.ops.Primitive.shard`.
  The parallel strategies of remaining operators are derived from the strategy specified by the input and output.

  Note:
@@ -589,8 +605,8 @@ class Cell(Cell_):
  If the input contain Parameter, its strategy should be set in `in_strategy`.

  Args:
- in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple or None. Tuple
- defines the layout of the corresponding input and None represents a data parallel strategy.
+ in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple. Tuple
+ defines the layout of the corresponding input.
  out_strategy (Union[None, tuple]): Define the layout of outputs similar with in_strategy.
  It is not in use right now. Default: ``None`` .
  parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict
@@ -625,7 +641,7 @@ class Cell(Cell_):
  ... def __init__(self):
  ... self.block1 = Block()
  ... self.block2 = Block()
- ... self.block2_shard = self.block2.shard(in_strategy=((2, 1),), out_strategy=(None,),
+ ... self.block2_shard = self.block2.shard(in_strategy=((2, 1),),
  ... parameter_plan={'self.block2.shard.dense1.weight': (4, 1)})
  ... def construct(self, x):
  ... x = self.block1(x)
@@ -638,7 +654,7 @@ class Cell(Cell_):

  shard_fn = Shard()
  fn = shard_fn(self, in_strategy, out_strategy, parameter_plan, device, level)
- object.__setattr__(self, "_shard_fn", fn)
+ self._shard_fn = fn
  return fn

  def auto_cast_inputs(self, inputs):
@@ -666,6 +682,7 @@ class Cell(Cell_):
  for param in self.get_parameters(expand=False):
  if param.has_init:
  param.init_data()
+ self._init_flag = True

  def _self_check(self):
  if not self._is_check_and_refresh:
@@ -684,7 +701,7 @@ class Cell(Cell_):

  def __call__(self, *args, **kwargs):
  # Run in Graph mode.
- if os.getenv("MS_JIT") != '0' and context._get_mode() == context.GRAPH_MODE:
+ if context._get_mode() == context.GRAPH_MODE and os.getenv("MS_JIT") != '0':
  if kwargs:
  bound_arguments = self.sig.bind(*args, **kwargs)
  bound_arguments.apply_defaults()
@@ -704,22 +721,69 @@ class Cell(Cell_):
  return out

  # Run in PyNative mode.
- self._self_check()
- if not self._init_flag:
+ if not (self._init_flag or self._is_check_and_refresh):
  self._init_check()
- self._init_flag = True
+ self._self_check()
+
+ if not (self.requires_grad or self._dynamic_shape_inputs or self.mixed_precision_type):
+ if not (self._forward_pre_hook or self._forward_hook or self._backward_pre_hook or self._backward_hook or
+ self._shard_fn or self._recompute_cell or (self.has_bprop and _pynative_executor.requires_grad())):
+ return self.construct(*args, **kwargs)
+
+ return self._run_construct(*args, **kwargs)
+
+ return self._complex_call(*args, **kwargs)

+ def _complex_call(self, *args, **kwargs):
+ """
+ PyNative call with requires_grad or hooks
+ """
+ self._call_pre_process(*args, **kwargs)
+
+ if not (self._forward_pre_hook or self._forward_hook or self._backward_pre_hook or self._backward_hook or
+ self._shard_fn or self._recompute_cell or self.has_bprop):
+ output = self.construct(*args, **kwargs)
+ else:
+ output = self._run_construct(*args, **kwargs)
+
+ self._call_post_process(output, *args, **kwargs)
+
+ return output
+
+ def _call_pre_process(self, *args, **kwargs):
+ """
+ Process cell info before call construct
+ """
  if self.requires_grad:
  _pynative_executor.set_grad_flag(True)
-
- try:
  _pynative_executor.new_graph(self, *args, **kwargs)
- output = self._run_construct(args, kwargs)
+ elif self._dynamic_shape_inputs is not None:
+ _pynative_executor.set_cell_use_dynamic_shape_process(True)
+
+ # Set mixed precision
+ if self.mixed_precision_type is not None:
+ _pynative_executor.set_mixed_precision_type(self.mixed_precision_type)
+
+ def _call_post_process(self, output, *args, **kwargs):
+ """
+ Process cell info after call construct
+ """
+ if self.requires_grad:
  _pynative_executor.end_graph(self, output, *args, **kwargs)
- except Exception as err:
- _pynative_executor.clear_res()
- raise err
+ elif self._dynamic_shape_inputs is not None:
+ _pynative_executor.set_cell_use_dynamic_shape_process(False)
+
+ # mixed precision reset
+ if self.mixed_precision_type is not None:
+ _pynative_executor.set_mixed_precision_type(MixedPrecisionType.NOTSET, False)

+ def _call_custom_bprop(self, *args, **kwargs):
+ """
+ Call custom bprop for cell bprop.
+ """
+ with _no_grad():
+ output = self.construct(*args, **kwargs)
+ _pynative_executor.call_custom_bprop(self, output, *args, **kwargs)
  return output

  def _add_attr(self, name, value):
@@ -961,9 +1025,12 @@ class Cell(Cell_):

  if not kwargs:
  self._dynamic_shape_inputs = inputs
- self._check_construct_args(*inputs)
  if context._get_mode() == context.PYNATIVE_MODE:
  _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)
+ else:
+ self._check_construct_args(*inputs)
+ # TODO(tronzhang): It may error for no actually args here. So just set in fullmode,
+ # which means that incremental mode is lacking dynamic input.
  else:
  self._dynamic_shape_inputs = _process_dyn_args(self.construct, kwargs)

@@ -1682,10 +1749,13 @@ class Cell(Cell_):
  def _add_mixed_precision_flag(self, **flags):
  """Add mixed precision flag to current cell"""
  if "fp16" in flags and flags.get("fp16", False):
+ self.mixed_precision_type = MixedPrecisionType.FP16
  Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP16)
  if "fp32" in flags and flags.get("fp32", False):
+ self.mixed_precision_type = MixedPrecisionType.FP32
  Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP32)
  if "bf16" in flags and flags.get("bf16", False):
+ self.mixed_precision_type = MixedPrecisionType.BF16
  Cell_.set_mixed_precision_type(self, MixedPrecisionType.BF16)

  def apply(self, fn):
@@ -2050,15 +2120,12 @@ class Cell(Cell_):
  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
  value= [ 2.00000000e+00]))
  """
+ if context._get_mode() == context.GRAPH_MODE:
+ return HookHandle()
  if not check_hook_fn("register_forward_pre_hook", hook_fn):
  return HookHandle()
- self._enable_forward_pre_hook = True
- _pynative_executor.set_hook_changed(self)
- if not hasattr(self, '_forward_pre_hook_key'):
- self._forward_pre_hook_key = -1
- self._forward_pre_hook_key += 1
- self._forward_pre_hook[self._forward_pre_hook_key] = hook_fn
- handle = HookHandle(self, self._forward_pre_hook_key, "_forward_pre_hook")
+ handle = HookHandle(self._forward_pre_hook)
+ self._forward_pre_hook[handle.handle_id] = hook_fn
  return handle

  def _run_forward_pre_hook(self, inputs):
@@ -2074,14 +2141,23 @@ class Cell(Cell_):
  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``
  """
+ forward_pre_hook_inputs = inputs
  for fn in self._forward_pre_hook.values():
- ret = fn(self, inputs)
+ ret = fn(self, forward_pre_hook_inputs)
  if ret is not None:
  if not isinstance(ret, tuple):
- inputs = (ret,)
+ forward_pre_hook_inputs = (ret,)
  else:
- inputs = ret
- return inputs
+ forward_pre_hook_inputs = ret
+
+ if isinstance(inputs, tuple):
+ if not isinstance(forward_pre_hook_inputs, tuple):
+ forward_pre_hook_inputs = (forward_pre_hook_inputs,)
+ if len(forward_pre_hook_inputs) != len(inputs):
+ raise TypeError(
+ "The forward pre hook return value size is {} not equal to input size {}".format(
+ len(forward_pre_hook_inputs), len(inputs)))
+ return forward_pre_hook_inputs

  def register_forward_hook(self, hook_fn):
  """
@@ -2142,15 +2218,12 @@ class Cell(Cell_):
  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
  value= [ 2.00000000e+00]))
  """
+ if context._get_mode() == context.GRAPH_MODE:
+ return HookHandle()
  if not check_hook_fn("register_forward_hook", hook_fn):
  return HookHandle()
- self._enable_forward_hook = True
- _pynative_executor.set_hook_changed(self)
- if not hasattr(self, '_forward_hook_key'):
- self._forward_hook_key = -1
- self._forward_hook_key += 1
- self._forward_hook[self._forward_hook_key] = hook_fn
- handle = HookHandle(self, self._forward_hook_key, "_forward_hook")
+ handle = HookHandle(self._forward_hook)
+ self._forward_hook[handle.handle_id] = hook_fn
  return handle

  def _run_forward_hook(self, inputs, output):
@@ -2167,11 +2240,110 @@ class Cell(Cell_):
  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``
  """
+ forward_hook_output = output
  for fn in self._forward_hook.values():
- ret = fn(self, inputs, output)
+ ret = fn(self, inputs, forward_hook_output)
  if ret is not None:
- output = ret
- return output
+ forward_hook_output = ret
+
+ if isinstance(output, tuple):
+ if not isinstance(forward_hook_output, tuple):
+ forward_hook_output = (forward_hook_output,)
+ if len(forward_hook_output) != len(output):
+ raise TypeError(
+ "The forward hook return value size is {} not equal to output size {}".format(
+ len(forward_hook_output), len(output)))
+ return forward_hook_output
+
+ def register_backward_pre_hook(self, hook_fn):
+ """
+ Register the backward pre hook function.
+
+ Note:
+ - The `register_backward_pre_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
+ - The 'hook_fn' must be defined as the following code.
+ `cell` is the Cell object. `grad_output` is the gradient passed to the Cell.
+ - The 'hook_fn' should have the following signature:
+ hook_fn(cell, grad_output) -> New grad_output gradient or None.
+ - The 'hook_fn' is executed in the python environment. In order to prevent running failed when switching to
+ graph mode, it is not recommended to write it in the `construct` function of Cell object.
+ - In the pynative
+ mode, if the `register_backward_pre_hook` function is called in the `construct` function of the Cell
+ object, a hook function will be added at each run time of Cell object.
+
+ Args:
+ hook_fn (function): Python function. Backward pre hook function.
+
+ Returns:
+ A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
+ `handle.remove()` .
+
+ Raises:
+ TypeError: If the `hook_fn` is not a function of python.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU`` ``CPU``
+
+ Examples:
+ >>> import numpy as np
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn, ops
+ >>> ms.set_context(mode=ms.PYNATIVE_MODE)
+ >>> def backward_pre_hook_fn(cell, grad_output):
+ ... print("backward input: ", grad_output)
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.relu = nn.ReLU()
+ ... self.handle = self.relu.register_backward_pre_hook(backward_pre_hook_fn)
+ ...
+ ... def construct(self, x):
+ ... x = x + x
+ ... x = self.relu(x)
+ ... return x
+ >>> grad = ops.GradOperation(get_all=True)
+ >>> net = Net()
+ >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)))
+ backward input: (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),)
+ >>> print(output)
+ (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
+ """
+ if context._get_mode() == context.GRAPH_MODE:
+ return HookHandle()
+ if not check_hook_fn("register_backward_pre_hook", hook_fn):
+ return HookHandle()
+ handle = HookHandle(self._backward_pre_hook)
+ self._backward_pre_hook[handle.handle_id] = hook_fn
+ if self._cell_backward_pre_hook is None:
+ # Generate a CellBackwardHook prim, and add function for it
+ self._cell_backward_pre_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
+ self, self._backward_pre_hook)
+ self._cell_backward_pre_hook.register_backward_pre_hook()
+ return handle
+
+ def _run_backward_pre_hook(self, outputs):
+ """
+ Running backward pre hook function registered on Cell object.
+
+ Args:
+ outputs: The output objects of cell object.
+
+ Returns:
+ - **outputs** - New backward gradient or None.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU`` ``CPU``
+ """
+ ret = self._cell_backward_pre_hook(outputs)
+ if isinstance(outputs, tuple):
+ if not isinstance(ret, tuple):
+ ret = (ret,)
+ if len(ret) != len(outputs):
+ raise TypeError(
+ "The backward pre hook return value size is {} not equal to output size {}".format(
+ len(ret), len(outputs)))
+ return ret

  def register_backward_hook(self, hook_fn):
  """
@@ -2180,11 +2352,11 @@ class Cell(Cell_):
  Note:
  - The `register_backward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
  - The 'hook_fn' must be defined as the following code.
- `cell_id` is the information of registered Cell object, including name and ID. `grad_input` is the
- gradient passed to the Cell. `grad_output` is the gradient computed and passed to the next Cell or
- primitive, which may be modified by returning a new output gradient.
+ `cell` is the registered Cell object. `grad_input` is the gradient computed and passed to
+ the next Cell or primitive, which can be return a new gradient or None. `grad_output` is the gradient
+ passed to the Cell.
  - The 'hook_fn' should have the following signature:
- hook_fn(cell_id, grad_input, grad_output) -> New output gradient or none.
+ hook_fn(cell, grad_input, grad_output) -> New grad_input gradient or none.
  - The 'hook_fn' is executed in the python environment. In order to prevent running failed when switching to
  graph mode, it is not recommended to write it in the `construct` function of Cell object. In the pynative
  mode, if the `register_backward_hook` function is called in the `construct` function of the Cell object,
@@ -2208,9 +2380,9 @@ class Cell(Cell_):
  >>> import mindspore as ms
  >>> from mindspore import Tensor, nn, ops
  >>> ms.set_context(mode=ms.PYNATIVE_MODE)
- >>> def backward_hook_fn(cell_id, grad_input, grad_output):
- ... print("backward input: ", grad_input)
- ... print("backward output: ", grad_output)
+ >>> def backward_hook_fn(cell, grad_input, grad_output):
+ ... print("backward input: ", grad_output)
+ ... print("backward output: ", grad_input)
  ...
  >>> class Net(nn.Cell):
  ... def __init__(self):
@@ -2230,16 +2402,17 @@ class Cell(Cell_):
  >>> print(output)
  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
  """
+ if context._get_mode() == context.GRAPH_MODE:
+ return HookHandle()
  if not check_hook_fn("register_backward_hook", hook_fn):
  return HookHandle()
+ handle = HookHandle(self._backward_hook)
+ self._backward_hook[handle.handle_id] = hook_fn
  if self._cell_backward_hook is None:
- self._enable_backward_hook = True
- self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")")
- backward_hook_key = self._cell_backward_hook.register_backward_hook(hook_fn)
- handle = HookHandle(self, backward_hook_key, "_cell_backward_hook")
- else:
- backward_hook_key = self._cell_backward_hook.register_backward_hook(hook_fn)
- handle = HookHandle(self, backward_hook_key, "_cell_backward_hook")
+ # Generate a CellBackwardHook prim, and add function for it
+ self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
+ self, self._backward_hook)
+ self._cell_backward_hook.register_backward_hook()
  return handle

  def _backward_hook_construct(self, *inputs, **kwargs):
@@ -2256,21 +2429,31 @@ class Cell(Cell_):
  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``
  """
- if len(inputs) > 1:
- inputs = self._cell_backward_hook(inputs)
- else:
- inputs = self._cell_backward_hook(*inputs)
- inputs = (inputs,)
- if self.recompute_cell is not None:
- if isinstance(inputs, tuple):
- outputs = self.recompute_cell(*inputs, **kwargs)
+ # cell_backward_hook has CellBackwardHook op, so keep input args as they are.
+ outputs = self._cell_backward_hook(*inputs)
+ # If the inputs have more than two args, the outputs will also have more than two args and will be wrapped into
+ # a tuple, so need to do unwrapping. If inputs is empty, we also need to unwrap it.
+ # Because when output of runop method is one, it will not wrap a tuple, we need not unwrap it.
+ is_need_unwrap = False
+ if isinstance(outputs, tuple) and len(inputs) != 1:
+ is_need_unwrap = True
+
+ if self._recompute_cell is not None:
+ if is_need_unwrap:
+ outputs = self._recompute_cell(*outputs, **kwargs)
+ else:
+ outputs = self._recompute_cell(outputs, **kwargs)
+ elif self.has_bprop:
+ if is_need_unwrap:
+ outputs = self._call_custom_bprop(*outputs, **kwargs)
  else:
- outputs = self.recompute_cell(inputs, **kwargs)
+ outputs = self._call_custom_bprop(outputs, **kwargs)
  else:
- if isinstance(inputs, tuple):
- outputs = self.construct(*inputs, **kwargs)
+ if is_need_unwrap:
+ outputs = self.construct(*outputs, **kwargs)
  else:
- outputs = self.construct(inputs, **kwargs)
+ outputs = self.construct(outputs, **kwargs)
+
  outputs = self._cell_backward_hook(outputs)
  return outputs

@@ -2401,7 +2584,7 @@ class Cell(Cell_):
  Default: ``False`` .
  """
  if context.get_context("mode") == context.PYNATIVE_MODE:
- self.recompute_cell = recompute_registry.get()(self.construct)
+ self._recompute_cell = recompute_registry.get()(self.construct)
  return
  self._recompute()
  if 'mp_comm_recompute' in kwargs.keys():
@@ -2579,7 +2762,6 @@ class GraphCell(Cell):
  params_dict = update_func_graph_hyper_params(self.graph, params_init)
  for name, param in params_dict.items():
  self._params[name] = param
- _cell_graph_executor.inc_graph_cell_count()

  def construct(self, *inputs):
  return self.graph(*inputs)
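
The docstring changes in this diff move the backward hook callback from a `cell_id` string to the registered Cell object (`hook_fn(cell, grad_input, grad_output)`), with registration now tracked per cell through `HookHandle`. A minimal sketch of the updated usage, assuming MindSpore 2.4.0 in PyNative mode; the network and hook body are illustrative and not taken from the diff:

import numpy as np
import mindspore as ms
from mindspore import Tensor, nn, ops

ms.set_context(mode=ms.PYNATIVE_MODE)

# 2.4.0 signature: the first argument is the Cell itself (2.3.0 passed a cell_id string).
def backward_hook_fn(cell, grad_input, grad_output):
    print("hooked cell:", type(cell).__name__)
    print("incoming gradient:", grad_output)

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        # Registration returns a HookHandle; the hook is stored per cell.
        self.handle = self.relu.register_backward_hook(backward_hook_fn)

    def construct(self, x):
        return self.relu(x + x)

net = Net()
grad_fn = ops.GradOperation(get_all=True)
grad_fn(net)(Tensor(np.ones([1]).astype(np.float32)))  # hook fires during backward

net.handle.remove()                                     # detach the hook via its handle
grad_fn(net)(Tensor(np.ones([1]).astype(np.float32)))  # backward runs again without the hook

Per the updated docstrings, `handle.remove()` detaches the registered hook, so the second backward pass above runs without it.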