mindspore 2.3.0__cp310-cp310-win_amd64.whl → 2.4.1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mindspore might be problematic.

Files changed (275)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/common/__init__.py +6 -4
  13. mindspore/common/_pijit_context.py +190 -0
  14. mindspore/common/_register_for_tensor.py +2 -1
  15. mindspore/common/_tensor_overload.py +139 -0
  16. mindspore/common/api.py +102 -87
  17. mindspore/common/dump.py +5 -6
  18. mindspore/common/generator.py +1 -7
  19. mindspore/common/hook_handle.py +14 -26
  20. mindspore/common/initializer.py +51 -15
  21. mindspore/common/mindir_util.py +2 -2
  22. mindspore/common/parameter.py +62 -15
  23. mindspore/common/recompute.py +39 -9
  24. mindspore/common/sparse_tensor.py +7 -3
  25. mindspore/common/tensor.py +183 -37
  26. mindspore/communication/__init__.py +1 -1
  27. mindspore/communication/_comm_helper.py +38 -3
  28. mindspore/communication/comm_func.py +315 -60
  29. mindspore/communication/management.py +14 -14
  30. mindspore/context.py +132 -22
  31. mindspore/dataset/__init__.py +1 -1
  32. mindspore/dataset/audio/__init__.py +1 -1
  33. mindspore/dataset/core/config.py +7 -0
  34. mindspore/dataset/core/validator_helpers.py +7 -0
  35. mindspore/dataset/engine/cache_client.py +1 -1
  36. mindspore/dataset/engine/datasets.py +72 -44
  37. mindspore/dataset/engine/datasets_audio.py +7 -7
  38. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  39. mindspore/dataset/engine/datasets_text.py +20 -20
  40. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  41. mindspore/dataset/engine/datasets_vision.py +33 -33
  42. mindspore/dataset/engine/iterators.py +29 -0
  43. mindspore/dataset/engine/obs/util.py +7 -0
  44. mindspore/dataset/engine/queue.py +114 -60
  45. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  46. mindspore/dataset/engine/validators.py +34 -14
  47. mindspore/dataset/text/__init__.py +1 -4
  48. mindspore/dataset/transforms/__init__.py +0 -3
  49. mindspore/dataset/utils/line_reader.py +2 -0
  50. mindspore/dataset/vision/__init__.py +1 -4
  51. mindspore/dataset/vision/utils.py +1 -1
  52. mindspore/dataset/vision/validators.py +2 -1
  53. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  54. mindspore/experimental/es/embedding_service.py +883 -0
  55. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  56. mindspore/experimental/llm_boost/__init__.py +21 -0
  57. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  58. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  59. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  60. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  61. mindspore/experimental/llm_boost/register.py +129 -0
  62. mindspore/experimental/llm_boost/utils.py +31 -0
  63. mindspore/experimental/optim/adamw.py +85 -0
  64. mindspore/experimental/optim/optimizer.py +3 -0
  65. mindspore/hal/__init__.py +3 -3
  66. mindspore/hal/contiguous_tensors_handle.py +175 -0
  67. mindspore/hal/stream.py +18 -0
  68. mindspore/include/api/model_group.h +13 -1
  69. mindspore/include/api/types.h +10 -10
  70. mindspore/include/dataset/config.h +2 -2
  71. mindspore/include/dataset/constants.h +2 -2
  72. mindspore/include/dataset/execute.h +2 -2
  73. mindspore/include/dataset/vision.h +4 -0
  74. mindspore/log.py +1 -1
  75. mindspore/mindrecord/filewriter.py +68 -51
  76. mindspore/mindspore_backend.dll +0 -0
  77. mindspore/mindspore_common.dll +0 -0
  78. mindspore/mindspore_core.dll +0 -0
  79. mindspore/mindspore_np_dtype.dll +0 -0
  80. mindspore/mindspore_ops.dll +0 -0
  81. mindspore/mint/__init__.py +983 -46
  82. mindspore/mint/distributed/__init__.py +31 -0
  83. mindspore/mint/distributed/distributed.py +254 -0
  84. mindspore/mint/nn/__init__.py +268 -23
  85. mindspore/mint/nn/functional.py +125 -19
  86. mindspore/mint/nn/layer/__init__.py +39 -0
  87. mindspore/mint/nn/layer/activation.py +133 -0
  88. mindspore/mint/nn/layer/normalization.py +477 -0
  89. mindspore/mint/nn/layer/pooling.py +110 -0
  90. mindspore/mint/optim/adamw.py +26 -13
  91. mindspore/mint/special/__init__.py +63 -0
  92. mindspore/multiprocessing/__init__.py +2 -1
  93. mindspore/nn/__init__.py +0 -1
  94. mindspore/nn/cell.py +276 -96
  95. mindspore/nn/layer/activation.py +211 -44
  96. mindspore/nn/layer/basic.py +137 -10
  97. mindspore/nn/layer/embedding.py +137 -2
  98. mindspore/nn/layer/normalization.py +101 -5
  99. mindspore/nn/layer/padding.py +34 -48
  100. mindspore/nn/layer/pooling.py +161 -7
  101. mindspore/nn/layer/transformer.py +3 -3
  102. mindspore/nn/loss/__init__.py +2 -2
  103. mindspore/nn/loss/loss.py +84 -6
  104. mindspore/nn/optim/__init__.py +2 -1
  105. mindspore/nn/optim/adadelta.py +1 -1
  106. mindspore/nn/optim/adam.py +1 -1
  107. mindspore/nn/optim/lamb.py +1 -1
  108. mindspore/nn/optim/tft_wrapper.py +124 -0
  109. mindspore/nn/wrap/cell_wrapper.py +12 -23
  110. mindspore/nn/wrap/grad_reducer.py +5 -5
  111. mindspore/nn/wrap/loss_scale.py +17 -3
  112. mindspore/numpy/__init__.py +1 -1
  113. mindspore/numpy/array_creations.py +65 -68
  114. mindspore/numpy/array_ops.py +64 -60
  115. mindspore/numpy/fft.py +610 -75
  116. mindspore/numpy/logic_ops.py +11 -10
  117. mindspore/numpy/math_ops.py +85 -84
  118. mindspore/numpy/utils_const.py +4 -4
  119. mindspore/opencv_core452.dll +0 -0
  120. mindspore/opencv_imgcodecs452.dll +0 -0
  121. mindspore/opencv_imgproc452.dll +0 -0
  122. mindspore/ops/__init__.py +6 -4
  123. mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
  124. mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
  125. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  126. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  127. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  128. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  129. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
  130. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  131. mindspore/ops/auto_generate/gen_extend_func.py +767 -13
  132. mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
  133. mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
  134. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  135. mindspore/ops/composite/base.py +85 -48
  136. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  137. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  138. mindspore/ops/function/__init__.py +22 -0
  139. mindspore/ops/function/array_func.py +492 -153
  140. mindspore/ops/function/debug_func.py +113 -1
  141. mindspore/ops/function/fft_func.py +15 -2
  142. mindspore/ops/function/grad/grad_func.py +3 -2
  143. mindspore/ops/function/math_func.py +564 -207
  144. mindspore/ops/function/nn_func.py +817 -383
  145. mindspore/ops/function/other_func.py +3 -2
  146. mindspore/ops/function/random_func.py +402 -12
  147. mindspore/ops/function/reshard_func.py +13 -11
  148. mindspore/ops/function/sparse_unary_func.py +1 -1
  149. mindspore/ops/function/vmap_func.py +3 -2
  150. mindspore/ops/functional.py +24 -14
  151. mindspore/ops/op_info_register.py +3 -3
  152. mindspore/ops/operations/__init__.py +7 -2
  153. mindspore/ops/operations/_grad_ops.py +2 -76
  154. mindspore/ops/operations/_infer_ops.py +1 -1
  155. mindspore/ops/operations/_inner_ops.py +71 -94
  156. mindspore/ops/operations/array_ops.py +14 -146
  157. mindspore/ops/operations/comm_ops.py +63 -53
  158. mindspore/ops/operations/custom_ops.py +83 -19
  159. mindspore/ops/operations/debug_ops.py +42 -10
  160. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  161. mindspore/ops/operations/manually_defined/ops_def.py +273 -20
  162. mindspore/ops/operations/math_ops.py +12 -223
  163. mindspore/ops/operations/nn_ops.py +20 -114
  164. mindspore/ops/operations/other_ops.py +7 -4
  165. mindspore/ops/operations/random_ops.py +46 -1
  166. mindspore/ops/primitive.py +18 -6
  167. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  168. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  169. mindspore/ops_generate/gen_constants.py +36 -0
  170. mindspore/ops_generate/gen_ops.py +67 -52
  171. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  172. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  173. mindspore/ops_generate/op_proto.py +10 -3
  174. mindspore/ops_generate/pyboost_utils.py +14 -1
  175. mindspore/ops_generate/template.py +43 -21
  176. mindspore/parallel/__init__.py +3 -1
  177. mindspore/parallel/_auto_parallel_context.py +31 -9
  178. mindspore/parallel/_cell_wrapper.py +85 -0
  179. mindspore/parallel/_parallel_serialization.py +47 -19
  180. mindspore/parallel/_tensor.py +127 -13
  181. mindspore/parallel/_utils.py +53 -22
  182. mindspore/parallel/algo_parameter_config.py +5 -5
  183. mindspore/parallel/checkpoint_transform.py +46 -39
  184. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  185. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  186. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  187. mindspore/parallel/parameter_broadcast.py +3 -4
  188. mindspore/parallel/shard.py +162 -31
  189. mindspore/parallel/transform_safetensors.py +1146 -0
  190. mindspore/profiler/__init__.py +2 -1
  191. mindspore/profiler/common/constant.py +29 -0
  192. mindspore/profiler/common/registry.py +47 -0
  193. mindspore/profiler/common/util.py +28 -0
  194. mindspore/profiler/dynamic_profiler.py +694 -0
  195. mindspore/profiler/envprofiling.py +17 -19
  196. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  197. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  198. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  199. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  200. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  201. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  202. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  203. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  205. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  206. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  207. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  208. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  209. mindspore/profiler/parser/framework_parser.py +1 -391
  210. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  211. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  212. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  213. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  214. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  215. mindspore/profiler/parser/profiler_info.py +78 -6
  216. mindspore/profiler/profiler.py +153 -0
  217. mindspore/profiler/profiling.py +285 -413
  218. mindspore/rewrite/__init__.py +1 -2
  219. mindspore/rewrite/common/namespace.py +4 -4
  220. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  221. mindspore/run_check/_check_version.py +39 -104
  222. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  223. mindspore/train/__init__.py +4 -3
  224. mindspore/train/_utils.py +105 -19
  225. mindspore/train/amp.py +171 -53
  226. mindspore/train/callback/__init__.py +2 -2
  227. mindspore/train/callback/_callback.py +4 -4
  228. mindspore/train/callback/_checkpoint.py +97 -31
  229. mindspore/train/callback/_cluster_monitor.py +1 -1
  230. mindspore/train/callback/_flops_collector.py +1 -0
  231. mindspore/train/callback/_loss_monitor.py +3 -3
  232. mindspore/train/callback/_on_request_exit.py +145 -31
  233. mindspore/train/callback/_summary_collector.py +5 -5
  234. mindspore/train/callback/_tft_register.py +375 -0
  235. mindspore/train/dataset_helper.py +15 -3
  236. mindspore/train/metrics/metric.py +3 -3
  237. mindspore/train/metrics/roc.py +4 -4
  238. mindspore/train/mind_ir_pb2.py +44 -39
  239. mindspore/train/model.py +154 -58
  240. mindspore/train/serialization.py +342 -128
  241. mindspore/utils/__init__.py +21 -0
  242. mindspore/utils/utils.py +60 -0
  243. mindspore/version.py +1 -1
  244. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
  245. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +248 -242
  246. mindspore/include/c_api/ms/abstract.h +0 -67
  247. mindspore/include/c_api/ms/attribute.h +0 -197
  248. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  249. mindspore/include/c_api/ms/base/macros.h +0 -32
  250. mindspore/include/c_api/ms/base/status.h +0 -33
  251. mindspore/include/c_api/ms/base/types.h +0 -283
  252. mindspore/include/c_api/ms/context.h +0 -102
  253. mindspore/include/c_api/ms/graph.h +0 -160
  254. mindspore/include/c_api/ms/node.h +0 -606
  255. mindspore/include/c_api/ms/tensor.h +0 -161
  256. mindspore/include/c_api/ms/value.h +0 -84
  257. mindspore/mindspore_shared_lib.dll +0 -0
  258. mindspore/nn/extend/basic.py +0 -140
  259. mindspore/nn/extend/embedding.py +0 -143
  260. mindspore/nn/extend/layer/normalization.py +0 -109
  261. mindspore/nn/extend/pooling.py +0 -117
  262. mindspore/nn/layer/embedding_service.py +0 -531
  263. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  264. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  265. mindspore/ops/extend/__init__.py +0 -53
  266. mindspore/ops/extend/array_func.py +0 -218
  267. mindspore/ops/extend/math_func.py +0 -76
  268. mindspore/ops/extend/nn_func.py +0 -308
  269. mindspore/ops/silent_check.py +0 -162
  270. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  271. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  272. mindspore/train/callback/_mindio_ttp.py +0 -443
  273. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +0 -0
  274. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
  275. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py CHANGED
@@ -32,7 +32,7 @@ from mindspore import context
  from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
  from mindspore import _checkparam as Validator
  from mindspore.common import dtype as mstype
- from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
+ from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache, _no_grad
  from mindspore.common.api import _generate_branch_control_input, _convert_python_data, _get_args_for_run_predict
  from mindspore.common.api import _process_dyn_args, _generate_dyn_compile_args
  from mindspore.common.parameter import Parameter, ParameterTuple
@@ -45,7 +45,6 @@ from mindspore._check_jit_forbidden_api import jit_forbidden_register
  from mindspore.common._decorator import deprecated
  from mindspore.common._register_for_recompute import recompute_registry

-
  class Cell(Cell_):
  """
  The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this
@@ -101,9 +100,9 @@ class Cell(Cell_):
  """

  IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_create_time',
- '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase',
- '_forward_pre_hook', '_forward_hook', '_enable_forward_pre_hook', '_enable_forward_hook',
- '_bprop_debug', '_enable_backward_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
+ '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase', '_bprop_debug',
+ '_forward_pre_hook', '_forward_hook', '_backward_pre_hook', '_backward_hook',
+ '_cell_backward_pre_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
  '_attr_synced', 'pynative', 'requires_grad', 'cell_type']
  total_instance_count = 0

@@ -135,7 +134,8 @@ class Cell(Cell_):
  self._id = 1
  self.exist_names = set("")
  self.exist_objs = set()
- self.recompute_cell = None
+ self._recompute_cell = None
+ self.mixed_precision_type = None
  self.sig = inspect.signature(self.construct)
  init_pipeline()

@@ -146,13 +146,16 @@ class Cell(Cell_):
  if flags:
  self.add_flags(**flags)
  self._bprop_debug = False
+
+ # hook
  self._forward_pre_hook = OrderedDict()
  self._forward_hook = OrderedDict()
- self._enable_forward_pre_hook = False
- self._enable_forward_hook = False
- self._enable_backward_hook = False
+ self._backward_pre_hook = OrderedDict()
+ self._cell_backward_pre_hook = None
+ self._backward_hook = OrderedDict()
  self._cell_backward_hook = None
  self._is_recursion_hook = False
+
  self.cell_type = None
  self.cast = Cast()
  self._has_config_recompute = False
@@ -166,6 +169,10 @@ class Cell(Cell_):
  self._is_check_and_refresh = False
  self._amp_level = ""
  self._init_flag = False
+ self._shard_fn = None
+ self.has_bprop = False
+ if hasattr(self, "bprop"):
+ self.has_bprop = True

  def __getstate__(self):
  base = Cell_.__getstate__(self)
@@ -223,8 +230,9 @@ class Cell(Cell_):
  Get whether cell custom bprop debug is enabled.

  Tutorial Examples:
- - `Cell and Parameter - Custom Cell Reverse
- <https://mindspore.cn/tutorials/en/master/advanced/modules/layer.html#custom-cell-reverse>`_
+ - `Custom Neural Network Layers - Custom Cell Reverse
+ <https://mindspore.cn/docs/en/master/model_train/custom_program/network_custom.html
+ #custom-cell-reverse>`_
  """
  return self._bprop_debug

@@ -374,6 +382,10 @@ class Cell(Cell_):
  def jit_config_dict(self):
  return self._jit_config_dict

+ @property
+ def enable_backward_hook(self):
+ return self._enable_backward_hook
+
  def get_func_graph_proto(self):
  """Return graph binary proto."""
  exec_id = ".".join([self.phase, str(self.create_time), str(id(self))])
@@ -401,8 +413,6 @@ class Cell(Cell_):
  cells_compile_cache.pop(id(self), None)
  if hasattr(self, "compile_cache") and self.compile_cache:
  _cell_graph_executor.del_net_res(self, self.compile_cache)
- if isinstance(self, GraphCell):
- _cell_graph_executor.dec_graph_cell_count()
  Cell.total_instance_count -= 1

  def __delattr__(self, name):
@@ -475,21 +485,28 @@ class Cell(Cell_):
  output = self._run_construct(cast_inputs, kwargs)
  return output

- def _run_construct(self, cast_inputs, kwargs):
+ def _run_construct(self, *inputs, **kwargs):
  """Run the construct function"""
- if self._enable_forward_pre_hook:
- cast_inputs = self._run_forward_pre_hook(cast_inputs)
- if self._enable_backward_hook:
- output = self._backward_hook_construct(*cast_inputs, **kwargs)
- elif hasattr(self, "_shard_fn"):
- output = self._shard_fn(*cast_inputs, **kwargs)
+ if self._forward_pre_hook:
+ inputs = self._run_forward_pre_hook(inputs)
+
+ if self._backward_hook:
+ output = self._backward_hook_construct(*inputs, **kwargs)
+ elif self._shard_fn is not None:
+ output = self._shard_fn(*inputs, **kwargs)
+ elif self._recompute_cell is not None:
+ output = self._recompute_cell(*inputs, **kwargs)
+ elif self.has_bprop and _pynative_executor.requires_grad():
+ output = self._call_custom_bprop(*inputs, **kwargs)
  else:
- if self.recompute_cell is not None:
- output = self.recompute_cell(*cast_inputs, **kwargs)
- else:
- output = self.construct(*cast_inputs, **kwargs)
- if self._enable_forward_hook:
- output = self._run_forward_hook(cast_inputs, output)
+ output = self.construct(*inputs, **kwargs)
+
+ if self._forward_hook:
+ output = self._run_forward_hook(inputs, output)
+
+ if self._backward_pre_hook:
+ output = self._run_backward_pre_hook(output)
+
  return output

  def _check_construct_args(self, *args):
@@ -527,7 +544,7 @@ class Cell(Cell_):
  '''Hook function in graph mode'''
  # Check super().__init__() in graph mode.
  try:
- if self._enable_forward_pre_hook or self._enable_forward_hook or self._enable_backward_hook:
+ if self._forward_pre_hook or self._forward_hook or self._backward_pre_hook or self._backward_hook:
  return True
  except AttributeError as e:
  raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
@@ -579,8 +596,7 @@ class Cell(Cell_):
  strategy for others will be set by sharding propagation.
  in_strategy and out_strategy define the input and output layout respectively.
  in_strategy/out_strategy should be a tuple, each element of which corresponds to the desired layout of
- this input/output, and None represents data_parallel,
- which can refer to the description of `mindspore.ops.Primitive.shard`.
+ this input/output, which can refer to the description of `mindspore.ops.Primitive.shard`.
  The parallel strategies of remaining operators are derived from the strategy specified by the input and output.

  Note:
@@ -589,8 +605,8 @@ class Cell(Cell_):
  If the input contain Parameter, its strategy should be set in `in_strategy`.

  Args:
- in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple or None. Tuple
- defines the layout of the corresponding input and None represents a data parallel strategy.
+ in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple. Tuple
+ defines the layout of the corresponding input.
  out_strategy (Union[None, tuple]): Define the layout of outputs similar with in_strategy.
  It is not in use right now. Default: ``None`` .
  parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict
@@ -625,7 +641,7 @@ class Cell(Cell_):
  ... def __init__(self):
  ... self.block1 = Block()
  ... self.block2 = Block()
- ... self.block2_shard = self.block2.shard(in_strategy=((2, 1),), out_strategy=(None,),
+ ... self.block2_shard = self.block2.shard(in_strategy=((2, 1),),
  ... parameter_plan={'self.block2.shard.dense1.weight': (4, 1)})
  ... def construct(self, x):
  ... x = self.block1(x)
@@ -638,7 +654,7 @@ class Cell(Cell_):

  shard_fn = Shard()
  fn = shard_fn(self, in_strategy, out_strategy, parameter_plan, device, level)
- object.__setattr__(self, "_shard_fn", fn)
+ self._shard_fn = fn
  return fn

  def auto_cast_inputs(self, inputs):
@@ -666,6 +682,7 @@ class Cell(Cell_):
  for param in self.get_parameters(expand=False):
  if param.has_init:
  param.init_data()
+ self._init_flag = True

  def _self_check(self):
  if not self._is_check_and_refresh:
@@ -684,7 +701,7 @@ class Cell(Cell_):

  def __call__(self, *args, **kwargs):
  # Run in Graph mode.
- if os.getenv("MS_JIT") != '0' and context._get_mode() == context.GRAPH_MODE:
+ if context._get_mode() == context.GRAPH_MODE and os.getenv("MS_JIT") != '0':
  if kwargs:
  bound_arguments = self.sig.bind(*args, **kwargs)
  bound_arguments.apply_defaults()
@@ -704,22 +721,69 @@ class Cell(Cell_):
  return out

  # Run in PyNative mode.
- self._self_check()
- if not self._init_flag:
+ if not (self._init_flag or self._is_check_and_refresh):
  self._init_check()
- self._init_flag = True
+ self._self_check()
+
+ if not (self.requires_grad or self._dynamic_shape_inputs or self.mixed_precision_type):
+ if not (self._forward_pre_hook or self._forward_hook or self._backward_pre_hook or self._backward_hook or
+ self._shard_fn or self._recompute_cell or (self.has_bprop and _pynative_executor.requires_grad())):
+ return self.construct(*args, **kwargs)
+
+ return self._run_construct(*args, **kwargs)
+
+ return self._complex_call(*args, **kwargs)

+ def _complex_call(self, *args, **kwargs):
+ """
+ PyNative call with requires_grad or hooks
+ """
+ self._call_pre_process(*args, **kwargs)
+
+ if not (self._forward_pre_hook or self._forward_hook or self._backward_pre_hook or self._backward_hook or
+ self._shard_fn or self._recompute_cell or self.has_bprop):
+ output = self.construct(*args, **kwargs)
+ else:
+ output = self._run_construct(*args, **kwargs)
+
+ self._call_post_process(output, *args, **kwargs)
+
+ return output
+
+ def _call_pre_process(self, *args, **kwargs):
+ """
+ Process cell info before call construct
+ """
  if self.requires_grad:
  _pynative_executor.set_grad_flag(True)
-
- try:
  _pynative_executor.new_graph(self, *args, **kwargs)
- output = self._run_construct(args, kwargs)
+ elif self._dynamic_shape_inputs is not None:
+ _pynative_executor.set_cell_use_dynamic_shape_process(True)
+
+ # Set mixed precision
+ if self.mixed_precision_type is not None:
+ _pynative_executor.set_mixed_precision_type(self.mixed_precision_type)
+
+ def _call_post_process(self, output, *args, **kwargs):
+ """
+ Process cell info after call construct
+ """
+ if self.requires_grad:
  _pynative_executor.end_graph(self, output, *args, **kwargs)
- except Exception as err:
- _pynative_executor.clear_res()
- raise err
+ elif self._dynamic_shape_inputs is not None:
+ _pynative_executor.set_cell_use_dynamic_shape_process(False)
+
+ # mixed precision reset
+ if self.mixed_precision_type is not None:
+ _pynative_executor.set_mixed_precision_type(MixedPrecisionType.NOTSET, False)

+ def _call_custom_bprop(self, *args, **kwargs):
+ """
+ Call custom bprop for cell bprop.
+ """
+ with _no_grad():
+ output = self.construct(*args, **kwargs)
+ _pynative_executor.call_custom_bprop(self, output, *args, **kwargs)
  return output

  def _add_attr(self, name, value):
@@ -961,9 +1025,12 @@ class Cell(Cell_):

  if not kwargs:
  self._dynamic_shape_inputs = inputs
- self._check_construct_args(*inputs)
  if context._get_mode() == context.PYNATIVE_MODE:
  _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)
+ else:
+ self._check_construct_args(*inputs)
+ # TODO(tronzhang): It may error for no actually args here. So just set in fullmode,
+ # which means that incremental mode is lacking dynamic input.
  else:
  self._dynamic_shape_inputs = _process_dyn_args(self.construct, kwargs)

@@ -1682,10 +1749,13 @@ class Cell(Cell_):
  def _add_mixed_precision_flag(self, **flags):
  """Add mixed precision flag to current cell"""
  if "fp16" in flags and flags.get("fp16", False):
+ self.mixed_precision_type = MixedPrecisionType.FP16
  Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP16)
  if "fp32" in flags and flags.get("fp32", False):
+ self.mixed_precision_type = MixedPrecisionType.FP32
  Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP32)
  if "bf16" in flags and flags.get("bf16", False):
+ self.mixed_precision_type = MixedPrecisionType.BF16
  Cell_.set_mixed_precision_type(self, MixedPrecisionType.BF16)

  def apply(self, fn):
@@ -1750,9 +1820,6 @@ class Cell(Cell_):
  if not hasattr(self, "_func_graph_flags"):
  self._func_graph_flags = {}
  self._func_graph_flags.update({**flags})
- if context._get_mode() == context.PYNATIVE_MODE and self._func_graph_flags.get("output_no_recompute"):
- raise TypeError("Recompute is not supported in PyNative mode currently, you can use "
- "'context.set_context(mode=context.GRAPH_MODE)' or @jit to set graph mode.")
  self.__dict__.update({**flags})
  self._add_mixed_precision_flag(**flags)
  return self
@@ -2050,15 +2117,12 @@ class Cell(Cell_):
  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
  value= [ 2.00000000e+00]))
  """
+ if context._get_mode() == context.GRAPH_MODE:
+ return HookHandle()
  if not check_hook_fn("register_forward_pre_hook", hook_fn):
  return HookHandle()
- self._enable_forward_pre_hook = True
- _pynative_executor.set_hook_changed(self)
- if not hasattr(self, '_forward_pre_hook_key'):
- self._forward_pre_hook_key = -1
- self._forward_pre_hook_key += 1
- self._forward_pre_hook[self._forward_pre_hook_key] = hook_fn
- handle = HookHandle(self, self._forward_pre_hook_key, "_forward_pre_hook")
+ handle = HookHandle(self._forward_pre_hook)
+ self._forward_pre_hook[handle.handle_id] = hook_fn
  return handle

  def _run_forward_pre_hook(self, inputs):
@@ -2074,14 +2138,23 @@ class Cell(Cell_):
  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``
  """
+ forward_pre_hook_inputs = inputs
  for fn in self._forward_pre_hook.values():
- ret = fn(self, inputs)
+ ret = fn(self, forward_pre_hook_inputs)
  if ret is not None:
  if not isinstance(ret, tuple):
- inputs = (ret,)
+ forward_pre_hook_inputs = (ret,)
  else:
- inputs = ret
- return inputs
+ forward_pre_hook_inputs = ret
+
+ if isinstance(inputs, tuple):
+ if not isinstance(forward_pre_hook_inputs, tuple):
+ forward_pre_hook_inputs = (forward_pre_hook_inputs,)
+ if len(forward_pre_hook_inputs) != len(inputs):
+ raise TypeError(
+ "The forward pre hook return value size is {} not equal to input size {}".format(
+ len(forward_pre_hook_inputs), len(inputs)))
+ return forward_pre_hook_inputs

  def register_forward_hook(self, hook_fn):
  """
@@ -2142,15 +2215,12 @@ class Cell(Cell_):
  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
  value= [ 2.00000000e+00]))
  """
+ if context._get_mode() == context.GRAPH_MODE:
+ return HookHandle()
  if not check_hook_fn("register_forward_hook", hook_fn):
  return HookHandle()
- self._enable_forward_hook = True
- _pynative_executor.set_hook_changed(self)
- if not hasattr(self, '_forward_hook_key'):
- self._forward_hook_key = -1
- self._forward_hook_key += 1
- self._forward_hook[self._forward_hook_key] = hook_fn
- handle = HookHandle(self, self._forward_hook_key, "_forward_hook")
+ handle = HookHandle(self._forward_hook)
+ self._forward_hook[handle.handle_id] = hook_fn
  return handle

  def _run_forward_hook(self, inputs, output):
@@ -2167,11 +2237,110 @@ class Cell(Cell_):
  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``
  """
+ forward_hook_output = output
  for fn in self._forward_hook.values():
- ret = fn(self, inputs, output)
+ ret = fn(self, inputs, forward_hook_output)
  if ret is not None:
- output = ret
- return output
+ forward_hook_output = ret
+
+ if isinstance(output, tuple):
+ if not isinstance(forward_hook_output, tuple):
+ forward_hook_output = (forward_hook_output,)
+ if len(forward_hook_output) != len(output):
+ raise TypeError(
+ "The forward hook return value size is {} not equal to output size {}".format(
+ len(forward_hook_output), len(output)))
+ return forward_hook_output
+
+ def register_backward_pre_hook(self, hook_fn):
+ """
+ Register the backward pre hook function.
+
+ Note:
+ - The `register_backward_pre_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
+ - The 'hook_fn' must be defined as the following code.
+ `cell` is the Cell object. `grad_output` is the gradient passed to the Cell.
+ - The 'hook_fn' should have the following signature:
+ hook_fn(cell, grad_output) -> New grad_output gradient or None.
+ - The 'hook_fn' is executed in the python environment. In order to prevent running failed when switching to
+ graph mode, it is not recommended to write it in the `construct` function of Cell object.
+ - In the pynative
+ mode, if the `register_backward_pre_hook` function is called in the `construct` function of the Cell
+ object, a hook function will be added at each run time of Cell object.
+
+ Args:
+ hook_fn (function): Python function. Backward pre hook function.
+
+ Returns:
+ A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
+ `handle.remove()` .
+
+ Raises:
+ TypeError: If the `hook_fn` is not a function of python.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU`` ``CPU``
+
+ Examples:
+ >>> import numpy as np
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn, ops
+ >>> ms.set_context(mode=ms.PYNATIVE_MODE)
+ >>> def backward_pre_hook_fn(cell, grad_output):
+ ... print("backward input: ", grad_output)
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.relu = nn.ReLU()
+ ... self.handle = self.relu.register_backward_pre_hook(backward_pre_hook_fn)
+ ...
+ ... def construct(self, x):
+ ... x = x + x
+ ... x = self.relu(x)
+ ... return x
+ >>> grad = ops.GradOperation(get_all=True)
+ >>> net = Net()
+ >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)))
+ backward input: (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),)
+ >>> print(output)
+ (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
+ """
+ if context._get_mode() == context.GRAPH_MODE:
+ return HookHandle()
+ if not check_hook_fn("register_backward_pre_hook", hook_fn):
+ return HookHandle()
+ handle = HookHandle(self._backward_pre_hook)
+ self._backward_pre_hook[handle.handle_id] = hook_fn
+ if self._cell_backward_pre_hook is None:
+ # Generate a CellBackwardHook prim, and add function for it
+ self._cell_backward_pre_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
+ self, self._backward_pre_hook)
+ self._cell_backward_pre_hook.register_backward_pre_hook()
+ return handle
+
+ def _run_backward_pre_hook(self, outputs):
+ """
+ Running backward pre hook function registered on Cell object.
+
+ Args:
+ outputs: The output objects of cell object.
+
+ Returns:
+ - **outputs** - New backward gradient or None.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU`` ``CPU``
+ """
+ ret = self._cell_backward_pre_hook(outputs)
+ if isinstance(outputs, tuple):
+ if not isinstance(ret, tuple):
+ ret = (ret,)
+ if len(ret) != len(outputs):
+ raise TypeError(
+ "The backward pre hook return value size is {} not equal to output size {}".format(
+ len(ret), len(outputs)))
+ return ret

  def register_backward_hook(self, hook_fn):
  """
@@ -2180,11 +2349,11 @@ class Cell(Cell_):
  Note:
  - The `register_backward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
  - The 'hook_fn' must be defined as the following code.
- `cell_id` is the information of registered Cell object, including name and ID. `grad_input` is the
- gradient passed to the Cell. `grad_output` is the gradient computed and passed to the next Cell or
- primitive, which may be modified by returning a new output gradient.
+ `cell` is the registered Cell object. `grad_input` is the gradient computed and passed to
+ the next Cell or primitive, which can be return a new gradient or None. `grad_output` is the gradient
+ passed to the Cell.
  - The 'hook_fn' should have the following signature:
- hook_fn(cell_id, grad_input, grad_output) -> New output gradient or none.
+ hook_fn(cell, grad_input, grad_output) -> New grad_input gradient or none.
  - The 'hook_fn' is executed in the python environment. In order to prevent running failed when switching to
  graph mode, it is not recommended to write it in the `construct` function of Cell object. In the pynative
  mode, if the `register_backward_hook` function is called in the `construct` function of the Cell object,
@@ -2208,9 +2377,9 @@ class Cell(Cell_):
  >>> import mindspore as ms
  >>> from mindspore import Tensor, nn, ops
  >>> ms.set_context(mode=ms.PYNATIVE_MODE)
- >>> def backward_hook_fn(cell_id, grad_input, grad_output):
- ... print("backward input: ", grad_input)
- ... print("backward output: ", grad_output)
+ >>> def backward_hook_fn(cell, grad_input, grad_output):
+ ... print("backward input: ", grad_output)
+ ... print("backward output: ", grad_input)
  ...
  >>> class Net(nn.Cell):
  ... def __init__(self):
@@ -2230,16 +2399,17 @@ class Cell(Cell_):
  >>> print(output)
  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
  """
+ if context._get_mode() == context.GRAPH_MODE:
+ return HookHandle()
  if not check_hook_fn("register_backward_hook", hook_fn):
  return HookHandle()
+ handle = HookHandle(self._backward_hook)
+ self._backward_hook[handle.handle_id] = hook_fn
  if self._cell_backward_hook is None:
- self._enable_backward_hook = True
- self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")")
- backward_hook_key = self._cell_backward_hook.register_backward_hook(hook_fn)
- handle = HookHandle(self, backward_hook_key, "_cell_backward_hook")
- else:
- backward_hook_key = self._cell_backward_hook.register_backward_hook(hook_fn)
- handle = HookHandle(self, backward_hook_key, "_cell_backward_hook")
+ # Generate a CellBackwardHook prim, and add function for it
+ self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
+ self, self._backward_hook)
+ self._cell_backward_hook.register_backward_hook()
  return handle

  def _backward_hook_construct(self, *inputs, **kwargs):
@@ -2256,21 +2426,31 @@ class Cell(Cell_):
  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``
  """
- if len(inputs) > 1:
- inputs = self._cell_backward_hook(inputs)
- else:
- inputs = self._cell_backward_hook(*inputs)
- inputs = (inputs,)
- if self.recompute_cell is not None:
- if isinstance(inputs, tuple):
- outputs = self.recompute_cell(*inputs, **kwargs)
+ # cell_backward_hook has CellBackwardHook op, so keep input args as they are.
+ outputs = self._cell_backward_hook(*inputs)
+ # If the inputs have more than two args, the outputs will also have more than two args and will be wrapped into
+ # a tuple, so need to do unwrapping. If inputs is empty, we also need to unwrap it.
+ # Because when output of runop method is one, it will not wrap a tuple, we need not unwrap it.
+ is_need_unwrap = False
+ if isinstance(outputs, tuple) and len(inputs) != 1:
+ is_need_unwrap = True
+
+ if self._recompute_cell is not None:
+ if is_need_unwrap:
+ outputs = self._recompute_cell(*outputs, **kwargs)
+ else:
+ outputs = self._recompute_cell(outputs, **kwargs)
+ elif self.has_bprop:
+ if is_need_unwrap:
+ outputs = self._call_custom_bprop(*outputs, **kwargs)
  else:
- outputs = self.recompute_cell(inputs, **kwargs)
+ outputs = self._call_custom_bprop(outputs, **kwargs)
  else:
- if isinstance(inputs, tuple):
- outputs = self.construct(*inputs, **kwargs)
+ if is_need_unwrap:
+ outputs = self.construct(*outputs, **kwargs)
  else:
- outputs = self.construct(inputs, **kwargs)
+ outputs = self.construct(outputs, **kwargs)
+
  outputs = self._cell_backward_hook(outputs)
  return outputs

@@ -2401,7 +2581,8 @@ class Cell(Cell_):
  Default: ``False`` .
  """
  if context.get_context("mode") == context.PYNATIVE_MODE:
- self.recompute_cell = recompute_registry.get()(self.construct)
+ self._recompute_cell = recompute_registry.get()(self.construct)
+ self._recompute()
  return
  self._recompute()
  if 'mp_comm_recompute' in kwargs.keys():
@@ -2579,7 +2760,6 @@ class GraphCell(Cell):
  params_dict = update_func_graph_hyper_params(self.graph, params_init)
  for name, param in params_dict.items():
  self._params[name] = param
- _cell_graph_executor.inc_graph_cell_count()

  def construct(self, *inputs):
  return self.graph(*inputs)
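
For reference, the docstrings added in this diff describe the reworked hook API in 2.4.1: `register_backward_pre_hook` accepts `hook_fn(cell, grad_output)`, and `register_backward_hook` now passes the `Cell` object itself rather than a `cell_id` string, i.e. `hook_fn(cell, grad_input, grad_output)`. The following is a minimal usage sketch adapted from the docstring examples above; it assumes MindSpore 2.4.1 running in PyNative mode, and the printed gradient values are illustrative rather than verified output.

import numpy as np
import mindspore as ms
from mindspore import Tensor, nn, ops

ms.set_context(mode=ms.PYNATIVE_MODE)

def backward_pre_hook_fn(cell, grad_output):
    # Runs before the cell's backward pass; grad_output is the gradient flowing into the cell.
    print("grad flowing into", cell.cls_name, ":", grad_output)

def backward_hook_fn(cell, grad_input, grad_output):
    # 2.4.1 signature: the first argument is the Cell object, no longer a cell_id string.
    # Returning a value would replace grad_input; returning None keeps it unchanged.
    print("grad produced by", cell.cls_name, ":", grad_input)

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        # Both registrations return a HookHandle; call handle.remove() to detach a hook later.
        self.pre_handle = self.relu.register_backward_pre_hook(backward_pre_hook_fn)
        self.handle = self.relu.register_backward_hook(backward_hook_fn)

    def construct(self, x):
        return self.relu(x + x)

net = Net()
grad_fn = ops.GradOperation(get_all=True)
grads = grad_fn(net)(Tensor(np.ones([1]).astype(np.float32)))
print(grads)

As the added code paths show, both registration functions return an empty HookHandle and skip registration when called in graph mode, so the hooks only take effect under PYNATIVE_MODE.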