mindspore-2.7.0-cp310-cp310-win_amd64.whl → mindspore-2.7.0rc1-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (196)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +1 -1
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +2 -2
  7. mindspore/_extends/builtin_operations.py +3 -3
  8. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  9. mindspore/_extends/parse/__init__.py +3 -3
  10. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -0
  11. mindspore/_extends/parse/parser.py +22 -28
  12. mindspore/_extends/parse/standard_method.py +1 -15
  13. mindspore/_extends/pijit/pijit_func_white_list.py +5 -2
  14. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  15. mindspore/amp.py +18 -0
  16. mindspore/avcodec-59.dll +0 -0
  17. mindspore/avdevice-59.dll +0 -0
  18. mindspore/avfilter-8.dll +0 -0
  19. mindspore/avformat-59.dll +0 -0
  20. mindspore/avutil-57.dll +0 -0
  21. mindspore/common/__init__.py +12 -18
  22. mindspore/common/_tensor_cpp_method.py +1 -1
  23. mindspore/common/_tensor_docs.py +38 -102
  24. mindspore/common/_utils.py +1 -9
  25. mindspore/common/api.py +106 -155
  26. mindspore/common/{dynamic_shape/auto_dynamic_shape.py → auto_dynamic_shape.py} +23 -17
  27. mindspore/common/dtype.py +57 -98
  28. mindspore/common/dump.py +1 -1
  29. mindspore/common/file_system.py +9 -59
  30. mindspore/common/hook_handle.py +3 -22
  31. mindspore/common/np_dtype.py +3 -3
  32. mindspore/common/parameter.py +20 -4
  33. mindspore/common/recompute.py +4 -2
  34. mindspore/common/tensor.py +52 -38
  35. mindspore/communication/_hccl_management.py +297 -0
  36. mindspore/context.py +21 -15
  37. mindspore/dataset/__init__.py +1 -1
  38. mindspore/dataset/audio/transforms.py +1 -1
  39. mindspore/dataset/core/config.py +1 -35
  40. mindspore/dataset/engine/datasets.py +315 -330
  41. mindspore/dataset/engine/datasets_user_defined.py +22 -38
  42. mindspore/dataset/transforms/c_transforms.py +2 -2
  43. mindspore/dataset/transforms/transforms.py +3 -3
  44. mindspore/dataset/vision/__init__.py +1 -1
  45. mindspore/dataset/vision/py_transforms.py +8 -8
  46. mindspore/dataset/vision/transforms.py +5 -17
  47. mindspore/dataset/vision/utils.py +21 -632
  48. mindspore/device_context/ascend/op_tuning.py +1 -35
  49. mindspore/dnnl.dll +0 -0
  50. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -3
  51. mindspore/include/api/cell.h +4 -28
  52. mindspore/include/api/cfg.h +7 -24
  53. mindspore/include/api/context.h +0 -1
  54. mindspore/include/api/delegate.h +2 -0
  55. mindspore/include/api/dual_abi_helper.h +19 -100
  56. mindspore/include/api/graph.h +1 -14
  57. mindspore/include/api/kernel.h +3 -16
  58. mindspore/include/api/kernel_api.h +1 -9
  59. mindspore/include/api/metrics/accuracy.h +0 -9
  60. mindspore/include/api/model.h +1 -5
  61. mindspore/include/api/model_group.h +0 -4
  62. mindspore/include/api/model_parallel_runner.h +0 -2
  63. mindspore/include/api/status.h +10 -48
  64. mindspore/include/api/types.h +1 -6
  65. mindspore/include/dataset/constants.h +0 -9
  66. mindspore/jpeg62.dll +0 -0
  67. mindspore/mindrecord/tools/cifar10.py +2 -3
  68. mindspore/mindrecord/tools/cifar10_to_mr.py +5 -5
  69. mindspore/mindspore_backend_common.dll +0 -0
  70. mindspore/mindspore_backend_manager.dll +0 -0
  71. mindspore/mindspore_common.dll +0 -0
  72. mindspore/mindspore_core.dll +0 -0
  73. mindspore/mindspore_cpu_res_manager.dll +0 -0
  74. mindspore/mindspore_dump.dll +0 -0
  75. mindspore/mindspore_frontend.dll +0 -0
  76. mindspore/mindspore_glog.dll +0 -0
  77. mindspore/mindspore_memory_pool.dll +0 -0
  78. mindspore/mindspore_ms_backend.dll +0 -0
  79. mindspore/mindspore_ops.dll +0 -0
  80. mindspore/mindspore_ops_host.dll +0 -0
  81. mindspore/mindspore_ops_kernel_common.dll +0 -0
  82. mindspore/mindspore_profiler.dll +0 -0
  83. mindspore/mindspore_pyboost.dll +0 -0
  84. mindspore/mindspore_pynative.dll +0 -0
  85. mindspore/mindspore_res_manager.dll +0 -0
  86. mindspore/mindspore_runtime_pipeline.dll +0 -0
  87. mindspore/mint/distributed/__init__.py +0 -4
  88. mindspore/mint/distributed/distributed.py +14 -217
  89. mindspore/mint/nn/layer/_functions.py +2 -1
  90. mindspore/mint/nn/layer/conv.py +6 -6
  91. mindspore/mint/nn/layer/normalization.py +3 -3
  92. mindspore/nn/cell.py +174 -216
  93. mindspore/nn/layer/activation.py +2 -4
  94. mindspore/nn/layer/basic.py +13 -7
  95. mindspore/nn/layer/image.py +1 -1
  96. mindspore/nn/optim/adam.py +3 -1
  97. mindspore/nn/optim/lamb.py +3 -1
  98. mindspore/nn/optim/tft_wrapper.py +3 -2
  99. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  100. mindspore/nn/wrap/cell_wrapper.py +5 -39
  101. mindspore/nn/wrap/grad_reducer.py +15 -0
  102. mindspore/numpy/array_creations.py +2 -2
  103. mindspore/numpy/utils_const.py +1 -1
  104. mindspore/opencv_core452.dll +0 -0
  105. mindspore/opencv_imgcodecs452.dll +0 -0
  106. mindspore/opencv_imgproc452.dll +0 -0
  107. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  108. mindspore/ops/_op_impl/cpu/__init__.py +0 -1
  109. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +2 -12
  110. mindspore/ops/auto_generate/gen_extend_func.py +4 -4
  111. mindspore/ops/auto_generate/gen_ops_def.py +16 -290
  112. mindspore/ops/auto_generate/gen_ops_prim.py +76 -563
  113. mindspore/ops/composite/base.py +1 -1
  114. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  115. mindspore/ops/function/__init__.py +0 -1
  116. mindspore/ops/function/array_func.py +6 -10
  117. mindspore/ops/function/debug_func.py +2 -4
  118. mindspore/ops/function/grad/grad_func.py +12 -4
  119. mindspore/ops/function/math_func.py +32 -44
  120. mindspore/ops/function/nn_func.py +20 -18
  121. mindspore/ops/functional.py +1 -2
  122. mindspore/ops/functional_overload.py +12 -23
  123. mindspore/ops/operations/_inner_ops.py +12 -11
  124. mindspore/ops/operations/array_ops.py +50 -4
  125. mindspore/ops/operations/comm_ops.py +15 -1
  126. mindspore/ops/operations/custom_ops.py +4 -10
  127. mindspore/ops/operations/debug_ops.py +6 -6
  128. mindspore/ops/operations/manually_defined/ops_def.py +12 -12
  129. mindspore/ops/operations/math_ops.py +5 -5
  130. mindspore/ops/operations/nn_ops.py +1 -1
  131. mindspore/ops/primitive.py +10 -3
  132. mindspore/ops/tensor_method.py +7 -16
  133. mindspore/ops_generate/pyboost/gen_pyboost_func.py +16 -0
  134. mindspore/parallel/_auto_parallel_context.py +15 -5
  135. mindspore/parallel/_parallel_serialization.py +2 -3
  136. mindspore/parallel/_ps_context.py +2 -2
  137. mindspore/parallel/_transformer/transformer.py +4 -4
  138. mindspore/parallel/_utils.py +11 -5
  139. mindspore/parallel/auto_parallel.py +9 -23
  140. mindspore/parallel/checkpoint_transform.py +0 -2
  141. mindspore/parallel/cluster/process_entity/_api.py +1 -4
  142. mindspore/parallel/cluster/run.py +3 -5
  143. mindspore/parallel/function/reshard_func.py +5 -6
  144. mindspore/parallel/nn/parallel_cell_wrapper.py +3 -40
  145. mindspore/parallel/nn/parallel_grad_reducer.py +8 -0
  146. mindspore/parallel/shard.py +21 -7
  147. mindspore/parallel/transform_safetensors.py +4 -10
  148. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +9 -10
  149. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
  150. mindspore/profiler/common/msprof_cmd_tool.py +2 -2
  151. mindspore/profiler/common/path_manager.py +0 -9
  152. mindspore/profiler/common/profiler_context.py +2 -25
  153. mindspore/profiler/common/profiler_meta_data.py +0 -1
  154. mindspore/profiler/common/profiler_op_analyse.py +6 -10
  155. mindspore/{ops/_op_impl/cpu/joinedstr_op.py → profiler/common/validator/__init__.py} +1 -15
  156. mindspore/profiler/common/validator/validate_path.py +84 -0
  157. mindspore/profiler/dynamic_profiler.py +46 -91
  158. mindspore/profiler/envprofiler.py +5 -30
  159. mindspore/profiler/experimental_config.py +1 -16
  160. mindspore/profiler/platform/cpu_profiler.py +4 -10
  161. mindspore/profiler/platform/npu_profiler.py +1 -1
  162. mindspore/profiler/profiler.py +145 -193
  163. mindspore/profiler/profiler_action_controller.py +1 -1
  164. mindspore/profiler/profiler_interface.py +2 -2
  165. mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
  166. mindspore/runtime/__init__.py +4 -6
  167. mindspore/runtime/executor.py +0 -27
  168. mindspore/runtime/memory.py +0 -1
  169. mindspore/runtime/thread_bind_core.py +1 -1
  170. mindspore/swresample-4.dll +0 -0
  171. mindspore/swscale-6.dll +0 -0
  172. mindspore/tinyxml2.dll +0 -0
  173. mindspore/train/_utils.py +3 -3
  174. mindspore/train/amp.py +3 -0
  175. mindspore/train/callback/_callback.py +1 -2
  176. mindspore/train/callback/_checkpoint.py +8 -1
  177. mindspore/train/callback/_flops_collector.py +6 -10
  178. mindspore/train/callback/_train_fault_tolerance.py +7 -3
  179. mindspore/train/data_sink.py +4 -4
  180. mindspore/train/dataset_helper.py +5 -5
  181. mindspore/train/model.py +20 -4
  182. mindspore/train/serialization.py +15 -35
  183. mindspore/train/train_thor/model_thor.py +2 -2
  184. mindspore/turbojpeg.dll +0 -0
  185. mindspore/utils/hooks.py +81 -0
  186. mindspore/utils/utils.py +8 -8
  187. mindspore/version.py +1 -1
  188. {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/METADATA +1 -1
  189. {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/RECORD +193 -192
  190. mindspore/_extends/parallel_compile/akg_compiler/custom.py +0 -1109
  191. mindspore/common/dynamic_shape/__init__.py +0 -0
  192. mindspore/common/dynamic_shape/enable_dynamic.py +0 -197
  193. /mindspore/common/{dynamic_shape/_auto_dynamic.py → _auto_dynamic.py} +0 -0
  194. {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/WHEEL +0 -0
  195. {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/entry_points.txt +0 -0
  196. {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py CHANGED
@@ -39,11 +39,10 @@ from typing import (
 
 import weakref
 import mindspore as ms
-import mindspore.ops as ops
 from mindspore._checkparam import args_type_check, check_hook_fn
-from mindspore.common.dynamic_shape._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
+from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
 from mindspore import log as logger
-from mindspore.common.hook_handle import HookHandle, _update_hook_version
+from mindspore.common.hook_handle import HookHandle
 from mindspore import context
 from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
 from mindspore import _checkparam as Validator
@@ -93,8 +92,9 @@ def register_cell_buffer_registration_hook(hook: Callable[..., None],):
         A handle that can be used to remove the added hook by calling
         `handle.remove()`.
     """
-    handle = HookHandle(_global_buffer_registration_hooks)
-    _global_buffer_registration_hooks[handle.handle_id] = hook
+    from mindspore.utils.hooks import _RemovableHandle
+    handle = _RemovableHandle(_global_buffer_registration_hooks)
+    _global_buffer_registration_hooks[handle.id] = hook
     return handle
 
 
@@ -155,8 +155,7 @@ class Cell(Cell_):
     IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_create_time',
                    '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase', '_bprop_debug',
                    '_forward_pre_hook', '_forward_hook', '_backward_pre_hook', '_backward_hook',
-                   '_cell_backward_pre_hook', '_cell_backward_hook', '_param_prefix',
-                   'requires_grad', 'cell_type', '_in_strategy', '_out_strategy']
+                   '_cell_backward_pre_hook', '_cell_backward_hook', '_param_prefix', 'requires_grad', 'cell_type']
     total_instance_count = 0
     _buffers: Dict[str, Optional[Tensor]]
     global_cells = weakref.WeakKeyDictionary()
@@ -207,7 +206,6 @@ class Cell(Cell_):
         super().__setattr__("mixed_precision_type", None)
         super().__setattr__("_lazy_construct_sig", None)
         super().__setattr__("_jit_graph_name", '')
-        super().__setattr__("_compiled", False)
         init_pipeline()
 
         # call gc to release GE session resources used by non-used cell objects
@@ -241,8 +239,6 @@ class Cell(Cell_):
         super().__setattr__("_amp_level", "")
         super().__setattr__("_init_flag", False)
         super().__setattr__("_shard_fn", None)
-        super().__setattr__("_in_strategy", None)
-        super().__setattr__("_out_strategy", None)
         super().__setattr__("has_bprop", False)
         if hasattr(self, "bprop"):
             super().__setattr__("has_bprop", True)
@@ -430,13 +426,6 @@ class Cell(Cell_):
         """
         return self._bprop_debug
 
-    @property
-    def compiled(self):
-        """
-        Get whether `Cell` is compiled in graph mode.
-        """
-        return self._compiled
-
     @bprop_debug.setter
     def bprop_debug(self, value):
         """
@@ -557,23 +546,10 @@ class Cell(Cell_):
 
     @property
     def pipeline_segment(self):
-        """
-        `pipeline_segment` represents the pipeline segment of current Cell.
-        """
         return self._pipeline_segment
 
     @pipeline_segment.setter
     def pipeline_segment(self, value):
-        """
-        Set the `pipeline_segment` of a Cell. Only effective in zero_bubble_v scheduler.
-
-        Args:
-            value (int): The pipeline segment of a parameter.
-
-        Raises:
-            TypeError: If `value` is not int type or is a bool type.
-            ValueError: If `value` is not a positive integer.
-        """
         if not isinstance(value, int) or isinstance(value, bool):
             raise TypeError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
                             "must be int type, but got type : {}".format(type(value)))
@@ -1051,13 +1027,12 @@ class Cell(Cell_):
         if self._forward_pre_hook:
             args, kwargs = self._run_forward_pre_hook(args, kwargs)
 
-        if self._backward_hook:
-            args = self._cell_backward_hook(args)
-
         if self._shard_fn is not None:
             output = self._shard_fn(*args, **kwargs)
         elif _pynative_executor.requires_grad():
-            if self._recompute_cell is not None:
+            if self._backward_hook:
+                output = self._backward_hook_construct(*args, **kwargs)
+            elif self._recompute_cell is not None:
                 output = self._recompute_cell(*args, **kwargs)
             elif self.has_bprop:
                 output = self._call_custom_bprop(*args, **kwargs)
@@ -1069,11 +1044,8 @@ class Cell(Cell_):
         if self._forward_hook:
             output = self._run_forward_hook(args, kwargs, output)
 
-        if self._backward_hook:
-            output = self._cell_backward_hook(output)
-
-        if self._backward_pre_hook:
-            output = self._cell_backward_pre_hook(output)
+        if self._backward_pre_hook and _pynative_executor.requires_grad():
+            output = self._run_backward_pre_hook(output)
 
         return output
 
@@ -1108,6 +1080,23 @@ class Cell(Cell_):
                 f"{default_args} default argument, total {positional_args + default_args}, "
                 f"but got {len(args)}.")
 
+    # pylint: disable=E0203
+    def _hook_fn_registered(self):
+        '''Hook function in graph mode'''
+        # Check super().__init__() in graph mode.
+        try:
+            if self._forward_pre_hook or self._forward_hook or self._backward_pre_hook or self._backward_hook:
+                return True
+        except AttributeError as e:
+            raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
+                                 f"Please use 'super().__init__()'.") from e
+        if not self._is_recursion_hook:
+            self._is_recursion_hook = True
+            for cell in self.cells():
+                if cell._hook_fn_registered():
+                    return True
+        return False
+
     def _get_prims_recursively(self):
         all_prims = list()
         for _, value in self._primitives.items():
@@ -1133,6 +1122,9 @@ class Cell(Cell_):
             >>> net = nn.Dense(3, 4)
             >>> net.set_data_parallel()
         """
+        if context._get_mode() == context.PYNATIVE_MODE:
+            raise ValueError("set_data_parallel: does not support PyNative mode.")
+
         all_prims = self._get_prims_recursively()
         for prim in all_prims:
             prim.add_prim_attr("strategy_gen_mode", "data_parallel")
@@ -1211,6 +1203,8 @@ class Cell(Cell_):
             ...         out = self.blocks[i](out)
             ...     return out
         """
+        if context._get_mode() == context.PYNATIVE_MODE:
+            raise ValueError("The Cell offload does not support PyNative mode now.")
         if isinstance(backward_prefetch, str):
             Validator.check_string(backward_prefetch, ['Auto'], 'backward_prefetch', self.cls_name)
         else:
@@ -1218,10 +1212,11 @@ class Cell(Cell_):
         for prim in self._get_prims_recursively():
             prim._offload(backward_prefetch=backward_prefetch)
 
-    def shard(self, in_strategy, out_strategy=None, parameter_plan=None):
+    def shard(self, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
         """
         Defining the input and output layouts of this cell and the parallel strategies of remaining ops will be
-        generated by sharding propagation. In Graph mode, use this method to specify distribution strategy for a Cell,
+        generated by sharding propagation. In PyNative mode, use this method to specify a Cell for distributed
+        execution in graph mode. In Graph mode, use this method to specify distribution strategy for a Cell,
         strategy for others will be set by sharding propagation.
         in_strategy and out_strategy define the input and output layout respectively.
         in_strategy/out_strategy should be a tuple, each element of which corresponds to the desired layout of
@@ -1233,14 +1228,11 @@ class Cell(Cell_):
           In other parallel modes, strategies set here will be ignored.
         - If the input contain Parameter, its strategy should be set in `in_strategy`.
 
-        .. warning::
-            The method is currently not supported in PyNative mode.
-
         Args:
             in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple. Tuple
                 defines the layout of the corresponding input.
             out_strategy (Union[None, tuple]): Define the layout of outputs similar with in_strategy.
-                Default: ``None`` .
+                It is not in use right now. Default: ``None`` .
             parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict
                 defines the layout of the parameter like "param_name: layout".
                 The key is a parameter name of type 'str'.
@@ -1248,6 +1240,14 @@ class Cell(Cell_):
                 If the parameter name is incorrect or the corresponding parameter
                 has been set, the parameter setting will be ignored.
                 Default: ``None`` .
+            device (str): Select a certain device target. It is not in use right now.
+                Support [ ``"CPU"`` , ``"GPU"`` , ``"Ascend"`` ]. Default: ``"Ascend"`` .
+            level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation
+                over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
+                use right now. Support [ ``"0"`` , ``"1"`` , ``"2"`` ]. Default: ``0`` .
+
+        Returns:
+            Function, return the cell construct function that will be executed under auto parallel process.
 
         Examples:
             >>> import mindspore.nn as nn
@@ -1265,34 +1265,19 @@ class Cell(Cell_):
             ...     def __init__(self):
             ...         self.block1 = Block()
             ...         self.block2 = Block()
-            ...         self.block2.shard(in_strategy=((2, 1),), parameter_plan={'self.block2.dense1.weight': (4, 1)})
+            ...         self.block2_shard = self.block2.shard(in_strategy=((2, 1),),
+            ...                                               parameter_plan={'self.block2.shard.dense1.weight': (4, 1)})
             ...     def construct(self, x):
             ...         x = self.block1(x)
-            ...         x = self.block2(x)
+            ...         x = self.block2_shard(x)
             ...         return x
         """
         if ms.communication.management.get_group_size() == 1:
-            return
-
+            return self
         shard_fn = Shard()
-        self._shard_fn = shard_fn(self, in_strategy, out_strategy, parameter_plan)
-
-        if self._in_strategy is not None:  # pylint: disable=E0203
-            msg = (
-                "For '%s', 'Shard' has been configured more than once. "
-                "The existing in_strategy is %s and the existing out_strategy is %s. "
-                "The new in_strategy %s and out_strategy %s may not take effect. "
-                "It is recommended to configure 'Shard' only once."
-            ) % (
-                self._cell_tag,
-                self._in_strategy,  # pylint: disable=E0203
-                self._out_strategy,  # pylint: disable=E0203
-                shard_fn.in_strategy,
-                shard_fn.out_strategy,
-            )
-            logger.warning(msg)
-        self._in_strategy = shard_fn.in_strategy
-        self._out_strategy = shard_fn.out_strategy
+        fn = shard_fn(self, in_strategy, out_strategy, parameter_plan, device, level)
+        self._shard_fn = fn
+        return fn
 
     def _init_check(self):
         for param in self.get_parameters(expand=False):
@@ -1301,13 +1286,9 @@ class Cell(Cell_):
             self._init_flag = True
 
     def _self_check(self):
-        try:
-            if not self._is_check_and_refresh:  # pylint: disable=E0203
-                self.check_names_and_refresh_name()
-                self._is_check_and_refresh = True
-        except AttributeError as e:
-            raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
-                                 f"Please use 'super().__init__()'.") from e
+        if not self._is_check_and_refresh:
+            self.check_names_and_refresh_name()
+            self._is_check_and_refresh = True
 
     def _predict(self, *args, **kwargs):
         '''Graph executor for predict'''
@@ -1328,7 +1309,6 @@ class Cell(Cell_):
     def __call__(self, *args, **kwargs):
         # Run in Graph mode.
         if context._get_mode() == context.GRAPH_MODE and os.getenv("MS_JIT") != '0':
-            self._compiled = True
             if kwargs:
                 bound_arguments = self._construct_sig.bind(*args, **kwargs)
                 bound_arguments.apply_defaults()
@@ -1339,8 +1319,11 @@ class Cell(Cell_):
             if predict_compiled:
                 return res
             self._check_construct_args(*args)
+
+            if self._hook_fn_registered():
+                logger.warning(f"For 'Cell', it's not support hook function in graph mode. If you want to use hook "
+                               f"function, please use context.set_context to set pynative mode.")
             self._self_check()
-            self.__compile_cell_hook__ = True
             out = self.compile_and_run(*args, **kwargs)
             return out
 
@@ -1438,7 +1421,16 @@ class Cell(Cell_):
                 exist_names.add(item.name)
                 self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
 
-        object.__setattr__(self, name, value)
+        if context._get_mode() == context.PYNATIVE_MODE:
+            if name in self.__dict__:
+                del self.__dict__[name]
+            params = self.__dict__.get('_params')
+            if name in params:
+                del params[name]
+            params_list = self.__dict__.get('_params_list')
+            params_list[name] = value
+        else:
+            object.__setattr__(self, name, value)
 
     def _set_attr_for_parameter_in_list_or_tuple(self, name, value):
         """Set attr for parameter in list or tuple."""
@@ -1617,6 +1609,8 @@ class Cell(Cell_):
                 _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)
             else:
                 self._check_construct_args(*inputs)
+                # TODO(tronzhang): It may error for no actually args here. So just set in fullmode,
+                # which means that incremental mode is lacking dynamic input.
         else:
             self._dynamic_shape_inputs = _process_dyn_args(self.construct, kwargs)
 
@@ -2604,7 +2598,6 @@ class Cell(Cell_):
             raise ValueError(f"Negative 'fusion_size' {fusion_size} is invalid.")
         Tensor._flatten_tensors(self.trainable_params(), fusion_size)  # pylint: disable=W0212
 
-    @jit_forbidden_register
     def register_forward_pre_hook(self, hook_fn, with_kwargs=False):
         """
         Register forward pre hook function for Cell object.
@@ -2624,6 +2617,7 @@ class Cell(Cell_):
               `with_kwargs` is ``True`` .
 
         Note:
+            - The feature does not take effect in graph mode or in PyNative mode with functions decorated by jit.
            - The `hook_fn` can modify the forward inputs by returning new inputs. If `with_kwargs` is ``Flase`` , a
              single value (whick will be wrapped into a tuple unless already a tuple) or a tuple of args should be
              returned. If `with_kwargs` is ``True`` , both `args` and `kwargs` should be returned.
@@ -2674,15 +2668,15 @@ class Cell(Cell_):
             (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
             value= [ 2.00000000e+00]))
         """
+        if context._get_mode() == context.GRAPH_MODE:
+            return HookHandle()
         check_hook_fn(hook_fn)
         handle = HookHandle(self._forward_pre_hook, extra_dict=self._forward_pre_hook_with_kwargs)
         self._forward_pre_hook[handle.handle_id] = hook_fn
         if with_kwargs:
             self._forward_pre_hook_with_kwargs[handle.handle_id] = True
-        _update_hook_version()
         return handle
 
-    @jit_forbidden_register
     def _run_forward_pre_hook(self, args, kwargs):
         """
         Running forward pre hook function registered on Cell object.
@@ -2706,35 +2700,6 @@ class Cell(Cell_):
                 args = ret
         return args, kwargs
 
-    def _jit_forward_pre_hook(self, inputs):
-        """
-        Compile forward pre hook function registered on Cell object.
-
-        Args:
-            inputs: The input objects of cell object.
-
-        Returns:
-            - **outputs** - New input objects or none.
-
-        Supported Platforms:
-            ``Ascend`` ``GPU`` ``CPU``
-        """
-        forward_pre_hook_inputs = inputs
-        for fn in self._forward_pre_hook.values():
-            ret = fn(self, forward_pre_hook_inputs)
-            if ret is not None:
-                if not isinstance(ret, tuple):
-                    forward_pre_hook_inputs = (ret,)
-                else:
-                    forward_pre_hook_inputs = ret
-
-        if len(forward_pre_hook_inputs) != len(inputs):
-            raise TypeError(
-                "The forward pre hook return value size is {} not equal to input size {}".format(
-                    len(forward_pre_hook_inputs), len(inputs)))
-        return forward_pre_hook_inputs
-
-    @jit_forbidden_register
     def register_forward_hook(self, hook_fn, with_kwargs=False):
         """
         Register forward hook function for Cell object.
@@ -2755,6 +2720,7 @@ class Cell(Cell_):
             - `output`: Output generated by the `construct` function.
 
         Note:
+            - The feature does not take effect in graph mode or in PyNative mode with functions decorated by jit.
            - The `hook_fn` can modify the forward outputs by returning new outputs.
            - In order to prevent running failed when switching to graph mode, it is not recommended to call it in the
              `construct` function of Cell object.
@@ -2807,44 +2773,15 @@ class Cell(Cell_):
         """
         if self.has_bprop:
             return HookHandle()
+        if context._get_mode() == context.GRAPH_MODE:
+            return HookHandle()
         check_hook_fn(hook_fn)
         handle = HookHandle(self._forward_hook, extra_dict=self._forward_hook_with_kwargs)
         self._forward_hook[handle.handle_id] = hook_fn
         if with_kwargs:
             self._forward_hook_with_kwargs[handle.handle_id] = True
-        _update_hook_version()
         return handle
 
-    def _jit_forward_hook(self, inputs, output):
-        """
-        Compile forward hook function registered on Cell object.
-
-        Args:
-            inputs: The input objects of Cell object.
-            output: The output object of Cell object.
-
-        Returns:
-            - **output** - New output object or none.
-
-        Supported Platforms:
-            ``Ascend`` ``GPU`` ``CPU``
-        """
-        forward_hook_output = output
-        for fn in self._forward_hook.values():
-            ret = fn(self, inputs, forward_hook_output)
-            if ret is not None:
-                forward_hook_output = ret
-
-        if isinstance(output, tuple):
-            if not isinstance(forward_hook_output, tuple):
-                forward_hook_output = (forward_hook_output,)
-            if len(forward_hook_output) != len(output):
-                raise TypeError(
-                    "The forward hook return value size is {} not equal to output size {}".format(
-                        len(forward_hook_output), len(output)))
-        return forward_hook_output
-
-    @jit_forbidden_register
     def _run_forward_hook(self, args, kwargs, output):
         """
         Running forward hook function registered on Cell object.
@@ -2858,12 +2795,12 @@ class Cell(Cell_):
                 output = ret
         return output
 
-    @jit_forbidden_register
     def register_backward_pre_hook(self, hook_fn):
         """
         Register the backward pre hook function.
 
         Note:
+            - The `register_backward_pre_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
            - The 'hook_fn' must be defined as the following code.
              `cell` is the Cell object. `grad_output` is the gradient passed to the Cell.
            - The 'hook_fn' should have the following signature:
@@ -2912,17 +2849,44 @@ class Cell(Cell_):
             >>> print(output)
             (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
         """
+        if context._get_mode() == context.GRAPH_MODE:
+            return HookHandle()
         check_hook_fn(hook_fn)
-        handle = HookHandle(self._backward_pre_hook, extra_dict=None)
+        handle = HookHandle(self._backward_pre_hook)
         self._backward_pre_hook[handle.handle_id] = hook_fn
-        if self._cell_backward_pre_hook is None:  # pylint: disable=E0203
+        if self._cell_backward_pre_hook is None:
             # Generate a CellBackwardHook prim, and add function for it
             self._cell_backward_pre_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
                                                                   self, self._backward_pre_hook)
             self._cell_backward_pre_hook.register_backward_pre_hook()
-        _update_hook_version()
         return handle
 
+    def _run_backward_pre_hook(self, outputs):
+        """
+        Running backward pre hook function registered on Cell object.
+
+        Args:
+            outputs: The output objects of cell object.
+
+        Returns:
+            - **outputs** - New backward gradient or None.
+
+        Supported Platforms:
+            ``Ascend`` ``GPU`` ``CPU``
+        """
+        if isinstance(outputs, tuple):
+            ret = self._cell_backward_pre_hook(*outputs)
+        else:
+            ret = self._cell_backward_pre_hook(outputs)
+        if isinstance(outputs, tuple):
+            if len(outputs) == 1:
+                ret = (ret,)
+            if len(ret) != len(outputs):
+                raise TypeError(
+                    "The backward pre hook return value size is {} not equal to output size {}".format(
+                        len(ret), len(outputs)))
+        return ret
+
     def get_extra_state(self) -> Any:
         """Return any extra state to include in the cell's state_dict.
 
@@ -2975,8 +2939,9 @@ class Cell(Cell_):
             A handle that can be used to remove the added hook by calling
             `handle.remove()`.
         """
-        handle = HookHandle(self._state_dict_hooks)
-        self._state_dict_hooks[handle.handle_id] = hook
+        from mindspore.utils.hooks import _RemovableHandle
+        handle = _RemovableHandle(self._state_dict_hooks)
+        self._state_dict_hooks[handle.id] = hook
         return handle
 
     @jit_forbidden_register
@@ -3022,8 +2987,9 @@ class Cell(Cell_):
            >>> print("extra_param" in net_state_dict)
            True
        """
-        handle = HookHandle(self._state_dict_pre_hooks)
-        self._state_dict_pre_hooks[handle.handle_id] = hook
+        from mindspore.utils.hooks import _RemovableHandle
+        handle = _RemovableHandle(self._state_dict_pre_hooks)
+        self._state_dict_pre_hooks[handle.id] = hook
         return handle
 
     def _save_to_state_dict(self, destination, prefix, keep_vars):
@@ -3169,8 +3135,9 @@ class Cell(Cell_):
             A handle that can be used to remove the added hook by calling
             `handle.remove()`.
         """
-        handle = HookHandle(self._load_state_dict_pre_hooks)
-        self._load_state_dict_pre_hooks[handle.handle_id] = hook
+        from mindspore.utils.hooks import _RemovableHandle
+        handle = _RemovableHandle(self._load_state_dict_pre_hooks)
+        self._load_state_dict_pre_hooks[handle.id] = hook
         return handle
 
     @jit_forbidden_register
@@ -3202,8 +3169,9 @@ class Cell(Cell_):
             A handle that can be used to remove the added hook by calling
             `handle.remove()`.
         """
-        handle = HookHandle(self._load_state_dict_post_hooks)
-        self._load_state_dict_post_hooks[handle.handle_id] = hook
+        from mindspore.utils.hooks import _RemovableHandle
+        handle = _RemovableHandle(self._load_state_dict_post_hooks)
+        self._load_state_dict_post_hooks[handle.id] = hook
         return handle
 
     def _load_from_state_dict(
@@ -3439,12 +3407,12 @@ class Cell(Cell_):
         )
         return _IncompatibleKeys(missing_keys, unexpected_keys)
 
-    @jit_forbidden_register
     def register_backward_hook(self, hook_fn):
         """
         Register the backward hook function.
 
         Note:
+            - The `register_backward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
            - The 'hook_fn' must be defined as the following code.
              `cell` is the registered Cell object. `grad_input` is the gradient computed and passed to
              the next Cell or primitive, which can be return a new gradient or None. `grad_output` is the gradient
@@ -3496,17 +3464,65 @@ class Cell(Cell_):
             >>> print(output)
             (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
         """
+        if context._get_mode() == context.GRAPH_MODE:
+            return HookHandle()
         check_hook_fn(hook_fn)
-        handle = HookHandle(self._backward_hook, extra_dict=None)
+        handle = HookHandle(self._backward_hook)
         self._backward_hook[handle.handle_id] = hook_fn
-        if self._cell_backward_hook is None:  # pylint: disable=E0203
+        if self._cell_backward_hook is None:
            # Generate a CellBackwardHook prim, and add function for it
            self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
                                                              self, self._backward_hook)
            self._cell_backward_hook.register_backward_hook()
-        _update_hook_version()
         return handle
 
+    def _backward_hook_construct(self, *inputs, **kwargs):
+        """
+        Backward hook construct method to replace original construct method.
+
+        Args:
+            inputs: The input objects of Cell object.
+            kwargs (dict): Dictionary of variable keyword parameters.
+
+        Returns:
+            - **outputs** - The output objects of Cell object.
+
+        Supported Platforms:
+            ``Ascend`` ``GPU`` ``CPU``
+        """
+        # cell_backward_hook has CellBackwardHook op, so keep input args as they are.
+        outputs = self._cell_backward_hook(*inputs)
+        # If the inputs have more than two args, the outputs will also have more than two args and will be wrapped into
+        # a tuple, so need to do unwrapping. If inputs is empty, we also need to unwrap it.
+        # Because when output of runop method is one, it will not wrap a tuple, we need not unwrap it.
+        is_need_unwrap = False
+        if isinstance(outputs, tuple) and len(inputs) != 1:
+            is_need_unwrap = True
+
+        if self._recompute_cell is not None:
+            if is_need_unwrap:
+                outputs = self._recompute_cell(*outputs, **kwargs)
+            else:
+                outputs = self._recompute_cell(outputs, **kwargs)
+        elif self.has_bprop:
+            if is_need_unwrap:
+                outputs = self._call_custom_bprop(*outputs, **kwargs)
+            else:
+                outputs = self._call_custom_bprop(outputs, **kwargs)
+        else:
+            if is_need_unwrap:
+                outputs = self.construct(*outputs, **kwargs)
+            else:
+                outputs = self.construct(outputs, **kwargs)
+        if isinstance(outputs, tuple):
+            new_outputs = self._cell_backward_hook(*outputs)
+        else:
+            new_outputs = self._cell_backward_hook(outputs)
+        # if outputs is (X,) and new_outpus is X
+        if isinstance(outputs, tuple) and len(outputs) == 1:
+            new_outputs = (new_outputs,)
+        return new_outputs
+
     def set_param_ps(self, recurse=True, init_in_server=False):
         """
         Set whether the trainable parameters are updated by parameter server and whether the
@@ -3585,7 +3601,7 @@ class Cell(Cell_):
         """
         Validator.check_bool(mode)
         Validator.check_bool(output_recompute)
-        if not self._has_config_recompute:  # pylint: disable=E0203
+        if not self._has_config_recompute:
             self._has_config_recompute = True
         else:
             logger.info("The recompute interface can be configured only once."
@@ -3628,7 +3644,7 @@ class Cell(Cell_):
                 introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
                 Default: ``False`` .
         """
-        if context._get_mode() == context.PYNATIVE_MODE:
+        if context.get_context("mode") == context.PYNATIVE_MODE:
             self._recompute_cell = recompute_registry.get()(self.construct)
             self._recompute()
         if 'mp_comm_recompute' in kwargs.keys():
@@ -3689,64 +3705,6 @@ class Cell(Cell_):
         """
         self._jit_graph_name = key
 
-    def _jit_backward_pre_hook(self, grad_output):
-        new_grad_output = grad_output
-        if not isinstance(grad_output, tuple):
-            new_grad_output = (grad_output,)
-
-        for fn in self._backward_pre_hook.values():
-            ret = fn(self, new_grad_output)
-            if ret is not None:
-                if not isinstance(ret, tuple):
-                    output = (ret,)
-                else:
-                    output = ret
-            else:
-                output = ops.Depend()(new_grad_output, ret)
-            new_grad_output = output
-
-        if not isinstance(grad_output, tuple):
-            if len(new_grad_output) == 1:
-                return new_grad_output[0]
-            raise TypeError(
-                "The backward pre hook return value size is {} not equal to input size 1".format(
-                    len(new_grad_output)))
-
-        if len(new_grad_output) != len(grad_output):
-            raise TypeError(
-                "The backward pre hook return value size is {} not equal to input size {}".format(
-                    len(new_grad_output), len(grad_output)))
-
-        return new_grad_output
-
-    def _jit_backward_hook(self, grad_input, grad_output):
-        backward_hook_input = grad_input
-        backward_hook_output = grad_output
-        if not isinstance(grad_input, tuple):
-            backward_hook_input = (grad_input,)
-        if not isinstance(grad_output, tuple):
-            backward_hook_output = (grad_output,)
-
-        for fn in self._backward_hook.values():
-            ret = fn(self, backward_hook_input, backward_hook_output)
-            if ret is not None:
-                if not isinstance(ret, tuple):
-                    output = (ret,)
-                else:
-                    output = ret
-            else:
-                output = ops.Depend()(backward_hook_input, ret)
-
-            backward_hook_input = output
-
-        if not isinstance(grad_input, tuple):
-            return backward_hook_input[0]
-
-        if len(backward_hook_input) != len(grad_input):
-            raise TypeError(
-                "The backward hook return value size is {} not equal to input size {}".format(
-                    len(backward_hook_input), len(grad_input)))
-        return backward_hook_input
 
 class GraphCell(Cell):
     """