mindspore 2.3.0__cp310-cp310-win_amd64.whl → 2.4.1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore has been flagged as potentially problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/initializer.py +51 -15
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +62 -15
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +183 -37
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +315 -60
- mindspore/communication/management.py +14 -14
- mindspore/context.py +132 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +983 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +268 -23
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +26 -13
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +276 -96
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +137 -10
- mindspore/nn/layer/embedding.py +137 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +124 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +767 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
- mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +492 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +564 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +402 -12
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +7 -2
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +14 -146
- mindspore/ops/operations/comm_ops.py +63 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +273 -20
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +31 -9
- mindspore/parallel/_cell_wrapper.py +85 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +127 -13
- mindspore/parallel/_utils.py +53 -22
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +1146 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +285 -413
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +39 -104
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +105 -19
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +97 -31
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +145 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +375 -0
- mindspore/train/dataset_helper.py +15 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +154 -58
- mindspore/train/serialization.py +342 -128
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +248 -242
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py
CHANGED
@@ -32,7 +32,7 @@ from mindspore import context
 from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
 from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
-from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
+from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache, _no_grad
 from mindspore.common.api import _generate_branch_control_input, _convert_python_data, _get_args_for_run_predict
 from mindspore.common.api import _process_dyn_args, _generate_dyn_compile_args
 from mindspore.common.parameter import Parameter, ParameterTuple
@@ -45,7 +45,6 @@ from mindspore._check_jit_forbidden_api import jit_forbidden_register
 from mindspore.common._decorator import deprecated
 from mindspore.common._register_for_recompute import recompute_registry

-
 class Cell(Cell_):
     """
     The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this
@@ -101,9 +100,9 @@ class Cell(Cell_):
     """

     IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_create_time',
-                   '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase',
-                   '_forward_pre_hook', '_forward_hook', '
-                   '
+                   '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase', '_bprop_debug',
+                   '_forward_pre_hook', '_forward_hook', '_backward_pre_hook', '_backward_hook',
+                   '_cell_backward_pre_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
                    '_attr_synced', 'pynative', 'requires_grad', 'cell_type']
     total_instance_count = 0

@@ -135,7 +134,8 @@ class Cell(Cell_):
         self._id = 1
         self.exist_names = set("")
         self.exist_objs = set()
-        self.
+        self._recompute_cell = None
+        self.mixed_precision_type = None
         self.sig = inspect.signature(self.construct)
         init_pipeline()

@@ -146,13 +146,16 @@ class Cell(Cell_):
         if flags:
             self.add_flags(**flags)
         self._bprop_debug = False
+
+        # hook
         self._forward_pre_hook = OrderedDict()
         self._forward_hook = OrderedDict()
-        self.
-        self.
-        self.
+        self._backward_pre_hook = OrderedDict()
+        self._cell_backward_pre_hook = None
+        self._backward_hook = OrderedDict()
         self._cell_backward_hook = None
         self._is_recursion_hook = False
+
         self.cell_type = None
         self.cast = Cast()
         self._has_config_recompute = False
@@ -166,6 +169,10 @@ class Cell(Cell_):
         self._is_check_and_refresh = False
         self._amp_level = ""
         self._init_flag = False
+        self._shard_fn = None
+        self.has_bprop = False
+        if hasattr(self, "bprop"):
+            self.has_bprop = True

     def __getstate__(self):
         base = Cell_.__getstate__(self)
@@ -223,8 +230,9 @@ class Cell(Cell_):
         Get whether cell custom bprop debug is enabled.

         Tutorial Examples:
-            - `
-              <https://mindspore.cn/
+            - `Custom Neural Network Layers - Custom Cell Reverse
+              <https://mindspore.cn/docs/en/master/model_train/custom_program/network_custom.html
+              #custom-cell-reverse>`_
         """
         return self._bprop_debug

@@ -374,6 +382,10 @@ class Cell(Cell_):
     def jit_config_dict(self):
         return self._jit_config_dict

+    @property
+    def enable_backward_hook(self):
+        return self._enable_backward_hook
+
     def get_func_graph_proto(self):
         """Return graph binary proto."""
         exec_id = ".".join([self.phase, str(self.create_time), str(id(self))])
@@ -401,8 +413,6 @@ class Cell(Cell_):
         cells_compile_cache.pop(id(self), None)
         if hasattr(self, "compile_cache") and self.compile_cache:
             _cell_graph_executor.del_net_res(self, self.compile_cache)
-        if isinstance(self, GraphCell):
-            _cell_graph_executor.dec_graph_cell_count()
         Cell.total_instance_count -= 1

     def __delattr__(self, name):
@@ -475,21 +485,28 @@ class Cell(Cell_):
         output = self._run_construct(cast_inputs, kwargs)
         return output

-    def _run_construct(self,
+    def _run_construct(self, *inputs, **kwargs):
         """Run the construct function"""
-        if self.
-
-
-
-
-
+        if self._forward_pre_hook:
+            inputs = self._run_forward_pre_hook(inputs)
+
+        if self._backward_hook:
+            output = self._backward_hook_construct(*inputs, **kwargs)
+        elif self._shard_fn is not None:
+            output = self._shard_fn(*inputs, **kwargs)
+        elif self._recompute_cell is not None:
+            output = self._recompute_cell(*inputs, **kwargs)
+        elif self.has_bprop and _pynative_executor.requires_grad():
+            output = self._call_custom_bprop(*inputs, **kwargs)
         else:
-
-
-
-
-
-
+            output = self.construct(*inputs, **kwargs)
+
+        if self._forward_hook:
+            output = self._run_forward_hook(inputs, output)
+
+        if self._backward_pre_hook:
+            output = self._run_backward_pre_hook(output)
+
         return output

     def _check_construct_args(self, *args):
@@ -527,7 +544,7 @@ class Cell(Cell_):
         '''Hook function in graph mode'''
         # Check super().__init__() in graph mode.
         try:
-            if self.
+            if self._forward_pre_hook or self._forward_hook or self._backward_pre_hook or self._backward_hook:
                 return True
         except AttributeError as e:
             raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
@@ -579,8 +596,7 @@ class Cell(Cell_):
        strategy for others will be set by sharding propagation.
        in_strategy and out_strategy define the input and output layout respectively.
        in_strategy/out_strategy should be a tuple, each element of which corresponds to the desired layout of
-       this input/output,
-       which can refer to the description of `mindspore.ops.Primitive.shard`.
+       this input/output, which can refer to the description of `mindspore.ops.Primitive.shard`.
        The parallel strategies of remaining operators are derived from the strategy specified by the input and output.

        Note:
@@ -589,8 +605,8 @@ class Cell(Cell_):
            If the input contain Parameter, its strategy should be set in `in_strategy`.

        Args:
-           in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple
-
+           in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple. Tuple
+               defines the layout of the corresponding input.
           out_strategy (Union[None, tuple]): Define the layout of outputs similar with in_strategy.
               It is not in use right now. Default: ``None`` .
           parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict
@@ -625,7 +641,7 @@ class Cell(Cell_):
            ...     def __init__(self):
            ...         self.block1 = Block()
            ...         self.block2 = Block()
-           ...         self.block2_shard = self.block2.shard(in_strategy=((2, 1),),
+           ...         self.block2_shard = self.block2.shard(in_strategy=((2, 1),),
            ...                                               parameter_plan={'self.block2.shard.dense1.weight': (4, 1)})
            ...     def construct(self, x):
            ...         x = self.block1(x)
@@ -638,7 +654,7 @@ class Cell(Cell_):

         shard_fn = Shard()
         fn = shard_fn(self, in_strategy, out_strategy, parameter_plan, device, level)
-
+        self._shard_fn = fn
         return fn

     def auto_cast_inputs(self, inputs):
@@ -666,6 +682,7 @@ class Cell(Cell_):
         for param in self.get_parameters(expand=False):
             if param.has_init:
                 param.init_data()
+        self._init_flag = True

     def _self_check(self):
         if not self._is_check_and_refresh:
@@ -684,7 +701,7 @@ class Cell(Cell_):

     def __call__(self, *args, **kwargs):
         # Run in Graph mode.
-        if os.getenv("MS_JIT") != '0'
+        if context._get_mode() == context.GRAPH_MODE and os.getenv("MS_JIT") != '0':
             if kwargs:
                 bound_arguments = self.sig.bind(*args, **kwargs)
                 bound_arguments.apply_defaults()
@@ -704,22 +721,69 @@ class Cell(Cell_):
             return out

         # Run in PyNative mode.
-        self.
-        if not self._init_flag:
+        if not (self._init_flag or self._is_check_and_refresh):
             self._init_check()
-        self.
+            self._self_check()
+
+        if not (self.requires_grad or self._dynamic_shape_inputs or self.mixed_precision_type):
+            if not (self._forward_pre_hook or self._forward_hook or self._backward_pre_hook or self._backward_hook or
+                    self._shard_fn or self._recompute_cell or (self.has_bprop and _pynative_executor.requires_grad())):
+                return self.construct(*args, **kwargs)
+
+            return self._run_construct(*args, **kwargs)
+
+        return self._complex_call(*args, **kwargs)

+    def _complex_call(self, *args, **kwargs):
+        """
+        PyNative call with requires_grad or hooks
+        """
+        self._call_pre_process(*args, **kwargs)
+
+        if not (self._forward_pre_hook or self._forward_hook or self._backward_pre_hook or self._backward_hook or
+                self._shard_fn or self._recompute_cell or self.has_bprop):
+            output = self.construct(*args, **kwargs)
+        else:
+            output = self._run_construct(*args, **kwargs)
+
+        self._call_post_process(output, *args, **kwargs)
+
+        return output
+
+    def _call_pre_process(self, *args, **kwargs):
+        """
+        Process cell info before call construct
+        """
         if self.requires_grad:
             _pynative_executor.set_grad_flag(True)
-
-        try:
             _pynative_executor.new_graph(self, *args, **kwargs)
-
+        elif self._dynamic_shape_inputs is not None:
+            _pynative_executor.set_cell_use_dynamic_shape_process(True)
+
+        # Set mixed precision
+        if self.mixed_precision_type is not None:
+            _pynative_executor.set_mixed_precision_type(self.mixed_precision_type)
+
+    def _call_post_process(self, output, *args, **kwargs):
+        """
+        Process cell info after call construct
+        """
+        if self.requires_grad:
             _pynative_executor.end_graph(self, output, *args, **kwargs)
-
-        _pynative_executor.
-
+        elif self._dynamic_shape_inputs is not None:
+            _pynative_executor.set_cell_use_dynamic_shape_process(False)
+
+        # mixed precision reset
+        if self.mixed_precision_type is not None:
+            _pynative_executor.set_mixed_precision_type(MixedPrecisionType.NOTSET, False)

+    def _call_custom_bprop(self, *args, **kwargs):
+        """
+        Call custom bprop for cell bprop.
+        """
+        with _no_grad():
+            output = self.construct(*args, **kwargs)
+        _pynative_executor.call_custom_bprop(self, output, *args, **kwargs)
         return output

     def _add_attr(self, name, value):
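The `_call_custom_bprop` path added above runs `construct` under `_no_grad()` and then hands the output to `_pynative_executor.call_custom_bprop`, so a Cell that defines its own `bprop` is dispatched through this branch whenever gradients are being recorded. A minimal usage sketch, not taken from the diff (the `MulAdd` cell and its gradient rule are illustrative only):

import mindspore as ms
from mindspore import Tensor, nn, ops

ms.set_context(mode=ms.PYNATIVE_MODE)

class MulAdd(nn.Cell):
    def construct(self, x, y):
        return 2 * x + y

    # Custom reverse rule: return one gradient per input of construct.
    def bprop(self, x, y, out, dout):
        return 2 * dout, dout

net = MulAdd()
grad_all = ops.GradOperation(get_all=True)
# Under GradOperation the new __call__ dispatch routes through _call_custom_bprop.
print(grad_all(net)(Tensor(1.0, ms.float32), Tensor(2.0, ms.float32)))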
@@ -961,9 +1025,12 @@ class Cell(Cell_):

         if not kwargs:
             self._dynamic_shape_inputs = inputs
-            self._check_construct_args(*inputs)
             if context._get_mode() == context.PYNATIVE_MODE:
                 _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)
+            else:
+                self._check_construct_args(*inputs)
+                # TODO(tronzhang): It may error for no actually args here. So just set in fullmode,
+                # which means that incremental mode is lacking dynamic input.
         else:
             self._dynamic_shape_inputs = _process_dyn_args(self.construct, kwargs)

@@ -1682,10 +1749,13 @@ class Cell(Cell_):
     def _add_mixed_precision_flag(self, **flags):
         """Add mixed precision flag to current cell"""
         if "fp16" in flags and flags.get("fp16", False):
+            self.mixed_precision_type = MixedPrecisionType.FP16
             Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP16)
         if "fp32" in flags and flags.get("fp32", False):
+            self.mixed_precision_type = MixedPrecisionType.FP32
             Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP32)
         if "bf16" in flags and flags.get("bf16", False):
+            self.mixed_precision_type = MixedPrecisionType.BF16
             Cell_.set_mixed_precision_type(self, MixedPrecisionType.BF16)

     def apply(self, fn):
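Caching `mixed_precision_type` in `_add_mixed_precision_flag` above is what lets the new `_call_pre_process` push the type to the PyNative executor before `construct` runs. A hedged sketch of one way the flag typically gets set (the `nn.Dense` layer and `to_float` call are illustrative assumptions, not part of the diff):

import mindspore as ms
from mindspore import nn, ops

ms.set_context(mode=ms.PYNATIVE_MODE)

# to_float(ms.float16) adds the fp16 flag to the cell, which in turn records
# MixedPrecisionType.FP16 via _add_mixed_precision_flag.
dense = nn.Dense(4, 4).to_float(ms.float16)

x = ops.ones((2, 4), ms.float32)
out = dense(x)
print(out.dtype)  # expected to be Float16 once the fp16 flag is in effect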
@@ -1750,9 +1820,6 @@ class Cell(Cell_):
         if not hasattr(self, "_func_graph_flags"):
             self._func_graph_flags = {}
         self._func_graph_flags.update({**flags})
-        if context._get_mode() == context.PYNATIVE_MODE and self._func_graph_flags.get("output_no_recompute"):
-            raise TypeError("Recompute is not supported in PyNative mode currently, you can use "
-                            "'context.set_context(mode=context.GRAPH_MODE)' or @jit to set graph mode.")
         self.__dict__.update({**flags})
         self._add_mixed_precision_flag(**flags)
         return self
@@ -2050,15 +2117,12 @@ class Cell(Cell_):
            (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
            value= [ 2.00000000e+00]))
        """
+        if context._get_mode() == context.GRAPH_MODE:
+            return HookHandle()
        if not check_hook_fn("register_forward_pre_hook", hook_fn):
            return HookHandle()
-
-
-        if not hasattr(self, '_forward_pre_hook_key'):
-            self._forward_pre_hook_key = -1
-        self._forward_pre_hook_key += 1
-        self._forward_pre_hook[self._forward_pre_hook_key] = hook_fn
-        handle = HookHandle(self, self._forward_pre_hook_key, "_forward_pre_hook")
+        handle = HookHandle(self._forward_pre_hook)
+        self._forward_pre_hook[handle.handle_id] = hook_fn
        return handle

    def _run_forward_pre_hook(self, inputs):
@@ -2074,14 +2138,23 @@ class Cell(Cell_):
        Supported Platforms:
            ``Ascend`` ``GPU`` ``CPU``
        """
+        forward_pre_hook_inputs = inputs
        for fn in self._forward_pre_hook.values():
-            ret = fn(self,
+            ret = fn(self, forward_pre_hook_inputs)
            if ret is not None:
                if not isinstance(ret, tuple):
-
+                    forward_pre_hook_inputs = (ret,)
                else:
-
-
+                    forward_pre_hook_inputs = ret
+
+        if isinstance(inputs, tuple):
+            if not isinstance(forward_pre_hook_inputs, tuple):
+                forward_pre_hook_inputs = (forward_pre_hook_inputs,)
+            if len(forward_pre_hook_inputs) != len(inputs):
+                raise TypeError(
+                    "The forward pre hook return value size is {} not equal to input size {}".format(
+                        len(forward_pre_hook_inputs), len(inputs)))
+        return forward_pre_hook_inputs

    def register_forward_hook(self, hook_fn):
        """
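The two hunks above replace the per-cell `_forward_pre_hook_key` counter with a `HookHandle` that owns its dictionary entry, and make `_run_forward_pre_hook` check that a hook returning new inputs keeps the original arity. A short usage sketch under those assumptions (the `scale_input` hook itself is illustrative, not from the diff):

import mindspore as ms
from mindspore import Tensor, nn

ms.set_context(mode=ms.PYNATIVE_MODE)

def scale_input(cell, inputs):
    # Return a tuple with the same arity as the original inputs, as the new
    # size check in _run_forward_pre_hook requires.
    return (inputs[0] * 2,)

relu = nn.ReLU()
handle = relu.register_forward_pre_hook(scale_input)
print(relu(Tensor([-1.0, 1.0], ms.float32)))  # hook doubles the input first
handle.remove()                               # drops the entry keyed by handle.handle_id
print(relu(Tensor([-1.0, 1.0], ms.float32)))  # hook no longer runs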
@@ -2142,15 +2215,12 @@ class Cell(Cell_):
            (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
            value= [ 2.00000000e+00]))
        """
+        if context._get_mode() == context.GRAPH_MODE:
+            return HookHandle()
        if not check_hook_fn("register_forward_hook", hook_fn):
            return HookHandle()
-
-
-        if not hasattr(self, '_forward_hook_key'):
-            self._forward_hook_key = -1
-        self._forward_hook_key += 1
-        self._forward_hook[self._forward_hook_key] = hook_fn
-        handle = HookHandle(self, self._forward_hook_key, "_forward_hook")
+        handle = HookHandle(self._forward_hook)
+        self._forward_hook[handle.handle_id] = hook_fn
        return handle

    def _run_forward_hook(self, inputs, output):
@@ -2167,11 +2237,110 @@ class Cell(Cell_):
        Supported Platforms:
            ``Ascend`` ``GPU`` ``CPU``
        """
+        forward_hook_output = output
        for fn in self._forward_hook.values():
-            ret = fn(self, inputs,
+            ret = fn(self, inputs, forward_hook_output)
            if ret is not None:
-
-
+                forward_hook_output = ret
+
+        if isinstance(output, tuple):
+            if not isinstance(forward_hook_output, tuple):
+                forward_hook_output = (forward_hook_output,)
+            if len(forward_hook_output) != len(output):
+                raise TypeError(
+                    "The forward hook return value size is {} not equal to output size {}".format(
+                        len(forward_hook_output), len(output)))
+        return forward_hook_output
+
+    def register_backward_pre_hook(self, hook_fn):
+        """
+        Register the backward pre hook function.
+
+        Note:
+            - The `register_backward_pre_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
+            - The 'hook_fn' must be defined as the following code.
+              `cell` is the Cell object. `grad_output` is the gradient passed to the Cell.
+            - The 'hook_fn' should have the following signature:
+              hook_fn(cell, grad_output) -> New grad_output gradient or None.
+            - The 'hook_fn' is executed in the python environment. In order to prevent running failed when switching to
+              graph mode, it is not recommended to write it in the `construct` function of Cell object.
+            - In the pynative
+              mode, if the `register_backward_pre_hook` function is called in the `construct` function of the Cell
+              object, a hook function will be added at each run time of Cell object.
+
+        Args:
+            hook_fn (function): Python function. Backward pre hook function.
+
+        Returns:
+            A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
+            `handle.remove()` .
+
+        Raises:
+            TypeError: If the `hook_fn` is not a function of python.
+
+        Supported Platforms:
+            ``Ascend`` ``GPU`` ``CPU``
+
+        Examples:
+            >>> import numpy as np
+            >>> import mindspore as ms
+            >>> from mindspore import Tensor, nn, ops
+            >>> ms.set_context(mode=ms.PYNATIVE_MODE)
+            >>> def backward_pre_hook_fn(cell, grad_output):
+            ...     print("backward input: ", grad_output)
+            ...
+            >>> class Net(nn.Cell):
+            ...     def __init__(self):
+            ...         super(Net, self).__init__()
+            ...         self.relu = nn.ReLU()
+            ...         self.handle = self.relu.register_backward_pre_hook(backward_pre_hook_fn)
+            ...
+            ...     def construct(self, x):
+            ...         x = x + x
+            ...         x = self.relu(x)
+            ...         return x
+            >>> grad = ops.GradOperation(get_all=True)
+            >>> net = Net()
+            >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)))
+            backward input: (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),)
+            >>> print(output)
+            (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
+        """
+        if context._get_mode() == context.GRAPH_MODE:
+            return HookHandle()
+        if not check_hook_fn("register_backward_pre_hook", hook_fn):
+            return HookHandle()
+        handle = HookHandle(self._backward_pre_hook)
+        self._backward_pre_hook[handle.handle_id] = hook_fn
+        if self._cell_backward_pre_hook is None:
+            # Generate a CellBackwardHook prim, and add function for it
+            self._cell_backward_pre_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
+                                                                  self, self._backward_pre_hook)
+            self._cell_backward_pre_hook.register_backward_pre_hook()
+        return handle
+
+    def _run_backward_pre_hook(self, outputs):
+        """
+        Running backward pre hook function registered on Cell object.
+
+        Args:
+            outputs: The output objects of cell object.
+
+        Returns:
+            - **outputs** - New backward gradient or None.
+
+        Supported Platforms:
+            ``Ascend`` ``GPU`` ``CPU``
+        """
+        ret = self._cell_backward_pre_hook(outputs)
+        if isinstance(outputs, tuple):
+            if not isinstance(ret, tuple):
+                ret = (ret,)
+            if len(ret) != len(outputs):
+                raise TypeError(
+                    "The backward pre hook return value size is {} not equal to output size {}".format(
+                        len(ret), len(outputs)))
+        return ret

    def register_backward_hook(self, hook_fn):
        """
@@ -2180,11 +2349,11 @@ class Cell(Cell_):
        Note:
            - The `register_backward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
            - The 'hook_fn' must be defined as the following code.
-              `
-
-
+              `cell` is the registered Cell object. `grad_input` is the gradient computed and passed to
+              the next Cell or primitive, which can be return a new gradient or None. `grad_output` is the gradient
+              passed to the Cell.
            - The 'hook_fn' should have the following signature:
-              hook_fn(
+              hook_fn(cell, grad_input, grad_output) -> New grad_input gradient or none.
            - The 'hook_fn' is executed in the python environment. In order to prevent running failed when switching to
              graph mode, it is not recommended to write it in the `construct` function of Cell object. In the pynative
              mode, if the `register_backward_hook` function is called in the `construct` function of the Cell object,
@@ -2208,9 +2377,9 @@ class Cell(Cell_):
            >>> import mindspore as ms
            >>> from mindspore import Tensor, nn, ops
            >>> ms.set_context(mode=ms.PYNATIVE_MODE)
-            >>> def backward_hook_fn(
-            ...     print("backward input: ",
-            ...     print("backward output: ",
+            >>> def backward_hook_fn(cell, grad_input, grad_output):
+            ...     print("backward input: ", grad_output)
+            ...     print("backward output: ", grad_input)
            ...
            >>> class Net(nn.Cell):
            ...     def __init__(self):
@@ -2230,16 +2399,17 @@ class Cell(Cell_):
            >>> print(output)
            (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
        """
+        if context._get_mode() == context.GRAPH_MODE:
+            return HookHandle()
        if not check_hook_fn("register_backward_hook", hook_fn):
            return HookHandle()
+        handle = HookHandle(self._backward_hook)
+        self._backward_hook[handle.handle_id] = hook_fn
        if self._cell_backward_hook is None:
-
-            self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")"
-
-
-        else:
-            backward_hook_key = self._cell_backward_hook.register_backward_hook(hook_fn)
-            handle = HookHandle(self, backward_hook_key, "_cell_backward_hook")
+            # Generate a CellBackwardHook prim, and add function for it
+            self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
+                                                              self, self._backward_hook)
+            self._cell_backward_hook.register_backward_hook()
        return handle

    def _backward_hook_construct(self, *inputs, **kwargs):
@@ -2256,21 +2426,31 @@ class Cell(Cell_):
        Supported Platforms:
            ``Ascend`` ``GPU`` ``CPU``
        """
-
-
-
-
-
-
-
-
+        # cell_backward_hook has CellBackwardHook op, so keep input args as they are.
+        outputs = self._cell_backward_hook(*inputs)
+        # If the inputs have more than two args, the outputs will also have more than two args and will be wrapped into
+        # a tuple, so need to do unwrapping. If inputs is empty, we also need to unwrap it.
+        # Because when output of runop method is one, it will not wrap a tuple, we need not unwrap it.
+        is_need_unwrap = False
+        if isinstance(outputs, tuple) and len(inputs) != 1:
+            is_need_unwrap = True
+
+        if self._recompute_cell is not None:
+            if is_need_unwrap:
+                outputs = self._recompute_cell(*outputs, **kwargs)
+            else:
+                outputs = self._recompute_cell(outputs, **kwargs)
+        elif self.has_bprop:
+            if is_need_unwrap:
+                outputs = self._call_custom_bprop(*outputs, **kwargs)
            else:
-            outputs = self.
+                outputs = self._call_custom_bprop(outputs, **kwargs)
        else:
-            if
-            outputs = self.construct(*
+            if is_need_unwrap:
+                outputs = self.construct(*outputs, **kwargs)
            else:
-                outputs = self.construct(
+                outputs = self.construct(outputs, **kwargs)
+
        outputs = self._cell_backward_hook(outputs)
        return outputs

@@ -2401,7 +2581,8 @@ class Cell(Cell_):
                Default: ``False`` .
        """
        if context.get_context("mode") == context.PYNATIVE_MODE:
-            self.
+            self._recompute_cell = recompute_registry.get()(self.construct)
+            self._recompute()
            return
        self._recompute()
        if 'mp_comm_recompute' in kwargs.keys():
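With the hunk above, the PyNative branch of `Cell.recompute()` now builds a `_recompute_cell` wrapper around `construct` via `recompute_registry` before returning, matching the removal of the graph-mode-only `TypeError` earlier in this diff. A minimal, hedged sketch (the `nn.Dense` block and `ms.grad` call are illustrative, not from the diff):

import mindspore as ms
from mindspore import nn, ops

ms.set_context(mode=ms.PYNATIVE_MODE)

block = nn.Dense(4, 4)
block.recompute()  # forward activations of this block are recomputed during backward

x = ops.ones((2, 4), ms.float32)
grad_fn = ms.grad(block, grad_position=0)
print(grad_fn(x))  # gradient w.r.t. the input, computed through the recompute wrapper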
@@ -2579,7 +2760,6 @@ class GraphCell(Cell):
        params_dict = update_func_graph_hyper_params(self.graph, params_init)
        for name, param in params_dict.items():
            self._params[name] = param
-        _cell_graph_executor.inc_graph_cell_count()

    def construct(self, *inputs):
        return self.graph(*inputs)